基于阶段训练和注意力融合的多出口架构自蒸馏方法与流程
- 国知局
- 2024-10-09 16:06:50
本发明涉及知识蒸馏领域,特别涉及基于阶段训练和注意力融合的多出口架构自蒸馏方法。
背景技术:
1、人工智能是当今世界上最具革命性和颠覆性的技术领域。然而,受限于较弱的计算能力和数据获取能力,早期的人工智能研究对改变人类社会的生产生活方式做出了相对有限的贡献。随着更强大的gpu计算能力的到来和大规模数据采集可行性的增加,神经网络开始发挥其真正的力量。自2012年以来,深度学习方法逐渐主导了各种计算机视觉和自然语言处理相关任务的精度基准,并已被广泛应用于图像分类、目标监控、机器翻译等领域,使得全球人工智能产业的发展达到了一个周期性的高峰。然而,这种算法性能的激增通常是以增加模型复杂度为代价的。一方面,高复杂性导致模型的总体响应时间较长,这在智能驾驶、智能医疗等对时效性要求较高的应用场景中难以被接受。另一方面,这些训练有素的模型带有大量冗余参数,这将远远超出边缘设备的有限计算资源容量。为了解决模型参数冗余的问题,研究人员围绕网络的轻量化进行了大量工作。现有的解决方案大致可以分为三类:参数剪枝、参数量化和知识蒸馏。参数剪枝的核心思想是通过开发不同的评估策略来删除无意义的参数。然而,许多方法仍然依赖于特定的硬件或实现,因此在实施上存在一定的局限性,难以被广泛实施。参数量化可以通过减少位数来提高计算速度。然而,量化权重通常使得神经网络更难以收敛。相比之下,知识提取可以通过多监督目标直接提高小网络的准确性,这使得它的应用非常灵活,因此具有广泛的优化前景。
2、知识蒸馏是神经网络压缩领域的一种重要方法。其核心思想是让一个弱小但紧凑的学生模型通过学习一个先进但庞大的教师模型的软目标来提高其性能。知识传递不受模型结构的限制,因此具有很高的灵活性。现有的知识蒸馏有以下缺点:不同模型之间的知识传递效率低,学生模型难以有效地从教师模型中获得帮助,尤其是当教师模型和学生模型之间的差距较大时。
3、因此,需要提供基于阶段训练和注意力融合的多出口架构自蒸馏方法,用于提高知识传递的效率,改进知识自蒸馏框架的性能。
技术实现思路
1、本发明提供基于阶段训练和注意力融合的多出口架构自蒸馏方法,包括:根据深度将教师模型划分为多个出口分支,其中,所述多个出口分支中,深度最深的分支为教师模型,深度最浅的出口分支为学生模型,其余的出口分支为中间模型,所述教师模型和所述学生模型均用于图像分类;建立总损失函数;基于所述多个出口分支及注意力融合算法,训练所述学生模型,基于所述总损失函数,计算总损失,基于所述总损失,优化所述学生模型,直至所述学生模型满足预设条件。
2、进一步地,所述根据深度将教师模型划分为多个分支,包括:根据卷积层数量将所述教师模型均分为多个出口分支。
3、进一步地,基于所述多个出口分支及注意力融合算法,训练所述学生模型,包括:将训练周期划分为多个阶段;在所述多个阶段,基于所述多个出口分支及注意力融合算法,对所述学生模型进行级联训练。
4、进一步地,在所述多个阶段,基于所述多个出口分支,对所述学生模型进行级联训练,包括:在所述多个阶段,基于所述多个出口分支及注意力融合算法,对所述学生模型进行级联训练,并对每个所述中间模型进行级联训练。
5、进一步地,所述总损失函数为:,其中,l为总损失,为第i个出口分支的交叉熵损失,为取值为0或1的函数,为第i个出口分支与第j个出口分支的蒸馏损失,为第i个出口分支的特征图损失,m为出口分支的总数,及均为权重,和均为用于求和的索引变量,为阶段。
6、进一步地,基于以下公式计算第i个出口分支的交叉熵损失:,其中,n为样本总数,g为图像类别总数,为第 i个样本对应第 j 个图像类别的真实标签,为第i个出口分支预测第 i 个样本属于第 j 个图像类别的概率, k为用于求和的索引变量,g为用于求和的索引变量。
7、进一步地,基于以下公式计算第i个出口分支与第j个出口分支的蒸馏损失:,其中,k为输入至第i个出口分支与第j个出口分支的样本,t为温度参数,为第j个出口分支在给定输入 k 和温度参数 t 下,对第 g个图像类别的预测概率,为第i个出口分支在给定输入 k 和温度参数 t 下,对第g个图像类别的预测概率。
8、进一步地,基于以下公式计算第i个出口分支的特征图损失:,其中,为超参数,为教师模型输出的特征图,为第i个出口分支输出的特征图,表示计算教师模型输出的特征图与第i个出口分支输出的特征图之间的k阶欧几里得距离。
9、进一步地,基于所述多个出口分支及注意力融合算法,对所述学生模型进行级联训练,包括:对于每个所述出口分支,通过注意力模块计算所述出口分支的注意力信息;将每个所述出口分支的注意力信息整合到学生分支中,对所述学生模型进行级联训练。
10、进一步地,基于以下公式整合每个所述出口分支的注意力信息:,其中,为融合后的注意力信息,为第i个出口分支的注意力信息,m为出口分支的总数,为空间注意力模块输出的加权特征图,为通道注意力模块输出的加权特征图,f为输入特征图,为元素级乘法。
11、相比于现有技术,本发明提供的基于阶段训练和注意力融合的多出口架构自蒸馏方法,至少具备以下有益效果:
12、在多出口架构框架下,提出了一种渐进式学生训练方法,可以有效减少教师与学生之间的能力差距,达到提高知识传递效率的目的。在多出口架构框架下,提出了一种注意力整合方法,可以使学生模型关注教师模型在不同抽象层次的注意力信息。
技术特征:1.基于阶段训练和注意力融合的多出口架构自蒸馏方法,其特征在于,包括:
2.根据权利要求1所述的基于阶段训练和注意力融合的多出口架构自蒸馏方法,其特征在于,所述根据深度将教师模型划分为多个分支,包括:根据卷积层数量将所述教师模型均分为多个出口分支。
3.根据权利要求2所述的基于阶段训练和注意力融合的多出口架构自蒸馏方法,其特征在于,基于所述多个出口分支及注意力融合算法,训练所述学生模型,包括:
4.根据权利要求3所述的基于阶段训练和注意力融合的多出口架构自蒸馏方法,其特征在于,在所述多个阶段,基于所述多个出口分支,对所述学生模型进行级联训练,包括:
5.根据权利要求4所述的基于阶段训练和注意力融合的多出口架构自蒸馏方法,其特征在于,所述总损失函数为:
6.根据权利要求5所述的基于阶段训练和注意力融合的多出口架构自蒸馏方法,其特征在于,基于以下公式计算第i个出口分支的交叉熵损失:
7.根据权利要求6所述的基于阶段训练和注意力融合的多出口架构自蒸馏方法,其特征在于,基于以下公式计算第i个出口分支与第j个出口分支的蒸馏损失:
8.根据权利要求6所述的基于阶段训练和注意力融合的多出口架构自蒸馏方法,其特征在于,基于以下公式计算第i个出口分支的特征图损失:
9.根据权利要求3-8中任意一项所述的基于阶段训练和注意力融合的多出口架构自蒸馏方法,其特征在于,基于所述多个出口分支及注意力融合算法,对所述学生模型进行级联训练,包括:
10.根据权利要求9所述的基于阶段训练和注意力融合的多出口架构自蒸馏方法,其特征在于,基于以下公式整合每个所述出口分支的注意力信息:
技术总结本发明提供基于阶段训练和注意力融合的多出口架构自蒸馏方法,涉及知识蒸馏领域,包括:根据深度将教师模型划分为多个出口分支,其中,教师模型和学生模型用于图像分类,多个出口分支中,深度最深的分支为教师模型,深度最浅的出口分支为学生模型,其余的出口分支为中间模型;建立总损失函数;基于多个出口分支及注意力融合算法,训练学生模型,基于总损失函数,计算总损失,基于总损失,优化学生模型,直至学生模型满足预设条件,具有提高知识传递的效率,改进知识自蒸馏框架的性能的优点。技术研发人员:张峰,黄赛枭,许剑受保护的技术使用者:国能大渡河大数据服务有限公司技术研发日:技术公布日:2024/9/26本文地址:https://www.jishuxx.com/zhuanli/20240929/311677.html
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 YYfuon@163.com 举报,一经查实,本站将立刻删除。
下一篇
返回列表