一种残膜回收机防缠绕挑膜装置的制 一种秧草收获机用电力驱动行走机构

一种基于元学习的自监督域适应方法与流程

2021-10-24 07:19:00 来源:中国专利 TAG:特征 提取 域图 目标 网络


1.本发明属于计算机视觉和图像处理领域,具体地,涉及一种基于元学习和图像重建的无监督域适应方法,可以使源域图像分类任务和目标域图像重建任务对特征提取网络在参数更新方向上趋于一致,即利用目标域图像中物体的空间关系、光照等信息使源域图像类别信息对特征提取网络的参数更新方向趋向利用目标域图像中物体的空间关系、光照等信息对特征提取网络的参数更新方向,从而使特征提取网络得到的特征为域不变特征,实现源域和目标域的域适应。


背景技术:

2.近些年来,基于深度神经网络的有监督学习已在许多方面有了成熟的应用,被广泛地应用于图像分类、目标检测、语义分割、自然语言处理等领域,极大地促进了人工智能技术与现实生活的结合。然而,许多有监督学习方法往往假设训练集样本和测试集样本的概率分布服从同一分布,此外,有监督学习方式为了使网络具有较好的泛化性能,避免过拟合问题,在训练阶段往往需要大量有标注的训练样本。然而随着大数据时代的到来,数据规模日益增大,不同数据集合间统计特性差异、数据标注的人工成本较高等问题逐渐显现。为了解决上述问题,无监督域适应方法被广泛研究。无监督域适应是一种解决源域和目标域之间分布差异的方法,其通过学习有标签的源域数据中的某些可以泛化的知识,并将其应用于无标签的目标域数据上的任务来提高网络在目标域上的表现。
3.目前,很多无监督域适应方法多聚焦于采用基于距离度量或基于对抗学习等方式来对齐源域数据和目标域数据的特征概率分布,从而减小域间差异,使网络能够学习到域不变特征空间。但是这些方法仅从源域数据和目标域数据的整体上进行分布对齐,没有关注目标域数据的内在特征对域适应过程中知识迁移的影响,因此,基于图像重建等的自监督域适应方法通过在无标签的目标域数据上设置图像重建、图像旋转角度预测、图像修复等自监督任务来协同有标签的源域数据对网络进行联合训练,即通过对目标域数据内在特征的挖掘来辅助网络迁移从源域数据中学习到的知识到目标域,此外,由于基于图像重建的方法在域适应提取可迁移特征的同时保留了数据的完整性,保证目标域数据中对提升特定任务性能的信息不被破坏,并且在对源域数据的知识迁移过程中尽可能保留了目标域数据的原始分布,因此能够更好地将源域数据中的学习到的知识应用在目标域任务中。然而,在域适应过程中,现有的很多无监督域适应方法往往将特征提取网络的参数分别通过自监督任务和图像分类等特定任务进行简单地更新,没有考虑两类任务对特征提取网络参数的更新方向是否一致,可能会使得自监督任务对图像分类等特定任务的特征学习造成负面影响,通过元学习可以将自监督任务和图像分类等特定任务分别作为训练器和测试器,通过训练器中的损失函数对测试器网络进行参数更新,从而迫使两类任务对网络的参数更新方向趋向一致。


技术实现要素:

4.针对现有技术中存在的不足,本发明提出一种基于元学习(meta learning)的自监督域适应方法。
5.一种基于元学习的自监督域适应方法,包括以下步骤:
6.步骤1,设置训练器和测试器:
7.将目标域样本的重建过程作为元学习中的训练器,将源域样本的分类过程作为元学习中的测试器。
8.步骤2,利用目标域样本进行图像重建任务并计算重建损失:
9.将无标签的目标域样本输入特征提取网络得到目标域样本特征,然后将目标域样本特征输入图像重建网络进行图像重建并计算重建损失。
10.步骤3,对特征提取网络进行参数更新:
11.利用训练器中的重建损失对训练器中的特征提取网络进行参数更新,由于权值共享,测试器中的特征提取网络的参数和训练器中的特征提取网络参数一起更新,即使得测试器中网络的参数更新方向趋向训练器中网络的参数更新方向。
12.步骤4,利用源域样本进行分类任务并计算分类损失:
13.将有标签的源域数据输入参数更新后的特征提取网络得到源域数据特征,然后将源域数据特征输入分类网络进行图像分类任务并计算分类损失。
14.步骤5,计算总损失函数并对全部网络进行参数更新:
15.计算总损失函数,并对训练器和测试器中的特征提取网络、重建网络和分类网络进行参数更新。
16.本发明有益效果如下:
17.(1)将目标域图像重建作为域适应过程中的自监督任务,监督信息即为目标域图像本身,不需要额外的目标域图像标注信息,节省了大量人工标注成本;此外,目标域图像的重建过程能够使网络学习到目标域图像中更丰富的高层语义信息,使得网络能够利用目标域数据的内在特征来辅助网络将源域数据中学习到的知识向目标域迁移,从而提升域适应方法的性能。
18.(2)通过将元学习策略引入自监督域适应中,使得目标域自监督任务和源域分类等特定任务对网络参数的更新方向趋于一致,使得网络能够更好地提取域不变特征,减少了域适应任务和特定任务对网络参数的更新方向不一致造成的负迁移问题,提升了域适应性能。
附图说明
19.图1为本发明提出的基于元学习的自监督域适应方法流程图。
20.图2为本发明提出的基于元学习的自监督域适应方法网络示意图。
21.图3为resnet网络的基本单元示意图。
22.图4为全连接层结构示意图。
具体实施方式
23.为使本发明的上述目的、特征和优点能够更加明显易懂,下面结合附图和具体实
施方式对本发明作进一步详细的说明。
24.如图1所示,一种基于元学习的自监督域适应方法,步骤如下:
25.步骤1,设置训练器和测试器:如图2所示,将目标域样本x
t
的重建过程作为元学习中的元训练器(meta

train),将源域样本x
s
的分类过程作为元学习中的元测试器(meta

test)。其中有标签的源域表示为s={x
s
,y
s
},x
s
∈x
s
和y
s
∈y
s
分别表示源域样本和相应的标签,无标签的目标域表示为t={x
t
},其中x
t
∈x
t
表示目标域样本。
26.步骤2,利用目标域样本进行图像重建任务并计算重建损失:
27.将无标签的目标域样本x
t
输入特征提取网络g得到目标域样本特征f
t
,然后将目标域样本特征f
t
输入图像重建网络d进行图像重建得到目标域重建样本并计算重建损失l
r
。其中特征提取网络g采用resnet

50结构,resnet的基本单元如图3所示,其通过跳接将之前层的输出与本层计算的输出相加,并将求和的结果输入到激活函数中作为本层的输出。通过resnet

50的特征提取过程,得到目标域样本特征f
t
=g(x
t
)。图像重建网络d采用解码器结构,通过一系列上采样将目标域样本特征f
t
还原为原图大小,即重建损失为:
[0028][0029]
其中n
t
为目标域样本个数,j为目标域中第j个样本。
[0030]
步骤3,对特征提取网络进行参数更新:
[0031]
利用训练器中的重建损失l
r
对特征提取网络g进行参数更新,即:
[0032][0033]
其中为当前特征提取网络的参数,为经过更新后的特征提取网络的参数,α为学习率,为解码器d的参数,表示对参数求梯度,梯度下降采用随机梯度下降算法。由于权值共享,测试器中的特征提取网络g的参数θ和训练器中的特征提取网络g的参数θ一起更新,即迫使测试器中图像分类等特定任务对网络的参数更新方向趋向训练器中自监督图像重建任务对网络的参数更新方向。
[0034]
步骤4,利用源域样本进行分类任务并计算分类损失:将有标签的源域数据x
s
输入参数更新后的特征提取网络g得到源域数据特征f
s
,然后将源域数据特征f
s
输入分类网络c进行图像分类任务并计算分类损失。其中分类网络c采用多全连接层加softmax层结构,全连接层结构示意图如图4所示,其每一个结点都与上一层的所有结点相连,用于将提取到的特征进行综合,并经过softmax层输出图像预测标签即分类损失为n
s
为源域样本个数,k为源域中第k个样本。
[0035]
步骤5,计算总损失函数并对全部网络进行参数更新:
[0036]
计算总损失函数l,并对训练器和测试器中的特征提取网络g、重建网络d和分类网络c进行参数更新,即:
[0037][0038]
其中β为学习率,{θ
g

d

c
}
t
为当前时刻网络的参数,总损失函数如下:
[0039][0040]
其中,θ
c
表示分类网络的参数,λ为超参数,用于控制图像重建任务和图像分类任务对网络参数更新的影响大小,l
r

g

d
)表示在网络参数为θ
g
和θ
d
时计算得到的图像重建损失,损失,表示在网络参数为和θ
c
时计算得到的图像分类损失。
再多了解一些

本文用于企业家、创业者技术爱好者查询,结果仅供参考。

发表评论 共有条评论
用户名: 密码:
验证码: 匿名发表

相关文献

  • 日榜
  • 周榜
  • 月榜