Simple-fasterrcnn源码学习笔记 (1)


阅读该博客 https://blog.csdn.net/weixin_41424027/article/details/87896768#convet_caffe_pretrain_318
对一些点进行补充记录,便于自己记忆

voc_dataset.py

自定义数据读取
init:需要读入路径,把文件名都读入一个self.ids中
len:返回一个文件总数
getitem:给一个idx,能够返回img,bbox,label 最后需要用np.stack堆成np形式

import os
import xml.etree.ElementTree as ET
import numpy as np
from .util import read_image
VOC_CLASSES=('aeroplane',
    'bicycle',
    'bird',
    'boat',
    'bottle',
    'bus',
    'car',
    'cat',
    'chair',
    'cow',
    'diningtable',
    'dog',
    'horse',
    'motorbike',
    'person',
    'pottedplant',
    'sheep',
    'sofa',
    'train',
    'tvmonitor')

class VOCBboxDataset:
    def __init__(self,data_dir,split='trainval',use_difficult=False,return_difficult=False):
        id_list_file=os.path.join(data_dir,'ImageSets/Main/{0}.txt'.format(split))
        self.ids=[ id_.strip() for id_ in open(id_list_file)]
        self.data_dir=data_dir
        self.use_difficult=use_difficult
        self.return_difficult=return_difficult
        self.label_names=VOC_CLASSES

    def __len__(self):
        return len(self.ids)
    def __getitem__(self, i):
        id_=self.ids[i]
        anno= ET.parse(os.path.join(self.data_dir,'Annotations',id_+'.xml'))
        bbox=list()
        label=list()
        difficult=list()
        for obj in anno.findall('object'):
            if not self.use_difficult and int(obj.find('difficult').text)==1:
                continue
            difficult.append(int(obj.find('difficult').text))
            bndbox_anno=obj.find('bndbox')
            bbox.append([int(bndbox_anno.find(tag).text-1) for tag in ('ymin','xmin','ymax','ymax')])
            name=obj.find('name').text.lower().strip()
            label.append(VOC_CLASSES.index(name))
        bbox=np.stack(bbox).astype(np.float32)#trans the box from list to np.float32
        label=np.stack(label).astype(np.int32)
        difficult=np.array(difficult,dtype=np.bool).astype(np.uint8)
        img_file=os.path.join(self.data_dir,'JPEGImages',id_+'.jpg')
        img=read_image(img_file,color=True)
        return img,bbox,label,difficult

dataset.py

把之前自定义的data类型包进dataset.py中
init :读入config的参数,初始化 自定义数据,初始化trans
getiitem:读入img,bbox,label后,用transform处理,然后返回预处理后的数据的copy
len:返回数目总条数

tsf数据预处理过程

1.训练预处理
(1)图像先归一化到0-1
(2)比例转化 比如到600,1000
(3)标准化,减去均值除以标准差
(4)若图片需要水平翻转等数据增强操作,在这时候添加
其中比例转化和数据增强操作也需要对gt_bbox进行操作,保持图片和gt_box能够对应上

from __future__ import absolute_import
#绝对引入主要是针对python2.4及之前的版本的,这些版本在引入某一个.py文件时,
# 会首先从当前目录下查找是否有该文件。如果有,则优先引用当前包内的文件。而如果我们想引用python自带的.py文件时,则需要使用,
from __future__ import division
import torch as t
from data.voc_dataset import VOCBboxDataset
from skimage import transform as sktsf
from torchvision import transforms as tvtsf
from data import util
import numpy as np
from utils._config import opt

def inverse_normalize(img):# for the vis
    if opt.caffe_pretrain:
        img=img+(np.array([122.7717,115.9465,102.9801]).reshape(3,1,1))#add mean
        #caffe has no std reshape is for the broadcast
        return img[::-1,:,:]#caffe is [BGR,H,W] need to be [RGB,H,W]
    return (img*0.225+0.45).clip(min=0,max=1)*255#pytorch pretrain img from 0-1 add mean and multiply std need to mul 255

def pytorch_normalize(img):
    normalize=tvtsf.Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])#pytorch method
    img=normalize(t.from_numpy(img))# 因为normalize方法只接受tensor对象,将img转化为tensor传入
    return img.numpy()#将标准化后的img再从tensor转化诶numpy

def caffe_normalize(img):
    img=img[::-1,:,:]#RGB2BGR 因为如果使用caffe_pretrain ,那么整个模型参数都是基于caffe训练的,需要先用bgr图篇进行训练,万了之后再inverse_nomalize还原
    img=img*255
    mean=np.array([122.7717,115.9465,102.9801]).reshape(3,1,1)
    img=(img-mean).astype(np.float32,copy=True)#返回一个float32类型的 img矩阵的副本
    return img

def preprocess(img,min_size=600,max_size=1000):#输入原始的img矩阵,返回取值0-1的,经过resize的,标准化后的img numpy矩阵 min_size就是输出图片的短边最长为600
    #长边最长为1000
    C,H,W=img.shape
    scale1=min_size/min(H,W)
    scale2=max_size/max(H,W)
    scale=min(scale1,scale2)#比较两个scale,看哪个才是主要影响所方因子,就是为防止 长边可能超过1000或者短边可能超过600
    img=img/255#先转化为0-1
    img=sktsf.resize(img,(C,H*scale,W*scale),mode='reflect')#resize来自skimage的transform
    if opt.caffe_pretrain:
        normalize=caffe_normalize
    else:
        normalize=pytorch_normalize
    return normalize(img)
#一张图可能有R个box和label box shape为(R,4) label shape 为(R,)
# 接受魔术方法get_example传来的一张图片的原始 img box label
    #                        返回resize和normalize后的img 对应处理后的box  以及label(没有处理)
class Transform():
    def __init__(self,min_size,max_size):
        self.min_size=min_size
        self.max_size=max_size

    def __call__(self, in_data):#使得类的实例也能像函数一样
        img,bbox,label=in_data
        _,H,W=img.shape
        img=preprocess(img,self.min_size,self.max_size)
        _,o_H,o_W=img.shape
        scale=o_H/H
        bbox=util.resize_bbox(bbox,(H,W),(o_H,o_W))

        #水平翻转
        img,params=util.random_flip(img,x_random=True,return_param=True)
        bbox=util.flip_bbox(bbox,(o_H,o_W),x_flip=params['x_flip'])#根据img水平翻转情况,对bbox也进行翻转
        return img,bbox,label,scale

class Dataset:
    # 取训练数据最大的类
    # 如果你读过pytorch源码
    # 你会发现其实并不用继承dataset类
    # 因为那个类是空
    # 的
    # 只实现了两个pass空方法
    # getitem和len两个魔术方法
    # 所以我们只要实现这两个方法就不用继承就可以传入DataLoader
    def __init__(self,opt):#opt是传进来的参数,来自utils.config 包含了voc_data 的路径
        self.opt=opt
        self.db=VOCBboxDataset(opt.voc_data_dir)
        self.tsf=Transform(opt.min_size,opt.max_size)

    def __getitem__(self, idx):
        ori_img,bbox,label,difficult=self.db.__getitem__(idx)
        img,bbox,label,scale=self.tsf((ori_img,bbox,label))
        return img.copy(),bbox.copy(),label.copy(),scale
    def __len__(self):
        return len(self.db)

class TestDataset:
    def __init__(self,opt,split='test',use_difficult=True):
        self.opt=opt
        self.db=VOCBboxDataset(opt.voc_data_dir,split=split,use_difficult=use_difficult)
    def __getitem__(self, idx):
        ori_img,bbox,label,difficult=self.db.__getitem__(idx)
        img=preprocess(ori_img)
        return img,ori_img.shape[1:],bbox,label,difficult#返回原图的HW 去掉了C
    def __len__(self):
        return len(self.db)
util.py
import numpy as np
from PIL import Image
import random


def read_image(path,dtype=np.float32,color=True):
    f=Image.open(path)
    try:
        if color:
            img=f.convert('RGB')
        else:
            img=f.convert('P')#gray
        img=np.asarray(img,dtype=dtype)#trans to np.float32 array
    finally:
        if hasattr(f,'close'):
            f.close()
    if img.ndim==2:#gray
        return img[np.newaxis]#add a new axis
    else:
        return img.transpose((2,0,1))# HWC 2 CHW

def resize_bbox(bbox,in_size,out_size):
    bbox=bbox.copy()
    y_scale=out_size[0]/in_size[0]
    x_scale=out_size[1]/in_size[1]
    bbox[:,0]=bbox[:,0]*y_scale
    bbox[:,1]=bbox[:,1]*x_scale
    bbox[:,2]=bbox[:,2]*y_scale
    bbox[:,3]=bbox[:,3]*y_scale
    return bbox

def random_flip(img,y_random=False,x_random=True,return_param=True,copy=False):
    #
    # img: 图片矩阵
    # y_random: 是否使用垂直随机翻
    # return_param:是否返回翻转状态
    # 一个dict很好懂
    # copy: 是否返回img的副本
    y_flip,x_flip=False,False
    if y_random:
        y_flip=random.choice([True,False])#随即选取是否翻转
    if x_random:
        x_flip=random.choice([True,False])
    if y_flip:
        img=img[:,::-1,:]#图片翻转,CHW H翻转
    if x_flip:
        img=img[:,:,::-1]
    if copy:
        img=img.copy()
    if return_param:
        # 因为我们这里只翻转了图片
        # 保留dict参数是为了翻转box时使用
        # 如果img水平翻转了
        # 那么x_flip = True
        # 我们记录这个参数
        # 以后也应当水平翻转这张图片的所有box
        # R个box
        return img,{'y_flip':y_flip,'x_flip':x_flip}
    else:
        return img

def flip_bbox(bbox,size,y_flip=False,x_flip=False):
    H,W=size
    bbox=bbox.copy()
    if y_flip:
        y_max=H-bbox[:,0]#H-ymin
        y_min=H-bbox[:,2]
        bbox[:,0]=y_min
        bbox[:,2]=y_max
    if x_flip:
        x_max=W-bbox[:,1]
        x_min=W-bbox[:,3]
        bbox[:,1]=x_max
        bbox[:,3]=x_min
    return bbox
_config.py
from pprint import pprint#打印出来更美观
class Config:
    #data
    voc_data_dir='/home/wrc/yuyijie/KITTI/VOCdevkit/VOC2007'
    min_size=600
    max_size=1000
    num_works=8
    test_num_works=8
    rpn_sigma=3.
    roi_sigma=1.
    # for optimizer
    wd=0.0005
    lr_decay=0.1
    lr=1e-3
    #vis
    env='faster-rcnn'
    port=8097#visdom 端口
    plot_every=40
    #preset
    data='voc'
    pretrained_model='vgg16'
    epoch=14
    use_adam=False
    use_chainer=False
    use_drop=False
    #debug
    debug_file='/tmp/debugf'
    test_num=10000
    #model
    load_path=None
    caffe_pretrain=False
    caffe_pretrain_path='checkpoints/vgg16_caffe.pth'
    def _parse(self,kwargs):#解析并设置用户设定的参数
        state_dict=self._state_dict()#读取Config类所有参数dict{para_name:para_value}
        for k,v in kwargs.items():#遍历用户传来的dict
            if k not in state_dict:
                raise ValueError('Unknow option:"--%s"'%k)
            setattr(self,k,v)#设置参数
        print('=============user config=========')
        pprint(self._state_dict())#打印参数
        print('=============end=========')


    def _state_dict(self):
        return {k:getattr(self,k) for k ,_ in Config.__dict__.items() if not k.startswith('_')}
    #字典解析,字典解析,Config.__dict__.items() 取出类中所有的函数、全局变量以及一些内置的属性
  # 前面我们设定的都是全局变量(键值对:比如min_size = 600),没有函数,而系统内置属性都是_打头的,
  # 所以我们要not k.startswith('_')  返回结果dict{para_name0:para_value0,para_name1:para_value1,....}
opt=Config()#创建config对象
这部分代码不熟悉的numpy,python,pytorch操作小结
bbox=np.stack(bbox).astype(np.float32)

trans the box from list to np.float32
想要转变np数据类型可以直接用astype

img[::-1,:,:]

[BGR,H,W] need to be [RGB,H,W]

numpy,tensor互转

t.from_numpy(img)
img.numpy()

对numpy数据增加一个新的轴

img[np.newaxis]#add a new axis
比如(100,2)新增之后会变成(1,100,2)

numpy数据维度互换位置

img.transpose((2,0,1))# HWC 2 CHW

从几个选项中选一个

random.choice([True,False])#随即选取是否翻转

  • 1
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值