Faster R-CNN 数据处理部分
1、图片——缩放、归一化、随机水平翻转
2、bbox——解析、缩放,并随图片一起水平翻转
3、label——解析类名,由类名获取对应的类别编号
4、scale——由 scale1、scale2 取较小值作为缩放比例,用于图片和 bbox 的缩放
代码如下:
from __future__ import absolute_import
from __future__ import division
import torch as t
from data.voc_dataset import VOCBboxDataset
from skimage import transform as sktsf
from torchvision import transforms as tvtsf
from data import util
import numpy as np
from utils.config import opt
def inverse_normalize(img):
    """Undo the input normalization so an image can be visualized.

    With ``opt.caffe_pretrain`` the Caffe mean is added back and the channel
    order is reversed (BGR -> RGB), undoing ``caffe_normalize``. Otherwise the
    PyTorch-style normalization is undone only approximately: a single scalar
    std (0.225) and mean (0.45) stand in for the per-channel values. The result
    is then clipped to [0, 1] and scaled to [0, 255].

    Args:
        img: normalized image indexed as ``(C, H, W)`` with C == 3 (the Caffe
            mean is reshaped to ``(3, 1, 1)`` to broadcast over it).

    Returns:
        An RGB image with values roughly in [0, 255]. In the Caffe branch this
        is a channel-reversed view (negative stride).
    """
    if not opt.caffe_pretrain:
        # approximate un-normalize for visualize
        restored = (img * 0.225 + 0.45).clip(min=0, max=1)
        return restored * 255
    caffe_mean = np.array([122.7717, 115.9465, 102.9801]).reshape(3, 1, 1)
    restored = img + caffe_mean
    # ``::-1`` on the channel axis flips BGR back to RGB.
    return restored[::-1, :, :]
def pytorch_normalze(img):
    """Normalize a ``(3, H, W)`` RGB image with ImageNet mean/std.

    Computes ``(img - mean) / std`` per channel with
    mean = (0.485, 0.456, 0.406) and std = (0.229, 0.224, 0.225)
    (see https://github.com/pytorch/vision/issues/223). For input in [0, 1]
    the output lies roughly in [-2.12, 2.64], not -1..1.

    Args:
        img: array indexed as ``(3, H, W)``, expected in [0, 1].

    Returns:
        Normalized array of the same shape. Floating inputs keep their dtype;
        non-floating inputs are promoted to float64.
    """
    img = np.asarray(img)
    # Plain NumPy broadcasting instead of a torch/torchvision round trip:
    # t.from_numpy rejects arrays with negative strides (e.g. flipped views),
    # and torchvision's Normalize rejects non-float tensors.
    if not np.issubdtype(img.dtype, np.floating):
        img = img.astype(np.float64)
    # Build mean/std in the input's dtype so the result dtype matches it.
    mean = np.array([0.485, 0.456, 0.406], dtype=img.dtype).reshape(3, 1, 1)
    std = np.array([0.229, 0.224, 0.225], dtype=img.dtype).reshape(3, 1, 1)
    return (img - mean) / std
def caffe_normalize(img):
    """Convert a ``(3, H, W)`` RGB image in [0, 1] into Caffe-style input.

    The channels are reordered to BGR, scaled to [0, 255], and a fixed
    per-channel mean is subtracted. The result is roughly -125..125, as float32.
    """
    # NOTE(review): after the RGB->BGR reorder channel 0 is blue, but these
    # means look like they are listed in R, G, B order (py-faster-rcnn's BGR
    # PIXEL_MEANS are 102.9801, 115.9465, 122.7717). inverse_normalize adds back
    # the same array, so any change must be made in both places — confirm.
    caffe_mean = np.array([122.7717, 115.9465, 102.9801]).reshape(3, 1, 1)
    bgr_255 = img[[2, 1, 0], :, :] * 255
    return (bgr_255 - caffe_mean).astype(np.float32, copy=True)
def preprocess(img, min_size=600, max_size=1000):
    """Rescale an image for the detector and normalize it.

    The scale factor is ``min(min_size / min(H, W), max_size / max(H, W))``.
    The shorter side therefore becomes (approximately) ``min_size``, unless
    that would push the longer side past ``max_size``. In that case the longer
    side becomes ``max_size``.

    Pixel values are divided by 255 before resizing. The image is then
    normalized with ``caffe_normalize`` when ``opt.caffe_pretrain`` is set,
    otherwise with ``pytorch_normalze``.

    Args:
        img: image array indexed as ``(C, H, W)``, presumably with values in
            [0, 255] since it is divided by 255 — TODO confirm against
            ``read_image``.
        min_size: target length of the shorter side.
        max_size: upper bound on the longer side.

    Returns:
        The resized, normalized image indexed as ``(C, H', W')``.
    """
    C, H, W = img.shape
    scale1 = min_size / min(H, W)
    scale2 = max_size / max(H, W)
    scale = min(scale1, scale2)
    img = img / 255.
    # H * scale is generally not an integer, and how skimage rounds a float
    # output shape has varied between versions. Pass explicit integer dims
    # (at least 1) so the resized size is deterministic.
    out_shape = (C,
                 max(1, int(round(H * scale))),
                 max(1, int(round(W * scale))))
    img = sktsf.resize(img, out_shape, mode='reflect', anti_aliasing=False)
    if opt.caffe_pretrain:
        normalize = caffe_normalize
    else:
        normalize = pytorch_normalze
    return normalize(img)
class Transform(object):
    """Resize, normalize and horizontally flip one training sample.

    Calling an instance with ``(img, bbox, label)`` returns
    ``(img, bbox, label, scale)``.
    """

    def __init__(self, min_size=600, max_size=1000):
        # Size limits forwarded to ``preprocess``.
        self.min_size = min_size
        self.max_size = max_size

    def __call__(self, in_data):
        """Transform a single ``(img, bbox, label)`` sample.

        Args:
            in_data: tuple ``(img, bbox, label)``; ``img`` is indexed as
                ``(C, H, W)``.

        Returns:
            ``(img, bbox, label, scale)``:
            - ``img``: resized, normalized and possibly flipped.
            - ``bbox``: rescaled and flipped to match the image.
            - ``label``: unchanged.
            - ``scale``: resized height divided by original height.
        """
        img, bbox, label = in_data
        _, H, W = img.shape
        img = preprocess(img, self.min_size, self.max_size)
        _, o_H, o_W = img.shape
        scale = o_H / H
        # Rescale the boxes from the original (H, W) to the resized (o_H, o_W).
        bbox = util.resize_bbox(bbox, (H, W), (o_H, o_W))

        # horizontally flip (data augmentation). The flip decision reported in
        # params['x_flip'] is applied to the boxes so they stay aligned.
        img, params = util.random_flip(
            img, x_random=True, return_param=True)
        bbox = util.flip_bbox(
            bbox, (o_H, o_W), x_flip=params['x_flip'])

        return img, bbox, label, scale
class Dataset:
    """Training dataset: VOC samples passed through ``Transform``.

    Each item is ``(img, bbox, label, scale)``; the ``difficult`` flags from
    the underlying dataset are dropped.
    """

    def __init__(self, opt):
        self.opt = opt
        self.db = VOCBboxDataset(opt.voc_data_dir)
        self.tsf = Transform(opt.min_size, opt.max_size)

    def __getitem__(self, idx):
        # ``idx`` is an integer index into the id list (not a file name);
        # VOCBboxDataset maps it to an image id and parses image/bbox/label.
        ori_img, bbox, label, _ = self.db.get_example(idx)
        # Resize/normalize/flip image and boxes; also returns the resize scale.
        img, bbox, label, scale = self.tsf((ori_img, bbox, label))
        # TODO: check whose stride is negative to fix this instead copy all
        # some of the strides of a given numpy array are negative.
        return img.copy(), bbox.copy(), label.copy(), scale

    def __len__(self):
        return len(self.db)
class TestDataset:
    """Evaluation dataset: VOC samples resized and normalized, not augmented.

    Each item is ``(img, original_hw, bbox, label, difficult)``:
    - ``original_hw`` is ``ori_img.shape[1:]``.
    - ``bbox`` stays in original-image coordinates.
    """

    def __init__(self, opt, split='test', use_difficult=True):
        self.opt = opt
        self.db = VOCBboxDataset(opt.voc_data_dir, split=split,
                                 use_difficult=use_difficult)

    def __getitem__(self, idx):
        ori_img, bbox, label, difficult = self.db.get_example(idx)
        # Use the same size limits as the training Transform, so test images
        # are resized consistently with training instead of always 600/1000.
        img = preprocess(ori_img, self.opt.min_size, self.opt.max_size)
        return img, ori_img.shape[1:], bbox, label, difficult

    def __len__(self):
        return len(self.db)
import os
import xml.etree.ElementTree as ET
import numpy as np
from .util import read_image
class VOCBboxDataset:
    """Bounding-box dataset over a PASCAL VOC style directory.

    Files read from ``data_dir``:
    - ``ImageSets/Main/<split>.txt``: one image id per line.
    - ``Annotations/<id>.xml``: object annotations.
    - ``JPEGImages/<id>.jpg``: the images.

    ``get_example(i)`` (also ``dataset[i]``) returns
    ``(img, bbox, label, difficult)``:

    * ``bbox``: float32 array of shape ``(R, 4)``. Rows are
      ``(ymin, xmin, ymax, xmax)`` minus 1, i.e. 0-based pixel indices.
    * ``label``: int32 array of shape ``(R,)``. Each value is the index of the
      object's (lower-cased) class name in ``VOC_BBOX_LABEL_NAMES``.
    * ``difficult``: uint8 array of shape ``(R,)`` with the VOC ``difficult``
      flag as 0/1.
    * ``img``: whatever ``read_image(path, color=True)`` returns — presumably
      a ``(3, H, W)`` array; TODO confirm against ``util.read_image``.

    Args:
        data_dir: root of the VOC directory.
        split: name of the id list under ``ImageSets/Main``.
        use_difficult: if False, objects flagged ``difficult`` are skipped.
        return_difficult: stored on the instance but currently unused;
            ``difficult`` is always returned.
    """

    def __init__(self, data_dir, split='trainval',
                 use_difficult=False, return_difficult=False,
                 ):
        id_list_file = os.path.join(
            data_dir, 'ImageSets/Main/{0}.txt'.format(split))
        # Read the whole id list eagerly and close the file afterwards.
        with open(id_list_file) as f:
            self.ids = [id_.strip() for id_ in f]
        self.data_dir = data_dir
        self.use_difficult = use_difficult
        self.return_difficult = return_difficult
        self.label_names = VOC_BBOX_LABEL_NAMES

    def __len__(self):
        """Number of image ids in the split."""
        return len(self.ids)

    def get_example(self, i):
        """Load image ``i`` and its annotations as described on the class."""
        id_ = self.ids[i]
        anno = ET.parse(
            os.path.join(self.data_dir, 'Annotations', id_ + '.xml'))
        bbox = list()
        label = list()
        difficult = list()
        for obj in anno.findall('object'):
            is_difficult = int(obj.find('difficult').text)
            # when not using difficult objects and this object is
            # difficult, skip it.
            if not self.use_difficult and is_difficult == 1:
                continue
            difficult.append(is_difficult)
            bndbox_anno = obj.find('bndbox')
            # subtract 1 to make pixel indexes 0-based. Parse via float so
            # annotations written with decimal coordinates don't crash int().
            bbox.append([
                float(bndbox_anno.find(tag).text) - 1
                for tag in ('ymin', 'xmin', 'ymax', 'xmax')])
            name = obj.find('name').text.lower().strip()
            label.append(VOC_BBOX_LABEL_NAMES.index(name))
        # np.stack raises on an empty list. Build explicitly shaped arrays so
        # an image whose objects were all skipped yields (0, 4) / (0,) arrays.
        bbox = np.array(bbox, dtype=np.float32).reshape(-1, 4)
        label = np.array(label, dtype=np.int32)
        # When `use_difficult==False`, all elements in `difficult` are False.
        # Builtin ``bool`` replaces ``np.bool`` (removed in NumPy 1.24). The
        # uint8 cast is kept because PyTorch didn't support np.bool.
        difficult = np.array(difficult, dtype=bool).astype(np.uint8)

        # Load the image.
        img_file = os.path.join(self.data_dir, 'JPEGImages', id_ + '.jpg')
        img = read_image(img_file, color=True)

        return img, bbox, label, difficult

    __getitem__ = get_example
# The 20 PASCAL VOC object classes. VOCBboxDataset encodes each object's class
# as its position in this tuple (via ``.index(name)``), so order is significant.
VOC_BBOX_LABEL_NAMES = (
    'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
    'bus', 'car', 'cat', 'chair', 'cow',
    'diningtable', 'dog', 'horse', 'motorbike', 'person',
    'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor',
)