simple-faster-rcnn-pytorch-master代码解读——数据预处理
在完成代码复现之后,我参阅了许多博客中的代码详解,逐句研读了具体代码,并结合自己的理解写下这篇博客。
data 目录下的 `dataset.py`、`util.py`、`voc_dataset.py` 和 `__init__.py` 文件负责读取数据并对其进行预处理。
1.dataset.py
from __future__ import absolute_import
from __future__ import division
import torch as t
from data.voc_dataset import VOCBboxDataset
from skimage import transform as sktsf
from torchvision import transforms as tvtsf
from data import util
import numpy as np
from utils.config import opt
def inverse_normalize(img):
    """Undo the normalization applied by ``preprocess`` so an image can be viewed.

    Caffe-pretrained models take BGR images in [0, 255] with the mean
    subtracted, while torchvision models take RGB images in [0, 1].

    * When ``opt.caffe_pretrain`` is set, the caffe channel mean is added back
      and the channel axis is reversed (BGR -> RGB). For an image produced by
      ``caffe_normalize`` this yields values in [0, 255].
    * Otherwise an *approximate* inverse of ``pytorch_normalze`` is applied:
      a single std (0.225) and mean (0.45) are used for all channels, the
      result is clipped to [0, 1] and rescaled to [0, 255].

    Args:
        img (~numpy.ndarray): A normalized image in CHW layout.

    Returns:
        ~numpy.ndarray: The de-normalized CHW image in RGB channel order.
    """
    if not opt.caffe_pretrain:
        # Approximate un-normalization; good enough for visualization only.
        restored = img * 0.225 + 0.45
        return restored.clip(min=0, max=1) * 255
    caffe_mean = np.array([122.7717, 115.9465, 102.9801]).reshape(3, 1, 1)
    # Reversing the channel axis turns BGR back into RGB (returns a view).
    return (img + caffe_mean)[::-1, :, :]
def pytorch_normalze(img, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Normalize an RGB image the way torchvision ImageNet models expect.

    Computes ``(img - mean) / std`` per channel, which is exactly what
    ``torchvision.transforms.Normalize`` does, but directly in NumPy. This
    avoids the ``torch.from_numpy`` round trip, which raises on arrays with
    negative strides (e.g. views produced by ``img[:, :, ::-1]``).
    See https://github.com/pytorch/vision/issues/223

    Args:
        img (~numpy.ndarray): Image in CHW layout, RGB order, values in [0, 1].
            A floating-point dtype is preserved; other dtypes become float64.
        mean (sequence of float): Per-channel mean (ImageNet by default).
        std (sequence of float): Per-channel standard deviation (ImageNet by
            default).

    Returns:
        ~numpy.ndarray: The normalized image. With the default statistics and
        inputs in [0, 1], values lie roughly in [-2.12, 2.64] (not -1~1).

    Raises:
        ValueError: If any entry of ``std`` is zero (same as torchvision).
    """
    img = np.asarray(img)
    # Match torchvision: the statistics take the image's own float dtype.
    dtype = img.dtype if np.issubdtype(img.dtype, np.floating) else np.float64
    mean_arr = np.asarray(mean, dtype=dtype).reshape(-1, 1, 1)
    std_arr = np.asarray(std, dtype=dtype).reshape(-1, 1, 1)
    if np.any(std_arr == 0):
        raise ValueError(
            'std evaluated to zero after conversion to {}, leading to division '
            'by zero.'.format(dtype))
    # Broadcasting (C, 1, 1) against (..., C, H, W) normalizes each channel.
    return (img - mean_arr) / std_arr
def caffe_normalize(img):
    """Convert an RGB image in [0, 1] into caffe-style network input.

    Caffe-pretrained models expect BGR channel order, values in [0, 255] and
    the per-channel mean subtracted, which gives values of roughly -125~125.

    Args:
        img (~numpy.ndarray): Image in CHW layout, RGB order, values in [0, 1].

    Returns:
        ~numpy.ndarray: A float32 CHW image in BGR order with the mean removed.
    """
    bgr_mean = np.array([122.7717, 115.9465, 102.9801]).reshape(3, 1, 1)
    rgb_to_bgr = [2, 1, 0]
    # Fancy indexing returns a copy, so the caller's array is left untouched.
    bgr = img[rgb_to_bgr, :, :]
    shifted = bgr * 255 - bgr_mean
    return shifted.astype(np.float32, copy=True)
def preprocess(img, min_size=600, max_size=1000):
    """Rescale an image and normalize it for the feature extractor.

    The image is scaled so its shorter edge becomes ``min_size``. If that
    would make the longer edge exceed ``max_size``, it is instead scaled so
    the longer edge becomes ``max_size``. The defaults (600 / 1000) follow
    the Faster R-CNN paper. After resizing, the image is normalized with
    ``caffe_normalize`` when ``opt.caffe_pretrain`` is set, and with
    ``pytorch_normalze`` otherwise.

    Args:
        img (~numpy.ndarray): An image in CHW layout and RGB order, with
            values in [0, 255].
        min_size (int): Target length of the shorter edge.
        max_size (int): Upper bound on the length of the longer edge.

    Returns:
        ~numpy.ndarray: The resized and normalized CHW image.
    """
    channels, height, width = img.shape
    # Taking the smaller ratio guarantees that neither edge exceeds its limit.
    scale = min(min_size / min(height, width), max_size / max(height, width))
    # Both normalizers expect input in [0, 1], so rescale before resizing.
    resized = sktsf.resize(
        img / 255.,
        (channels, height * scale, width * scale),
        mode='reflect',
        anti_aliasing=False,
    )
    normalize = caffe_normalize if opt.caffe_pretrain else pytorch_normalze
    return normalize(resized)
class Transform(object):
    """Resize an ``(img, bbox, label)`` triple and randomly flip it horizontally.

    Args:
        min_size (int): Shorter-edge target forwarded to ``preprocess``.
        max_size (int): Longer-edge limit forwarded to ``preprocess``.
    """

    def __init__(self, min_size=600, max_size=1000):
        self.min_size = min_size
        self.max_size = max_size

    def __call__(self, in_data):
        """Apply the transform.

        Args:
            in_data (tuple): ``(img, bbox, label)`` with ``img`` in CHW layout.

        Returns:
            tuple: ``(img, bbox, label, scale)``, where ``scale`` is the
            resized height divided by the original height. The returned
            image may be a flipped view of the preprocessed array.
        """
        img, bbox, label = in_data
        _, src_h, src_w = img.shape
        img = preprocess(img, self.min_size, self.max_size)
        _, dst_h, dst_w = img.shape

        # Boxes follow the image through the same resize ...
        bbox = util.resize_bbox(bbox, (src_h, src_w), (dst_h, dst_w))
        # ... and through the same random horizontal flip.
        img, flip_params = util.random_flip(img, x_random=True, return_param=True)
        bbox = util.flip_bbox(bbox, (dst_h, dst_w), x_flip=flip_params['x_flip'])

        return img, bbox, label, dst_h / src_h
class Dataset:
    """Training dataset: VOC examples passed through ``Transform``.

    Args:
        opt: Config object; ``voc_data_dir``, ``min_size`` and ``max_size``
            are read from it.
    """

    def __init__(self, opt):
        self.opt = opt
        self.db = VOCBboxDataset(opt.voc_data_dir)
        self.tsf = Transform(opt.min_size, opt.max_size)

    def __getitem__(self, idx):
        """Return ``(img, bbox, label, scale)`` for example ``idx``."""
        ori_img, bbox, label, _difficult = self.db.get_example(idx)
        img, bbox, label, scale = self.tsf((ori_img, bbox, label))
        # TODO: check whose stride is negative to fix this instead of copying.
        # Flipping can leave arrays with negative strides, which PyTorch
        # cannot wrap directly, so fresh copies are returned.
        return img.copy(), bbox.copy(), label.copy(), scale

    def __len__(self):
        return len(self.db)
class TestDataset:
    """Evaluation dataset: VOC examples resized and normalized, without flipping.

    Args:
        opt: Config object. ``voc_data_dir`` is read, and ``min_size`` /
            ``max_size`` are read when present (falling back to 600 / 1000,
            the ``preprocess`` defaults).
        split (str): Name of the id list under ``ImageSets/Main``; the
            default ``'test'`` reads ``ImageSets/Main/test.txt``.
        use_difficult (bool): Forwarded to ``VOCBboxDataset``; when True,
            objects marked difficult are kept and returned with their
            ``difficult`` flags.
    """

    def __init__(self, opt, split='test', use_difficult=True):
        self.opt = opt
        self.db = VOCBboxDataset(opt.voc_data_dir, split=split,
                                 use_difficult=use_difficult)
        # Use the same resize limits as the training Transform, so that
        # evaluation images are at the scale the model was trained on.
        # Previously preprocess() always used its hard-coded 600/1000.
        self.min_size = getattr(opt, 'min_size', 600)
        self.max_size = getattr(opt, 'max_size', 1000)

    def __getitem__(self, idx):
        """Return ``(img, (H, W) of the original image, bbox, label, difficult)``."""
        ori_img, bbox, label, difficult = self.db.get_example(idx)
        img = preprocess(ori_img, self.min_size, self.max_size)
        return img, ori_img.shape[1:], bbox, label, difficult

    def __len__(self):
        return len(self.db)
2.util.py
import numpy as np
from PIL import Image
import random
def read_image(path, dtype=np.float32, color=True):
    """Read an image from a file into a CHW array with values in [0, 255].

    Args:
        path (str): Path to the image file.
        dtype: Dtype of the returned array.
        color (bool): If True, return a 3-channel image in RGB order. If
            False, return a single-channel grayscale (luminance) image.

    Returns:
        ~numpy.ndarray: ``(3, H, W)`` when ``color`` is True, otherwise
        ``(1, H, W)``.
    """
    with Image.open(path) as f:
        # 'L' is 8-bit grayscale. The former 'P' mode is a *palette* image:
        # its pixel values are palette indices, not intensities.
        converted = f.convert('RGB' if color else 'L')
        img = np.asarray(converted, dtype=dtype)

    if img.ndim == 2:
        # (H, W) -> (1, H, W)
        return img[np.newaxis]
    # (H, W, C) -> (C, H, W)
    return img.transpose((2, 0, 1))
def resize_bbox(bbox, in_size, out_size):
    """Rescale bounding boxes to follow an image resize.

    Args:
        bbox (~numpy.ndarray): Boxes of shape ``(R, 4)`` whose columns are
            ``(y_min, x_min, y_max, x_max)``.
        in_size (tuple): ``(height, width)`` of the image before resizing.
        out_size (tuple): ``(height, width)`` of the image after resizing.

    Returns:
        ~numpy.ndarray: A rescaled copy of ``bbox``; the input is not modified.
    """
    resized = bbox.copy()
    y_ratio = float(out_size[0]) / in_size[0]
    x_ratio = float(out_size[1]) / in_size[1]
    # Columns 0 and 2 hold y coordinates; columns 1 and 3 hold x coordinates.
    resized[:, [0, 2]] = y_ratio * resized[:, [0, 2]]
    resized[:, [1, 3]] = x_ratio * resized[:, [1, 3]]
    return resized
def flip_bbox(bbox, size, y_flip=False, x_flip=False):
    """Mirror bounding boxes to follow a flipped image.

    Args:
        bbox (~numpy.ndarray): Boxes of shape ``(R, 4)`` whose columns are
            ``(y_min, x_min, y_max, x_max)``.
        size (tuple): ``(height, width)`` of the image the boxes belong to
            (in ``Transform`` this is the already-resized image).
        y_flip (bool): Mirror the boxes for a vertical flip.
        x_flip (bool): Mirror the boxes for a horizontal flip.

    Returns:
        ~numpy.ndarray: A flipped copy of ``bbox``.
    """
    height, width = size
    flipped = bbox.copy()
    if y_flip:
        # y -> height - y, which also swaps which edge is the minimum.
        flipped[:, [0, 2]] = height - bbox[:, [2, 0]]
    if x_flip:
        # x -> width - x, likewise swapping x_min and x_max.
        flipped[:, [1, 3]] = width - bbox[:, [3, 1]]
    return flipped
def crop_bbox(
        bbox, y_slice=None, x_slice=None,
        allow_outside_center=True, return_param=False):
    """Adapt bounding boxes to a cropped region of an image.

    Each box is clipped to the crop window and shifted so the window's
    top-left corner becomes the origin. Boxes left with zero or negative
    extent are removed.

    Args:
        bbox (~numpy.ndarray): Boxes of shape ``(R, 4)`` whose columns are
            ``(y_min, x_min, y_max, x_max)``.
        y_slice (slice or None): Vertical extent of the crop; ``None`` or a
            missing bound means unbounded on that side.
        x_slice (slice or None): Horizontal extent of the crop.
        allow_outside_center (bool): If False, boxes whose center lies
            outside the crop window are also removed.
        return_param (bool): If True, also return
            ``{'index': indices of the kept boxes}``.

    Returns:
        ~numpy.ndarray or tuple: The adjusted boxes, optionally with the
        parameter dict.
    """
    top, bottom = _slice_to_bounds(y_slice)
    left, right = _slice_to_bounds(x_slice)
    window = np.array((top, left, bottom, right))
    low, high = window[:2], window[2:]

    if allow_outside_center:
        keep = np.ones(bbox.shape[0], dtype=bool)
    else:
        center = (bbox[:, :2] + bbox[:, 2:]) / 2.0
        keep = np.logical_and(low <= center, center < high).all(axis=1)

    adjusted = bbox.copy()
    adjusted[:, :2] = np.maximum(adjusted[:, :2], low)
    adjusted[:, 2:] = np.minimum(adjusted[:, 2:], high)
    # Move into the crop's coordinate frame.
    adjusted[:, :2] -= low
    adjusted[:, 2:] -= low

    has_area = (adjusted[:, :2] < adjusted[:, 2:]).all(axis=1)
    keep = np.logical_and(keep, has_area)
    adjusted = adjusted[keep]

    if not return_param:
        return adjusted
    return adjusted, {'index': np.flatnonzero(keep)}


def _slice_to_bounds(slice_):
    """Return ``(lower, upper)`` bounds of a slice; missing bounds are 0 / inf."""
    if slice_ is None:
        return 0, np.inf
    lower = 0 if slice_.start is None else slice_.start
    upper = np.inf if slice_.stop is None else slice_.stop
    return lower, upper
def translate_bbox(bbox, y_offset=0, x_offset=0):
    """Shift bounding boxes by a constant offset.

    Meant for image transforms such as padding and cropping, which move the
    image's top-left corner from ``(0, 0)`` to ``(y_offset, x_offset)``.

    Args:
        bbox (~numpy.ndarray): Boxes of shape ``(R, 4)`` whose columns are
            ``(y_min, x_min, y_max, x_max)``.
        y_offset (int or float): Shift along the y axis.
        x_offset (int or float): Shift along the x axis.

    Returns:
        ~numpy.ndarray: A shifted copy of ``bbox``.
    """
    offset = (y_offset, x_offset)
    shifted = bbox.copy()
    # Both corners (min and max) move by the same offset.
    for corner in (slice(0, 2), slice(2, None)):
        shifted[:, corner] += offset
    return shifted
def random_flip(img, y_random=False, x_random=False,
                return_param=False, copy=False):
    """Randomly flip a CHW image vertically and/or horizontally.

    Args:
        img (~numpy.ndarray): Image in CHW layout.
        y_random (bool): If True, flip vertically with probability 1/2.
        x_random (bool): If True, flip horizontally with probability 1/2.
        return_param (bool): If True, also return which flips were applied.
        copy (bool): If True, return a copy. Otherwise a flipped result is a
            view of ``img`` (with negative strides), and an unflipped result
            is ``img`` itself.

    Returns:
        The (possibly flipped) image, or
        ``(image, {'y_flip': bool, 'x_flip': bool})`` when ``return_param``
        is True.
    """
    # The vertical decision is drawn before the horizontal one.
    y_flip = random.choice([True, False]) if y_random else False
    x_flip = random.choice([True, False]) if x_random else False

    out = img
    if y_flip:
        out = out[:, ::-1, :]
    if x_flip:
        out = out[:, :, ::-1]
    if copy:
        out = out.copy()

    if not return_param:
        return out
    return out, {'y_flip': y_flip, 'x_flip': x_flip}
3.voc_dataset.py
import os
import xml.etree.ElementTree as ET
import numpy as np
from .util import read_image
class VOCBboxDataset:
    """Bounding-box dataset for PASCAL VOC.

    Image ids come from ``<data_dir>/ImageSets/Main/<split>.txt``,
    annotations from ``<data_dir>/Annotations/<id>.xml`` and images from
    ``<data_dir>/JPEGImages/<id>.jpg``.

    Args:
        data_dir (str): Root of a VOC-style dataset directory.
        split (str): Name of the id list, e.g. ``'trainval'`` or ``'test'``.
        use_difficult (bool): If False, objects whose ``difficult`` flag is 1
            are skipped.
        return_difficult (bool): Stored on the instance; ``get_example``
            returns the difficult flags regardless of this value.
    """

    def __init__(self, data_dir, split='trainval',
                 use_difficult=False, return_difficult=False,
                 ):
        id_list_file = os.path.join(
            data_dir, 'ImageSets/Main/{0}.txt'.format(split))
        # Read the ids inside a context manager so the file handle is closed.
        with open(id_list_file) as f:
            self.ids = [id_.strip() for id_ in f]
        self.data_dir = data_dir
        self.use_difficult = use_difficult
        self.return_difficult = return_difficult
        self.label_names = VOC_BBOX_LABEL_NAMES  # the 20 VOC classes

    def __len__(self):
        return len(self.ids)

    def get_example(self, i):
        """Return the i-th example.

        Returns:
            tuple: ``(img, bbox, label, difficult)``:

            * ``img``: CHW RGB image from ``read_image``.
            * ``bbox``: float32, shape ``(R, 4)``, columns
              ``(y_min, x_min, y_max, x_max)`` as 0-based pixel indices.
            * ``label``: int32, shape ``(R,)``, indices into
              ``VOC_BBOX_LABEL_NAMES``.
            * ``difficult``: uint8, shape ``(R,)``.

            ``R`` may be 0, e.g. when every object is difficult and
            ``use_difficult`` is False.
        """
        id_ = self.ids[i]
        anno = ET.parse(
            os.path.join(self.data_dir, 'Annotations', id_ + '.xml'))
        bbox = list()
        label = list()
        difficult = list()
        for obj in anno.findall('object'):
            is_difficult = int(obj.find('difficult').text)
            # When not using difficult objects, skip them.
            if not self.use_difficult and is_difficult == 1:
                continue
            difficult.append(is_difficult)
            bndbox_anno = obj.find('bndbox')
            # Subtract 1 to make the (1-based) VOC pixel indexes 0-based.
            bbox.append([
                int(bndbox_anno.find(tag).text) - 1
                for tag in ('ymin', 'xmin', 'ymax', 'xmax')])
            name = obj.find('name').text.lower().strip()
            label.append(VOC_BBOX_LABEL_NAMES.index(name))

        # np.array + reshape also handles images with no kept objects, where
        # np.stack([]) would raise "need at least one array to stack".
        bbox = np.array(bbox, dtype=np.float32).reshape(-1, 4)
        label = np.array(label, dtype=np.int32)
        # `np.bool` was removed in NumPy 1.24; use the builtin `bool`.
        # Stored as uint8 (the original code notes PyTorch did not support
        # bool arrays here).
        difficult = np.array(difficult, dtype=bool).astype(np.uint8)

        img_file = os.path.join(self.data_dir, 'JPEGImages', id_ + '.jpg')
        img = read_image(img_file, color=True)
        return img, bbox, label, difficult

    # Allow `dataset[i]` indexing.
    __getitem__ = get_example
# The 20 PASCAL VOC object classes; a label id is an index into this tuple.
VOC_BBOX_LABEL_NAMES = (
    'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
    'bus', 'car', 'cat', 'chair', 'cow',
    'diningtable', 'dog', 'horse', 'motorbike', 'person',
    'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor',
)