多任务预测模型的训练、预测方法以及存储介质与流程
- 国知局
- 2024-08-05 12:04:36
本发明涉及计算机,尤其涉及一种多任务预测模型的训练、预测方法以及存储介质。
背景技术:
1、基于深度学习的3d目标检测,逐渐成为智能驾驶主流方法,但是,在通用障碍物检测分域,一直以来是以传统规则为主,基于lidar深度学习的少之又少。同样,基于lidar的通用目标检测,实时性同样是该方法落地的障碍。虽然,近几年人工智能芯片的高速发展,对深度学习模型部署的实时性带来巨大化提升,但仍不能解决很多实时性问题。基于传统规则为主的通用目标检测,其对于算例和计算量的需求较高,再融合带有分类标签的3d目标检测结果,都会消耗一部分计算资源,增加部署难度,这样在整个自动驾驶过程会造成推理延迟。
2、传统的解决方法一般是采用cpu计算方式,如果在nvidia平台,可以转换到cuda计算,来加速计算结果,但是如果在其他平台,这样的ai算力一般不支持该类型的加速计算,因此,即使在nv平台也不能从根本上解决问题。
技术实现思路
1、本发明提供了一种多任务预测模型的训练、预测方法以及存储介质,可以降低多任务预测过程中的计算复杂度,提高多任务预测的效率以及实时性。
2、一方面,本发明提供了一种多任务预测模型的训练方法,所述方法包括:
3、获取样本点云数据;所述样本点云数据标注了样本对象属性标签、样本中心点标签以及样本偏移量标签;所述样本对象属性包括样本对象的热力图、尺寸信息、类别信息中的至少一种;
4、将所述样本点云数据输入预设模型,基于所述预设模型的点云特征提取网络对所述样本点云数据进行点云特征提取,得到样本点云特征;所述样本点云特征用于提取所述样本对象的多个样本语义特征;
5、基于所述预设模型的对象属性预测网络对所述多个样本语义特征进行对象属性预测,得到样本预测对象属性;
6、基于所述预设模型的聚类网络对所述多个样本语义特征进行聚类处理,得到样本预测中心点以及样本预测偏移量;
7、根据所述样本预测对象属性、所述样本对象属性标签、所述样本预测中心点、所述样本中心点标签、所述样本预测偏移量以及所述样本偏移量标签,训练所述预设模型,得到多任务预测模型。
8、可选的,所述聚类网络包括聚类子网络、中心点预测网络以及偏移量预测网络,所述基于所述预设模型的聚类网络对所述多个样本语义特征进行聚类处理,得到样本预测中心点以及样本预测偏移量,包括:
9、基于所述聚类子网络对所述多个样本语义特征进行聚类处理,得到样本聚类结果;
10、基于所述中心点预测网络,对所述样本聚类结果进行中心点预测处理,得到所述样本预测中心点;
11、基于所述偏移量预测网络,对所述样本聚类结果进行偏移量预测处理,得到所述样本预测偏移量。
12、可选的,所述根据所述样本预测对象属性、所述样本对象属性标签、所述样本预测中心点、所述样本中心点标签、所述样本预测偏移量以及所述样本偏移量标签,训练所述预设模型,得到多任务预测模型,包括:
13、根据所述样本预测对象属性与所述样本对象属性标签之间的第一差异、所述样本预测中心点与所述样本中心点标签之间的第二差异以及所述样本预测偏移量与所述样本偏移量标签之间的第三差异,确定目标损失信息;
14、根据所述目标损失信息调整所述预设模型的模型参数直至满足训练结束条件,并将训练结束时的所述预设模型确定为所述多任务预测模型。
15、可选的,所述根据所述样本预测对象属性与所述样本对象属性标签之间的第一差异、所述样本预测中心点与所述样本中心点标签之间的第二差异以及所述样本预测偏移量与所述样本偏移量标签之间的第三差异,确定目标损失信息,包括:
16、根据所述样本预测对象属性与所述样本对象属性标签之间的第一差异,确定第一损失信息;
17、根据所述样本预测中心点与所述样本中心点标签之间的第二差异,确定第二损失信息;
18、根据所述样本预测偏移量与所述样本偏移量标签之间的第三差异,确定第三损失信息;
19、根据所述第一损失信息、所述第二损失信息以及所述第三损失信息,确定所述目标损失信息。
20、可选的,所述基于所述预设模型的点云特征提取网络对所述样本点云数据进行点云特征提取,得到样本点云特征,包括:
21、基于所述预设模型的点云特征提取网络提取所述样本点云数据的竖轴数据以及点云强度;
22、基于所述样本点云数据的竖轴数据以及点云强度,生成样本伪图像以及样本立柱;
23、基于所述样本伪图像以及样本立柱进行点云特征提取,得到所述样本点云特征。
24、可选的,所述预设模型还包括语义分割网络,所述基于所述预设模型的点云特征提取网络对所述样本点云数据进行点云特征提取,得到样本点云特征之后,所述方法还包括:
25、基于所述语义分割网络对所述样本点云特征进行语义分割处理,得到多个样本语义特征。
26、可选的,所述获取样本点云数据包括:
27、获取样本初始点云数据,所述样本初始点云数据标注了所述样本对象属性标签以及样本分割结果标签;
28、根据所述样本初始点云数据以及所述样本分割结果标签,生成样本中心点标签以及样本偏移量标签;
29、对所述样本初始点云数据标注所述样本对象属性标签、所述样本中心点标签以及所述样本偏移量标签,得到所述样本点云数据。
30、另一方面还提供了一种多任务预测模型的预测方法,所述方法包括:
31、获取待测点云数据;
32、将所述待测点云数据输入所述多任务预测模型进行多任务预测处理,得到目标预测对象属性、目标预测中心点以及目标预测偏移量中的至少一项;所述多任务预测模型采用上述的方法训练得到。
33、可选的,所述将所述待测点云数据输入所述多任务预测模型进行多任务预测处理,得到目标预测对象属性、目标预测中心点以及目标预测偏移量中的至少一项,包括:
34、将所述待测点云数据输入所述多任务预测模型,基于所述多任务预测模型的点云特征提取网络对所述待测点云数据进行点云特征提取,得到目标点云特征;所述目标点云特征用于提取多个目标语义特征;
35、基于所述多任务预测模型的对象属性预测网络对所述多个目标语义特征进行对象属性预测,得到目标预测对象属性;
36、基于所述多任务预测模型的聚类网络对所述多个目标语义特征进行聚类处理,得到目标预测中心点以及目标预测偏移量。
37、另一方面提供了一种多任务预测模型的训练装置,所述装置包括:
38、样本点云获取模块,用于获取样本点云数据;所述样本点云数据标注了样本对象属性标签、样本中心点标签以及样本偏移量标签;所述样本对象属性包括样本对象的热力图、尺寸信息、类别信息中的至少一种;
39、样本特征确定模块,用于将所述样本点云数据输入预设模型,基于所述预设模型的点云特征提取网络对所述样本点云数据进行点云特征提取,得到样本点云特征;所述样本点云特征用于提取所述样本对象的多个样本语义特征;
40、样本属性确定模块,用于基于所述预设模型的对象属性预测网络对所述多个样本语义特征进行对象属性预测,得到样本预测对象属性;
41、样本偏移量确定模块,用于基于所述预设模型的聚类网络对所述多个样本语义特征进行聚类处理,得到样本预测中心点以及样本预测偏移量;
42、模型确定模块,用于根据所述样本预测对象属性、所述样本对象属性标签、所述样本预测中心点、所述样本中心点标签、所述样本预测偏移量以及所述样本偏移量标签,训练所述预设模型,得到多任务预测模型。
43、可选的,所述聚类网络包括聚类子网络、中心点预测网络以及偏移量预测网络,所述样本偏移量确定模块包括:
44、聚类结果确定单元,用于基于所述聚类子网络对所述多个样本语义特征进行聚类处理,得到样本聚类结果;
45、中心点确定单元,用于基于所述中心点预测网络,对所述样本聚类结果进行中心点预测处理,得到所述样本预测中心点;
46、偏移量确定单元,用于基于所述偏移量预测网络,对所述样本聚类结果进行偏移量预测处理,得到所述样本预测偏移量。
47、可选的,所述模型确定模块包括:
48、目标损失确定单元,用于根据所述样本预测对象属性与所述样本对象属性标签之间的第一差异、所述样本预测中心点与所述样本中心点标签之间的第二差异以及所述样本预测偏移量与所述样本偏移量标签之间的第三差异,确定目标损失信息;
49、模型确定单元,用于根据所述目标损失信息调整所述预设模型的模型参数直至满足训练结束条件,并将训练结束时的所述预设模型确定为所述多任务预测模型。
50、可选的,所述目标损失确定单元包括:
51、第一确定子单元,用于根据所述样本预测对象属性与所述样本对象属性标签之间的第一差异,确定第一损失信息;
52、第二确定子单元,用于根据所述样本预测中心点与所述样本中心点标签之间的第二差异,确定第二损失信息;
53、第三确定子单元,用于根据所述样本预测偏移量与所述样本偏移量标签之间的第三差异,确定第三损失信息;
54、目标确定子单元,用于根据所述第一损失信息、所述第二损失信息以及所述第三损失信息,确定所述目标损失信息。
55、可选的,所述样本特征确定模块包括:
56、点云数据提取单元,用于基于所述预设模型的点云特征提取网络提取所述样本点云数据的竖轴数据以及点云强度;
57、样本图像生成单元,用于基于所述样本点云数据的竖轴数据以及点云强度,生成样本伪图像以及样本立柱;
58、点云确定单元,用于基于所述样本伪图像以及样本立柱进行点云特征提取,得到所述样本点云特征。
59、可选的,所述预设模型还包括语义分割网络,所述装置还包括:
60、样本语义特征确定模块,用于基于所述语义分割网络对所述样本点云特征进行语义分割处理,得到多个样本语义特征。
61、可选的,样本点云获取模块包括:
62、初始点云获取单元,用于获取样本初始点云数据,所述样本初始点云数据标注了所述样本对象属性标签以及样本分割结果标签;
63、标签生成单元,用于根据所述样本初始点云数据以及所述样本分割结果标签,生成样本中心点标签以及样本偏移量标签;
64、标签标注单元,用于对所述样本初始点云数据标注所述样本对象属性标签、所述样本中心点标签以及所述样本偏移量标签,得到所述样本点云数据。
65、另一方面提供了一种多任务预测模型的预测装置,所述装置包括:
66、待测点云获取模块,用于获取待测点云数据;
67、多任务预测模块,用于将所述待测点云数据输入所述多任务预测模型进行多任务预测处理,得到目标预测对象属性、目标预测中心点以及目标预测偏移量中的至少一项;所述多任务预测模型采用上述的方法训练得到。
68、可选的,所述多任务预测模块包括:
69、目标点云确定单元,用于将所述待测点云数据输入所述多任务预测模型,基于所述多任务预测模型的点云特征提取网络对所述待测点云数据进行点云特征提取,得到目标点云特征;所述目标点云特征用于提取多个目标语义特征;
70、目标属性预测单元,用于基于所述多任务预测模型的对象属性预测网络对所述多个目标语义特征进行对象属性预测,得到目标预测对象属性;
71、目标偏移量预测单元,用于基于所述多任务预测模型的聚类网络对所述多个目标语义特征进行聚类处理,得到目标预测中心点以及目标预测偏移量。
72、另一方面提供了一种电子设备,所述设备包括处理器和存储器,所述存储器中存储有至少一条指令或至少一段程序,所述至少一条指令或所述至少一段程序由所述处理器加载并执行以实现如上所述的多任务预测模型的训练方法或预测方法。
73、另一方面提供了一种计算机存储介质,所述计算机存储介质存储有至少一条指令或至少一段程序,所述至少一条指令或至少一段程序由处理器加载并执行以实现如上所述的多任务预测模型的训练方法或预测方法。
74、另一方面提供了一种计算机程序产品或计算机程序,该计算机程序产品或计算机程序包括计算机指令,该计算机指令存储在计算机可读存储介质中。计算机设备的处理器从计算机可读存储介质读取该计算机指令,处理器执行该计算机指令,使得该计算机设备执行以实现如上所述的多任务预测模型的训练方法或预测方法。
75、本发明提供的多任务预测模型的训练、预测方法以及存储介质,具有如下技术效果:
76、本发明获取样本点云数据;样本点云数据标注了样本对象属性标签、样本中心点标签以及样本偏移量标签;样本对象属性包括样本对象的热力图、尺寸信息、类别信息中的至少一种;将样本点云数据输入预设模型,基于预设模型的点云特征提取网络对样本点云数据进行点云特征提取,得到样本点云特征;样本点云特征用于提取样本对象的多个样本语义特征;基于预设模型的对象属性预测网络对多个样本语义特征进行对象属性预测,得到样本预测对象属性;基于预设模型的聚类网络对多个样本语义特征进行聚类处理,得到样本预测中心点以及样本预测偏移量;再根据样本预测对象属性、样本对象属性标签、样本预测中心点、样本中心点标签、样本预测偏移量以及样本偏移量标签,训练预设模型,得到多任务预测模型。本发明训练得到的模型可以降低多任务预测过程中的计算复杂度,提高多任务预测的效率以及实时性。
本文地址:https://www.jishuxx.com/zhuanli/20240802/260961.html
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 YYfuon@163.com 举报,一经查实,本站将立刻删除。
下一篇
返回列表