Faster-RCNN完美解读和运行(1)

bbox_tools 回归函数 https://www.cnblogs.com/king-lps/p/8981222.html
摘要由CSDN通过智能技术生成

Faster-RCNN完美解读和运行

**代码源码:** https://github.com/chenyuntc/simple-faster-rcnn-pytorch
运行错误解决:
部分模块代码解读:
运行结果:test-Loss
rpn_cls_loss
rpn_loc_loss
roi_cls_loss
roi_loc_loss
total_loss
map
gt_img
预测图
代码解读:
网上有很多代码解读,但通俗易懂的不多,而且解读也不够仔细(本文在重要的代码后面都加了注解)。目录:
在这里插入图片描述
train.py
按 train.py 代码的先后顺序进行解读:

from __future__ import  absolute_import
# though cupy is not used but without this line, it raise errors...
# import cupy as cp
import os

import ipdb
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from tqdm import tqdm

from util.config import opt
from data.dataset import Dataset, TestDataset, inverse_normalize
from model import FasterRCNNVGG16
from torch.utils import data as data_
from trainer import FasterRCNNTrainer
from util import array_tool as at
from util.vis_tool import visdom_bbox
from util.eval_tool import eval_detection_voc
import numpy as np

def eval(dataloader, faster_rcnn, test_num=10000):
    """Evaluate ``faster_rcnn`` on ``dataloader`` with the PASCAL VOC metric.

    NOTE: this shadows the builtin ``eval``; the name is kept because
    ``train()`` calls it.

    Args:
        dataloader: yields ``(imgs, sizes, gt_bboxes, gt_labels, gt_difficults)``
            batches. Only element 0 of ``sizes`` is used, so batch_size is
            assumed to be 1.
        faster_rcnn: model exposing ``predict(imgs, sizes)`` returning lists
            of predicted boxes, labels and scores.
        test_num: maximum number of batches to evaluate.

    Returns:
        Whatever ``eval_detection_voc`` returns for the collected predictions
        (computed with ``use_07_metric=True``); ``train()`` reads its
        ``'map'`` entry.
    """
    pred_bboxes, pred_labels, pred_scores = list(), list(), list()
    gt_bboxes, gt_labels, gt_difficults = list(), list(), list()
    for ii, (imgs, sizes, gt_bboxes_, gt_labels_, gt_difficults_) in tqdm(enumerate(dataloader)):
        # Checked before processing so exactly ``test_num`` batches are evaluated.
        if ii >= test_num:
            break
        # ``sizes`` holds two collated tensors (presumably height and width);
        # take the entry for image 0 as plain Python numbers.
        sizes = [sizes[0][0].item(), sizes[1][0].item()]
        pred_bboxes_, pred_labels_, pred_scores_ = faster_rcnn.predict(imgs, [sizes])
        gt_bboxes += list(gt_bboxes_.numpy())
        gt_labels += list(gt_labels_.numpy())
        gt_difficults += list(gt_difficults_.numpy())
        pred_bboxes += pred_bboxes_
        pred_labels += pred_labels_
        pred_scores += pred_scores_

    result = eval_detection_voc(
        pred_bboxes, pred_labels, pred_scores,
        gt_bboxes, gt_labels, gt_difficults,
        use_07_metric=True)
    return result


def train(**kwargs):
    """Train Faster R-CNN (VGG16 backbone) and evaluate it after every epoch.

    Keyword arguments override attributes of the same name on the global
    ``opt`` config (``opt._parse`` raises ``ValueError`` for unknown names).

    Each epoch makes one pass over the training set, with visdom plots every
    ``opt.plot_every`` iterations, then computes test mAP. The checkpoint
    with the best mAP so far is saved. At the end of epoch index 9 the best
    checkpoint (if any) is reloaded and the learning rate is scaled by
    ``opt.lr_decay``. Training stops after epoch index 13, or after
    ``opt.epoch`` epochs if that is smaller.
    """
    opt._parse(kwargs)

    dataset = Dataset(opt)
    print('load data')
    # batch_size stays 1: the code below indexes element 0 of every batch.
    # num_workers is the number of DataLoader worker processes (0 = load in
    # the main process).
    dataloader = data_.DataLoader(dataset,
                                  batch_size=1,
                                  shuffle=True,
                                  num_workers=opt.num_workers)
    testset = TestDataset(opt)
    test_dataloader = data_.DataLoader(testset,
                                       batch_size=1,
                                       num_workers=opt.test_num_workers,
                                       shuffle=False,
                                       pin_memory=True
                                       )
    faster_rcnn = FasterRCNNVGG16()
    print('model construct completed')

    # Wrap the model in the trainer and move everything to the GPU.
    trainer = FasterRCNNTrainer(faster_rcnn).cuda()
    if opt.load_path:
        trainer.load(opt.load_path)
        print('load pretrained model from %s' % opt.load_path)

    trainer.vis.text(dataset.db.label_names, win='labels')

    best_map = 0
    # Path of the best checkpoint so far. It stays None until some epoch's
    # mAP exceeds 0, so the epoch-9 reload below must check it.
    best_path = None

    for epoch in range(opt.epoch):

        trainer.reset_meters()

        for ii, (img, bbox_, label_, scale) in tqdm(enumerate(dataloader)):
            scale = at.scalar(scale)
            img, bbox, label = img.cuda().float(), bbox_.cuda(), label_.cuda()
            # One training step on this image (see trainer.py).
            trainer.train_step(img, bbox, label, scale)

            if (ii + 1) % opt.plot_every == 0:
                # Creating opt.debug_file drops training into the debugger.
                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()

                # plot loss
                trainer.vis.plot_many(trainer.get_meter_data())

                # plot ground-truth bboxes on the de-normalised image
                ori_img_ = inverse_normalize(at.tonumpy(img[0]))
                gt_img = visdom_bbox(ori_img_,
                                     at.tonumpy(bbox_[0]),
                                     at.tonumpy(label_[0]))
                trainer.vis.img('gt_img', gt_img)

                # plot predicted bboxes for the same image
                _bboxes, _labels, _scores = trainer.faster_rcnn.predict([ori_img_], visualize=True)
                pred_img = visdom_bbox(ori_img_,
                                       at.tonumpy(_bboxes[0]),
                                       at.tonumpy(_labels[0]).reshape(-1),
                                       at.tonumpy(_scores[0]))
                trainer.vis.img('pred_img', pred_img)

                # rpn confusion matrix (meter), shown as text
                trainer.vis.text(str(trainer.rpn_cm.value().tolist()), win='rpn_cm')
                # roi confusion matrix, shown as an image
                trainer.vis.img('roi_cm', at.totensor(trainer.roi_cm.conf, False).float())

        eval_result = eval(test_dataloader, faster_rcnn, test_num=opt.test_num)
        trainer.vis.plot('test_map', eval_result['map'])
        # Log the current learning rate, test mAP and training losses.
        lr_ = trainer.faster_rcnn.optimizer.param_groups[0]['lr']
        log_info = 'lr:{}, map:{},loss:{}'.format(str(lr_),
                                                  str(eval_result['map']),
                                                  str(trainer.get_meter_data()))
        trainer.vis.log(log_info)

        # Keep only the checkpoint with the best test mAP.
        if eval_result['map'] > best_map:
            best_map = eval_result['map']
            best_path = trainer.save(best_map=best_map)

        if epoch == 9:
            # Restart from the best checkpoint with a decayed learning rate.
            if best_path is not None:
                trainer.load(best_path)
            trainer.faster_rcnn.scale_lr(opt.lr_decay)

        if epoch == 13:
            break


if __name__ == '__main__':
    # Script entry point: train with the default configuration.
    # The commented-out alternative exposed train()'s keyword arguments on the
    # command line via python-fire (`import fire; fire.Fire()`).
    train()

一、默认设置(util/config.py)

from pprint import pprint


# Default Configs for training
# NOTE that, config items could be overwriten by passing argument through command line.
# e.g. --voc-data-dir='./data/'

class Config:
    """Default training options.

    Every public class attribute is an option. ``_parse`` overrides options
    on the instance from a dict of keyword arguments and rejects names that
    are not options.
    """

    # ---- data ----
    # Root of the VOC2007 dataset (machine-specific path; edit for your setup).
    voc_data_dir = r'G:\Faster-Rcnn-Pytorch\simple-faster-rcnn-pytorch-master\VOCdevkit\VOC2007/'
    min_size = 600  # image resize
    max_size = 1000  # image resize
    num_workers = 0  # DataLoader worker processes for the training set
    test_num_workers = 8  # DataLoader worker processes for the test set

    # ---- sigma for smooth-L1 loss ----
    rpn_sigma = 3.
    roi_sigma = 1.

    # ---- optimizer ----
    # 0.0005 in the original paper, 0.0001 in tf-faster-rcnn
    weight_decay = 0.0005
    lr_decay = 0.1  # lr multiplier applied once during training: 1e-3 -> 1e-4
    lr = 1e-3

    # ---- visualization ----
    env = 'faster-rcnn'  # visdom environment name
    port = 8097
    plot_every = 40  # visualize every N iterations

    # ---- preset ----
    data = 'voc'
    pretrained_model = 'vgg16'

    # ---- training ----
    epoch = 14

    use_adam = False  # use the Adam optimizer
    use_chainer = False  # try to match everything with chainer
    use_drop = False  # use dropout in RoIHead

    # ---- debug ----
    debug_file = '/tmp/debugf'  # if this file exists, training drops into ipdb

    test_num = 10000  # maximum number of test batches evaluated per epoch

    # ---- model ----
    load_path = None  # checkpoint of a pretrained model to load, if any

    caffe_pretrain = False  # use caffe pretrained model instead of torchvision
    caffe_pretrain_path = 'checkpoints/vgg16_caffe.pth'

    def _parse(self, kwargs):
        """Override options from ``kwargs``, then print the resulting config.

        Raises:
            ValueError: on the first key that is not a known option.
        """
        known = self._state_dict()
        for name, value in kwargs.items():
            if name not in known:
                raise ValueError('UnKnown Option: "--%s"' % name)
            setattr(self, name, value)

        print('======user config========')
        pprint(self._state_dict())
        print('==========end============')

    def _state_dict(self):
        """Return ``{option_name: current_value}`` for every public option."""
        public_names = (name for name in Config.__dict__ if not name.startswith('_'))
        return {name: getattr(self, name) for name in public_names}


opt = Config()

二、训练中的数据加载
data/voc_dataset.py

import os
import xml.etree.ElementTree as ET

import cv2
import numpy as np
import matplotlib.pyplot as plt
# from data.util import read_image
from PIL import Image
class VOCBboxDataset:
    """PASCAL VOC detection dataset: loads an image with its boxes and labels.

    An annotation XML may contain several ``<object>`` entries, so boxes and
    labels are collected in a loop and returned together with the image.

    Args:
        data_dir: VOC root directory (e.g. ``.../VOCdevkit/VOC2007/``)
            containing ``ImageSets/Main``, ``Annotations`` and ``JPEGImages``.
        split: name of the id list ``ImageSets/Main/<split>.txt``.
        use_difficult: keep objects marked ``difficult``; skipped by default.
        return_difficult: stored but currently unused. ``get_example``
            always returns the ``difficult`` array as its 4th element.
    """

    def __init__(self, data_dir, split='trainval',
                 use_difficult=False, return_difficult=False,
                 ):
        id_list_file = os.path.join(
            data_dir, 'ImageSets/Main/{0}.txt'.format(split))

        # One image id per line (e.g. '000005'). Strip newlines and skip blank
        # lines, which would otherwise become ids pointing at no file.
        with open(id_list_file) as f:
            self.ids = [line.strip() for line in f if line.strip()]
        self.data_dir = data_dir
        self.use_difficult = use_difficult
        self.return_difficult = return_difficult
        self.label_names = VOC_BBOX_LABEL_NAMES

    def __len__(self):
        return len(self.ids)

    def get_example(self, i):
        """Return ``(img, bbox, label, difficult)`` for the ``i``-th image id.

        - img: array from ``read_image(..., color=True)``, shape (3, H, W).
        - bbox: float32 array, shape (R, 4), rows ``(ymin, xmin, ymax, xmax)``
          with 1 subtracted from each coordinate (presumably converting VOC's
          1-based pixel indices to 0-based).
        - label: int32 array, shape (R,), indices into ``VOC_BBOX_LABEL_NAMES``.
        - difficult: uint8 array, shape (R,), 1 for objects marked difficult.

        R is 0 when no object survives the difficult filter.
        """
        id_ = self.ids[i]
        anno = ET.parse(os.path.join(self.data_dir, 'Annotations', id_ + '.xml'))
        bbox = list()
        label = list()
        difficult = list()

        for obj in anno.findall('object'):
            # A missing <difficult> tag counts as 0 (not difficult).
            difficult_node = obj.find('difficult')
            is_difficult = int(difficult_node.text) if difficult_node is not None else 0
            # Skip difficult objects unless use_difficult is set.
            if not self.use_difficult and is_difficult == 1:
                continue
            difficult.append(is_difficult)

            # e.g. <bndbox><xmin>263</xmin><ymin>211</ymin><xmax>324</xmax>
            # <ymax>339</ymax></bndbox> becomes [210, 262, 338, 323].
            # float() also accepts annotations with fractional coordinates.
            bndbox_anno = obj.find('bndbox')
            bbox.append([float(bndbox_anno.find(tag).text) - 1
                         for tag in ('ymin', 'xmin', 'ymax', 'xmax')])
            name = obj.find('name').text.lower().strip()  # e.g. 'chair'
            label.append(VOC_BBOX_LABEL_NAMES.index(name))  # e.g. 8

        # reshape keeps well-formed (0, 4) / (0,) arrays when no objects remain
        # (np.stack raises on an empty list).
        bbox = np.array(bbox, dtype=np.float32).reshape(-1, 4)
        label = np.array(label, dtype=np.int32).reshape(-1)

        # Builtin bool (np.bool was removed in NumPy 1.24), then uint8 because,
        # per the original note, PyTorch did not support bool arrays.
        difficult = np.array(difficult, dtype=bool).astype(np.uint8)

        img_file = os.path.join(self.data_dir, 'JPEGImages', id_ + '.jpg')
        img = read_image(img_file, color=True)
        return img, bbox, label, difficult

    __getitem__ = get_example


def read_image(path, dtype=np.float32, color=True):
    """Read an image file into a channel-first numpy array.

    Args:
        path: image file path.
        dtype: dtype of the returned array (default float32).
        color: if True, convert to RGB and return shape (3, H, W). Otherwise
            convert to 8-bit grayscale (PIL mode 'L') and return (1, H, W).

    Raises:
        OSError: if PIL cannot open or identify the file. The message
            'fail to load image!' is printed first.
    """
    try:
        f = Image.open(path)
    except IOError:
        print('fail to load image!')
        # Re-raise: continuing would fail later with an unrelated NameError.
        raise

    with f:
        # Mode 'L' is luminance. Mode 'P' would yield palette indices, which
        # are not intensities.
        img = f.convert('RGB' if color else 'L')
        img = np.asarray(img, dtype=dtype)

    if img.ndim == 2:
        # reshape (H, W) -> (1, H, W)
        return img[np.newaxis]
    # transpose (H, W, C) -> (C, H, W)
    return img.transpose((2, 0, 1))

# The 20 PASCAL VOC object classes. Labels are indices into this tuple
# (VOCBboxDataset uses .index(name)), so the order must not change.
VOC_BBOX_LABEL_NAMES = (
    'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',        # 0-4
    'bus', 'car', 'cat', 'chair', 'cow',                      # 5-9
    'diningtable', 'dog', 'horse', 'motorbike', 'person',     # 10-14
    'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor',     # 15-19
)

if __name__ == '__main__':
    # Quick visual check: show the first trainval image and print its annotations.
    # (The former call to pytorch_normalze referenced a name not defined in this
    # module, raising NameError, and its result was unused.)
    data = VOCBboxDataset('G:/Faster-Rcnn-Pytorch/simple-faster-rcnn-pytorch-master/VOCdevkit/VOC2007/')[0]
    # (C, H, W) float -> (H, W, C) int for matplotlib
    data_one = data[0].transpose((1, 2, 0)).astype(np.int32)
    plt.imshow(data_one)
    # Set before show() so it applies to the displayed figure.
    plt.axis('off')
    plt.show()
    print(data[1])  # bbox, (R, 4)
    print(data[2])  # label, (R,)
    print(data[3])  # difficult, (R,)

脚本运行结果:
脚本运行结果
在这里插入图片描述
data/dataset.py

from __future__ import  absolute_import
from __future__ import  division

from data.voc_dataset import VOCBboxDataset
import torch as t
from skimage import 
  • 5
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值