Faster R-CNN代码理解(一)之train.py

代码链接https://github.com/DetectionTeamUCAS/FPN_Tensorflow
在此,感谢他对我帮助
废话不多说,我们开始对代码进行理解,首先从train.py开始

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import tensorflow as tf
import tensorflow.contrib.slim as slim
import os, sys
import numpy as np
import time
sys.path.append("../")

from libs.configs import cfgs
# from libs.networks import build_whole_network2
from libs.networks import build_whole_network
from data.io.read_tfrecord import next_batch
from libs.box_utils import show_box_in_tensor
from help_utils import tools

os.environ["CUDA_VISIBLE_DEVICES"] = cfgs.GPU_GROUP

这一段讲的是一些包的引用及GPU的调用,我也说的不是很明白

def train():
    faster_rcnn = build_whole_network.DetectionNetwork(base_network_name=cfgs.NET_NAME,
                 is_training=True)
   #将以batch_size=1的进度获取变换后的图像
    #img_name_batch, img_batch, gtboxes_and_label_batch, num_obs_batch
    #img_batch: shape:(1, new_imgH, new_imgW, C)
    with tf.name_scope('get_batch'):
        img_name_batch, img_batch, gtboxes_and_label_batch, num_objects_batch = \
            next_batch(dataset_name=cfgs.DATASET_NAME,  # 'pascal', 'coco'
                       batch_size=cfgs.BATCH_SIZE,
                       shortside_len=cfgs.IMG_SHORT_SIDE_LEN,
                       is_training=True)

此处faster_rcnn代表了实例化,可以调用build_whole_network这个类里面的方法了。
next_batch函数的过程是:

  1. 读取训练集文件夹内的文件名,并以一个list的方式返回
  2. 提取每个图像的名字、图像、标注框、标签和标注数目
  3. tf.train.batch利用一个tensor的列表或字典来获取一个batch数据。分别得到一个batch内的图像名、图像、带位置信息的框、框的个数

shape:

  1. img_name_batch: shape(1, 1)
  2. img_batch: shape(1, new_imgH, new_imgW, C)
  3. gtboxes_and_label_batch: shape(1, Num_Of_objects, 5]
    gtboxes_and_label_batch中的5表示[x1, y1, x2, y2, label]
gtboxes_and_label = tf.reshape(gtboxes_and_label_batch, [-1, 5])

看上去像是降维了,每一行就是一个框,每个框的信息代表左上角、右下角,坐标信息是[x1, y1, x2, y2, label],有好多好多的框啊

    with slim.arg_scope([slim.conv2d, slim.conv2d_in_plane, \
                         slim.conv2d_transpose, slim.separable_conv2d, slim.fully_connected],
                        weights_regularizer=weights_regularizer,
                        biases_regularizer=biases_regularizer,
                        biases_initializer=tf.constant_initializer(0.0)):
        final_bbox, final_scores, final_category, loss_dict = faster_rcnn.build_whole_detection_network(
            input_img_batch=img_batch,
            gtboxes_batch=gtboxes_and_label)

slim.arg_scope(XXXX)里面的东西相当于对各个属性赋默认值,之后可以在这里面的代码(就是with包裹住的代码)直接用。

final_bbox, final_scores, final_category, loss_dict = faster_rcnn.build_whole_detection_network(
            input_img_batch=img_batch,
            gtboxes_batch=gtboxes_and_label)

这一段就要说的重点!因为这个faster-rcnn相当于通过这个函数构造网络

先打住,我们现在看一下怎么做的=-=,超级超级繁琐啊

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值