技术新讯 > 计算推算,计数设备的制造及其应用技术 > 图神经网络的训练方法及装置与流程  >  正文

图神经网络的训练方法及装置与流程

  • 国知局
  • 2024-09-05 14:45:22

本说明书一个或多个实施方式涉及图数据,尤其涉及一种图神经网络的训练方法及装置。

背景技术:

1、图数据(graph data)是由节点和边组成的数据结构,主要用于表示复杂或不规则的数据关系,其广泛应用于如社交网络、知识图谱、分子结构等数据场景,图神经网络(gnn,graph neural network)是一种专门针对图数据设计的神经网络架构。

2、在实际应用中,通常采用自监督学习对图神经网络进行训练,例如可以采用掩码自编码(mae,masked autoencoders)技术进行训练,但是mae往往只能获取到局部邻域信息,难以获取整个图的全局信息,导致网络训练效果受限。

技术实现思路

1、有鉴于此,为了提高图神经网络的任务性能和效果,本说明书一个或多个实施方式提供了一种图神经网络的训练方法及装置、图数据处理方法及装置、电子设备、存储介质以及计算机程序。

2、第一方面,本说明书提供了一种图神经网络的训练方法,包括:

3、获取样本图数据,对所述样本图数据的节点特征进行掩码处理得到掩码图数据,并对所述样本图数据的数据结构进行增强处理得到目标图数据;

4、将所述掩码图数据和所述目标图数据输入待训练的图神经网络,利用所述图神经网络的编码器对所述掩码图数据进行编码得到第一图嵌入向量,并对所述目标图数据进行编码得到第二图嵌入向量;

5、基于所述第一图嵌入向量确定节点特征损失,并基于所述第一图嵌入向量和所述第二图嵌入向量之间的差异确定对比损失,所述节点特征损失表示重建后的图数据与样本图数据之间的差异;

6、基于所述节点特征损失和对比损失,对所述图神经网络的所述编码器进行参数调整,直至满足收敛条件,得到训练后的图神经网络。

7、在一些实施方式中,对所述样本图数据的节点特征进行掩码处理得到掩码图数据,包括:

8、对所述样本图数据中的节点进行随机采样,得到预设数量的掩码节点;

9、对所述掩码节点的节点特征进行掩码处理,得到所述掩码图数据。

10、在一些实施方式中,对所述样本图数据的数据结构进行增强处理得到目标图数据,包括:

11、对所述样本图数据中预设数量的节点和/或边进行随机舍弃,得到所述目标图数据。

12、在一些实施方式中,基于所述第一图嵌入向量确定节点特征损失,包括:

13、将所述第一图嵌入向量输入所述图神经网络的解码器,得到所述解码器输出的图表征,所述图表征用于重建所述掩码图数据;

14、基于所述第一图嵌入向量、所述图表征以及所述样本图数据的节点特征,计算得到所述节点特征损失。

15、在一些实施方式中,基于所述第一图嵌入向量和所述第二图嵌入向量之间的差异确定对比损失,包括:

16、对所述第一图嵌入向量进行空间映射得到第一对比向量,对所述第二图嵌入向量进行空间映射得到第二对比向量;

17、基于所述第一对比向量和所述第二对比向量中的每个节点对之间的异同,计算得到所述对比损失。

18、在一些实施方式中,本说明书的训练方法还包括:

19、确定所述第一图嵌入向量和所述第二图嵌入向量中各个节点特征的方差,并基于预设约束条件计算得到所述第一图嵌入向量和所述第二图嵌入向量的区分度损失,所述预设约束条件用于约束所述第一图嵌入向量和所述第二图嵌入向量的方差最小值;

20、基于所述节点特征损失和对比损失,对所述图神经网络的所述编码器进行参数调整,直至满足收敛条件,得到训练后的图神经网络,包括:

21、基于所述节点特征损失、所述对比损失以及所述区分度损失,对所述图神经网络的所述编码器进行参数调整,直至满足收敛条件,得到训练后的图神经网络。

22、在一些实施方式中,本说明书的训练方法还包括:

23、基于所述图表征重建所述掩码图数据包括的第二邻接矩阵,并根据所述第二邻接矩阵与原始样本图数据包括的第一邻接矩阵之间的差异,确定第一损失项;

24、基于所述第一邻接矩阵与所述第二邻接矩阵所表示的边差异,确定第二损失项;

25、基于所述图表征中每任意两个节点之间的相对距离,确定第三损失项;

26、根据所述第一损失项、第二损失项以及第三损失项确定邻接矩阵损失;

27、基于所述节点特征损失和对比损失,对所述图神经网络的所述编码器进行参数调整,直至满足收敛条件,得到训练后的图神经网络,包括:

28、基于所述节点特征损失、所述对比损失以及所述邻接矩阵损失,对所述图神经网络的所述编码器进行参数调整,直至满足收敛条件,得到训练后的图神经网络。

29、第二方面,本说明书提供了一种图数据处理方法,包括:

30、获取待处理的图数据;

31、将所述图数据输入预先训练的图神经网络,利用所述图神经网络的编码器对所述图数据进行编码得到对应的图嵌入向量,所述图神经网络通过上述任意实施方式所述的方法训练得到;

32、将所述图嵌入向量输入相应的任务网络中,得到所述任务网络预测输出的任务结果。

33、第三方面,本说明书提供了一种图神经网络的训练装置,包括:

34、数据处理模块,被配置为获取样本图数据,对所述样本图数据的节点特征进行掩码处理得到掩码图数据,并对所述样本图数据的数据结构进行增强处理得到目标图数据;

35、编码器模块,被配置为将所述掩码图数据和所述目标图数据输入待训练的图神经网络,利用所述图神经网络的编码器对所述掩码图数据进行编码得到第一图嵌入向量,并对所述目标图数据进行编码得到第二图嵌入向量;

36、网络训练模块,被配置为基于所述第一图嵌入向量确定节点特征损失,并基于所述第一图嵌入向量和所述第二图嵌入向量之间的差异确定对比损失,所述节点特征损失表示重建后的图数据与样本图数据之间的差异;基于所述节点特征损失和对比损失,对所述图神经网络的所述编码器进行参数调整,直至满足收敛条件,得到训练后的图神经网络。

37、在一些实施方式中,所述数据处理模块被配置为:

38、对所述样本图数据中的节点进行随机采样,得到预设数量的掩码节点;

39、对所述掩码节点的节点特征进行掩码处理,得到所述掩码图数据。

40、在一些实施方式中,所述数据处理模块被配置为:

41、对所述样本图数据中预设数量的节点和/或边进行随机舍弃,得到所述目标图数据。

42、在一些实施方式中,所述网络训练模块被配置为:

43、将所述第一图嵌入向量输入所述图神经网络的解码器,得到所述解码器输出的图表征,所述图表征用于重建所述掩码图数据;

44、基于所述第一图嵌入向量、所述图表征以及所述样本图数据的节点特征,计算得到所述节点特征损失。

45、在一些实施方式中,所述网络训练模块被配置为:

46、对所述第一图嵌入向量进行空间映射得到第一对比向量,对所述第二图嵌入向量进行空间映射得到第二对比向量;

47、基于所述第一对比向量和所述第二对比向量中的每个节点对之间的异同,计算得到所述对比损失。

48、在一些实施方式中,所述网络训练模块被配置为:

49、确定所述第一图嵌入向量和所述第二图嵌入向量中各个节点特征的方差,并基于预设约束条件计算得到所述第一图嵌入向量和所述第二图嵌入向量的区分度损失,所述预设约束条件用于约束所述第一图嵌入向量和所述第二图嵌入向量的方差最小值;

50、基于所述节点特征损失和对比损失,对所述图神经网络的所述编码器进行参数调整,直至满足收敛条件,得到训练后的图神经网络,包括:

51、基于所述节点特征损失、所述对比损失以及所述区分度损失,对所述图神经网络的所述编码器进行参数调整,直至满足收敛条件,得到训练后的图神经网络。

52、在一些实施方式中,所述网络训练模块被配置为:

53、基于所述图表征重建所述掩码图数据包括的第二邻接矩阵,并根据所述第二邻接矩阵与原始样本图数据包括的第一邻接矩阵之间的差异,确定第一损失项;

54、基于所述第一邻接矩阵与所述第二邻接矩阵所表示的边差异,确定第二损失项;

55、基于所述图表征中每任意两个节点之间的相对距离,确定第三损失项;

56、根据所述第一损失项、第二损失项以及第三损失项确定邻接矩阵损失;

57、基于所述节点特征损失和对比损失,对所述图神经网络的所述编码器进行参数调整,直至满足收敛条件,得到训练后的图神经网络,包括:

58、基于所述节点特征损失、所述对比损失以及所述邻接矩阵损失,对所述图神经网络的所述编码器进行参数调整,直至满足收敛条件,得到训练后的图神经网络。

59、第四方面,本说明书提供了一种图数据处理装置,包括:

60、数据获取模块,被配置为获取待处理的图数据;

61、图神经网络模块,被配置为将所述图数据输入预先训练的图神经网络,利用所述图神经网络的编码器对所述图数据进行编码得到对应的图嵌入向量,所述图神经网络通过上述任意实施方式所述的方法训练得到;

62、任务网络模块,被配置为将所述图嵌入向量输入相应的任务网络中,得到所述任务网络预测输出的任务结果。

63、第六方面,本说明书提供了一种电子设备,包括:

64、处理器;

65、用于存储处理器可执行指令的存储器,其中,所述处理器通过运行所述可执行指令以实现上述任意实施方式所述方法的步骤。

66、第七方面,本说明书提供了一种计算机可读存储介质,其上存储有计算机指令,该指令被处理器执行时实现上述任意实施方式所述方法的步骤。

67、第八方面,本说明书提供了一种计算机程序产品,包括计算机程序/指令,该计算机程序/指令被处理器执行时实现上述任意实施方式所述方法的步骤。

68、本说明书实施方式的图神经网络的训练方法,通过共享的编码器架构实现掩码自编码mae和对比学习cl的融合训练,使得训练后的gnn可以同时捕获图结构数据的局部信息和全局信息,提高gnn对整图全局信息的感知能力,进而提高gnn在整图处理任务中的表现性能。而且在网络训练过程中,掩码自编码mae和对比学习cl共享同一个编码器,无需额外的嵌入向量的融合策略,简化训练算法过程。

本文地址:https://www.jishuxx.com/zhuanli/20240905/288014.html

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 YYfuon@163.com 举报,一经查实,本站将立刻删除。