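Code link: https://github.com/DetectionTeamUCAS/FPN_Tensorflow
Many thanks to the author for the help this code has given me.
Without further ado, let's start working through the code, beginning with train.py.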
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import tensorflow as tf
import tensorflow.contrib.slim as slim
import os, sys
import numpy as np
import time
sys.path.append("../")
from libs.configs import cfgs
# from libs.networks import build_whole_network2
from libs.networks import build_whole_network
from data.io.read_tfrecord import next_batch
from libs.box_utils import show_box_in_tensor
from help_utils import tools
os.environ["CUDA_VISIBLE_DEVICES"] = cfgs.GPU_GROUP
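This part imports the required packages and chooses which GPU(s) to use. I can't explain every detail, but the gist is: sys.path.append("../") adds the parent directory (relative to the current working directory) to Python's module search path so that libs, data and help_utils can be imported, and setting the CUDA_VISIBLE_DEVICES environment variable to cfgs.GPU_GROUP restricts the process to the GPUs listed there.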
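To make the GPU line a bit more concrete, here is a small sketch of what setting CUDA_VISIBLE_DEVICES amounts to. GPU_GROUP and select_gpus are hypothetical stand-ins (the real value comes from cfgs.GPU_GROUP). Two things are worth knowing: the variable only takes effect if it is set before TensorFlow first initializes the GPUs, which is why train.py sets it at the top before any session exists, and inside the process the listed GPUs are renumbered from 0, so with "2,3" TensorFlow would see them as /gpu:0 and /gpu:1.

```python
import os

# Hypothetical stand-in for cfgs.GPU_GROUP (the real one lives in libs/configs/cfgs.py)
GPU_GROUP = "0,1"

def select_gpus(gpu_group):
    """Expose only the listed GPUs to CUDA and return how many were requested.

    This must run before TensorFlow initializes the GPUs; setting it later
    has no effect on the current process.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu_group
    ids = [g.strip() for g in gpu_group.split(",") if g.strip()]
    return len(ids)

num_gpus = select_gpus(GPU_GROUP)
print(os.environ["CUDA_VISIBLE_DEVICES"], num_gpus)  # 0,1 2
```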
def train():
    faster_rcnn = build_whole_network.DetectionNetwork(base_network_name=cfgs.NET_NAME,
                                                       is_training=True)
    # Fetch the transformed images one batch at a time, with batch_size=1
    # Returns: img_name_batch, img_batch, gtboxes_and_label_batch, num_obs_batch
    # img_batch: shape (1, new_imgH, new_imgW, C)
    with tf.name_scope('get_batch'):
        img_name_batch, img_batch, gtboxes_and_label_batch, num_objects_batch = \
            next_batch(dataset_name=cfgs.DATASET_NAME,  # 'pascal', 'coco'
                       batch_size=cfgs.BATCH_SIZE,
                       shortside_len=cfgs.IMG_SHORT_SIDE_LEN,
                       is_training=True)
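Here faster_rcnn is an instance of the DetectionNetwork class from the build_whole_network module, so from now on we can call its methods.
What next_batch does:
- reads the file names in the training-set folder and returns them as a list
- extracts each image's name, the image itself, its annotated boxes, their labels, and the number of annotations
- uses tf.train.batch, which takes a list or dictionary of tensors and assembles them into a batch, yielding the image names, images, boxes (with location information), and box counts for one batch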
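Conceptually, the batching step just groups per-example fields into batches, one list per field. The toy below is pure Python and is not how tf.train.batch works internally (that relies on TensorFlow input queues); make_batches and the example records are made up for illustration. With batch_size=1 every field simply gains a leading dimension of size 1, which is where the leading 1 in the shapes comes from.

```python
def make_batches(examples, batch_size):
    """Group dicts of per-example fields into batches of `batch_size`.

    Returns a list of dicts; each dict maps a field name to a list with one
    entry per example in the batch (the leading "batch" dimension).
    Incomplete trailing batches are dropped here for simplicity.
    """
    batches = []
    for start in range(0, len(examples) - batch_size + 1, batch_size):
        chunk = examples[start:start + batch_size]
        batches.append({key: [ex[key] for ex in chunk] for key in chunk[0]})
    return batches

# Made-up records: name, boxes as [x1, y1, x2, y2, label], and box count
examples = [
    {"img_name": "000001.jpg", "gtboxes_and_label": [[48, 240, 195, 371, 12]], "num_objects": 1},
    {"img_name": "000002.jpg", "gtboxes_and_label": [[8, 12, 352, 498, 15], [10, 20, 30, 40, 7]],
     "num_objects": 2},
]

for batch in make_batches(examples, batch_size=1):
    print(batch["img_name"], batch["num_objects"])
# ['000001.jpg'] [1]
# ['000002.jpg'] [2]
```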
Shapes:
- img_name_batch: shape (1, 1)
- img_batch: shape (1, new_imgH, new_imgW, C)
- gtboxes_and_label_batch: shape (1, Num_Of_objects, 5)
The 5 in gtboxes_and_label_batch stands for [x1, y1, x2, y2, label].
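Where do new_imgH and new_imgW come from? Judging by the shortside_len argument, next_batch presumably rescales each image so its shorter side equals cfgs.IMG_SHORT_SIDE_LEN while keeping the aspect ratio, and the box coordinates have to be scaled by the same factors. The sketch below shows that idea under this assumption; short_side_resize is a hypothetical helper, and I have not checked the exact rounding the repo uses. Only the four coordinates change; the label column is left alone.

```python
def short_side_resize(img_h, img_w, boxes, shortside_len):
    """Hypothetical helper: scale (img_h, img_w) so the shorter side equals
    shortside_len (aspect ratio preserved) and rescale [x1, y1, x2, y2, label]
    boxes to match. Labels are copied unchanged.
    """
    if img_h < img_w:
        new_h, new_w = shortside_len, round(img_w * shortside_len / img_h)
    else:
        new_h, new_w = round(img_h * shortside_len / img_w), shortside_len
    sx, sy = new_w / img_w, new_h / img_h  # horizontal / vertical scale factors
    new_boxes = [[x1 * sx, y1 * sy, x2 * sx, y2 * sy, label]
                 for x1, y1, x2, y2, label in boxes]
    return new_h, new_w, new_boxes

new_h, new_w, new_boxes = short_side_resize(300, 500, [[50, 30, 250, 150, 12]], 600)
print(new_h, new_w)   # 600 1000
print(new_boxes[0])   # [100.0, 60.0, 500.0, 300.0, 12]
```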
gtboxes_and_label = tf.reshape(gtboxes_and_label_batch, [-1, 5])
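This looks like a dimensionality reduction: since the batch size is 1, reshaping to [-1, 5] simply drops the batch dimension and leaves a (Num_Of_objects, 5) tensor. Each row is one box: its top-left corner (x1, y1), its bottom-right corner (x2, y2), and its label, i.e. [x1, y1, x2, y2, label]. And there can be lots and lots of boxes!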
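The [-1, 5] reshape can be mimicked on nested lists: -1 tells reshape to infer that dimension from the total number of elements. reshape_minus_one is a hypothetical helper for illustration only. With batch size 1 it just drops the batch dimension; with a larger batch, boxes from different images would be merged into one list with no record of which image each came from, which is presumably one reason this code works with batch_size=1.

```python
def reshape_minus_one(batch, row_len=5):
    """Mimic tf.reshape(x, [-1, row_len]) on a nested list of shape (B, N, row_len).

    The -1 dimension is inferred as (total number of values) / row_len.
    """
    flat = [v for image in batch for box in image for v in box]
    if len(flat) % row_len != 0:
        raise ValueError("total size is not divisible by row_len")
    return [flat[i:i + row_len] for i in range(0, len(flat), row_len)]

gtboxes_and_label_batch = [[[10, 20, 110, 220, 3],
                            [30, 40, 90, 160, 7]]]  # shape (1, 2, 5): one image, two boxes
gtboxes_and_label = reshape_minus_one(gtboxes_and_label_batch)
print(len(gtboxes_and_label))  # 2, i.e. shape (2, 5)
print(gtboxes_and_label[1])    # [30, 40, 90, 160, 7]
```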
# weights_regularizer and biases_regularizer are defined elsewhere in train.py (not shown in this excerpt)
with slim.arg_scope([slim.conv2d, slim.conv2d_in_plane,
                     slim.conv2d_transpose, slim.separable_conv2d, slim.fully_connected],
                    weights_regularizer=weights_regularizer,
                    biases_regularizer=biases_regularizer,
                    biases_initializer=tf.constant_initializer(0.0)):
    final_bbox, final_scores, final_category, loss_dict = faster_rcnn.build_whole_detection_network(
        input_img_batch=img_batch,
        gtboxes_batch=gtboxes_and_label)
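The arguments given to slim.arg_scope(...) effectively assign default values to the listed layer functions; the code wrapped by the with statement can then call those layers directly without repeating these arguments. An argument passed explicitly to a layer still overrides the default.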
final_bbox, final_scores, final_category, loss_dict = faster_rcnn.build_whole_detection_network(
    input_img_batch=img_batch,
    gtboxes_batch=gtboxes_and_label)
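This is the key part! faster_rcnn builds the entire detection network through this function, which returns the final boxes, scores and categories, plus a dictionary of losses (loss_dict).
Let's pause here and look at how it does that =-=. It is really, really tedious.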