技术新讯 > 计算推算,计数设备的制造及其应用技术 > 用于图片分类的遗忘模型的训练方法及图片的分类方法  >  正文

用于图片分类的遗忘模型的训练方法及图片的分类方法

  • 国知局
  • 2024-10-09 15:20:45

本发明涉及图像分类,具体涉及一种用于图片分类的遗忘模型的训练方法及图片的分类方法。

背景技术:

1、机器遗忘学习指从模型中抹除训练集中某个类别或特征对于模型的影响,从而实现机器学习模型对于指定数据的遗忘删除。随着人工智能技术的广泛应用,大量用户的个人隐私图片被应用于深度神经网络模型的训练当中。然而,众多研究已经证明,深度神经网络模型在实际应用中存在着数据隐私泄露的风险,恶意攻击者可以通过一定技术手段来获取模型训练数据,这就需要人工智能服务商在结束对于隐私图片的使用后进行有效删除。同时在法律层面上,众多数据安全法和隐私保护法也已经明确规定了数据所有者有权要求服务商将其隐私数据从存储该数据的实体中删除。因此,对于深度神经网络模型中用户隐私图片遗忘方法的开发是亟需且重要的。

2、目前对于深度神经网络模型中的数据遗忘删除的主流思路是将待遗忘数据剔除出训练数据集后,再次使用原训练方法对模型进行重新训练,或是使用其他近似的思路。例如,lucas bourtoule等人提出的sisa(sharded、isolated、sliced and aggregatedtraining)算法是该领域目前使用最广泛的方法,这种方法的大致流程为:首先,将完整的训练数据集平均划分为等长的n份小数据集,这n份数据互相独立,不存在交集;之后,通过在n份数据上分别进行模型的训练,记录存储n个不同的模型及其参数;在遗忘数据时,只需定位到待遗忘数据所在块对应的模型,在对应子数据集中剔除出待遗忘数据后重新训练该模型,并于其他模型进行后验聚合得到最终输出。

3、但是,重训练方法或其他近似思路的资源开销较大,需要服务商存储完整的训练数据集和训练信息,并且需要进行多次训练,造成了较大的资源开销。

技术实现思路

1、为了解决现有技术中存在的上述问题,本发明提供了一种用于图片分类的遗忘模型的训练方法及图片的分类方法。本发明要解决的技术问题通过以下技术方案实现:

2、本发明提供一种用于图片分类的遗忘模型的训练方法,所述隐私图片遗忘方法包括:

3、将第一高斯噪声分布输入至训练好的生成对抗网络中得到图片数据集dt,并根据图片数据集dt得到没有交集的数据集du和数据集dr,其中,待遗忘的图片均包含在所述数据集du;

4、基于初始模型,根据所述数据集du得到q维离散输出分布,并根据所述q维离散输出分布得到未包括待遗忘的图片的q-1维平均输出分布向量,根据所述数据集du和所述数据集dr得到q-1维余弦相似度分布向量,并根据所述q-1维平均输出分布向量和所述q-1维余弦相似度分布向量得到q维遗忘训练目标分布向量,其中,q为类别的总数量;

5、将所述数据集du和所述数据集dr分别输入遗忘模型中,基于所述遗忘模型的总损失函数得到总损失值,基于所述总损失值对所述遗忘模型进行训练,得到训练好的遗忘模型,其中,所述总损失函数中包括根据所述遗忘训练目标分布向量和所述数据集du输入至所述遗忘模型后的输出分布构建的第一子损失函数,所述训练好的遗忘模型用于对待分类的图片进行分类,所述遗忘模型与所述初始模型的结构相同,且初始化的遗忘模型与所述初始模型的参数相同。

6、在本发明的一个实施例中,所述生成对抗网络包括生成器和判别器;

7、其中,将第一高斯噪声分布输入至训练好的生成对抗网络中得到图片数据集dt,并根据图片数据集dt得到没有交集的数据集du和数据集dr,包括:

8、获取随机采样的第一高斯噪声分布;

9、将所述第一高斯噪声分布输入至训练好的生成对抗网络中,所述训练好的生成对抗网络的生成器输出图片数据集dt;

10、将所述图片数据集dt划分为没有交集的数据集du和数据集dr。

11、在本发明的一个实施例中,所述生成对抗网络的训练方法包括:

12、s1、获取n张训练图片;

13、s2、将每张所述训练图片转换为张量并进行归一化,得到包含n个真实样本的第一集合;

14、s3、在第二高斯噪声分布中进行随机抽样,得到包含n个噪声样本的第二集合;

15、s4、将第n个所述噪声样本输入至生成器,得到第n个虚假样本,其中,n个所述虚假样本组成第三集合,第n个所述真实样本与第n个所述虚假样本的规格相同,1≤n≤n;

16、s5、随机从所述第一集合中选取m个真实样本和从所述第三集合中选取与所述m个真实样本对应的m个虚假样本;

17、s6、将m个真实样本和m个虚假样本依次输入至所述判别器,并利用所述判别器的损失函数得到第一损失值,以使用随机梯度下降法对所述判别器的参数进行更新,所述判别器的损失函数表示为:

18、

19、其中,lossd为第一损失值,d(t(m))为m个真实样本中第m个真实样本的得分,d(g(u(m)))为m个虚假样本中第m个虚假样本g(u(m)的得分;

20、s7、固定步骤s6所述判别器的参数,将m个虚假样本依次输入至所述生成器,并利用所述生成器的损失函数得到第二损失值,以使用随机梯度下降法对所述生成器的参数进行更新,所述生成器的损失函数为:

21、

22、其中,lossg为第二损失值;

23、s8、重复步骤s1至s7,直至达到预设的第一训练条件,得到所述训练好的生成对抗网络。

24、在本发明的一个实施例中,基于初始模型,根据所述数据集du得到q维离散输出分布,并根据所述q维离散输出分布得到未包括待遗忘的图片的q-1维平均输出分布向量,根据所述数据集du和所述数据集dr得到q-1维余弦相似度分布向量,并根据所述q-1维平均输出分布向量和所述q-1维余弦相似度分布向量得到q维遗忘训练目标分布向量,包括:

25、在所述数据集du中进行随机抽样,得到第一抽样数据集,其中,所述第一抽样数据集包括多个第一抽样数据样本,且所有所述第一抽样数据样本中包含待遗忘的第一抽样数据样本;

26、基于利用初始模型对所有所述第一抽样数据集分别进行识别分类得到的q维离散输出分布,得到q个平均预测概率,并根据所述q个平均预测概率得到平均输出分布向量,其中,所述平均输出分布向量由未包含待遗忘的第一抽样数据样本的其余q-1个平均预测概率得到;

27、在所述数据集du中进行随机抽样,得到第二抽样数据集,其中,所述第二抽样数据集包括多个第二抽样数据样本,且所有所述第二抽样数据样本中包含待遗忘的第二抽样数据样本;

28、在所述数据集dr中进行随机抽样,得到第三抽样数据集,其中,所述第三抽样数据集包括q-1个第三抽样数据样本;

29、根据基于待遗忘的第二抽样数据样本的平均向量和q-1个所述第三抽样数据样本的平均向量得到的q-1个所述余弦相似度,得到余弦相似度分布向量;

30、根据所述平均输出分布向量和所述余弦相似度分布向量得到q维遗忘训练目标分布向量。

31、在本发明的一个实施例中,基于利用初始模型对所有所述第一抽样数据集分别进行识别分类得到的q维离散输出分布,得到q个平均预测概率,并根据所述q个平均预测概率得到平均输出分布向量,包括:

32、利用初始模型对所述第一抽样数据集的每个第一抽样数据样本进行识别分类,以统计得到所述第一抽样数据样本的q维离散输出分布;

33、根据所有所述第一抽样数据样本的q维离散输出分布得到q维平均离散输出分布,其中,所述q维平均离散输出分布包括q个类别的平均预测概率;

34、确定所述待遗忘的第一抽样数据样本的q维离散输出分布中的最大概率的类别作为待遗忘类别,去除所述q维平均离散输出分布中待遗忘类别的平均预测概率,根据所述q维平均离散输出分布中其余的q-1个平均预测概率得到平均输出分布向量。

35、在本发明的一个实施例中,根据基于待遗忘的第二抽样数据样本的平均向量和q-1个所述第三抽样数据样本的平均向量得到的q-1个所述余弦相似度,得到余弦相似度分布向量,包括:

36、分别将待遗忘的第二抽样数据样本和所有所述第三抽样数据样本均展开为向量形式,并计算向量形式的待遗忘的第二抽样数据样本和所有向量形式的第三抽样数据样本的平均向量;

37、根据向量形式的待遗忘的第二抽样数据样本得到待遗忘的第二抽样数据样本的平均向量,根据向量形式的第q个第三抽样数据样本得到第q个第三抽样数据样本的平均向量,并根据待遗忘的第二抽样数据样本的平均向量和第q个第三抽样数据样本的平均向量得到第q个余弦相似度,其中,1≤q≤q-1;

38、根据q-1个所述余弦相似度的绝对值得到余弦相似度分布向量。

39、在本发明的一个实施例中,根据所述平均输出分布向量和所述余弦相似度分布向量得到q维遗忘训练目标分布向量,包括:

40、根据所述平均输出分布向量和所述余弦相似度分布向量得到样本相似度分布向量,所述样本相似度分布向量表示为:

41、psimilarity=[z1,z2,...,zq,...,zq-1]

42、zq=xq·[yq-min(y1,y2,...,yq,...,yq-1)]

43、其中,psimilarity为样本相似度分布向量,xq为平均输出分布向量中的第q个元素,yq为余弦相似度分布向量中的第q个元素,1≤q≤q-1;

44、向所述样本相似度分布向量中增加元素0,以将所述样本相似度分布向量扩展至q维,其中,所述元素0的位置为所述待遗忘的图片对应的类别序号;

45、对扩展至q维的样本相似度分布向量中的每个元素进行归一化,得到包括q个归一化后的元素的遗忘训练目标分布向量,其中,所述q个归一化后的元素的和为1。

46、在本发明的一个实施例中,将所述数据集du和所述数据集dr分别输入遗忘模型中,基于所述遗忘模型的总损失函数得到总损失值,基于所述总损失值对所述遗忘模型进行训练,得到训练好的遗忘模型,包括:

47、步骤3.1、初始化所述遗忘模型;

48、步骤3.2、将所述数据集du输入所述遗忘模型,基于第一子损失函数得到第三损失值;

49、步骤3.3、将所述数据集dr输入所述遗忘模型,基于第二子损失函数得到第四损失值;

50、步骤3.4、基于所述第一子损失函数和所述第二子损失函数构建总损失函数,以根据所述第三损失值和所述第四损失值得到总损失值:

51、步骤3.5、基于所述总损失值,使用随机梯度下降法对所述遗忘模型的参数进行更新;

52、步骤3.6、重复步骤3.2至步骤3.5,直至达到预设的第二训练条件,得到所述训练好的遗忘模型。

53、在本发明的一个实施例中,所述第一子损失函数表示为:

54、lu=kl(p(mu(du))||pforecast)

55、其中,lu为第三损失值,p(mu(du))为数据集du输入至输入遗忘模型后的输出分布,pforecast为遗忘训练目标分布向量,kl(·)为相对熵;

56、所述第二子损失函数表示为:

57、lr=lossinit(mu(dr))

58、其中,lr为第四损失值,lossinit(mu(dr))为数据集dr输入至输入遗忘模型后的损失值;

59、所述总损失函数表示为:

60、loss=α·kl(p(mu(du))||ptarget)+(1-α)·lossinit(mu(dr))

61、其中,loss为总损失值,α为权重参数,取值范围为[0,1]。

62、本发明还提供一种基于遗忘模型的图片的分类方法,所述分类方法包括:

63、获取待分类的图片;

64、将所述待分类的图片输入至上述任一项实施例所述的训练好的遗忘模型中,得到分类结果。

65、与现有技术相比,本发明的有益效果在于:

66、本发明首先将第一高斯噪声分布输入至训练好的生成对抗网络中得到图片数据集dt,并根据图片数据集dt得到没有交集的数据集du和数据集dr,然后基于初始模型,根据数据集du得到q维离散输出分布,并根据q维离散输出分布得到未包括待遗忘的图片的q-1维平均输出分布向量,根据数据集du和所述数据集dr得到q-1维余弦相似度分布向量,并根据q-1维平均输出分布向量和q-1维余弦相似度分布向量得到q维遗忘训练目标分布向量。并且本发明的总损失函数中包括根据遗忘训练目标分布向量和数据集du输入至所述遗忘模型后的输出分布构建的第一子损失函数。因此,本发明所提出的基于生成对抗网络的数据集生成与划分方法不需要人工智能服务商事先存储原模型的完整训练数据集和训练信息,降低了服务商的存储成本。同时本发明通过基于预测概率分布的损失函数设计,使得对于用户隐私图片遗忘的各项任务在单次训练中完成,提高了时间效率。

67、以下将结合附图及实施例对本发明做进一步详细说明。

本文地址:https://www.jishuxx.com/zhuanli/20241009/308427.html

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 YYfuon@163.com 举报,一经查实,本站将立刻删除。