技术新讯 > 计算推算,计数设备的制造及其应用技术 > 一种用于图像分类的部分蒸馏联邦学习方法、电子设备及存储介质  >  正文

一种用于图像分类的部分蒸馏联邦学习方法、电子设备及存储介质

  • 国知局
  • 2024-07-31 23:21:12

本发明属于图像分类处理,具体涉及一种用于图像分类的部分蒸馏联邦学习方法、电子设备及存储介质。

背景技术:

1、联邦学习基于各客户端的协作来联合训练模型,同时保护每个客户端的数据隐私。在联邦学习中,客户端不需要传输本地数据,而是将使用本地数据训练的本地模型参数上传到服务器,服务器聚合客户端模型生成全局模型并下发到客户端,客户端使用全局模型来优化本地模型。但是在现实世界中,不同客户端拥有数据的分布往往是不一样的,即存在客户端数据非独立同分布。此外,客户端通常具有不同的计算资源和需求,这导致客户端模型拥有不同的结构,这使得要求模型同构的联邦学习方法失效。联邦蒸馏方法交换客户端的对蒸馏数据集的预测而并非模型参数,因此可以在客户端模型异构的场景下实现客户端之间的协作。

2、现有的图像分类技术往往通过使用大量的数据来训练模型得到一个性能较好的图像分类模型,但在联邦学习场景中,由于不能集中每个客户端的原始数据,因此,每个客户端的图像分类模型只能使用客户端所拥有的少量数据进行单独模型训练,无法得到一个较好的客户端图像分类模型。

技术实现思路

1、本发明要解决的问题是提高客户端图像分类模型的准确性,提出一种用于图像分类的部分蒸馏联邦学习方法、电子设备及存储介质。

2、为实现上述目的,本发明通过以下技术方案实现:

3、一种用于图像分类的部分蒸馏联邦学习方法,包括如下步骤:

4、s1.对服务器和图像分类客户端进行初始化;

5、s2.步骤s1初始化后的服务器随机选取一定数量的图像分类客户端参与训练,得到参与训练的图像分类客户端;

6、s3.对于步骤s2得到的参与训练的图像分类客户端,使用本地图像分类模型对蒸馏数据集中的蒸馏样本生成特征,得到图像分类客户端本地知识集合zn,并将图像分类客户端本地知识集合上传到服务器;

7、s4.服务器根据接收到的图像分类客户端本地知识集合,在蒸馏数据集上训练相应的服务器模型;

8、s5.基于步骤s4训练完成的服务器模型,服务器为下一轮计算全局基础模型;

9、s6.基于步骤s4训练完成的服务器模型,对蒸馏数据集中的蒸馏样本生成部分知识集合,并发送到图像分类客户端;

10、s7.图像分类客户端接收到步骤s6生成的部分知识集合,计算部分蒸馏系数,进行图像分类客户端模型参数更新;

11、s8.重复步骤s1-s7直至执行完所有的通信轮次,得到图像分类结果。

12、进一步的,步骤s3得到的图像分类客户端本地知识集合zn的表达式为:

13、

14、其中,为第n个图像分类客户端对第i个蒸馏样本生成的本地知识,i为蒸馏数据集中蒸馏样本的个数,为蒸馏数据集的长度。

15、进一步的,步骤s4服务器根据接收到的图像分类客户端本地知识集合,在蒸馏数据集上训练相应的服务器模型,服务器模型的损失函数的表达式为:

16、

17、其中,为第n个图像分类客户端对应的服务器模型的模型参数,为蒸馏数据集,为蒸馏数据集中第i个蒸馏样本,为服务器模型对蒸馏数据集中第i个蒸馏样本的输出,为第n个图像分类客户端对应的服务器模型去除最后一层的模型参数,为全局基础模型,μ为超参数;

18、服务器模型参数的更新的表达式为

19、

20、其中,ηs为服务器模型的学习率,代表损失函数对服务器模型参数的梯度。

21、进一步的,步骤s5中服务器为下一轮计算全局基础模型的表达式为:

22、

23、其中,n为图像分类客户端的数量。

24、进一步的,步骤s6中服务器模型对蒸馏数据集中的蒸馏样本生成部分知识集合的表达式为:

25、其中,

26、其中,为第n个图像分类客户端的部分知识集合,为第n个图像分类客户端对第i个蒸馏样本生成的部分知识。

27、进一步的,步骤s7的具体实现方法为:

28、图像分类客户端接收部分知识集合,第n个图像分类客户端的损失函数ln(wn)的表达式为:

29、ln(wn)=lce(wn)+λld(wn;αn)

30、其中,lce(wn)为第n个图像分类客户端的交叉熵损失函数,λ为超参数,ld(wn;αn)为第n个图像分类客户端的蒸馏损失函数,wn为第n个图像分类客户端的客户端模型参数,αn为第n个图像分类客户端的部分蒸馏系数;

31、第n个图像分类客户端的蒸馏损失函数的表达式为:

32、

33、其中,αn,i为第i个蒸馏样本的部分蒸馏系数,l1为绝对值损失函数,τ为超参数,为第n个图像分类客户端对第i个蒸馏样本生成的特征,为第n个图像分类客户端模型的特征提取部分的参数;

34、第n个图像分类客户端使用和客户端模型参数wn计算部分蒸馏系数αn的表达式为:

35、

36、其中,ηα为学习率超参数;

37、结合部分蒸馏系数αn和第n个图像分类客户端的部分知识集合客户端模型参数wn更新的表达式为:

38、

39、其中,ηw为学习率超参数。

40、一种电子设备,包括存储器和处理器,存储器存储有计算机程序,所述的处理器执行所述计算机程序时实现所述的一种用于图像分类的部分蒸馏联邦学习方法的步骤。

41、一种计算机可读存储介质,其上存储有计算机程序,所述计算机程序被处理器执行时实现所述的一种用于图像分类的部分蒸馏联邦学习方法。

42、本发明的有益效果:

43、本发明所述的一种用于图像分类的部分蒸馏联邦学习方法,提出使用部分蒸馏系数来衡量不同蒸馏样本的重要性,以指导客户进行根据本地数据有选择地进行部分蒸馏,帮助客户端选取对本地模型更有益的全局知识,降低全局知识中冗余知识的影响,从而提升客户端本地模型在图片分类任务上的准确率。此外,为了更好地集成来自不同客户的知识,本发明提出了一种新的部分知识集成方法,该方法使用客户端知识来训练服务器模型,允许客户端模型拥有完全不同的架构,并为每个客户生成不同部分全局知识来指导不同客户模型的训练,从而提升客户端模型在图片分类任务上的准确率。

技术特征:

1.一种用于图像分类的部分蒸馏联邦学习方法,其特征在于,包括如下步骤:

2.根据权利要求1所述的一种用于图像分类的部分蒸馏联邦学习方法,其特征在于:步骤s3得到的图像分类客户端本地知识集合zn的表达式为:

3.根据权利要求2所述的一种用于图像分类的部分蒸馏联邦学习方法,其特征在于:步骤s4服务器根据接收到的图像分类客户端本地知识集合,在蒸馏数据集上训练相应的服务器模型,服务器模型的损失函数的表达式为:

4.根据权利要求3所述的一种用于图像分类的部分蒸馏联邦学习方法,其特征在于:步骤s5中服务器为下一轮计算全局基础模型的表达式为:

5.根据权利要求4所述的一种用于图像分类的部分蒸馏联邦学习方法,其特征在于:步骤s6中服务器模型对蒸馏数据集中的蒸馏样本生成部分知识集合的表达式为:

6.根据权利要求5所述的一种用于图像分类的部分蒸馏联邦学习方法,其特征在于:步骤s7的具体实现方法为:

7.一种电子设备,其特征在于,包括存储器和处理器,存储器存储有计算机程序,所述的处理器执行所述计算机程序时实现权利要求1-6任一项所述的一种用于图像分类的部分蒸馏联邦学习方法的步骤。

8.一种计算机可读存储介质,其上存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现权利要求1-6任一项所述的一种用于图像分类的部分蒸馏联邦学习方法。

技术总结一种用于图像分类的部分蒸馏联邦学习方法、电子设备及存储介质,属于图像分类处理技术领域。为提高客户端图像分类模型的准确性,本发明对服务器和图像分类客户端进行初始化得到参与训练的图像分类客户端,使用本地图像分类模型对蒸馏数据集中的蒸馏样本生成特征得到图像分类客户端本地知识集合上传到服务器;服务器根据接收到的图像分类客户端本地知识集合,在蒸馏数据集上训练相应的服务器模型;服务器为下一轮计算全局基础模型;对蒸馏数据集中的蒸馏样本生成部分知识集合,发送到图像分类客户端;图像分类客户端接收到生成的部分知识集合,计算部分蒸馏系数,进行图像分类客户端模型参数更新;重复直至执行完所有的通信轮次得到图像分类结果。技术研发人员:廖清,杨旭,周骏,冯纪元,冯云青,张博雅,郭松岳,刘瑞受保护的技术使用者:哈尔滨工业大学技术研发日:技术公布日:2024/7/29

本文地址:https://www.jishuxx.com/zhuanli/20240730/197153.html

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 YYfuon@163.com 举报,一经查实,本站将立刻删除。