tensorflow学习笔记(四)

代码学习有点吃力,学习了YOLOv1的代码,主要是训练部分的代码,对yolo的又有了进一步的理解。其文件夹下主要包含py文件为,train.py, yolo_net.py, pascal_voc.。下面是比较详细的代码解读。但是还是有一些内容理解的不是很透彻。暂时就这样吧。

首先看一下yolo_net.py文件,这个文件主要定义了网络结构,损失函数的计算等内容。

import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
'''YOLOnet主要包含了损失函数的计算 和 定义网络结构,其中损失函数的计算包括了交并比函数'''
class YOLONet(object):
    def __init__(self, is_training=True):
        self.classes = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
           'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']
        #self.num_class = len(self.classes)
        self.num_class = 20
        self.image_size = 448       #图片尺寸
        self.cell_size = 7          #格子数目
        self.boxes_per_cell = 2     #每个格子预测2个框
        self.output_size = (self.cell_size * self.cell_size) * (self.num_class + self.boxes_per_cell * 5)  #输出的维度S*S*(B*5+C) = 1470
        self.scale = 1.0 * self.image_size / self.cell_size     #每个格子的像素大小
        self.boundary1 = self.cell_size * self.cell_size * self.num_class       # 7 * 7 * 20
        self.boundary2 = self.boundary1 + self.cell_size * self.cell_size * self.boxes_per_cell     #7 * 7 * 20 + 7 * 7 *2

        self.object_scale = 1       # 这四个是损失函数前面的系数
        self.noobject_scale = 1
        self.class_scale = 2
        self.coord_scale = 5

        self.learning_rate = 0.0001
        self.batch_size = 20
        self.alpha = 0.1
        #offset.shape = [7,7,2]
        self.offset = np.transpose(np.reshape(np.array([np.arange(self.cell_size)] * self.cell_size * self.boxes_per_cell),
            (self.boxes_per_cell, self.cell_size, self.cell_size)), (1, 2, 0))

        self.images = tf.placeholder(tf.float32, [None, self.image_size, self.image_size, 3], name='images')
        self.logits = self.build_network(self.images, num_outputs=self.output_size, alpha=self.alpha, is_training=is_training)  #self.logits.shape: (?, 1470)
        print("self.logits.shape :",self.logits.shape)
        if is_training:
            #self.labels.shape = [None,7,7,25]
            self.labels = tf.placeholder(tf.float32, [None, self.cell_size, self.cell_size, 5 + self.num_class])
            # self.logits.shape: (?, 1470)
            self.loss_layer(self.logits, self.labels)
            self.total_loss = tf.losses.get_total_loss()
            tf.summary.scalar('total_loss', self.total_loss)

    def build_network(self, images, num_outputs, alpha, keep_prob=0.5, is_training=True):
        # num_outputs = 1470
        with tf.variable_scope('yolo'):
            with slim.arg_scope(
                [slim.conv2d, slim.fully_connected],
                activation_fn=leaky_relu(alpha),
                weights_regularizer=slim.l2_regularizer(0.0005),
                weights_initializer=tf.truncated_normal_initializer(0.0, 0.01)
            ):
                net = tf.         pad(images, np.array([[0, 0], [3, 3], [3, 3], [0, 0]]), name='pad_1')
                net = slim.    conv2d(net, 64, 7, 2, padding='VALID', scope='conv_2')
                net = slim.max_pool2d(net, 2,    padding='SAME', scope='pool_3')
                net = slim.    conv2d(net, 192,  3, scope='conv_4')
                net = slim.max_pool2d(net, 2,    padding='SAME', scope='pool_5')
                net = slim.    conv2d(net, 128,  1, scope='conv_6')
                net = slim.    conv2d(net, 256,  3, scope='conv_7')
                net = slim.    conv2d(net, 256,  1, scope='conv_8')
                net = slim.    conv2d(net, 512,  3, scope='conv_9')
                net = slim.max_pool2d(net, 2,    padding='SAME', scope='pool_10')
                net = slim.    conv2d(net, 256,  1, scope='conv_11')
                net = slim.    conv2d(net, 512,  3, scope='conv_12')
                net = slim.    conv2d(net, 256,  1, scope='conv_13')
                net = slim.    conv2d(net, 512,  3, scope='conv_14')
                net = slim.    conv2d(net, 256,  1, scope='conv_15')
                net = slim.    conv2d(net, 512,  3, scope='conv_16')
                net = slim.    conv2d(net, 256,  1, scope='conv_17')
                net = slim.    conv2d(net, 512,  3, scope='conv_18')
                net = slim.    conv2d(net, 512,  1, scope='conv_19')
                net = slim.    conv2d(net, 1024, 3, scope='conv_20')
                net = slim.max_pool2d(net, 2,    padding='SAME', scope='pool_21')
                net = slim.    conv2d(net, 512,  1, scope='conv_22')
                net = slim.    conv2d(net, 1024, 3, scope='conv_23')
                net = slim.    conv2d(net, 512,  1,  scope='conv_24')
                net = slim.    conv2d(net, 1024, 3, scope='conv_25')
                print(net.op.name, net.shape)
                net = slim.    conv2d(net, 1024, 3, scope='conv_26')
                print(net.op.name, net.shape)
                net = tf.         pad(net, np.array([[0, 0], [1, 1], [1, 1], [0, 0]]), name='pad_27')
                print(net.op.name, net.shape)
                net = slim .   conv2d(net, 1024, 3, 2, padding='VALID', scope='conv_28')
                print(net.op.name, net.shape)
                net = slim.    conv2d(net, 1024, 3, scope='conv_29')
                print(net.op.name, net.shape)
                net = slim.    conv2d(net, 1024, 3, scope='conv_30')
                print(net.op.name ,net.shape)
                net = tf.   transpose(net, [0, 3, 1, 2], name='trans_31')
                print(net.op.name, net.shape)
                net = slim.   flatten(net, scope='flat_32')
                print(net.op.name, net.shape)
                net = slim.fully_connected(net, 512, scope='fc_33')
                print(net.op.name, net.shape)
                net = slim.fully_connected(net, 4096, scope='fc_34')
                print(net.op.name, net.shape)
                net = slim.        dropout(net, keep_prob=keep_prob, is_training=is_training, scope='dropout_35')
                print(net.op.name, net.shape)
                net = slim.fully_connected(net, num_outputs, activation_fn=None, scope='fc_36')
                print(net.op.name, net.shape)
        return net

    def calc_iou(self, boxes1, boxes2, scope='iou'):
        """calculate ious
        Args:
          boxes1: 5-D tensor [BATCH_SIZE, CELL_SIZE, CELL_SIZE, BOXES_PER_CELL, 4]  ====> (x_center, y_center, w, h)
          boxes2: 5-D tensor [BATCH_SIZE, CELL_SIZE, CELL_SIZE, BOXES_PER_CELL, 4] ===> (x_center, y_center, w, h)
        Return:
          iou: 4-D tensor [BATCH_SIZE, CELL_SIZE, CELL_SIZE, BOXES_PER_CELL]
        """
        with tf.variable_scope(scope):
            # transform (x_center, y_center, w, h) to (x1, y1, x2, y2)
            boxes1_t = tf.stack([boxes1[..., 0] - boxes1[..., 2] / 2.0,#四个坐标的计算方法:中心点减去二分之一个宽,得到左边坐标 x1
                                 boxes1[..., 1] - boxes1[..., 3] / 2.0,#中心点减去二分之一个高,得到上坐标 y1
                                 boxes1[..., 0] + boxes1[..., 2] / 2.0,#中心点加上二分之一个高,得到上坐标 y2
                                 boxes1[..., 1] + boxes1[..., 3] / 2.0],#中心点加上二分之一个高,得到上坐标y2
                                axis=-1)
            boxes2_t = tf.stack([boxes2[..., 0] - boxes2[..., 2] / 2.0, #那么下面几行就是计算第二个框的四个坐标值了
                                 boxes2[..., 1] - boxes2[..., 3] / 2.0,
                                 boxes2[..., 0] + boxes2[..., 2] / 2.0,
                                 boxes2[..., 1] + boxes2[..., 3] / 2.0],
                                axis=-1)
            # 计算左上点和右下点
            lu = tf.maximum(boxes1_t[..., :2], boxes2_t[..., :2])
            rd = tf.minimum(boxes1_t[..., 2:], boxes2_t[..., 2:])

            # 计算相交部分面积(我没太弄明白,这个函数到底是怎么计算相交面积的)
            intersection = tf.maximum(0.0, rd - lu)
            inter_square = intersection[..., 0] * intersection[..., 1]

            # 分别计算两个框的面积(真实框和预测框)
            square1 = boxes1[..., 2] * boxes1[..., 3]
            square2 = boxes2[..., 2] * boxes2[..., 3]
            # 这一步在计算两个框相交的面积,公共面积union_square
            union_square = tf.maximum(square1 + square2 - inter_square, 1e-10)
        # 虽然细节的地方没弄太明白,但是明显该受到这个返回值是交并比,也就是论文中的IOU
        return tf.clip_by_value(inter_square / union_square, 0.0, 1.0)

    def loss_layer(self, predicts, labels, scope='loss_layer'):
        # self.logits.shape : (?, 1470)     预测的是两个框,和20个类别的概率,一共30维
        # labels.shape = [None,7,7,25]      真实的图片,只有一个类别和一个框,所以25维
        with tf.variable_scope(scope):
            # 哦,这里原来就是花式索引那部分:将7*7*30个向量,花式索引
            predict_classes = tf.reshape(
                predicts[:, :self.boundary1],   # boundary1 = 7*7*20
                [self.batch_size, self.cell_size, self.cell_size, self.num_class])  # shape=(None,7,7,20)
            predict_scales = tf.reshape(
                predicts[:, self.boundary1:self.boundary2], # 索引第21和22个,意义是预测的概率大小
                [self.batch_size, self.cell_size, self.cell_size, self.boxes_per_cell])
            predict_boxes = tf.reshape(predicts[:, self.boundary2:],       # 索引框的位置,22之后的元素
                [self.batch_size, self.cell_size, self.cell_size, self.boxes_per_cell, 4])#predict_boxes.shape=(None,7,7,2,4)
            response = tf.reshape(labels[..., 0], [self.batch_size, self.cell_size, self.cell_size, 1])
            # 这些定义的都是框的维度,里面还没有具体的内容
            boxes = tf.reshape(labels[..., 1:5], [self.batch_size, self.cell_size, self.cell_size, 1, 4])   # boxes.shape=(None,7,7,1,4)
            #那么我这个boxes应该就是真实的框,然后经过了一个归一化
            boxes = tf.tile(boxes, [1, 1, 1, self.boxes_per_cell, 1]) / self.image_size # boxes.shape=(None,7,7,2,4)

            classes = labels[..., 5:]

            # offset.shapae(7,7,2) ----->(1,7,7,2)
            offset = tf.reshape(tf.constant(self.offset, dtype=tf.float32), [1, self.cell_size, self.cell_size, self.boxes_per_cell])

            # offset.shape(1,7,7,2) ----->(None,7,7,2)
            offset = tf.tile(offset, [self.batch_size, 1, 1, 1])

            # offset_tran.shape(None,7,7,2)----->(None,7,7,2)        不知道为何7和7要互换一下位置
            offset_tran = tf.transpose(offset, (0, 2, 1, 3))

            predict_boxes_tran = tf.stack(
                [(predict_boxes[..., 0] + offset) / self.cell_size,
                 (predict_boxes[..., 1] + offset_tran) / self.cell_size,
                 tf.square(predict_boxes[..., 2]),
                 tf.square(predict_boxes[..., 3])], axis=-1)

            iou_predict_truth = self.calc_iou(predict_boxes_tran, boxes)

            # calculate I tensor [BATCH_SIZE, CELL_SIZE, CELL_SIZE, BOXES_PER_CELL]
            object_mask = tf.reduce_max(iou_predict_truth, 3, keep_dims=True)
            object_mask = tf.cast((iou_predict_truth >= object_mask), tf.float32) * response

            # calculate no_I tensor [CELL_SIZE, CELL_SIZE, BOXES_PER_CELL]
            noobject_mask = tf.ones_like(object_mask, dtype=tf.float32) - object_mask

            # boxes.shape=(None,7,7,2,4)那么最后一个4,是四维的,(x,y,w,h)
            boxes_tran = tf.stack(
                [boxes[..., 0] * self.cell_size - offset,
                 boxes[..., 1] * self.cell_size - offset_tran,
                 tf.sqrt(boxes[..., 2]),
                 tf.sqrt(boxes[..., 3])], axis=-1)

            # class_loss
            class_delta = response * (predict_classes - classes)
            class_loss = tf.reduce_mean(
                tf.reduce_sum(tf.square(class_delta), axis=[1, 2, 3]),
                name='class_loss') * self.class_scale

            # object_loss
            object_delta = object_mask * (predict_scales - iou_predict_truth)
            object_loss = tf.reduce_mean(tf.reduce_sum(tf.square(object_delta), axis=[1, 2, 3]),
                name='object_loss') * self.object_scale

            # noobject_loss
            noobject_delta = noobject_mask * predict_scales
            noobject_loss = tf.reduce_mean(
                tf.reduce_sum(tf.square(noobject_delta), axis=[1, 2, 3]),
                name='noobject_loss') * self.noobject_scale

            # coord_loss
            coord_mask = tf.expand_dims(object_mask, 4)
            boxes_delta = coord_mask * (predict_boxes - boxes_tran)
            coord_loss = tf.reduce_mean(
                tf.reduce_sum(tf.square(boxes_delta), axis=[1, 2, 3, 4]),
                name='coord_loss') * self.coord_scale
            # 累加loss
            tf.losses.add_loss(class_loss)
            tf.losses.add_loss(object_loss)
            tf.losses.add_loss(noobject_loss)
            tf.losses.add_loss(coord_loss)
            # summary基本上都是把数据记录到可视化文件的功能
            tf.summary.scalar('class_loss', class_loss)
            tf.summary.scalar('object_loss', object_loss)
            tf.summary.scalar('noobject_loss', noobject_loss)
            tf.summary.scalar('coord_loss', coord_loss)

            tf.summary.histogram('boxes_delta_x', boxes_delta[..., 0])
            tf.summary.histogram('boxes_delta_y', boxes_delta[..., 1])
            tf.summary.histogram('boxes_delta_w', boxes_delta[..., 2])
            tf.summary.histogram('boxes_delta_h', boxes_delta[..., 3])
            tf.summary.histogram('iou', iou_predict_truth)

#重载了leaky_relu函数
def leaky_relu(alpha):
    def op(inputs):
        return tf.nn.leaky_relu(inputs, alpha=alpha, name='leaky_relu')
    return op

第二个文件是pascal_voc.py,这个文件主要定义了一些数据处理的函数,以及从xml文件中提取标签信息的方法。

#将尺寸缩放到448*448并进行归一化到(-1,1) 和对应的labels的数据
import os
import xml.etree.ElementTree as ET  # 用于解析xml文件的
import numpy as np
import cv2
import pickle
import copy

class pascal_voc(object):  # 定义一个pascal_voc类
    def __init__(self, phase, rebuild=False):
        self.devkil_path = os.path.join('E:\Desktop\yolo_tensorflow-master1\data\pascal_voc', 'VOCdevkit')  # 开发包列表目录:当前工作路径/data/pascal_voc/VOCdevkit
        self.data_path = os.path.join(self.devkil_path, 'VOC2007')  # 开发包数据目录:当前工作路径/data/pascal_voc/VOCdevkit/VOC2007
        self.cache_path = 'E:\Desktop\yolo_tensorflow-master1\data\pascal_voc\cache'  # 见yolo目录下的config.py文件
        self.batch_size = 20
        self.image_size = 448
        self.cell_size = 7
        self.classes = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
           'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']
        # 将类别中文名数字序列化成0,1,2,……
        self.class_to_ind = dict(zip(self.classes, range(len(self.classes))))
        self.flipped = True     # 水平翻转的参数,应该是想增强数据
        self.phase = phase  # 定义训练or测试
        self.rebuild = rebuild
        self.cursor = 0  # 光标移动用,查询gt_labels这个结构
        self.epoch = 1
        self.gt_labels = None
        self.prepare()

    def get(self):
        # 初始化图像。batch_size x 448x448x3, self.batch_size=30,  image_size=448
        images = np.zeros((self.batch_size, self.image_size, self.image_size, 3))
        # 初始化类别(gt)。batch_size x 7x7x25 ,cell_size = 7,对于另外一个box就不构建维度了,因此是25
        labels = np.zeros((self.batch_size, self.cell_size, self.cell_size, 25))
        count = 0
        while count < self.batch_size:  # batch处理
            imname = self.gt_labels[self.cursor]['imname']  # 从gt label中读取图像名
            flipped = self.gt_labels[self.cursor]['flipped']  # 从gt label中查看是否flipped
            images[count, :, :, :] = self.image_read(imname, flipped)
            # 从gt label中获取label类别坐标等信息
            labels[count, :, :, :] = self.gt_labels[self.cursor]['label']
            count += 1
            self.cursor += 1
            if self.cursor >= len(self.gt_labels):  # 判断是否训练完一个epoch了
                np.random.shuffle(self.gt_labels)
                self.cursor = 0
                self.epoch += 1
        return images, labels  # 返回尺寸缩放和归一化后的image序列;以及labels 真实信息
    #读取图片做归一化
    def image_read(self, imgname, flipped=False):
        image = cv2.imread(imgname)
        image = cv2.resize(image, (self.image_size, self.image_size))
        # astype,转换数据类型
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
        # 图像像素值归一化到(-1,1)
        image = (image / 255.0) * 2.0 - 1.0
        if flipped:
            image = image[:, ::-1, :]
        return image

    def prepare(self):  # 是否做flipped并打乱原来次序返回结果
        gt_labels = self.load_labels()  # 获取gt labels数据
        if self.flipped:  # 判断是否做flipped
            print('Appending horizontally-flipped training examples ...')
            gt_labels_cp = copy.deepcopy(gt_labels)
            for idx in range(len(gt_labels_cp)):
                gt_labels_cp[idx]['flipped'] = True
                gt_labels_cp[idx]['label'] = \
                    gt_labels_cp[idx]['label'][:, ::-1, :]
                for i in range(self.cell_size):
                    for j in range(self.cell_size):
                        if gt_labels_cp[idx]['label'][i, j, 0] == 1:
                            gt_labels_cp[idx]['label'][i, j, 1] = \
                                self.image_size - 1 - \
                                gt_labels_cp[idx]['label'][i, j, 1]
            gt_labels += gt_labels_cp
        np.random.shuffle(gt_labels)  # 对gt labels打乱顺序
        self.gt_labels = gt_labels
        return gt_labels

    def load_labels(self):
        cache_file = os.path.join(
            self.cache_path, 'pascal_' + self.phase + '_gt_labels.pkl')  # cache/pascal_test/train_gt_labels.pkl

        if os.path.isfile(cache_file) and not self.rebuild:
            print('Loading gt_labels from: ' + cache_file)  # 从cache目录加载gt label文件
            with open(cache_file, 'rb') as f:
                gt_labels = pickle.load(f)
            return gt_labels  # 返回gt

        print('Processing gt_labels from: ' + self.data_path)  # 处理来自data目录下的gt label

        if not os.path.exists(self.cache_path):  # 如果不存在目录文件则创建
            os.makedirs(self.cache_path)

        if self.phase == 'train':
            # 如果是train阶段,则txtname是:当前工作路径/data/pascal_voc/VOCdevkit/VOC2007/ImageSets/Main/trainval.txt这个
            txtname = os.path.join(self.data_path, 'ImageSets', 'Main', 'trainval.txt')
        else:
            # 如果是test阶段,则txtname是:当前工作路径/data/pascal_voc/VOCdevkit/VOC2007/ImageSets/Main/test.txt这个
            txtname = os.path.join(self.data_path, 'ImageSets', 'Main', 'test.txt')
        with open(txtname, 'r') as f:
            self.image_index = [x.strip() for x in f.readlines()]

        gt_labels = []  # 创建列表存放gt label
        for index in self.image_index:
            print(index)
            label, num = self.load_pascal_annotation(index)  # 取gt label以及num目标数
            if num == 0:
                continue
            # 找到图像文件夹下对应索引号的图像
            imname = os.path.join(self.data_path, 'JPEGImages', index + '.jpg')
            gt_labels.append({'imname': imname,
                              'label': label,
                              'flipped': False})
        print('Saving gt_labels to: ' + cache_file)
        with open(cache_file, 'wb') as f:
            pickle.dump(gt_labels, f)  # 将gt labels(图形名,目标类别位置坐标信息,是否flipped)写入cache中
        return gt_labels

    def load_pascal_annotation(self, index):  # 从xml文件中获取bbox信息
        """
        Load image and bounding boxes info from XML file in the PASCAL VOC
        format.
        """
        imname = os.path.join(self.data_path, 'JPEGImages', index + '.jpg')  # 图像目录下读取jpg文件:当前工作路径/data/pascal_voc/VOCdevkit/VOC2007/JPEGImages
        im = cv2.imread(imname)
        h_ratio = 1.0 * self.image_size / im.shape[0]  # 尺寸缩放系数
        w_ratio = 1.0 * self.image_size / im.shape[1]
        # im = cv2.resize(im, [self.image_size, self.image_size])

        label = np.zeros((self.cell_size, self.cell_size, 25))
        filename = os.path.join(self.data_path, 'Annotations', index + '.xml')  # 读取xml文件
        tree = ET.parse(filename)  # 解析树
        objs = tree.findall('object')  # 找xml文件中的object

        for obj in objs:  # 遍历object
            bbox = obj.find('bndbox')  # 查找object的bounding box
            # Make pixel indexes 0-basedq
            #           (float(bbox.find('xmin').text) - 1) * w_ratio
            x1 = max(min((float(bbox.find('xmin').text) - 1) * w_ratio, self.image_size - 1), 0)  # 将xml文件中的坐标做尺寸缩放
            y1 = max(min((float(bbox.find('ymin').text) - 1) * h_ratio, self.image_size - 1), 0)
            x2 = max(min((float(bbox.find('xmax').text) - 1) * w_ratio, self.image_size - 1), 0)
            y2 = max(min((float(bbox.find('ymax').text) - 1) * h_ratio, self.image_size - 1), 0)
            cls_ind = self.class_to_ind[obj.find('name').text.lower().strip()]  # 实际类别对应数字序号
            boxes = [(x2 + x1) / 2.0, (y2 + y1) / 2.0, x2 - x1, y2 - y1]  # 坐标转换成(x,y,w,h)
            x_ind = int(boxes[0] * self.cell_size / self.image_size)  # 判断x属于第几个cell

            y_ind = int(boxes[1] * self.cell_size / self.image_size)  # 判断y属于第几个cell
            if label[y_ind, x_ind, 0] == 1:
                print('x_ind{},y_ind{}'.format(x_ind, y_ind))
                print('label[y_ind, x_ind, 1:5]{}'.format(label[y_ind, x_ind, 1:5]))  # 坐标赋值
                print('label[y_ind, x_ind, 0]等于1       ', label[y_ind, x_ind, 0])
                continue
            print('x_ind{},y_ind{}'.format(x_ind, y_ind))
            print('label[y_ind, x_ind, 1:5]{}'.format(label[y_ind, x_ind, 1:5]))  # 坐标赋值
            print('label[y_ind, x_ind, 0]不等于1         ', label[y_ind, x_ind, 0])
            label[y_ind, x_ind, 0] = 1  # cell索引后,是否存在目标位赋1
            label[y_ind, x_ind, 1:5] = boxes  # 坐标赋值
            label[y_ind, x_ind, 5 + cls_ind] = 1  # 类别赋值


        return label, len(objs)  # 返回label(gt)/以及xml中目标个数

第三个是train.py,这个主要定义了一个主函数接口,将数据和网络丢给这个处理机制(类),就可以进行训练了。

import os
import argparse
import datetime
import tensorflow as tf
import yolo.config as cfg
from yolo.yolo_net import YOLONet
from utils.timer import Timer
from utils.pascal_voc import pascal_voc
slim = tf.contrib.slim

class Solver(object):
    def __init__(self, net, data):
        self.net = net
        self.data = data
        self.weights_file = 'E:\Desktop\yolo_tensorflow-master1\data\pascal_voc\weights\YOLO_small.ckpt'
        self.max_iter = 10000
        self.initial_learning_rate = 0.0001
        self.decay_steps = 30000
        self.decay_rate = 0.1
        self.staircase = True
        self.summary_iter = 10
        self.save_iter = 1000
        self.output_dir = os.path.join(
            'E:\Desktop\yolo_tensorflow-master1\data\pascal_voc\output', datetime.datetime.now().strftime('%Y_%m_%d_%H_%M'))
        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)
        self.save_cfg()

        self.variable_to_restore = tf.global_variables()
        self.saver = tf.train.Saver(self.variable_to_restore, max_to_keep=None)
        self.ckpt_file = os.path.join(self.output_dir, 'yolo')
        self.summary_op = tf.summary.merge_all()
        self.writer = tf.summary.FileWriter(self.output_dir, flush_secs=60)
        # self.global_step疑似一个计数器,好像没什么太大用处吧。比如说下面的train.op和learning_rate用到了global_step
        # 那么作用就是我的train_op每训练一次,或者learning_rate每更新一次,我这个global_step会自动加一,起到一个计数的效果
        self.global_step = tf.train.create_global_step()
        self.learning_rate = tf.train.exponential_decay(
            self.initial_learning_rate, self.global_step, self.decay_steps,
            self.decay_rate, self.staircase, name='learning_rate')
        self.optimizer = tf.train.GradientDescentOptimizer(learning_rate=self.learning_rate)
        self.train_op = slim.learning.create_train_op(self.net.total_loss, self.optimizer, global_step=self.global_step)

        gpu_options = tf.GPUOptions()
        config = tf.ConfigProto(gpu_options=gpu_options)
        self.sess = tf.Session(config=config)
        self.sess.run(tf.global_variables_initializer())

        if self.weights_file is not None:
            print('Restoring weights from: ' + self.weights_file)
            self.saver.restore(self.sess, self.weights_file)

        self.writer.add_graph(self.sess.graph)

    # 定义训练函数
    def train(self):
        train_timer = Timer()  # 计算时间
        load_timer = Timer()  # 计算时间
        for step in range(1, self.max_iter + 1):
            load_timer.tic()  # 计算时间
            images, labels = self.data.get()
            load_timer.toc()  # 计算时间
            feed_dict = {self.net.images: images, self.net.labels: labels}
            # 这一层的if是说,每10次我就把我的信息保存一次到模型里,所以else中的内容就是正常的训练过程,当然这个模型是为了进行可视化的吧
            # 当然大部分肯定是执行的else部分
            if step % self.summary_iter == 0:
                # 这一层的if是说,每100次打印出一次结果
                if step % (self.summary_iter * 10) == 0:
                    train_timer.tic()   # 计算时间
                    # 运行模型、损失函数、和train_op
                    summary_str = self.sess.run(self.summary_op, feed_dict=feed_dict)
                    loss = self.sess.run(self.net.total_loss, feed_dict=feed_dict)
                    _ = self.sess.run(self.train_op, feed_dict=feed_dict)

                    train_timer.toc()   # 计算时间
                    # 打印时间、loss、迭代次数等信息
                    log_str = '''{} Epoch: {}, Step: {}, Learning rate: {}, Loss: {:5.3f}\nSpeed: 
                    {:.3f}s/iter, Load: {:.3f}s/iter, Remain: {}
                    '''.format(
                        datetime.datetime.now().strftime('%m-%d %H:%M:%S'),
                        self.data.epoch,
                        int(step),
                        round(self.learning_rate.eval(session=self.sess), 6),
                        loss,
                        train_timer.average_time,
                        load_timer.average_time,
                        train_timer.remain(step, self.max_iter))
                    print(log_str)
                else:
                    train_timer.tic()    # 计算时间
                    summary_str, _ = self.sess.run([self.summary_op, self.train_op], feed_dict=feed_dict)
                    train_timer.toc()   # 计算时间
                self.writer.add_summary(summary_str, step)

            else:
                train_timer.tic()   # 计算时间
                self.sess.run(self.train_op, feed_dict=feed_dict)
                train_timer.toc()   # 计算时间
            # 每1000次保存一次模型文件,到指定路径
            if step % self.save_iter == 0:
                print('{} Saving checkpoint file to: {}'.format(
                    datetime.datetime.now().strftime('%m-%d %H:%M:%S'),
                    self.output_dir))
                self.saver.save(self.sess, self.ckpt_file, global_step=self.global_step)

    def save_cfg(self):
        with open(os.path.join(self.output_dir, 'config.txt'), 'w') as f:
            cfg_dict = cfg.__dict__
            for key in sorted(cfg_dict.keys()):
                if key[0].isupper():
                    cfg_str = '{}: {}\n'.format(key, cfg_dict[key])
                    f.write(cfg_str)

def update_config_paths(data_dir, weights_file):
    cfg.DATA_PATH = data_dir
    cfg.PASCAL_PATH = os.path.join(data_dir, 'pascal_voc')
    cfg.CACHE_PATH = os.path.join('E:\Desktop\yolo_tensorflow-master1\data\pascal_voc', 'cache')
    cfg.OUTPUT_DIR = os.path.join('E:\Desktop\yolo_tensorflow-master1\data\pascal_voc', 'output')
    cfg.WEIGHTS_DIR = os.path.join('E:\Desktop\yolo_tensorflow-master1\data\pascal_voc', 'weights')
    cfg.WEIGHTS_FILE = os.path.join('E:\Desktop\yolo_tensorflow-master1\data\pascal_voc\weights', weights_file)

def main():
    '''
    虽说这里是main函数,但是实际上他只是定义了一些接口,比如说存放模型文件的路径,是否使用gpu训练等
    就是说定义了下面这些东西,你在运行train.py的时候,通过改变下面定义的这些参数,可以更改一些配置
    而真正的训练过程,在Solver.train()中
    '''
    parser = argparse.ArgumentParser()      # 添加参数的函数,在终端运行train.py时,输入想要改变的东西
    parser.add_argument('--weights', default="YOLO_small.ckpt", type=str)
    parser.add_argument('--data_dir', default="data", type=str)
    parser.add_argument('--threshold', default=0.2, type=float)
    parser.add_argument('--iou_threshold', default=0.5, type=float)
    parser.add_argument('--gpu', default='', type=str)
    args = parser.parse_args()
    # 设置GPU和权重文件路径的
    if args.gpu is not None:
        cfg.GPU = args.gpu
    if args.data_dir != cfg.DATA_PATH:
        update_config_paths(args.data_dir, args.weights)
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    # 下面几乎都是各个调用,因为之前定义的都是Class,然后我们要创建实际的对象
    yolo = YOLONet()        # 首先建立网络模型对象
    pascal = pascal_voc('train')    # 然后又建立数据集的对象
    solver = Solver(yolo, pascal)   # 把网络模型和数据集丢给solver
    solver.train()          # 使用train函数进行训练

# 设置一个主函数借口,可以从外部调用
if __name__ == '__main__':
    # python train.py --weights YOLO_small.ckpt --gpu 0
    main()

有可能会有少量表述不正当的地方,仅供参考。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

CV界的文盲

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值