技术新讯 > 计算推算,计数设备的制造及其应用技术 > 非对称双任务协同训练的偏标记学习模型及图像处理装置  >  正文

非对称双任务协同训练的偏标记学习模型及图像处理装置

  • 国知局
  • 2024-10-09 15:19:49

本发明属于人工智能领域中的弱监督学习技术,具体涉及一种非对称双任务协同训练的偏标记学习模型及图像处理装置。

背景技术:

1、近年来,人工智能技术得到了快速发展,广泛应用于现实世界中的各种场景,如国防建设、医疗健康、电子商务等。深度学习是人工智能的重要支撑技术,依赖大量精准标注的样本完成参数训练。然而,在现实世界中,由于安全、隐私、样本歧义等因素,精准标准样本的获取成本较高。因此,依赖不完全、不精确和不准确标注数据的弱监督学习方法近年来受到了广泛关注。

2、偏标记学习是一类典型的弱监督学习问题。不同于为每个样本标注唯一且准确的标签,偏标记学习假设每个样本被标注了一组候选标签,其中有且仅有一个标签为真实标签。为样本标注候选标签集的难度远远低于为其进行精准唯一的标注,特别当样本的特征层面语义模糊程度比较高时。例如,西伯利亚雪橇犬与狼外观特征相似,对于一张西伯利亚雪橇犬或狼的图片,标注者准确识别其唯一真实标签比较困难,会带来较高人力和时间成本,而为其标注同时包含西伯利亚雪橇犬和狼在内的候选标签集则容易得多。

3、最近对偏标记学习的研究主要集中在基于辨识的方法上,这些方法将真实标签视为潜在变量,并通过标签消歧来识别真实标签。基于此出现了各种算法,如最大边距法、图模型法、期望最大化算法、对比学习法和一致性正则化法等。在这些方法中,自训练深度模型是一种有前景的方法,它通过学习标签置信度向量并进行迭代训练,取得了最先进的性能。

4、然而,自训练的偏标记学习(pll)模型存在错误累积的问题,因为复杂的实例难以分类,容易被错误地消歧,这可能会进一步误导模型,导致假阳性标签以及性能下降。协同策略通过同时训练两个网络并使它们相互作用,是缓解错误累积的可行解决方案。虽然协同策略在噪声标记学习(nll)中得到了广泛研究,但在偏标记学习中并没有被充分研究。最近,yao等人(yao y,gong c,deng j,et al.network cooperation with progressivedisambiguation for partial label learning.in:proceedings of machine learningand knowledge discovery in databases:european conference,2020,471-488)提出了一种基于协同训练的新方法,称为ncpd。ncpd通过数据复制将偏标签数据集转换为高噪声率的数据集,并采用一种典型的nll方法——co-teaching协同训练(han b,yao q,yux,etal.co-teaching:robusttrainingof deep neural networks withextremely noisylabels.in:proceedings of advances inneural informationprocessing systems,2018,31)。然而,ncpd不仅导致极高的时间和空间复杂度,其模型性能也有限。

5、此外,现有的大多数协同训练模型,包括ncpd在内,都是对称的,即它们的两个网络分支具有相同的结构,并使用相同的输入数据和损失函数进行训练。他们假设,通过不同的参数初始化,可以使两个结构相同的网络在同一任务上获得不同的能力,从而能够相互纠正错误。然而,在对称模式下训练会使得两个网络更容易产生相同的问题,例如,两个网络都难以正确识别复杂的样本。因此,它们无法有效地纠正错误。

技术实现思路

1、有鉴于此,本发明的首要目的在于提供一种非对称双任务协同训练的偏标记学习模型,通过不同任务协同训练两个结构相同的网络,使其可以从偏标记数据中有效训练模型参数,得到一个面向偏标记数据集测试环境的分类器,并在均匀的偏标签生成策略和实例依赖的偏标签生成策略之下表现出优越的性能。

2、为实现上述目的,本发明所采用的具体技术方案如下:

3、一种非对称双任务协同训练的偏标记学习模型,其关键在于,包括消歧网络模块、辅助网络模块和错误校正模块,其中:

4、所述消歧网络模块用于根据样本数据计算各个类别的分类概率和标签置信度;

5、所述辅助网络模块以样本数据和其通过所述消歧网络模块产生的分类概率和标签置信度为输入,输出辅助网络模块对样本数据在各类别的最终分类概率及标签最终置信度;

6、所述错误校正模块用于根据所述消歧网络所得的分类结果和所述辅助网络模块所得的最终分类结果进行信息初步提取和置信度完善,所述信息初步提取用于将辅助网络模块所得的最终分类概率视为真实分布,并引入基于kl散度的损失,确保所述消歧网络模块的预测概率与所述辅助网络模块的预测概率一致;所述置信度完善用于根据辅助网络模块所得的标签最终置信度对所述消歧网络模块的标签置信度进行微调。

7、可选地,所述消歧网络模块通过多层感知器计算分类得分,然后通过softmax函数计算分类概率pi∈rm,其中pi中的第k个元素pik表示样本数据xi被分类到标签k概率;

8、

9、其中,m为标签类别总数,yi为样本数据xi类结果,mlpk(xi)表示将数据xi类到标签k类得分,τ温度参数。

10、可选地,所述样本数据为图像数据,分别采用两种不同的图像增强技术为每个样本图像xi成两个增强图像,记为x′i=aug1(xi)和x″i=aug2(xi),且组成增强样本数据集(xi)={xi,x′i,x″i};

11、所述消歧网络模块根据增强样本数据集计算得到增强样本对应的分类概率p′i,p″i和标签置信度c′i,c″i。

12、可选地,所述消歧网络模块:

13、按照计算样本数据xi的分类一致性损失,其中,表示的模,表示xi及其增强样本所构成的样本集合,表示中的任一样本,表示样本分类到类别k的概率,yi表示样本数据xi的候选标签集;

14、按照计算样本数据xi属于类别k的综合置信度,得到样本数据xi的类别综合置信度向量wi=[wi1,wi2,…,wim],其中表示样本xi对应的样本集合中的任意样本在类别k上的置信度,其根据样本在上个轮次中的概率预测结果计算,若类别k∈yi,则否则

15、按照计算样本数据xi的风险一致性损失,其中表示的模,表示xi及其增强样本所构成的样本集合,wik表示样本xi属于类别k的置信度,表示中的任一样本,表示样本分类到类别k的概率,yi表示样本数据xi的候选标签集。通过最小化数据增强后的风险一致性损失,使得原始样本数据和增强样本数据的分类概率都逐渐接近综合置信度向量wi;

16、可选地,所述消歧网络模块按照构建的损失函数ldisam(xi)=lcc(xi)+γ(t)lrc(xi)进行训练,其中权重参数λ,t是超参数,t为训练轮次。

17、可选地,根据消歧网络的预测结果,为每个实例生成伪类标签。伪类标签是从置信度向量ci中选择概率最大的标签,表示模型认为最可能的真实标签。如果一对样本数据共享相同的伪标签,则为它们分配一个相似性标签1;否则,分配一个相似性标签为0,生成的相似性数据集表示为其中,分别用于表示样本xi和xj经过辅助网络抽取的特征表示,表示一对样本对,χ表示样本空间,sij∈{0,1}表示由样本xi和xj所组成样本对的相似性标签,利用生成的相似性数据集来训练所述辅助网络模块;

18、对于相似性标签为1的样本数据对,期望它们的预测分类概率表现出高度相似性,训练过程中采用损失函数:

19、其中,二元交叉熵损失函数表示样本xi通过辅助网络预测的在类别k上的分类概率,表示样本xi的数据增强样本在类别k上的分类概率。表示在反向传播期间停止损失函数中关于的梯度更新。

20、可选地,所述信息初步提取步骤中引入的基于kl散度的损失按照:

21、计算;其中,kl(·)表示kl散度函数。

22、可选地,所述置信度完善步骤中,在第t个轮次,微调后的置信度计算如下:

23、其中,为辅助网络计算的样本xi的综合置信度向量,μ(t)=min(ρ×max(t-t0,0),μmax)是训练轮次t的递增函数,t0和

24、μmax是超参数,在第t0个训练次数之前,μ(t)=0,并且0≤μmax≤1是μ(t)的上限,μ的增长速度取决于ρ,在协同训练期间,原始置信度wi被改善后的置信度替代。

25、可选地,按照设置模型总体训练损失进行训练;首先,消歧网络模块经过前期预热,确保通过置信度向量准确识别部分训练样本数据的真实标签;随后,预训练模型参数用于初始化辅助网络的参数,此外,为了增强模型训练的效率,使用同一小批量内的样本数据生成噪声相似性标签;在推断阶段,选择使用消歧网络模块或辅助网络模块中的一个模块进行预测。

26、基于上述模型,本发明还提供一种图像处理装置,其关键在于:采用前文所述的非对称双任务协同训练的偏标记学习模型进行图像分类。

27、本发明的有益效果是:

28、(1)本发明探索了偏标记学习的非对称协同训练,并提出了一种全新的基于深度学习的双任务偏标记学习模型,通过不同任务协同训练两个结构相同的网络。

29、(2)本发明提出了一种有效的监督学习辅助网络,利用消歧网络识别的伪标签进行训练,并通过信息蒸馏和标签置信度精调逐步缓解错误累积问题。

30、(3)本发明在基准数据集上的广泛实验结果表明,所提模型在均匀的偏标签生成策略和样本数据依赖的偏标签生成策略之下都表现出优越的性能。

本文地址:https://www.jishuxx.com/zhuanli/20241009/308362.html

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 YYfuon@163.com 举报,一经查实,本站将立刻删除。