基于客户端自主性的分布式训练方法、系统、介质及设备
- 国知局
- 2024-12-06 12:30:14
本公开涉及机器学习,具体地,涉及一种基于客户端自主性的分布式训练方法、系统、介质及设备。
背景技术:
1、联邦学习是一种基于隐私保护的分布式模型训练架构,允许多个边缘设备联合训练一个全局模型而无需泄漏本地数据。联邦学习的最通用的算法为fedavg,整个过程训练多个轮次直至模型收敛。客户端每轮训练重复如下过程:首先,在每轮训练开始时,参与训练的客户端会获取最新的全局模型并更新至本地模型;然后,客户端分别基于本地小批次样本进行多次计算并更新本地模型;最后,客户端在该轮本地训练结束后将对本地模型的更新上传至服务端,服务端将所有更新进行聚合后更新全局模型。
2、训练效率低是现实联邦学习应用面临的最主要问题之一,由于参与联邦学习的边缘设备(例如手机和iot设备等)通常具有异构的硬件资源和有限的网络带宽,并且资源和带宽状态又通常是动态变化的,这影响了设备的计算和通信时间,进而影响全局训练过程。在同步联邦学习架构中,计算或通信缓慢的设备会拖慢整个训练过程,严重影响训练效率;在异步联邦学习系统中通常设置一个容忍区间,服务端通过接受区间范围内非本轮次的过时更新来解决设备异构带来的性能问题,然而过时信息会影响全局模型的收敛速度和性能。
3、为了解决计算和通信异构性带来的问题,现有一些优化方法,例如:现有技术一:为了解决异构性给联邦学习带来的挑战,fedprox作为fedavg的一种泛化和重参数化方法被提出。fedprox在每轮训练中,在各个客户端本地优化目标上增加了一个近端项,例如,在第k个参与训练的设备上,将它的本地优化目标由fk(w)变为其中w表示初始化模型参数,wt表示第t个轮次的模型参数,μ表示算法超参数。增加近端项后可以在两个方面减少异构性带来的影响:(1)容忍系统异构性带来的影响,例如,允许设备本地存在不同的训练进度;(2)保证本地的更新不过于偏离初始的全局模型,减轻客户端本地数据统计学异构性带来的影响。同时fedprox通过理论证明证实了该方法能够保证模型收敛性能。
4、虽然fedprox通过设定不同训练轮次模拟设备系统异构性,并在实验中验证了算法对这种异构性具有容忍度,但是,该算法并没有提供选择合适训练轮次的有效策略。此外,训练轮次只是针对客户端粗粒度的工作负载控制,然而参与训练的客户端状态不仅会在轮间产生变化,每轮训练内部也会快速发生变化,该算法忽略了潜在每轮训练内部更细粒度的控制优化空间。同时,该算法在一定程度上会影响模型的收敛速度。
5、现有技术二:fedada通过结合设备系统性能和本地数据统计学特征动态为参与训练的客户端自适应分配工作负载,具体地说,fedada采用训练时间代表设备系统特征,本地损失值代表本地数据的统计学特征,服务端每轮都基于客户端上报的训练时间和本地损失值来动态地为各个客户端下一轮分配不同的小批量迭代训练次数。fedada为算法设定了一个多目标优化公式:其中t表示每一轮各客户端计算时间与通信时间之和,tidle表示表示完成训练较快的客户端等待所有客户端都完成的空闲时间,表示初始t中的最大值,而c则是对两轮之间本地损失差值进行sigmoid函数计算后得到的值,因此csum表示各个客户端c的和,和cmax表示最大值。fedada并未直接对优化目标进行求解,而是采用从设定的迭代次数初始值开始,不断减少每个客户端的迭代次数并计算p的值直至p最小,经过计算后fedada通常会为每轮训练较慢的客户端分配较少的迭代次数以使得所有客户端都能在相近的时间内完成训练。
6、fedada虽然能够通过平衡系统和数据异构性特征自适应地动态调节各个客户端每轮的迭代次数,但是仍然存在一定局限。fedada由服务端做优化决策,服务端通常只能获取到客户端每轮训练完成后的状态数据,无法感知每轮内客户端训练状态的动态变化,然而每轮内客户端的状态变化是非常重要的,例如,一个客户端可能会由于负载高、节能模式等原因计算能力下降,如果能感知到这样的变化就可以采取策略即时避免该客户端拖慢整体训练进程。另外,每个客户端本地模型参数收敛速度也并不相同,客户端能够通过提前传输收敛参数来进行计算-通信时间重叠优化,进一步提高训练效率。
技术实现思路
1、针对现有技术中的缺陷,本公开的目的是提供一种基于客户端自主性的分布式训练方法、系统、介质及设备。
2、根据本公开的一个方面,提供一种基于客户端自主性的分布式训练方法,包括:
3、确定客户端本地训练的特征,所述客户端本地训练的特征包括系统特征和统计特征;
4、根据所述客户端本地训练的特征,采用提前停止策略决策所述客户端的本地训练进程;
5、根据所述客户端本地训练的统计特征,采用基于错误反馈的层级提前传输策略优化所述客户端的本地训练的通信进程。
6、可选地,所述确定客户端本地训练的特征,包括:
7、采用周期性采样方法采样所述客户端在采样训练轮次内每次迭代训练后对应的累积梯度;
8、根据所述客户端在采样训练轮次内的累积梯度,分别按层以及整体计算模型的统计进程,确定所述客户端本地训练的统计特征。
9、可选地,所述确定客户端本地训练的特征,还包括:
10、采用所述客户端在每轮训练中完成每次迭代的时间表示所述客户端本地训练的系统特征。
11、可选地,所述根据所述客户端本地训练的特征,采用提前停止策略决策所述客户端的本地训练进程,包括:
12、将所述客户端在每轮训练中完成每次迭代产生的边际收益进行量化处理,确定所述客户端在每轮训练中完成每次迭代的边际收益;
13、将所述客户端在每轮训练中完成每次迭代对应的边际成本进行量化处理,确定所述客户端在每轮训练中完成每次迭代的边际成本;
14、根据所述客户端在每轮训练中完成每次迭代的边际收益和所述客户端在每轮训练中完成每次迭代的边际成本,确定所述客户端在每轮训练中完成每次迭代的净收益;
15、根据所述客户端在每轮训练中完成每次迭代后的净收益,所述客户端自主决策是否提前终止该轮训练。
16、可选地,所述根据所述客户端在每轮训练中完成每次迭代后的净收益,所述客户端自主决策是否提前终止该轮训练,包括:
17、若所述客户端在该轮训练中完成该次迭代后的净收益为负值,所述客户端提前终止本轮训练;
18、若所述客户端在该轮训练中完成该次迭代后的净收益为非负值,所述客户端继续本轮训练。
19、可选地,所述根据所述客户端本地训练的统计特征,采用基于错误反馈的层级提前传输策略优化所述客户端的本地训练的通信进程,包括:
20、在所述客户端的每轮训练中,将所述客户端的已收敛的层级参数提前传输至服务端;
21、在所述客户端完成每轮训练时,根据所述提前传输时的所述客户端的累积梯度快照和所述客户端完成每轮训练后对应的累积梯度,确定余弦相似度;
22、若所述余弦相似度小于预设的第一阈值,在所述客户端完成该轮训练后,将在该轮训练中提前传输的层级参数重新上传至所述服务端。
23、可选地,所述在客户端的每轮训练中,将所述客户端的已收敛的层级参数提前传输至服务端,包括:
24、若所述客户端的层级参数在每轮训练中完成每次迭代时,在相邻周期性采样训练轮次的相同迭代次数的统计特征指标值不小于预设的第二阈值,确定所述客户端的层级参数在该轮训练中完成该次迭代时达到收敛状态;
25、将所述客户端在每轮训练中完成每次迭代时达到收敛状态的层级参数提前上传至所述服务端。
26、根据本公开的第二方面,提供一种基于客户端自主性的分布式训练系统,包括:
27、特征确定模块,用于确定客户端本地训练的特征,所述客户端本地训练的特征包括系统特征和统计特征;
28、第一优化模块,用于根据所述客户端本地训练的特征,采用提前停止策略决策所述客户端的本地训练进程;
29、第二优化模块,用于根据所述客户端本地训练的统计特征,采用基于错误反馈的层级提前传输策略优化所述客户端的本地训练的通信进程。
30、根据本公开的第三方面,提供一种非临时性计算机可读存储介质,其上存储有计算机程序,该程序被处理器执行时实现本公开第一方面提供的所述方法的步骤。
31、根据本公开的第四方面,提供一种电子设备,包括:
32、存储器,其上存储有计算机程序;
33、处理器,用于执行所述存储器中的所述计算机程序,以实现本公开第一方面提供的所述方法的步骤。
34、与现有技术相比,本公开实施例具有如下至少一种有益效果:
35、通过上述技术方案,确定客户端本地训练的特征,其中,客户端本地训练的特征包括系统特征和统计特征,实现对训练轮次内客户端状态变化的快速感知;基于客户端训练轮次内本地训练特征的指导,客户端采用提前停止策略,自主决策是否停止本轮训练进程,能够实现在最小化对模型精度的影响的前提下,节省计算资源,减少训练较慢的客户端对全局训练进程的影响,提升训练效率;基于客户端训练轮次内本地训练特征的指导,客户端采用基于错误反馈的层级提前传输策略,提前传输已收敛的层级参数来增加计算-通信重叠时间,并引入错误反馈机制减少提前传输对全局精度的影响,提高资源利用率,并提升客户端本地训练效率。
本文地址:https://www.jishuxx.com/zhuanli/20241204/341767.html
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 YYfuon@163.com 举报,一经查实,本站将立刻删除。