基于连续学习的模型推理方法、装置及存储介质与流程
- 国知局
- 2024-10-21 14:57:51
本技术涉及人工智能,具体涉及一种基于连续学习的模型推理方法、装置及存储介质。
背景技术:
1、近年来,随着互联网技术的不断发展,互联网中的数据也呈现出爆炸式的增长。在机器学习算法中,训练数据与测试数据很难达到分布相似的状态,一般的机器学习算法也很难在动态的环境中连续自适应的学习。因此,学者们提出了连续学习算法来解决上述问题。其中,连续学习(continual learning,cl),又称持续学习或终生学习。连续学习模拟了人类大脑的学习思考方式,能够对非独立同分布的数据流进行学习。他的本质是既能够对到来的新数据进行利用,并基于之前任务积累的经验,在新的数据上很好地完成任务;又能够避免遗忘问题,对曾经训练过的任务以及保持很高的精度。
2、然而,在相关技术中,连续学习方法对每个任务都需要使用一个输出层(例如分类任务则需要分类层)来解决任务,对于不同的任务需要使用不同的权重来构造输出层以适应新的任务。随着连续学习的模型任务的数量的增加,模型需要保存的输出层的参数也会逐渐增加存储的负担,如此便会使得模型变得臃肿,从而影响模型的推理效率。
技术实现思路
1、本技术实施例提供一种基于连续学习的模型推理方法、装置及存储介质,该方法可以提升基于连续学习的模型推理的推理效率。
2、本技术第一方面提供一种基于连续学习的模型推理方法,方法包括:
3、获取目标推理任务以及所述目标推理任务对应的推理数据;
4、将所述推理数据输入至训练后的神经网络模型中,得到模型输出数据,所述神经网络模型为对多个推理任务对应的任务模型进行连续学习得到的神经网络模型,所述多个推理任务包括所述目标推理任务;
5、从记忆数据集合中获取与所述目标推理任务关联的关联记忆数据,所述记忆数据集合中的数据为对特征数据进行聚类得到的数据,所述特征数据为在对所述任务模型进行连续学习过程中由训练后的任务模型对相应的训练样本数据进行特征提取得到的数据;
6、根据所述关联记忆数据与所述模型输出数据之间的关联关系,在所述关联记忆数据中确定多个目标记忆数据;
7、基于所述多个目标记忆数据对应的标签数据确定推理结果。
8、相应的,本技术第二方面提供一种基于连续学习的模型推理装置,所述装置包括:
9、第一获取单元,用于获取目标推理任务以及所述目标推理任务对应的推理数据;
10、输入单元,用于将所述推理数据输入至训练后的神经网络模型中,得到模型输出数据,所述神经网络模型为对多个推理任务对应的任务模型进行连续学习得到的神经网络模型,所述多个推理任务包括所述目标推理任务;
11、第二获取单元,用于从记忆数据集合中获取与所述目标推理任务关联的关联记忆数据,所述记忆数据集合中的数据为对特征数据进行聚类得到的数据,所述特征数据为在对所述任务模型进行连续学习过程中由训练后的任务模型对相应的训练样本数据进行特征提取得到的数据;
12、第一确定单元,用于根据所述关联记忆数据与所述模型输出数据之间的关联关系,在所述关联记忆数据中确定多个目标记忆数据;
13、第二确定单元,用于基于所述多个目标记忆数据对应的标签数据确定推理结果。
14、可选地,在一些实施例中,本技术提供的基于连续学习的模型推理装置,还包括:
15、第一获取子单元,用于获取与多个推理任务对应的训练样本数据集合,得到多个训练样本数据集合;
16、训练子单元,用于依次采用每一推理任务对应的训练样本数据集合对所述多个推理任务对应的任务模型进行训练,得到训练后的神经网络模型。
17、可选地,在一些实施例中,所述多个推理任务按照预设顺序排列,所述训练子单元,包括:
18、第一获取模块,用于获取对所述任务模型进行训练的目标训练样本数据集合,所述目标训练样本数据集合为当前推理任务对应的训练样本数据集合,所述当前推理任务为所述多个推理任务中的一个第一推理任务;
19、训练模块,用于根据所述目标训练样本数据集合对所述任务模型进行训练;
20、执行模块,用于当所述当前推理任务不是所述多个推理任务中的最后一个推理任务时,获取下一个第二推理任务并采用所述第二推理任务更新所述当前推理任务,以及返回执行获取对所述任务模型进行训练的目标训练样本数据集合,并根据所述目标训练样本数据集合对所述任务模型进行训练的步骤;
21、终止模块,用于当所述当前推理任务是所述多个推理任务中最后一个推理任务时,终止对所述任务模型的训练,得到训练后的神经网络模型。
22、可选地,在一些实施例中,本技术提供的基于连续学习的模型推理装置还包括:
23、构建子单元,用于构建记忆数据集合;
24、训练模块,包括:
25、训练子模块,用于根据所述记忆数据集合与所述目标训练样本数据集合对所述任务模型进行训练;
26、更新子模块,用于采用训练后的任务模型对所述目标训练样本数据集合中的训练样本数据进行特征提取,并基于提取到的特征更新所述记忆数据集合。
27、可选地,在一些实施例中,所述训练子模块,还用于:
28、获取对任务模型进行训练的轮次,并从所述目标训练样本数据集合中获取批样本数据;
29、当所述轮次不是预设次数的整数倍时,根据所述批样本数据对所述任务模型进行训练;
30、当所述轮次是所述预设次数的整数倍时,根据所述批样本数据和所述记忆数据集合中的记忆数据对所述任务模型进行训练;
31、循环执行更新对所述任务模型进行训练的轮次,并从所述目标训练样本数据集合中获取批样本数据,当所述轮次不是预设次数的整数倍时,根据所述批样本数据对所述任务模型进行训练,以及当所述轮次是所述预设次数的整数倍时,根据所述批样本数据和所述记忆数据集合中的记忆数据对所述任务模型进行训练的步骤。
32、可选地,在一些实施例中,所述训练子模块,还用于:
33、将所述批样本数据输入至所述任务模型,并基于所述任务模型的输出与所述批样本数据中的标签数据计算对比学习损失;
34、基于所述对比学习损失对所述任务模型的参数进行更新。
35、可选地,在一些实施例中,本技术提供的基于连续学习的模型推理装置,还包括:
36、第一计算子单元,用于当所述当前推理任务不是所述多个推理任务中的第一个推理任务时,计算实例关系蒸馏损失,所述实例关系蒸馏损失用于约束基于不同推理任务对应的训练样本数据集合训练得到的任务模型之间的差异;
37、所述训练子模块,还用于:
38、根据所述对比学习损失和所述实例关系蒸馏损失计算目标损失;
39、基于所述目标损失对所述任务模型的参数进行更新。
40、可选地,在一些实施例中,所述第一计算子单元,包括:
41、第二获取模块,用于当所述当前推理任务不是所述多个推理任务中的第一个推理任务时,获取采用每一推理任务对应的样本特征集合,所述样本特征为对任务模型进行训练后采用训练后的任务模型对相应的训练样本数据进行特征提取得到的特征;
42、计算模块,用于基于相似标签对应的样本特征在不同样本特征集合中的相似关系,计算实例关系蒸馏损失。
43、可选地,在一些实施例中,所述更新子模块,还用于:
44、采用训练后的任务模型对所述目标训练样本数据进行特征提取,得到多个样本特征;
45、对基于所述多个样本特征对应的标签数据将所述多个样本特征进行特征聚类,并根据聚类结果从所述多个样本特征中确定多个目标样本特征;
46、将所述多个目标样本特征添加至所述记忆数据集合中,以对所述记忆数据集合进行更新。
47、可选地,在一些实施例中,第一确定单元,包括:
48、第二计算子单元,用于计算所述关联记忆数据中的每一记忆数据与所述模型输出数据之间的相似度;
49、确定子单元,用于根据所述相似度在所述关联记忆数据中确定多个目标记忆数据。
50、可选地,在一些实施例中,第二确定单元,包括:
51、第二获取子单元,用于获取所述多个目标记忆数据对应的标签数据,得到多个标签数据;
52、处理子单元,用于对所述多个标签数据进行聚类处理,并确定标签数量最多的类别对应的标签为推理结果。
53、本技术第三方面还提供一种计算机可读存储介质,所述计算机可读存储介质存储有多条指令,所述指令适于处理器进行加载,以执行本技术第一方面所提供的基于连续学习的模型推理方法中的步骤。
54、本技术第四方面提供一种计算机设备,包括存储器、处理器以及存储在所述存储器中并可以在所述处理器上运行的计算机程序,所述处理器执行所述计算机程序时实现本技术第一方面所提供的基于连续学习的模型推理方法中的步骤。
55、本技术第五方面提供一种计算机程序产品,包括计算机程序/指令,所述计算机程序/指令被处理器执行时实现第一方面所提供的基于连续学习的模型推理方法中的步骤。
56、本技术实施例提供的基于连续学习的模型推理方法,通过获取目标推理任务以及目标推理任务对应的推理数据;将推理数据输入至训练后的神经网络模型中,得到模型输出数据,神经网络模型为对多个推理任务对应的任务模型进行连续学习得到的神经网络模型,多个推理任务包括目标推理任务;从记忆数据集合中获取与目标推理任务关联的关联记忆数据,记忆数据集合中的数据为对特征数据进行聚类得到的数据,特征数据为在对任务模型进行连续学习过程中由训练后的任务模型对相应的训练样本数据进行特征提取得到的数据;根据关联记忆数据与模型输出数据之间的关联关系,在关联记忆数据中确定多个目标记忆数据;基于多个目标记忆数据对应的标签数据确定推理结果。
57、以此,本技术提供的基于连续学习的模型推理方法,在对多任务模型进行连续学习时,不仅输出训练后的神经网络模型,还输出了用于协助进行模型推理的记忆数据集合。当需要进行模型任务推理时,可以采用连续学习得到的神经网络模型对推理数据进行特征提取,得到模型输出数据,然后基于模型输出数据和记忆数据集合中与该模型推理任务相应的记忆数据之间的相似关系确定最终的推理结果。如此则无需再构建模型输出层,避免了模型任务增加使得存储负担增加,进而导致影响模型推理效率的问题。从而可以大大提升模型推理效率。
本文地址:https://www.jishuxx.com/zhuanli/20241021/319958.html
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 YYfuon@163.com 举报,一经查实,本站将立刻删除。