代码使用解析
- 代码路径:https://github.com/SuperMHP/GUPNet
- 输入数据路径修改:config.yaml中root_dir
- 训练数据组织形式:KITTI文件夹下ImageSets、testing、training。
- 自动下载预训练权重:“http://dl.yf.io/dla/models/imagenet/dla34-ba72cf86.pth” to /home/yang/.cache/torch/checkpoints/dla34-ba72cf86.pth
datasets目录下kitti.py文件
-
数据增强与否的控制 :
self.data_augmentation = True if split in ['train', 'trainval','sample'] else False
-
修改图像分辨率:
self.resolution = np.array([1280, 720])
gupnet.py
输出类别的控制
self.heatmap = nn.Sequential(nn.Conv2d(channels[self.first_level], self.head_conv, kernel_size=3, padding=1, bias=True),
nn.ReLU(inplace=True),
nn.Conv2d(self.head_conv, 9, kernel_size=1, stride=1, padding=0, bias=True))