图像分类模型的训练方法、图像分类方法及相关装置与流程
- 国知局
- 2024-09-14 14:50:44
本技术涉及计算机视觉领域,尤其涉及一种图像分类模型的训练方法、图像分类方法及相关装置。
背景技术:
1、近年来,预训练模型被广泛应用于计算机视觉领域的图像分类任务。目前,可以通过为预训练模型的第一层或每一层的输入串接可优化的提示向量,并基于待实现的图像分类任务为该模型添加相应的分类头,之后通过微调提示向量和分类头参数来实现相关的图像分类任务。但是,针对不同的图像分类任务,上述方法需要使用不同的学习策略(例如,学习率、权重衰减因子)来调整提示向量和分类头参数,导致需要耗费大量的时间来搜索最优的学习策略,影响图像分类任务的效率和成本。
技术实现思路
1、本技术提供了一种图像分类模型的训练方法、图像分类方法及相关装置,能够提高图像分类任务的效率,降低图像分类任务的成本。
2、第一方面,本技术提供了一种图像分类模型的训练方法。该方法可以应用于训练装置。训练装置获取训练数据集,训练数据集包括多张训练图像。然后,训练装置为每张训练图像拼接提示向量和预训练模型的掩码向量,得到每张训练图像对应的目标输入向量,此处的预训练模型是通过掩码向量训练得到。之后,训练装置基于该预训练模型和每张训练图像对应的目标输入向量对上述提示向量进行优化,得到可用于执行图像分类任务的图像分类模型,其中,该图像分类模型包括以下参数:预训练模型的参数(例如,上述掩码向量)和已优化的提示向量。
3、本技术提供的技术方案中,训练装置通过掩码预测的方式来训练图像分类模型,即,为训练图像拼接掩码向量和提示向量以得到目标输入向量,通过对目标输入向量进行预测来训练图像分类模型。由于预训练模型也是通过掩码预测的方式训练得到的,因此利用本技术提供的方法训练图像分类模型,可以缩小上游任务和下游任务的差距,使得实现下游任务时所需的学习策略具有较好的稳定性,可以应用到不同的下游任务中。相应地,在使用预训练模型实现不同的下游任务时,可以使用相同的学习策略,从而令预训练模型高效地应用到多种下游任务中。
4、在第一方面的一种可能实现方式中,上述训练装置为每张训练图像拼接提示向量和预训练模型的掩码向量,包括:训练装置将提示向量和预训练模型的掩码向量与每张训练图像进行串接。其中,串接是指沿着向量的行的方向进行拼接。
5、应理解,在实际应用中为了方便处理训练图像,通常会将训练图像处理为行向量。而且,提示向量和掩码向量一般也会选用行向量,也就是说,训练图像、提示向量和掩码向量均为行向量,那么在拼接方式选用串接更为合适。
6、在第一方面的一种可能实现方式中,上述训练装置基于预训练模型和每张训练图像对应的目标输入向量对提示向量进行优化,包括:训练装置利用预训练模型对每张训练图像对应的目标输入向量中的掩码向量进行预测,得到每张训练图像对应的预测视觉词,然后将每张训练图像对应的预测视觉词和每张训练图像对应的类别标签映射到同一向量空间,之后基于映射后的预测视觉词和类别标签的相似度对提示向量进行优化。
7、上述实现方式中,训练装置对训练图像对应的目标输入向量中的掩码向量进行预测,实现了以掩码预测的方式来训练图像分类模型。另外,通过将训练图像对应的预测视觉词和该图像对应的类别标签映射到同一向量空间,可以拉近二者的距离,从而更准确地确定预测视觉词与类别标签的差距,进而更准确地调整提示向量,使得训练得到的图像分类模型具有更高的准确性。
8、在第一方面的一种可能实现方式中,上述训练装置将每张训练图像对应的预测视觉词和每张训练图像对应的类别标签映射到同一向量空间,包括:训练装置为每张训练图像对应的类别标签设置对应的类别向量,然后基于类别向量的维度对上述预测视觉词进行维度变换。变换后的预测视觉词与上述类别向量的维度相同,即二者在同一向量空间。
9、第二方面,本技术提供了一种图像分类方法。该方法可以应用于推理装置。推理装置获取待处理图像,然后基于图像分类模型对待处理图像进行分类,得到该图像的分类结果。其中,图像分类模型包括以下参数:预训练模型的参数和已优化的提示向量,预训练模型的参数包括掩码向量。预训练模型是基于掩码向量训练得到。已优化的提示向量是基于预训练模型和多张训练图像对应的目标输入向量优化得到,每张训练图像对应的目标输入向量是通过为每张训练图像拼接上述掩码向量和未优化的提示向量得到。
10、本技术提供的技术方案中,图像分类模型是通过掩码预测的方式训练得到,也就是说,利用图像分类模型实现的图像分类任务是通过掩码预测的方式实现。如此,缩小了上游任务和下游任务的差距,从而更充分地将上游任务中的预训练知识应用到下游任务中,提高了分类结果的准确性。
11、在第二方面的一种可能实现方式中,上述每张训练图像对应的目标输入向量是通过将预训练模型的掩码向量和未优化的提示向量与每张训练图像进行串接得到。其中,串接是指沿着向量的行的方向进行拼接。
12、应理解,在实际应用中为了方便处理训练图像,通常会将训练图像处理为行向量。而且,提示向量和掩码向量一般也会选用行向量,也就是说,训练图像、提示向量和掩码向量均为行向量,那么在拼接方式选用串接更为合适。
13、在第二方面的一种可能实现方式中,上述推理装置基于图像分类模型对待处理图像进行分类,包括:推理装置将预训练模型的掩码向量和已优化的提示向量与该图像进行串接,得到该图像对应的目标输入向量,然后基于该图像对应的目标输入向量对该图像进行分类。
14、在第二方面的一种可能实现方式中,上述图像分类模型还包括以下参数:至少一个类别向量,每个类别向量对应的一个类别标签。
15、在第二方面的一种可能实现方式中,上述推理装置基于待处理图像对应的目标输入向量对该图像进行分类,包括:推理装置基于待处理图像对应的目标输入向量确定该图像对应的预测视觉词,然后基于上述至少一个类别向量的维度对该图像对应的预测视觉词进行维度变换,之后再基于变换后的预测视觉词与上述至少一个类别向量的相似度对该图像进行分类。
16、上述实现方式中,变换后的预测视觉词和至少一个类别向量在同一向量空间,同一向量空间中的向量更容易计算距离,从而更准确地确定变换后的预测视觉词与哪一个类别向量的距离更近,更近的类别向量对应的类别标签即为图像的分类结果。因此,通过上述实现方式可以提高图像分类结果的准确性。
17、第三方面,本技术提供了一种训练装置。该装置包括获取模块、拼接模块和优化模块。获取模块用于获取训练数据集,训练数据集包括多张训练图像。拼接模块用于为训练数据集中的每张训练图像拼接提示向量和预训练模型的掩码向量,得到每张训练图像对应的目标输入向量,此处的预训练模型是通过掩码向量训练得到。优化模块用于基于该预训练模型和每张训练图像对应的目标输入向量对上述提示向量进行优化,得到可用于执行图像分类任务的图像分类模型。其中,该图像分类模型包括以下参数:预训练模型的参数(例如,上述掩码向量)和已优化的提示向量。
18、在第三方面的一种可能实现方式中,上述拼接模块用于将提示向量和预训练模型的掩码向量与每张训练图像进行串接。其中,串接是指沿着向量的行的方向进行拼接。
19、在第三方面的一种可能实现方式中,上述优化模块用于利用预训练模型对每张训练图像对应的目标输入向量中的掩码向量进行预测,得到每张训练图像对应的预测视觉词,将每张训练图像对应的预测视觉词和每张训练图像对应的类别标签映射到同一向量空间,基于映射后的预测视觉词和类别标签的相似度对提示向量进行优化。
20、在第三方面的一种可能实现方式中,上述优化模块用于为每张训练图像对应的类别标签设置对应的类别向量,基于类别向量的维度对上述预测视觉词进行维度变换,变换后的预测视觉词与上述类别向量的维度相同。
21、第四方面,本技术提供了一种推理装置。该装置包括获取模块和分类模块。获取模块用于获取待处理图像。分类模块用于基于图像分类模型对待处理图像进行分类,得到该图像的分类结果。其中,图像分类模型包括以下参数:预训练模型的参数和已优化的提示向量,预训练模型的参数包括掩码向量。预训练模型是基于掩码向量训练得到。已优化的提示向量是基于预训练模型和多张训练图像对应的目标输入向量优化得到,每张训练图像对应的目标输入向量是通过为每张训练图像拼接上述掩码向量和未优化的提示向量得到。
22、在第四方面的一种可能实现方式中,上述每张训练图像对应的目标输入向量是通过将预训练模型的掩码向量和未优化的提示向量与每张训练图像进行串接得到。其中,串接是指沿着向量的行的方向进行拼接。
23、在第四方面的一种可能实现方式中,上述图像分类模型还包括以下参数:至少一个类别向量,每个类别向量对应的一个类别标签。
24、在第四方面的一种可能实现方式中,上述分类模块用于将预训练模型的掩码向量和已优化的提示向量与该图像进行串接,得到该图像对应的目标输入向量,基于该图像对应的目标输入向量对该图像进行分类。
25、在第四方面的一种可能实现方式中,上述分类模块用于基于待处理图像对应的目标输入向量确定该图像对应的预测视觉词,基于上述至少一个类别向量的维度对该图像对应的预测视觉词进行维度变换,基于变换后的预测视觉词与上述至少一个类别向量的相似度对该图像进行分类。
26、第五方面,本技术提供了一种计算设备,该计算设备包括处理器和存储器,处理器执行存储器中的计算机程序代码以实现前述第一方面及第一方面的任一种实现方式所描述部分或全部方法,和/或,前述第二方面及第二方面的任一种实现方式所描述的部分或全部方法。
27、第六方面,本技术提供了一种计算设备集群。该计算设备集群包括至少一个计算设备,每个计算设备包括处理器和存储器。至少一个计算设备的处理器用于执行至少一个计算设备的存储器中存储的指令,以使得该计算设备集群执行前述第一方面及第一方面的任一种实现方式所描述部分或全部方法,和/或,前述第二方面及第二方面的任一种实现方式所描述的部分或全部方法。
28、第七方面,本技术提供了一种计算机程序产品。该计算机程序产品可以是包含指令的、能够运行在计算设备上或被储存在任何可用介质中的软件或程序产品。当该计算机程序产品在计算设备或计算设备集群上运行时,使得该计算设备或该计算设备集群执行前述第一方面及第一方面的任一种实现方式所描述部分或全部方法,和/或,前述第二方面及第二方面的任一种实现方式所描述的部分或全部方法。
29、第八方面,本技术提供了一种计算机可读存储介质。该计算机存储介质包括计算机程序指令,当所述计算机程序指令被计算设备或计算设备集群执行时,使得该计算设备或该计算设备集群执行前述第一方面及第一方面的任一种实现方式所描述部分或全部方法,和/或,前述第二方面及第二方面的任一种实现方式所描述的部分或全部方法。
本文地址:https://www.jishuxx.com/zhuanli/20240914/296132.html
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 YYfuon@163.com 举报,一经查实,本站将立刻删除。
下一篇
返回列表