FasterRCNN源码解析(四)——GeneralizedRCNNTransform部分
文章目录
前言
主要是对框架中对数据进行预处理的类进行解读,重点在于数据集以及标签的转化上进行剖析
一、前期训练部分
在我们train_res50_fpn.py脚本中,我们先通过读取解析PASCAL VOC2012数据集一文中的VOC2012DataSet
类来获取我们的数据集
# load train data set
# VOCdevkit -> VOC2012 -> ImageSets -> Main -> train.txt
train_data_set = VOC2012DataSet(VOC_root, data_transform["train"], "train.txt")
然后我们对获得的数据用torch.utils.data.DataLoader
进行载入,batch_size设为2,其中train_data_set.collate_fn
方法是将我们的数据集 img和target 各自打包放在一起
train_data_loader = torch.utils.data.DataLoader(train_data_set,
batch_size=batch_size,
shuffle=True,
num_workers=nw,
collate_fn=train_data_set.collate_fn)
二、GeneralizedRCNNTransform
这一模块的作用就是将图片标准化,并将图片缩放到统一尺大小,经过这一模块之后才得到真正的batch数据
1.标准化处理函数
( 图 像 数 据 − 均 值 ) / 方 差 (图像数据-均值)/方差 (图像数据−均值)/方差
def normalize(self, image):
"""标准化处理"""
dtype, device = image.dtype, image.device
mean = torch.as_tensor(self.image_mean, dtype=dtype, device=device)
std = torch.as_tensor(self.image_std, dtype=dtype, device=device)
# [:, None, None]: shape [3] -> [3, 1, 1]
return (image - mean[:, None, None]) / std[:, None, None]
2.resize函数
- 获取缩放因子
- 对图片进行缩放
- 对boxes进行缩放
def resize(self, image, target):
# type: (Tensor, Optional[Dict[str, Tensor]]) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]
"""
将图片缩放到指定的大小范围内,并对应缩放bboxes信息
Args:
image: 输入的图片
target: 输入图片的相关信息(包括bboxes信息)
Returns:
image: 缩放后的图片
target: 缩放bboxes后的图片相关信息
"""
# image shape is [channel, height, width]
h, w = image.shape[