一种基于改进联邦学习的模型训练方法及相关装置与流程
- 国知局
- 2024-11-06 14:27:32
本发明属于数据安全领域,具体涉及一种基于改进联邦学习的模型训练方法及相关装置。
背景技术:
1、在人工智能(ai)领域,联邦学习(federated learning,fl)作为一种新兴的机器学习方法,已经在分布式数据环境下显示出巨大的潜力。联邦学习允许多个参与者在不共享原始数据的情况下共同训练一个模型,从而解决了传统机器学习中的数据隐私和安全问题。然而,尽管联邦学习具有这些优点,但在实际应用中,特别是在数据不平衡的场景下,传统的联邦学习算法往往面临着性能下降的挑战。
2、数据不平衡问题在机器学习中是普遍存在的,它指的是不同类别的数据样本数量差异较大。在联邦学习的框架下,当参与者的数据类别分布不均衡时,传统的联邦学习算法难以训练出泛化能力强的模型。这是因为模型在训练过程中更多地受到数量较多的数据类别的影响,而忽视了数量较少的类别,导致模型在少数类别上的性能较差。此外,数据不平衡还可能导致联邦学习算法的安全性问题。由于模型性能下降,攻击者更容易利用模型的弱点进行攻击,从而威胁到数据的安全。因此,在数据不平衡的场景下,如何提高联邦学习算法的性能和安全性是一个亟待解决的问题。
技术实现思路
1、针对现有技术中存在的问题,本发明提供了一种基于改进联邦学习的模型训练方法及相关装置,旨在解决数据不平衡场景下的模型性能和隐私保护问题,不仅提高了模型的泛化能力,还增强了数据的安全性。
2、为了解决上述技术问题,本发明通过以下技术方案予以实现:
3、根据本发明的第一方面,提供一种基于改进联邦学习的模型训练方法,应用于基于改进联邦学习的模型训练系统的服务器,所述基于改进联邦学习的模型训练系统包括服务器以及与服务器通信连接的若干参与端;所述基于改进联邦学习的模型训练方法包括:
4、获取初始模型并发送至各参与端;
5、迭代进行更新步骤至预设更新迭代阈值,将当前模型作为训练完成的模型,并发送至各参与端;
6、其中,所述更新步骤包括:
7、接收各参与端发送的梯度参数,进入全局参数更新过程;其中,梯度参数为参与端根据本地训练数据,结合本地差分隐私保护和知识蒸馏进行梯度计算得到的本地模型的各神经网络层的梯度;
8、所述全局参数更新过程包括:根据各参与端发送的梯度参数,更新模型的全局参数,并将更新的全局参数发送至各参与端;其中,更新的全局参数用于更新参与端的本地模型的模型参数。
9、在第一方面的一种可能的实现方式中,所述参与端根据本地训练数据,结合本地差分隐私保护和知识蒸馏进行梯度计算得到的本地模型的各神经网络层的梯度,包括:
10、参与端采用本地差分隐私保护向本地训练数据添加噪声,并在本地模型的softmax层通过学生模型和教师模型学习,最小化教师模型和学生模型输出之间的损失。
11、在第一方面的一种可能的实现方式中,所述根据各参与端发送的梯度参数,更新模型的全局参数,包括:
12、平均各参与端发送的梯度参数,得到平均梯度参数;
13、根据平均梯度参数,采用反向传播算法更新模型的全局参数。
14、在第一方面的一种可能的实现方式中,所述梯度计算具体为随机梯度下降方法。
15、根据本发明的第二方面,提供一种基于改进联邦学习的模型训练方法,应用于基于改进联邦学习的模型训练系统的参与端,所述基于改进联邦学习的模型训练系统包括服务器以及与服务器通信连接的若干参与端;所述基于改进联邦学习的模型训练方法包括:
16、接收服务器发送的初始模型,作为本地模型;
17、迭代进行梯度计算步骤至预设梯度计算迭代阈值,接收服务器发送的训练完成的模型;
18、其中,所述梯度计算步骤包括:
19、根据本地训练数据,结合本地差分隐私保护和知识蒸馏进行梯度计算得到的本地模型的各神经网络层的梯度,完成梯度计算后将本地模型的各神经网络层的梯度作为梯度参数发送至服务器;
20、当接收到服务器发送的更新的全局参数时,根据更新的全局参数更新本地模型的模型参数;
21、所述全局参数为服务器根据各参与端发送的梯度参数,更新模型的全局参数
22、在第二方面的一种可能的实现方式中,所述根据本地训练数据,结合本地差分隐私保护和知识蒸馏进行梯度计算得到的本地模型的各神经网络层的梯度,包括:
23、参与端采用本地差分隐私保护向本地训练数据添加噪声,并在本地模型的softmax层通过学生模型和教师模型学习,最小化教师模型和学生模型输出之间的损失。
24、在第二方面的一种可能的实现方式中,所述根据各参与端发送的梯度参数,更新模型的全局参数,包括:
25、平均各参与端发送的梯度参数,得到平均梯度参数;
26、根据平均梯度参数,采用反向传播算法更新模型的全局参数。
27、在第二方面的一种可能的实现方式中,所述梯度计算具体为随机梯度下降方法。
28、根据本发明的第三方面,提供一种基于改进联邦学习的模型训练装置,应用于基于改进联邦学习的模型训练系统的服务器,所述基于改进联邦学习的模型训练系统包括服务器以及与服务器通信连接的若干参与端;所述基于改进联邦学习的模型训练装置包括:
29、发送模块,用于获取初始模型并发送至各参与端;
30、迭代更新模块,用于迭代进行更新步骤至预设更新迭代阈值,将当前模型作为训练完成的模型,并发送至各参与端;
31、其中,所述更新步骤包括:
32、接收各参与端发送的梯度参数,进入全局参数更新过程;其中,梯度参数为参与端根据本地训练数据,结合本地差分隐私保护和知识蒸馏进行梯度计算得到的本地模型的各神经网络层的梯度;
33、所述全局参数更新过程包括:根据各参与端发送的梯度参数,更新模型的全局参数,并将更新的全局参数发送至各参与端;其中,更新的全局参数用于更新参与端的本地模型的模型参数。
34、根据本发明的第四方面,提供一种基于改进联邦学习的模型训练装置,应用于基于改进联邦学习的模型训练系统的参与端,所述基于改进联邦学习的模型训练系统包括服务器以及与服务器通信连接的若干参与端;所述基于改进联邦学习的模型训练装置包括:
35、接收模块,用于接收服务器发送的初始模型,作为本地模型;
36、迭代计算模块,用于迭代进行梯度计算步骤至预设梯度计算迭代阈值,接收服务器发送的训练完成的模型;
37、其中,所述梯度计算步骤包括:
38、根据本地训练数据,结合本地差分隐私保护和知识蒸馏进行梯度计算得到的本地模型的各神经网络层的梯度,完成梯度计算后将本地模型的各神经网络层的梯度作为梯度参数发送至服务器;
39、当接收到服务器发送的更新的全局参数时,根据更新的全局参数更新本地模型的模型参数;
40、所述全局参数为服务器根据各参与端发送的梯度参数,更新模型的全局参数。
41、与现有技术相比,本发明至少具有以下有益效果:
42、本发明提供的一种基于改进联邦学习的模型训练方法,将本地差分隐私保护和知识蒸馏技术相结合,可以在提高模型性能的同时保护数据的隐私安全,解决了数据不平衡场景下的模型性能和隐私保护问题。具体而言,本发明在联邦学习算法的基础上,创新性地结合了本地差分隐私保护和知识蒸馏,解决了数据不平衡场景下的模型性能和隐私保护问题,不仅提高了模型的泛化能力,还增强了数据的安全性,为联邦学习在实际应用中的推广提供了有力的支持,实现了在保护用户隐私的同时优化联邦学习系统的性能,具有显著的隐私保护能力、较低的通信开销、较高的系统性能以及增强的鲁棒性,为联邦学习在实际应用中的广泛部署提供了有力支持。
43、为使本发明的上述目的、特征和优点能更明显易懂,下文特举较佳实施例,并配合所附附图,作详细说明如下。
本文地址:https://www.jishuxx.com/zhuanli/20241106/322200.html
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 YYfuon@163.com 举报,一经查实,本站将立刻删除。
下一篇
返回列表