Faster R-CNN Source Code Reading, Part 10: Faster R-CNN/lib/fast_rcnn/train.py

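Copyright notice: this is an original article by the author, licensed under CC 4.0 BY-SA. When reposting, please include a link to the original source and this notice.
Article link: https://blog.csdn.net/DaVinciL/article/details/81982798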
  1. Faster R-CNN Source Code Reading, Part 0: Preface
  2. Faster R-CNN Source Code Reading, Part 1: Faster R-CNN/lib/networks/network.py
  3. Faster R-CNN Source Code Reading, Part 2: Faster R-CNN/lib/networks/factory.py
  4. Faster R-CNN Source Code Reading, Part 3: Faster R-CNN/lib/networks/VGGnet_test.py
  5. Faster R-CNN Source Code Reading, Part 4: Faster R-CNN/lib/rpn_msr/generate_anchors.py
  6. Faster R-CNN Source Code Reading, Part 5: Faster R-CNN/lib/rpn_msr/proposal_layer_tf.py
  7. Faster R-CNN Source Code Reading, Part 6: Faster R-CNN/lib/fast_rcnn/bbox_transform.py
  8. Faster R-CNN Source Code Reading, Part 7: Faster R-CNN/lib/rpn_msr/anchor_target_layer_tf.py
  9. Faster R-CNN Source Code Reading, Part 8: Faster R-CNN/lib/rpn_msr/proposal_target_layer_tf.py
  10. Faster R-CNN Source Code Reading, Part 9: Faster R-CNN/tools/train_net.py
  11. Faster R-CNN Source Code Reading, Part 10: Faster R-CNN/lib/fast_rcnn/train.py
  12. Faster R-CNN Source Code Reading, Part 11: Completing the Faster R-CNN Prediction Demo Code
  13. Faster R-CNN Source Code Reading, Part 12: Closing Remarks

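I. Introduction
   This demo is provided by the official Faster R-CNN implementation. I have only added comments to the official code, partly for my own study and partly to share and discuss it with others.
   The functions in this file are mainly used to train the whole Faster R-CNN network.
II. Code and Comments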

# coding=utf-8
# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------

"""Train a Fast R-CNN network."""

from fast_rcnn.config import cfg
import gt_data_layer.roidb as gdl_roidb
import roi_data_layer.roidb as rdl_roidb
from roi_data_layer.layer import RoIDataLayer
from utils.timer import Timer
import numpy as np
import os
import tensorflow as tf
import sys
from tensorflow.python.client import timeline
import time


class SolverWrapper(object):
    """
    A simple wrapper around Caffe's solver.
    This wrapper gives us control over the snapshot process, which we
    use to unnormalize the learned bounding-box regression weights.

    Note: this docstring is inherited from the original Caffe implementation;
    in this port the wrapper drives TensorFlow training rather than a Caffe solver.
    """

    def __init__(self, sess, saver, network, imdb, roidb, output_dir, pretrained_model=None):
        """Initialize the SolverWrapper."""
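        # The Faster R-CNN network to train
        self.net = network
        # Image database (imdb)
        self.imdb = imdb
        # RoI database (roidb)
        self.roidb = roidb
        # Output directory for the saved network structure and weights
        self.output_dir = output_dir
        # Path to the pretrained model file
        self.pretrained_model = pretrained_model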

        print 'Computing bounding-box regression targets...'
        # cfg.TRAIN.BBOX_REG defaults to True
        if cfg.TRAIN.BBOX_REG:
            # Per-class means and standard deviations of the regression targets,
            # returned as means.ravel(), stds.ravel()
            self.bbox_means, self.bbox_stds = rdl_roidb.add_bbox_regression_targets(roidb)
        print 'done'

        # For checkpoint
        self.saver = saver

    def snapshot(self, sess, iter):
        """
        Take a snapshot of the network after unnormalizing the learned
        bounding-box regression weights. This enables easy use at test-time.
        The normalized weights are restored after saving, so training
        continues unaffected.
        """
        net = self.net

        if cfg.TRAIN.BBOX_REG and net.layers.has_key('bbox_pred'):
            # save the original (normalized) values
            with tf.variable_scope('bbox_pred', reuse=True):
                weights = tf.get_variable("weights")
                biases = tf.get_variable("biases")

            orig_0 = weights.eval()
            orig_1 = biases.eval()

            # scale and shift with bbox reg unnormalization; then save snapshot
            # W' = W * stds (stds tiled across the input rows), b' = b * stds + means
            weights_shape = weights.get_shape().as_list()
            sess.run(net.bbox_weights_assign,
                     feed_dict={net.bbox_weights: orig_0 * np.tile(self.bbox_stds, (weights_shape[0], 1))})
            sess.run(net.bbox_bias_assign,
                     feed_dict={net.bbox_biases: orig_1 * self.bbox_stds + self.bbox_means})

        # Create the output directory if it does not exist
        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)

        # Optional infix for the snapshot filename
        infix = ('_' + cfg.TRAIN.SNAPSHOT_INFIX
                 if cfg.TRAIN.SNAPSHOT_INFIX != '' else '')
        # Build the snapshot filename
        filename = (cfg.TRAIN.SNAPSHOT_PREFIX + infix +
                    '_iter_{:d}'.format(iter + 1) + '.ckpt')
        filename = os.path.join(self.output_dir, filename)

        # Save the network
        self.saver.save(sess, filename)
        print 'Wrote snapshot to: {:s}'.format(filename)

        # Restore the normalized weights so that training can continue
        if cfg.TRAIN.BBOX_REG and net.layers.has_key('bbox_pred'):
            with tf.variable_scope('bbox_pred', reuse=True):
                # restore net to original state
                sess.run(net.bbox_weights_assign, feed_dict={net.bbox_weights: orig_0})
                sess.run(net.bbox_bias_assign, feed_dict={net.bbox_biases: orig_1})

    # Smooth L1 loss
    def _modified_smooth_l1(self, sigma, bbox_pred, bbox_targets, bbox_inside_weights, bbox_outside_weights):
        """
            ResultLoss = outside_weights * SmoothL1(inside_weights * (bbox_pred - bbox_targets))
            SmoothL1(x) = 0.5 * (sigma * x)^2,    if |x| < 1 / sigma^2
                          |x| - 0.5 / sigma^2,    otherwise
        """
        # Compute sigma^2
        sigma2 = sigma * sigma

        # Compute the input x of the formula, masked by the inside weights: x = inside_weights * (pred - targets)
        inside_mul = tf.multiply(bbox_inside_weights, tf.subtract(bbox_pred, bbox_targets))
        # Indicator of |x| < 1 / sigma^2: a True/False mask per element, cast to 1.0 / 0.0
        smooth_l1_sign = tf.cast(tf.less(tf.abs(inside_mul), 1.0 / sigma2), tf.float32)
        # First branch of the formula, computed for every element (the condition is applied later)
        smooth_l1_option1 = tf.multiply(tf.multiply(inside_mul, inside_mul), 0.5 * sigma2)
        # Second branch of the formula
        smooth_l1_option2 = tf.subtract(tf.abs(inside_mul), 0.5 / sigma2)
        # Combine the two branches using the indicator; this is where the condition is finally applied
        smooth_l1_result = tf.add(tf.multiply(smooth_l1_option1, smooth_l1_sign),
                                  tf.multiply(smooth_l1_option2, tf.abs(tf.subtract(smooth_l1_sign, 1.0))))

        # Multiply by the outside weights and return the result
        outside_mul = tf.multiply(bbox_outside_weights, smooth_l1_result)

        return outside_mul

    def train_model(self, sess, max_iters):
        """Network training loop."""

        data_layer = get_data_layer(self.roidb, self.imdb.num_classes)

        # RPN
        # classification loss
        # The 'rpn-data' outputs are all produced by the anchor target layer
        # Reshape the output of 'rpn_cls_score_reshape' to (-1, 2): one row per anchor, two columns for the background/foreground scores
        rpn_cls_score = tf.reshape(self.net.get_output('rpn_cls_score_reshape'), [-1, 2])
        # Flatten the labels to 1-D
        rpn_label = tf.reshape(self.net.get_output('rpn-data')[0], [-1])
        # Keep only the rows of rpn_cls_score whose label is not -1 (-1 marks anchors that are ignored)
        rpn_cls_score = tf.reshape(tf.gather(rpn_cls_score, tf.where(tf.not_equal(rpn_label, -1))), [-1, 2])
        # Likewise keep only the labels that are not -1
        rpn_label = tf.reshape(tf.gather(rpn_label, tf.where(tf.not_equal(rpn_label, -1))), [-1])
        # Cross-entropy loss over the labels.
        # tf.nn.sparse_softmax_cross_entropy_with_logits returns one loss per example, so a reduction is needed to get a scalar.
        rpn_cross_entropy = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(logits=rpn_cls_score, labels=rpn_label))

        # bounding box regression L1 loss
        # Bbox regression predictions produced by the RPN
        rpn_bbox_pred = self.net.get_output('rpn_bbox_pred')
        # Bbox regression targets, inside weights and outside weights from the 'rpn-data' layer, transposed to [N, H, W, C]
        rpn_bbox_targets = tf.transpose(self.net.get_output('rpn-data')[1], [0, 2, 3, 1])
        rpn_bbox_inside_weights = tf.transpose(self.net.get_output('rpn-data')[2], [0, 2, 3, 1])
        rpn_bbox_outside_weights = tf.transpose(self.net.get_output('rpn-data')[3], [0, 2, 3, 1])

        # Compute the smooth L1 loss
        rpn_smooth_l1 = self._modified_smooth_l1(3.0, rpn_bbox_pred, rpn_bbox_targets, rpn_bbox_inside_weights,
                                                 rpn_bbox_outside_weights)
        # Reduce the element-wise smooth L1 result to a scalar: sum per image, then mean over the batch
        rpn_loss_box = tf.reduce_mean(tf.reduce_sum(rpn_smooth_l1, reduction_indices=[1, 2, 3]))

        # R-CNN
        # classification loss
        # 'roi-data' is produced by the proposal target layer
        # Predicted class scores (logits) for each RoI
        cls_score = self.net.get_output('cls_score')
        # Ground-truth label of each RoI, flattened to 1-D
        label = tf.reshape(self.net.get_output('roi-data')[1], [-1])
        # Cross-entropy loss for RoI classification
        cross_entropy = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=cls_score, labels=label))

        # bounding box regression L1 loss
        # Bbox regression predictions produced by the Fast R-CNN head
        bbox_pred = self.net.get_output('bbox_pred')
        # Bbox regression targets, inside weights and outside weights from the 'roi-data' layer
        bbox_targets = self.net.get_output('roi-data')[2]
        bbox_inside_weights = self.net.get_output('roi-data')[3]
        bbox_outside_weights = self.net.get_output('roi-data')[4]

        # Compute the smooth L1 loss
        smooth_l1 = self._modified_smooth_l1(1.0, bbox_pred, bbox_targets, bbox_inside_weights, bbox_outside_weights)
        # Reduce the smooth L1 result: sum per RoI, then mean over RoIs
        loss_box = tf.reduce_mean(tf.reduce_sum(smooth_l1, reduction_indices=[1]))

        # final loss
        # The total loss is the sum of the four losses above
        loss = cross_entropy + loss_box + rpn_cross_entropy + rpn_loss_box

        # optimizer and learning rate
        # Global step counter
        global_step = tf.Variable(0, trainable=False)
        # Learning rate schedule: multiplied by 0.1 every cfg.TRAIN.STEPSIZE steps (staircase)
        lr = tf.train.exponential_decay(cfg.TRAIN.LEARNING_RATE, global_step,
                                        cfg.TRAIN.STEPSIZE, 0.1, staircase=True)
        # Momentum (default 0.9)
        momentum = cfg.TRAIN.MOMENTUM
        # Optimizer
        train_op = tf.train.MomentumOptimizer(lr, momentum).minimize(loss, global_step=global_step)

        # initialize all variables
        sess.run(tf.global_variables_initializer())
        # Load the pretrained model if one was provided
        if self.pretrained_model is not None:
            print ('Loading pretrained model '
                   'weights from {:s}').format(self.pretrained_model)
            self.net.load(self.pretrained_model, sess, self.saver, True)

        last_snapshot_iter = -1
        # Timer
        timer = Timer()
        # Training loop
        for iter in range(max_iters):
            # get one batch
            # blobs provides 'data', 'im_info' and 'gt_boxes' for this batch
            blobs = data_layer.forward()

            # Make one SGD update
            # Build the feed dict for the network inputs
            feed_dict = {self.net.data: blobs['data'],
                         self.net.im_info: blobs['im_info'],
                         self.net.keep_prob: 0.5,
                         self.net.gt_boxes: blobs['gt_boxes']}

            # cfg.TRAIN.DEBUG_TIMELINE defaults to False. Setting it to True is not recommended, as it may cause errors. The same applies below.
            run_options = None
            run_metadata = None
            if cfg.TRAIN.DEBUG_TIMELINE:
                run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
                run_metadata = tf.RunMetadata()

            # Start the timer
            timer.tic()

            # Run one training step
            rpn_loss_cls_value, rpn_loss_box_value, loss_cls_value, loss_box_value, _ = sess.run(
                [rpn_cross_entropy, rpn_loss_box, cross_entropy, loss_box, train_op],
                feed_dict=feed_dict,
                options=run_options,
                run_metadata=run_metadata)

            # Stop the timer
            timer.toc()

            if cfg.TRAIN.DEBUG_TIMELINE:
                trace = timeline.Timeline(step_stats=run_metadata.step_stats)
                trace_file = open(str(long(time.time() * 1000)) + '-train-timeline.ctf.json', 'w')
                trace_file.write(trace.generate_chrome_trace_format(show_memory=False))
                trace_file.close()

            # Periodically print training progress, mainly the loss values
            if (iter + 1) % (cfg.TRAIN.DISPLAY) == 0:
                print 'iter: %d / %d, total loss: %.4f, rpn_loss_cls: %.4f, rpn_loss_box: %.4f, loss_cls: %.4f, loss_box: %.4f, lr: %f' % \
                      (iter + 1, max_iters, rpn_loss_cls_value + rpn_loss_box_value + loss_cls_value + loss_box_value,
                       rpn_loss_cls_value, rpn_loss_box_value, loss_cls_value, loss_box_value, lr.eval())
                print 'speed: {:.3f}s / iter'.format(timer.average_time)

            # Take a snapshot and save the whole Faster R-CNN network
            if (iter + 1) % cfg.TRAIN.SNAPSHOT_ITERS == 0:
                last_snapshot_iter = iter
                self.snapshot(sess, iter)

        # Take one final snapshot at the end if the last iteration was not already saved
        if last_snapshot_iter != iter:
            self.snapshot(sess, iter)
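
The piecewise formula in `_modified_smooth_l1` above can be checked outside TensorFlow. The following is a minimal NumPy sketch of the same formula, not the repository's code; the function name `smooth_l1_numpy` and the sample values are made up for illustration.

```python
import numpy as np


def smooth_l1_numpy(sigma, pred, targets, inside_w, outside_w):
    """Illustrative NumPy version of the formula in _modified_smooth_l1."""
    sigma2 = sigma * sigma
    # x = inside_weights * (pred - targets)
    x = inside_w * (pred - targets)
    abs_x = np.abs(x)
    # Quadratic branch where |x| < 1 / sigma^2, linear branch elsewhere
    quadratic = 0.5 * sigma2 * x * x
    linear = abs_x - 0.5 / sigma2
    loss = np.where(abs_x < 1.0 / sigma2, quadratic, linear)
    return outside_w * loss


# Hypothetical values: a small error, a large error, and a masked-out entry
pred = np.array([0.1, 2.0, 5.0])
targets = np.array([0.0, 0.0, 0.0])
inside_w = np.array([1.0, 1.0, 0.0])   # the third entry is masked out
outside_w = np.array([1.0, 1.0, 1.0])

result = smooth_l1_numpy(1.0, pred, targets, inside_w, outside_w)
print(result)
```

With sigma = 1.0 the switch between the two branches happens at |x| = 1; with the RPN's sigma = 3.0 it happens at |x| = 1/9, so the loss becomes linear much earlier.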


def get_training_roidb(imdb):
    """
    Returns a roidb (Region of Interest database) for use in training.
    """
    if cfg.TRAIN.USE_FLIPPED:
        print 'Appending horizontally-flipped training examples...'
        imdb.append_flipped_images()
        print 'done'

    print 'Preparing training data...'
    if cfg.TRAIN.HAS_RPN:
        if cfg.IS_MULTISCALE:
            gdl_roidb.prepare_roidb(imdb)
        else:
            rdl_roidb.prepare_roidb(imdb)
    else:
        rdl_roidb.prepare_roidb(imdb)
    print 'done'

    return imdb.roidb


def get_data_layer(roidb, num_classes):
    """
    Return a data layer.
    """
    if cfg.TRAIN.HAS_RPN:
        if cfg.IS_MULTISCALE:
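            # NOTE: GtDataLayer is not imported in this file, so this branch
            # raises NameError unless it is imported first.
            layer = GtDataLayer(roidb)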
        else:
            layer = RoIDataLayer(roidb, num_classes)
    else:
        layer = RoIDataLayer(roidb, num_classes)

    return layer


def filter_roidb(roidb):
    """
    Remove roidb entries that have no usable RoIs.
    """

    def is_valid(entry):
        # Valid images have:
        #   (1) At least one foreground RoI OR
        #   (2) At least one background RoI
        overlaps = entry['max_overlaps']
        # find boxes with sufficient overlap
        fg_inds = np.where(overlaps >= cfg.TRAIN.FG_THRESH)[0]
        # Select background RoIs as those within [BG_THRESH_LO, BG_THRESH_HI)
        bg_inds = np.where((overlaps < cfg.TRAIN.BG_THRESH_HI) &
                           (overlaps >= cfg.TRAIN.BG_THRESH_LO))[0]
        # image is only valid if such boxes exist
        valid = len(fg_inds) > 0 or len(bg_inds) > 0
        return valid

    num = len(roidb)
    filtered_roidb = [entry for entry in roidb if is_valid(entry)]
    num_after = len(filtered_roidb)
    print 'Filtered {} roidb entries: {} -> {}'.format(num - num_after,
                                                       num, num_after)
    return filtered_roidb
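
To see what `filter_roidb` keeps and drops, the validity test can be reproduced on toy data. This sketch assumes illustrative thresholds (FG_THRESH = 0.5, BG_THRESH_HI = 0.5, BG_THRESH_LO = 0.1); the real values come from `cfg.TRAIN` and may differ, and the toy entries are made up.

```python
import numpy as np

# Assumed thresholds for illustration only; the real ones live in cfg.TRAIN
FG_THRESH = 0.5
BG_THRESH_HI = 0.5
BG_THRESH_LO = 0.1


def is_valid(entry):
    """An entry is usable if it has at least one foreground or background RoI."""
    overlaps = entry['max_overlaps']
    has_fg = np.any(overlaps >= FG_THRESH)
    has_bg = np.any((overlaps < BG_THRESH_HI) & (overlaps >= BG_THRESH_LO))
    return bool(has_fg or has_bg)


# Toy roidb: each entry only carries the 'max_overlaps' field used by the test
toy_roidb = [
    {'max_overlaps': np.array([0.9, 0.3])},    # one fg, one bg -> kept
    {'max_overlaps': np.array([0.05, 0.02])},  # all below BG_THRESH_LO -> dropped
    {'max_overlaps': np.array([0.2])},         # one bg -> kept
]

kept = [entry for entry in toy_roidb if is_valid(entry)]
print('Filtered {} roidb entries: {} -> {}'.format(
    len(toy_roidb) - len(kept), len(toy_roidb), len(kept)))
```

Under these assumed thresholds, an image is dropped only when every RoI overlaps less than BG_THRESH_LO, that is, when it offers neither foreground nor background samples.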


def train_net(network, imdb, roidb, output_dir, pretrained_model=None, max_iters=40000):
    """
    Train a Fast R-CNN network.
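    :param network: the Faster R-CNN network to train
    :param imdb: image database
    :param roidb: RoI database
    :param output_dir: directory where the network weights are saved
    :param pretrained_model: path to the pretrained weights file
    :param max_iters: maximum number of iterations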
    :return: None
    """
    # Filter out roidb entries without usable RoIs
    roidb = filter_roidb(roidb)
    # TensorFlow saver (keeps up to 100 checkpoints)
    saver = tf.train.Saver(max_to_keep=100)
    # TensorFlow session
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        # Solver wrapper
        sw = SolverWrapper(sess, saver, network, imdb, roidb, output_dir, pretrained_model=pretrained_model)
        print 'Solving...'
        # Train the network
        sw.train_model(sess, max_iters)
        print 'done solving'
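
The rescaling that `SolverWrapper.snapshot` applies to the `bbox_pred` weights can be verified numerically: if the layer is trained to predict normalized targets, folding the statistics into the weights and biases gives a layer that outputs un-normalized deltas directly. The sketch below uses random NumPy arrays with made-up shapes and statistics; it illustrates the algebra only and is not the repository's code.

```python
import numpy as np

rng = np.random.RandomState(0)

in_dim, out_dim = 8, 4             # made-up sizes for illustration
W = rng.randn(in_dim, out_dim)     # stands in for weights trained on normalized targets
b = rng.randn(out_dim)
stds = np.array([0.1, 0.1, 0.2, 0.2])    # made-up target statistics
means = np.array([0.0, 0.0, 0.1, 0.1])
x = rng.randn(5, in_dim)           # a batch of 5 feature vectors

# Option A: predict normalized deltas, then un-normalize the outputs
normalized = x.dot(W) + b
unnormalized_a = normalized * stds + means

# Option B: fold the statistics into the parameters, as snapshot() does
W_folded = W * np.tile(stds, (in_dim, 1))
b_folded = b * stds + means
unnormalized_b = x.dot(W_folded) + b_folded

print(np.allclose(unnormalized_a, unnormalized_b))
```

Both options agree because (xW + b) * stds + means = x(W * stds) + (b * stds + means), which is why the saved checkpoint can be used at test time without a separate un-normalization step.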