PointPillarV2VNet
PointPillarV2VNet
# -*- coding: utf-8 -*-
# Author: Hao Xiang <haxiang@g.ucla.edu>, Runsheng Xu <rxx3386@ucla.edu>
# License: TDG-Attribution-NonCommercial-NoDistrib
import torch.nn as nn
from opencood.models.sub_modules.pillar_vfe import PillarVFE
from opencood.models.sub_modules.point_pillar_scatter import PointPillarScatter
from opencood.models.sub_modules.base_bev_backbone import BaseBEVBackbone
from opencood.models.sub_modules.downsample_conv import DownsampleConv
from opencood.models.sub_modules.naive_compress import NaiveCompressor
from opencood.models.fuse_modules.v2v_fuse import V2VNetFusion
class PointPillarV2VNet(nn.Module):
    """PointPillar detector whose per-agent BEV features are fused by V2VNet.

    The forward pass chains PillarVFE -> PointPillarScatter ->
    BaseBEVBackbone -> (optional) DownsampleConv -> (optional)
    NaiveCompressor -> V2VNetFusion, followed by two 1x1 convolution heads
    producing the classification and regression maps.
    """

    def __init__(self, args):
        """Build all sub-modules from the model configuration.

        Parameters
        ----------
        args : dict
            Model configuration. Required keys: ``max_cav``, ``pillar_vfe``,
            ``voxel_size``, ``lidar_range``, ``point_pillar_scatter``,
            ``base_bev_backbone``, ``v2vfusion`` and ``anchor_number``.
            Optional keys: ``shrink_header`` (enables the downsample conv),
            ``compression`` (a value > 0 enables the compressor) and
            ``backbone_fix`` (truthy freezes everything except the fusion
            network).
        """
        super().__init__()

        self.max_cav = args['max_cav']

        # Pillar VFE
        self.pillar_vfe = PillarVFE(args['pillar_vfe'],
                                    num_point_features=4,
                                    voxel_size=args['voxel_size'],
                                    point_cloud_range=args['lidar_range'])
        self.scatter = PointPillarScatter(args['point_pillar_scatter'])
        self.backbone = BaseBEVBackbone(args['base_bev_backbone'], 64)

        # used to downsample the feature map for efficient computation
        self.shrink_flag = 'shrink_header' in args
        if self.shrink_flag:
            self.shrink_conv = DownsampleConv(args['shrink_header'])

        # `compression` is an optional switch just like `shrink_header`;
        # a config that omits it means "no compression" instead of KeyError.
        self.compression = 'compression' in args and args['compression'] > 0
        if self.compression:
            self.naive_compressor = NaiveCompressor(256, args['compression'])

        self.fusion_net = V2VNetFusion(args['v2vfusion'])

        # 1x1 conv heads: one score per anchor, seven regression values
        # per anchor.
        self.cls_head = nn.Conv2d(128 * 2, args['anchor_number'],
                                  kernel_size=1)
        self.reg_head = nn.Conv2d(128 * 2, 7 * args['anchor_number'],
                                  kernel_size=1)

        # `backbone_fix` is optional as well; absent means "train everything".
        if 'backbone_fix' in args and args['backbone_fix']:
            self.backbone_fix()

    def backbone_fix(self):
        """
        Fix the parameters of backbone during finetune on timedelay.

        Every sub-module except ``fusion_net`` has ``requires_grad`` turned
        off on all of its parameters; the optional compressor and shrink
        conv are only touched when they were actually created.
        """
        frozen_modules = [self.pillar_vfe, self.scatter, self.backbone]
        if self.compression:
            frozen_modules.append(self.naive_compressor)
        if self.shrink_flag:
            frozen_modules.append(self.shrink_conv)
        frozen_modules.extend([self.cls_head, self.reg_head])

        for module in frozen_modules:
            for param in module.parameters():
                param.requires_grad = False

    def forward(self, data_dict):
        """Run detection on a batch of cooperative LiDAR data.

        Parameters
        ----------
        data_dict : dict
            Must provide ``processed_lidar`` (with ``voxel_features``,
            ``voxel_coords`` and ``voxel_num_points``), ``record_len`` and
            ``pairwise_t_matrix``.

        Returns
        -------
        dict
            ``{'psm': <cls_head output>, 'rm': <reg_head output>}`` computed
            on the fused feature map.
        """
        processed_lidar = data_dict['processed_lidar']
        record_len = data_dict['record_len']
        pairwise_t_matrix = data_dict['pairwise_t_matrix']

        batch_dict = {
            'voxel_features': processed_lidar['voxel_features'],
            'voxel_coords': processed_lidar['voxel_coords'],
            'voxel_num_points': processed_lidar['voxel_num_points'],
            'record_len': record_len}

        # n, 4 -> n, c
        batch_dict = self.pillar_vfe(batch_dict)
        # n, c -> N, C, H, W
        batch_dict = self.scatter(batch_dict)
        batch_dict = self.backbone(batch_dict)

        spatial_features_2d = batch_dict['spatial_features_2d']
        # downsample feature to reduce memory
        if self.shrink_flag:
            spatial_features_2d = self.shrink_conv(spatial_features_2d)
        # compressor
        if self.compression:
            spatial_features_2d = self.naive_compressor(spatial_features_2d)

        fused_feature = self.fusion_net(spatial_features_2d,
                                        record_len,
                                        pairwise_t_matrix)

        return {
            'psm': self.cls_head(fused_feature),
            'rm': self.reg_head(fused_feature)}
以下是对提供的代码逐行解析:
# -*- coding: utf-8 -*-
# Author: Hao Xiang <haxiang@g.ucla.edu>
# License: TDG-Attribution-NonCommercial-NoDistrib
这些是文件的元信息注释,指定了文件的编码方式为 UTF-8,提供了作者信息和许可证信息。
"""
Implementation of V2VNet Fusion
"""
这是一个文档字符串,说明了该文件的目的,即实现了 V2VNet Fusion。
import torch
import torch.nn as nn
from opencood.models.sub_modules.torch_transformation_utils import \
get_discretized_transformation_matrix, get_transformation_matrix, \
warp_affine, get_rotated_roi
from opencood.models.sub_modules.convgru import ConvGRU
导入所需的模块和类,包括 PyTorch 库和自定义模块中的一些工具函数和 ConvGRU 类。
class V2VNetFusion(nn.Module):
def __init__(self, args):
super(V2VNetFusion, self).__init__()
定义了一个名为 V2VNetFusion 的类,它继承自 nn.Module,表示它是一个 PyTorch 模型。
in_channels = args['in_channels']
H, W = args['conv_gru']['H'], args[

博客对PointPillarV2VNet、V2VNetFusion和ConvGRU进行代码逐行解析。介绍了V2VNet Fusion的参数初始化、前向传播等,阐述变换矩阵离散化原因。还解析了V2VNetFusion类和ConvGRU模型,说明ConvGRU结合CNN和RNN处理时序数据的原理。
最低0.47元/天 解锁文章
1217

被折叠的 条评论
为什么被折叠?



