一种基于知识蒸馏的图像特征提取方法及装置
- 国知局
- 2024-10-09 14:39:35
本发明涉及图像处理,尤其是指一种基于知识蒸馏的图像特征提取方法及装置。
背景技术:
1、近年来,机器学习(machine learning,ml)和深度学习(deep learning,dl)在计算机视觉领域取得了显著进步,从第一个深度卷积神经网络alexnet开始,用越来越多的参数和计算量创建的更深层次的网络逐渐被提出,例如vggnet、inception、resnet等。这些模型被广泛用于图像特征提取任务,为了提高特征提取的有效性,卷积神经网络模型越来越复杂,计算量和参数也越来越多,对硬件资源的要求随之提高,导致模型很难部署在资源受限的边缘设备上。
2、现有技术中使用知识蒸馏法对深度神经网络模型进行压缩,通过将大模型(教师模型)的知识迁移到小模型(学生模型)上,在不改变小模型结构的基础上,提升小模型的精度,以小模型的成本获得与大模型相媲美的精度,从而实现模型的轻量化,使得模型能够在不降低图像特征提取精度的情况下减少计算量和参数量。现有的知识蒸馏法将教师模型输出层的特征作为监督信息,用于训练学生模型,使得学生模型输出层输出的特征尽可能与教师模型输出层的特征一致,最终将训练好的学生模型作为特征提取模型,通过计算量和参数量更少的学生模型实现与教师模型同样的特征提取精度,在不降低图像特征提取精度的情况下实现模型的轻量化。
3、目前基于输出层特征的知识蒸馏法通过构建类内蒸馏损失以最小化学生模型和教师模型对同一图像样本在输出层特征上的差异性,但是当教师模型和学生模型的模型架构不同时,例如教师模型为resnet架构的模型,而学生模型为mobilenet架构的模型,教师模型和学生模型对图像的理解有很大的差异,因此教师模型可以给学生模型提供更多样的知识和特征表示,仅通过类内知识蒸馏损失对异构师生模型中的学生模型进行训练,无法有效将教师模型的知识迁移给学生模型,导致学生模型无法充分学习到教师模型中的全部知识,从而使得学生模型的性能下降,影响学生模型的图像特征提取精度。
4、综上所述,现有的基于知识蒸馏的图像特征提取方法在对异构师生模型中的学生模型进行训练时,学生模型无法充分学习教师模型中的知识,导致学生模型的图像特征提取精度较低。
技术实现思路
1、为此,本发明所要解决的技术问题在于克服现有技术中的基于知识蒸馏的图像特征提取方法应用于异构师生模型时,存在学生模型无法充分学习教师模型中的知识,导致学生模型的图像特征提取精度较低的问题。
2、为解决上述技术问题,本发明提供了一种基于知识蒸馏的图像特征提取方法,包括:
3、将训练集中的多个图像样本分别输入至教师模型进行特征提取,基于所述教师模型的输出得到每个图像样本对应的第一特征向量;将所述多个图像样本分别输入至学生模型进行特征提取,基于所述学生模型的输出得到每个图像样本对应的第二特征向量;
4、将多个图像样本对应的第一特征向量中排序相同的特征值组合,得到m个第一组合向量;将多个图像样本对应的第二特征向量中排序相同的特征值组合,得到m个第二组合向量;其中,m表示第一特征向量和第二特征向量中的特征值数量;
5、基于每个图像样本对应的第一特征向量和该图像样本对应的第二特征向量构建该图像样本的目标类知识蒸馏损失函数;基于多个图像样本的目标类知识蒸馏损失函数之和得到第一损失函数;
6、基于每个图像样本对应的第一特征向量和该图像样本对应的第二特征向量之间的皮尔逊相关系数,构建该图像样本的皮尔逊相关系数损失函数;基于多个图像样本的皮尔逊相关系数损失函数之和得到第二损失函数;
7、基于第i个第一组合向量和第i个第二组合向量之间的皮尔逊相关系数,构建第i个组合向量的皮尔逊相关系数损失函数;基于m个组合向量的皮尔逊相关系数损失函数之和得到第三损失函数;其中,i∈[1,m];
8、对所述第一损失函数、所述第二损失函数和所述第三损失函数加权求和,得到图像特征提取损失函数,利用训练集中的图像样本对所述学生模型进行迭代训练,直到所述图像特征提取损失函数的值最小,得到训练好的学生模型。
9、优选地,图像样本的目标类知识蒸馏损失函数表示为:
10、
11、其中,tckdj表示第j个图像样本的目标类知识蒸馏损失函数,kl表示和之间的相似度,表示第j个图像样本的第一特征向量经softmax归一化后得到的目标类别预测概率,表示第j个图像样本的第二特征向量经softmax归一化后得到的目标类别预测概率;
12、第一损失函数表示为:
13、
14、其中,tckd表示第一损失函数。
15、优选地,图像样本的皮尔逊相关系数损失函数表示为:
16、
17、其中,pearsonj表示第j个图像样本的皮尔逊相关系数损失函数,表示第j个图像样本的第一特征向量,表示第j个图像样本的第二特征向量;
18、第二损失函数表示为:
19、
20、其中,pearson(pt‖ps)表示第二损失函数。
21、优选地,第i个组合向量的皮尔逊相关系数损失函数表示为:
22、
23、其中,pearsoni表示第i个组合向量的皮尔逊相关系数损失函数,表示第i个第一组合向量,表示第i个第二组合向量;
24、第三损失函数表示为:
25、
26、其中,pearson(ptz‖psz)表示第三损失函数。
27、优选地,图像特征提取损失函数表示为:
28、kd=α·tckd+β·(pearson(pt‖ps)+pearson(ptz‖psz)),
29、其中,kd表示图像特征提取损失函数,α表示第一损失函数的权重,tckd表示第一损失函数,β表示第二损失函数和第三损失函数的权重,pearson(pt‖ps)表示第二损失函数,pearson(ptz‖psz)表示第三损失函数。
30、优选地,还包括:
31、分别对每个图像样本对应的第一特征向量进行归一化,输出每个图像样本的类别预测概率值;对多个图像样本的类别预测概率值取均值,得到目标类别预测概率值;
32、基于所述目标类别预测概率值更新第二损失函数和第三损失函数的权重,得到目标图像特征提取损失函数。
33、优选地,目标图像特征提取损失函数表示为:
34、
35、其中,kd′表示目标图像特征提取损失函数,表示目标类别预测概率值,k为预设系数。
36、优选地,还包括对所述目标类别预测概率值进行缩放、取反,得到温度因子,从而基于所述温度因子调整所述教师模型的输出概率分布。
37、优选地,所述学生模型包括stem模块和沿正传播方向依次串联的多个卷积核尺寸不同的特征提取模块,每个特征提取模块均包括无填充下采样子模块和特征提取卷积子模块,每个无填充下采样子模块均包括无填充卷积单元、无填充最大池化单元和特征融合单元。
38、本发明还提供了一种基于知识蒸馏的图像特征提取装置,包括:
39、特征提取模块,用于将训练集中的多个图像样本分别输入至教师模型进行特征提取,基于所述教师模型的输出得到每个图像样本对应的第一特征向量;将所述多个图像样本分别输入至学生模型进行特征提取,基于所述学生模型的输出得到每个图像样本对应的第二特征向量;
40、特征组合模块,用于将多个图像样本对应的第一特征向量中排序相同的特征值组合,得到m个第一组合向量;将多个图像样本对应的第二特征向量中排序相同的特征值组合,得到m个第二组合向量;其中,m表示第一特征向量和第二特征向量中的特征值数量;
41、第一损失函数构建模块,用于基于每个图像样本对应的第一特征向量和该图像样本对应的第二特征向量构建该图像样本的目标类知识蒸馏损失函数;基于多个图像样本的目标类知识蒸馏损失函数之和得到第一损失函数;
42、第二损失函数构建模块,用于基于每个图像样本对应的第一特征向量和该图像样本对应的第二特征向量之间的皮尔逊相关系数,构建所述图像样本的皮尔逊相关系数损失函数;基于多个图像样本的皮尔逊相关系数损失函数之和得到第二损失函数;
43、第三损失函数构建模块,用于基于第i个第一组合向量和第i个第二组合向量之间的皮尔逊相关系数,构建第i个组合向量的皮尔逊相关系数损失函数;基于m个组合向量的皮尔逊相关系数损失函数之和得到第三损失函数;其中,i∈[1,m];
44、模型训练模块,用于对所述第一损失函数、所述第二损失函数和所述第三损失函数加权求和,得到图像特征提取损失函数,利用训练集中的图像样本对所述学生模型进行迭代训练,直到所述图像特征提取损失函数的值最小,得到训练好的学生模型。
45、本技术提供的基于知识蒸馏的图像特征提取方法在构建损失函数时,先基于同一图像样本的第一特征向量和第二特征向量构建目标类知识蒸馏损失,从而基于所有图像样本的目标类知识蒸馏损失得到第一损失函数,以约束学生模型对于图像样本的目标类别预测值趋近于教师模型;同时,基于同一图像样本对应的第一特征向量和第二特征向量之间的皮尔逊相关系数构建皮尔逊相关系数损失函数,从而基于所有图像样本的皮尔逊相关系数损失函数得到第二损失函数,以约束学生模型和教师模型对于相同样本提取的特征向量之间的差异;除此之外,还基于不同样本对应的第一特征向量中排序相同的特征值和第二特征向量中排序相同的特征值之间的皮尔逊相关系数,构建第三损失函数,使得学生模型学习教师模型对于不同样本中同一类别的输出特征值;最后基于第一损失函数、第二损失函数和第三损失函数共同构建图像损失函数对学生模型进行迭代训练,使得学生模型不仅能够学习教师模型对图像样本的目标类别预测值,还能学习教师模型对同一图像样本的特征表示以及对不同样本中同一类别的特征表示,即使在学生模型和教师模型的模型架构相差较大的情况下,学生模型依然能够充分学习到教师模型的知识,从而提高学生模型的图像特征提取精度。
本文地址:https://www.jishuxx.com/zhuanli/20241009/306002.html
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 YYfuon@163.com 举报,一经查实,本站将立刻删除。