Faster R-CNN: Full Code Walkthrough and How to Run It
**Source code:** https://github.com/chenyuntc/simple-faster-rcnn-pytorch
Fixing errors when running the code:
Walkthrough of selected modules:
Run results:
Code walkthrough:
There are plenty of walkthroughs of this code online, but few are both easy to follow and sufficiently detailed (here, the important lines are annotated right after the code). Contents:
train.py
The walkthrough below follows the order of the code in train.py:
from __future__ import absolute_import
# though cupy is not used, without this line it raises errors...
# import cupy as cp
import os
import ipdb
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from tqdm import tqdm
from util.config import opt
from data.dataset import Dataset, TestDataset, inverse_normalize
from model import FasterRCNNVGG16
from torch.utils import data as data_
from trainer import FasterRCNNTrainer
from util import array_tool as at
from util.vis_tool import visdom_bbox
from util.eval_tool import eval_detection_voc
import numpy as np
def eval(dataloader, faster_rcnn, test_num=10000):
pred_bboxes, pred_labels, pred_scores = list(), list(), list()
gt_bboxes, gt_labels, gt_difficults = list(), list(), list()
for ii, (imgs, sizes, gt_bboxes_, gt_labels_, gt_difficults_) in tqdm(enumerate(dataloader)):
sizes = [sizes[0][0].item(), sizes[1][0].item()]
pred_bboxes_, pred_labels_, pred_scores_ = faster_rcnn.predict(imgs, [sizes])
gt_bboxes += list(gt_bboxes_.numpy())
gt_labels += list(gt_labels_.numpy())
gt_difficults += list(gt_difficults_.numpy())
pred_bboxes += pred_bboxes_
pred_labels += pred_labels_
pred_scores += pred_scores_
if ii == test_num: break
result = eval_detection_voc(
pred_bboxes, pred_labels, pred_scores,
gt_bboxes, gt_labels, gt_difficults,
use_07_metric=True)
return result
def train(**kwargs):
opt._parse(kwargs)
dataset = Dataset(opt)
print('load data')
dataloader = data_.DataLoader(dataset, \
batch_size=1, \
shuffle=True, \
# pin_memory=True,
num_workers=opt.num_workers)
testset = TestDataset(opt)
# VOCBboxDataset does the actual reading: images are loaded from the dataset one by one,
# and the Transform(object) class resizes each image and applies a random horizontal flip.
test_dataloader = data_.DataLoader(testset,
batch_size=1,
num_workers=opt.test_num_workers,
shuffle=False, \
pin_memory=True
)
# Wrap the dataset in a DataLoader: shuffle=True randomizes the sample order and
# num_workers sets how many worker processes load data in parallel. The test set is
# handled the same way and wrapped in test_dataloader.
faster_rcnn = FasterRCNNVGG16()  # build the detection model: Faster R-CNN with a VGG16 backbone
print('model construct completed')
trainer = FasterRCNNTrainer(faster_rcnn).cuda()
# Wrap FasterRCNNVGG16 in a FasterRCNNTrainer and move everything to the GPU.
if opt.load_path:
trainer.load(opt.load_path)
print('load pretrained model from %s' % opt.load_path)
trainer.vis.text(dataset.db.label_names, win='labels')
best_map = 0
lr_ = opt.lr
for epoch in range(opt.epoch):
trainer.reset_meters()
for ii, (img, bbox_, label_, scale) in tqdm(enumerate(dataloader)):
scale = at.scalar(scale)
img, bbox, label = img.cuda().float(), bbox_.cuda(), label_.cuda()
# Enumerate the training dataloader, convert scale to a plain scalar, and move img, bbox and label onto the GPU.
trainer.train_step(img, bbox, label, scale)
# trainer.train_step(img, bbox, label, scale) (defined in trainer.py) runs one forward/backward pass and updates the parameters.
if (ii + 1) % opt.plot_every == 0:
if os.path.exists(opt.debug_file):
ipdb.set_trace()
# Every plot_every iterations: if debug_file exists, drop into an ipdb breakpoint;
# then trainer.vis.plot_many(trainer.get_meter_data()) pushes the accumulated
# training metrics to visdom for plotting.
# plot loss
trainer.vis.plot_many(trainer.get_meter_data())
# plot ground truth bboxes
ori_img_ = inverse_normalize(at.tonumpy(img[0]))
gt_img = visdom_bbox(ori_img_,
at.tonumpy(bbox_[0]),
at.tonumpy(label_[0]))
trainer.vis.img('gt_img', gt_img)
# Debug check that ori_img_ displays correctly (uncomment to dump it to disk):
# plt.imshow(ori_img_.transpose((1,2,0)).astype(np.int32))
# plt.savefig('/home/dell/Desktop/AFA/Faster-Rcnn-Pytorch/simple-faster-rcnn-pytorch-master/'+'22.png')
# Each sampled image is first de-normalized with inverse_normalize() from the dataset module,
# then visdom_bbox() draws the ground-truth boxes on it for display.
# plot predicted bboxes
_bboxes, _labels, _scores = trainer.faster_rcnn.predict([ori_img_], visualize=True)
pred_img = visdom_bbox(ori_img_,
at.tonumpy(_bboxes[0]),
at.tonumpy(_labels[0]).reshape(-1),
at.tonumpy(_scores[0]))
trainer.vis.img('pred_img', pred_img)
# rpn confusion matrix(meter)
trainer.vis.text(str(trainer.rpn_cm.value().tolist()), win='rpn_cm')
# roi confusion matrix
trainer.vis.img('roi_cm', at.totensor(trainer.roi_cm.conf, False).float())
eval_result = eval(test_dataloader, faster_rcnn, test_num=opt.test_num)
trainer.vis.plot('test_map', eval_result['map'])
# trainer.vis.img displays the RoI confusion matrix (roi_cm) as an image in visdom.
lr_ = trainer.faster_rcnn.optimizer.param_groups[0]['lr']
log_info = 'lr:{}, map:{},loss:{}'.format(str(lr_),
str(eval_result['map']),
str(trainer.get_meter_data()))
trainer.vis.log(log_info)  # push the current learning rate, mAP and loss values to the visdom log
if eval_result['map'] > best_map:
best_map = eval_result['map']
best_path = trainer.save(best_map=best_map)  # always keep the checkpoint with the best mAP so far
if epoch == 9:
trainer.load(best_path)
trainer.faster_rcnn.scale_lr(opt.lr_decay)
lr_ = lr_ * opt.lr_decay  # once epoch 9 is reached, scale the learning rate by lr_decay=0.1 (e.g. 1e-3 -> 1e-4)
if epoch == 13:
break  # stop the whole train/eval loop after epoch 13
if __name__ == '__main__':
# import fire
#
# fire.Fire()
train()
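A side note on how the entry point is meant to be used: with `fire.Fire()` commented out, the script above simply runs `train()` with all defaults from `Config`. Below is a minimal sketch of two ways to override options; the values are only examples, and any field of `Config` (e.g. `plot_every`, `env`) can be passed the same way.
# Variant 1: call train() with keyword overrides; opt._parse() copies them onto Config.
if __name__ == '__main__':
    train(plot_every=100, env='faster-rcnn-debug')

# Variant 2: restore the fire lines so that train() is exposed on the command line, e.g.
#   python train.py train --plot-every=100 --env=faster-rcnn-debug
# if __name__ == '__main__':
#     import fire
#     fire.Fire()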
1. Default configuration (util/config.py)
from pprint import pprint
# Default Configs for training
# NOTE that config items can be overwritten by passing arguments through the command line,
# e.g. --voc-data-dir='./data/'
class Config:
# data
# voc_data_dir = '/home/cy/.chainer/dataset/pfnet/chainercv/voc/VOCdevkit/VOC2007/'
voc_data_dir = r'G:\Faster-Rcnn-Pytorch\simple-faster-rcnn-pytorch-master\VOCdevkit\VOC2007/'
min_size = 600 # image resize
max_size = 1000 # image resize
num_workers = 0
test_num_workers = 8
# sigma for l1_smooth_loss
rpn_sigma = 3.
roi_sigma = 1.
# param for optimizer
# 0.0005 in origin paper but 0.0001 in tf-faster-rcnn
weight_decay = 0.0005
lr_decay = 0.1 # 1e-3 -> 1e-4
lr = 1e-3
# visualization
env = 'faster-rcnn' # visdom env
port = 8097
plot_every = 40 # vis every N iter
# preset
data = 'voc'
pretrained_model = 'vgg16'
# training
epoch = 14
use_adam = False # Use Adam optimizer
use_chainer = False # try match everything as chainer
use_drop = False # use dropout in RoIHead
# debug
debug_file = '/tmp/debugf'
test_num = 10000
# model
load_path = None  # path of a pretrained model to load
caffe_pretrain = False # use caffe pretrained model instead of torchvision
caffe_pretrain_path = 'checkpoints/vgg16_caffe.pth'
def _parse(self, kwargs):
state_dict = self._state_dict()
for k, v in kwargs.items():
if k not in state_dict:
raise ValueError('UnKnown Option: "--%s"' % k)
setattr(self, k, v)
print('======user config========')
pprint(self._state_dict())
print('==========end============')
def _state_dict(self):
return {k: getattr(self, k) for k, _ in Config.__dict__.items() \
if not k.startswith('_')}
opt = Config()
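To make the override mechanism concrete, here is a small sketch of how `_parse()` behaves (assuming it is run from the project root so that `util.config` is importable):
from util.config import opt

opt._parse({'lr': 1e-4, 'plot_every': 100})  # known keys simply overwrite the class defaults
print(opt.lr, opt.plot_every)                # -> 0.0001 100

# opt._parse({'batch_size': 2})              # a key that is not in Config raises
#                                            # ValueError: UnKnown Option: "--batch_size"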
2. Data loading during training
data/voc_dataset.py
import os
import xml.etree.ElementTree as ET
import cv2
import numpy as np
import matplotlib.pyplot as plt
# from data.util import read_image
from PIL import Image
'''
1. Load an image and its annotations.
The annotations consist of bounding boxes and their class-name labels; since one image can
contain several objects, they are read in a loop. Returns the image together with its labels.
'''
class VOCBboxDataset:
def __init__(self, data_dir, split='trainval',
use_difficult=False, return_difficult=False,
):
# if split not in ['train', 'trainval', 'val']:
# if not (split == 'test' and year == '2007'):
# warnings.warn(
# 'please pick split from \'train\', \'trainval\', \'val\''
# 'for 2012 dataset. For 2007 dataset, you can pick \'test\''
# ' in addition to the above mentioned splits.'
# )
id_list_file = os.path.join(
data_dir, 'ImageSets/Main/{0}.txt'.format(split))#'G:/Faster-Rcnn-Pytorch/simple-faster-rcnn-pytorch-master/VOCdevkit/VOC2007/ImageSets/Main/trainval.txt'
self.ids = [id_.strip() for id_ in open(id_list_file)]  # strip the trailing newline from every image id
# self.ids = [id_ for id_ in open(id_list_file)]  # the raw image ids
self.data_dir = data_dir
self.use_difficult = use_difficult
self.return_difficult = return_difficult
self.label_names = VOC_BBOX_LABEL_NAMES
def __len__(self):
return len(self.ids)
def get_example(self, i):
# print("RUN____________________")
id_ = self.ids[i]  # e.g. for i == 0 this is '000005'
anno = ET.parse(os.path.join(self.data_dir, 'Annotations', id_ + '.xml'))#'G:/Faster-Rcnn-Pytorch/simple-faster-rcnn-pytorch-master/VOCdevkit/VOC2007/Annotations/000005.xml'
# (open 000005.xml in a text editor or browser to inspect the raw annotation)
bbox = list()
label = list()
difficult = list()
for obj in anno.findall('object'):
# when not using the difficult split and the object is marked
# difficult, skip it.
if not self.use_difficult and int(obj.find('difficult').text) == 1:  # skip objects whose <difficult> flag is 1
continue
difficult.append(int(obj.find('difficult').text))  # e.g. 0, 0, 0 for the three objects kept here
bndbox_anno = obj.find('bndbox')  # the <bndbox> element of this object
# Example: the three <bndbox> entries in 000005.xml look like
#   <bndbox><xmin>263</xmin><ymin>211</ymin><xmax>324</xmax><ymax>339</ymax></bndbox>
#   <bndbox><xmin>165</xmin><ymin>264</ymin><xmax>253</xmax><ymax>372</ymax></bndbox>
#   <bndbox><xmin>241</xmin><ymin>194</ymin><xmax>295</xmax><ymax>299</ymax></bndbox>
bbox.append([int(bndbox_anno.find(tag).text) - 1 for tag in ('ymin', 'xmin', 'ymax', 'xmax')])
# resulting bbox rows ((ymin, xmin, ymax, xmax), made 0-based by the -1 above):
# [[210. 262. 338. 323.]
#  [263. 164. 371. 252.]
#  [193. 240. 298. 294.]]
name = obj.find('name').text.lower().strip()#'chair'
label.append(VOC_BBOX_LABEL_NAMES.index(name))  # 'chair' is index 8 in VOC_BBOX_LABEL_NAMES
bbox = np.stack(bbox).astype(np.float32)  # shape (3, 4) in this example
label = np.stack(label).astype(np.int32)  # shape (3,) in this example
# When `use_difficult==False`, all elements in `difficult` are False.
difficult = np.array(difficult, dtype=np.bool).astype(np.uint8)  # PyTorch doesn't support np.bool
# Load the image
img_file = os.path.join(self.data_dir, 'JPEGImages', id_ + '.jpg')#'G:/Faster-Rcnn-Pytorch/simple-faster-rcnn-pytorch-master/VOCdevkit/VOC2007/JPEGImages/000005.jpg'
img = read_image(img_file, color=True)
# img.show()
# if self.return_difficult:
# return img, bbox, label, difficult
return img, bbox, label, difficult
__getitem__ = get_example
def read_image(path, dtype=np.float32, color=True):  # getting the image layout right here matters downstream
try:
f = Image.open(path)  # PIL returns an Image object, not a numpy array
except IOError:
print('fail to load image!')
try:
if color:
img = f.convert('RGB')
else:
img = f.convert('P')
img= np.asarray(img, dtype=dtype)
finally:
if hasattr(f, 'close'):
f.close()
if img.ndim == 2:
# reshape (H, W) -> (1, H, W)
return img[np.newaxis]
else:
# transpose (H, W, C) -> (C, H, W)
return img.transpose((2, 0, 1))  # reorder to (C, H, W), i.e. 3 x H x W
VOC_BBOX_LABEL_NAMES = (
'aeroplane',
'bicycle',
'bird',
'boat',
'bottle',
'bus',
'car',
'cat',
'chair',
'cow',
'diningtable',
'dog',
'horse',
'motorbike',
'person',
'pottedplant',
'sheep',
'sofa',
'train',
'tvmonitor')
if __name__ == '__main__':
data = VOCBboxDataset('G:/Faster-Rcnn-Pytorch/simple-faster-rcnn-pytorch-master/VOCdevkit/VOC2007/')[0]
# pytorch_normalze() is defined in data/dataset.py; import it first if you want to test it here
# img = pytorch_normalze(data[0])
data_one = data[0].transpose((1, 2, 0)).astype(np.int32)  # (C, H, W) -> (H, W, C) for matplotlib
plt.axis('off')
plt.imshow(data_one)
plt.show()
print(data[1])  # bounding boxes
print(data[2])  # labels
print(data[3])  # difficult flags
Script output:
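Beyond the matplotlib demo above, the dataset class can also be exercised from an interactive session. A minimal sketch (the data path is a placeholder for a local VOCdevkit copy; the printed class names follow the three-chair annotation discussed above):
from data.voc_dataset import VOCBboxDataset, VOC_BBOX_LABEL_NAMES

ds = VOCBboxDataset('VOCdevkit/VOC2007/')        # placeholder path
img, bbox, label, difficult = ds[0]              # __getitem__ is aliased to get_example

print(img.dtype, img.shape)                      # float32 (3, H, W)
print(bbox)                                      # (R, 4) array of (ymin, xmin, ymax, xmax), 0-based
print([VOC_BBOX_LABEL_NAMES[i] for i in label])  # e.g. ['chair', 'chair', 'chair']
print(difficult)                                 # uint8 flags; all 0 when use_difficult=False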
data/dataset.py
from __future__ import absolute_import
from __future__ import division
from data.voc_dataset import VOCBboxDataset
import torch as t
from skimage import