TensorFlow YOLO代码解析(4)

下面介绍训练和测试代码。训练代码主要包括graph的构建、预训练模型的加载、训练过程中的数据读取,以及相关日志和模型文件的保存等内容;测试代码的主要部分是对模型预测结果进行格式转换。

其他相关的部分请见:
YOLO代码解析(1) 代码总览与使用
YOLO代码解析(2) 数据处理
YOLO代码解析(3) 模型和损失函数
YOLO代码解析(4) 训练和测试代码

训练相关代码:yolo_solver.py

def _train(self):
    """Build the training op that minimizes ``self.total_loss``.

    Creates a Momentum optimizer and applies its gradients of the loss.
    The method takes no arguments; it reads these attributes from ``self``:
      learning_rate, moment: hyper-parameters of the Momentum optimizer.
      total_loss: scalar loss tensor (built in ``construct_graph``).
      global_step: step-counter variable passed to ``apply_gradients`` so it
        is incremented each time the returned op runs.

    Returns:
      train_op: the op that performs one optimization step.
    """
    # Momentum optimizer.
    optimizer = tf.train.MomentumOptimizer(self.learning_rate, self.moment)

    # compute_gradients + apply_gradients is equivalent to
    #   optimizer.minimize(self.total_loss, global_step=self.global_step)
    # but the two-step form leaves room to transform the (gradient, variable)
    # pairs (e.g. clipping) before they are applied.
    grads_and_vars = optimizer.compute_gradients(self.total_loss)
    train_op = optimizer.apply_gradients(grads_and_vars,
                                         global_step=self.global_step)
    return train_op

  def construct_graph(self):
    """Build the training graph: input placeholders, inference, loss, train op.

    Sets ``self.global_step``, the placeholders ``self.images``,
    ``self.labels`` and ``self.objects_num``, the network output
    ``self.predicts``, the scalar ``self.total_loss`` (also recorded as the
    ``'loss'`` scalar summary) and ``self.train_op``.
    """
    # Non-trainable step counter; the train op built by _train advances it.
    self.global_step = tf.Variable(0, trainable=False)

    # (1) Network inputs for training.
    # A batch of 3-channel images.
    self.images = tf.placeholder(
        tf.float32, (self.batch_size, self.height, self.width, 3))
    # Per image, up to max_objects labels of 5 values each (the meaning of
    # the 5 fields is defined by the dataset/loss, not visible here).
    self.labels = tf.placeholder(
        tf.float32, (self.batch_size, self.max_objects, 5))
    # Number of valid objects per image. Note the trailing comma: the shape
    # must be the 1-tuple (batch_size,), not a parenthesized int.
    self.objects_num = tf.placeholder(tf.int32, (self.batch_size,))

    # (2) Inference: the input is a batch of images; the output has shape
    # (N, cell_size, cell_size, class_num + box_num * 5).
    self.predicts = self.net.inference(self.images)

    # (3) Loss, exported to TensorBoard as the 'loss' scalar.
    self.total_loss = self.net.loss(self.predicts, self.labels, self.objects_num)
    tf.summary.scalar('loss', self.total_loss)

    self.train_op = self._train()

  def solve(self):
    """Run the training loop.

    Initializes all variables, restores ``self.net.pretrained_collection``
    from ``self.pretrain_path`` and then runs ``self.max_iterators`` training
    steps on batches returned by ``self.dataset.batch()``:

    * every 10 steps: print the loss and throughput;
    * every 100 steps: write the merged summaries to ``self.train_dir``;
    * every 5000 steps, and once more after the last step if that step was
      not already saved: checkpoint ``self.net.trainable_collection`` under
      ``self.train_dir + '/model.ckpt'``.

    The session and the summary writer are closed even if training fails.

    Raises:
      AssertionError: if the loss becomes NaN (the model diverged).
    """
    # saver1 only restores the pretrained weights; saver2 checkpoints the
    # trainable variables.
    saver1 = tf.train.Saver(self.net.pretrained_collection, write_version=1)
    saver2 = tf.train.Saver(self.net.trainable_collection, write_version=1)

    init = tf.global_variables_initializer()
    summary_op = tf.summary.merge_all()
    checkpoint_path = self.train_dir + '/model.ckpt'

    # The context manager closes the session even if an exception escapes.
    with tf.Session() as sess:
      sess.run(init)

      # Load the pretrained weights on top of the fresh initialization.
      saver1.restore(sess, self.pretrain_path)

      # Event file writer for TensorBoard.
      summary_writer = tf.summary.FileWriter(self.train_dir, sess.graph)
      try:
        for step in range(self.max_iterators):
          start_time = time.time()

          # Fetch one batch of training data.
          np_images, np_labels, np_objects_num = self.dataset.batch()
          feed_dict = {self.images: np_images,
                       self.labels: np_labels,
                       self.objects_num: np_objects_num}

          _, loss_value = sess.run([self.train_op, self.total_loss],
                                   feed_dict=feed_dict)
          duration = time.time() - start_time

          # Explicit check: a bare `assert` would be stripped under `python -O`.
          if np.isnan(loss_value):
            raise AssertionError('Model diverged with loss = NaN')

          if step % 10 == 0:
            num_examples_per_step = self.dataset.batch_size
            examples_per_sec = num_examples_per_step / duration
            sec_per_batch = float(duration)

            format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f sec/batch)')
            print (format_str % (datetime.now(), step, loss_value, examples_per_sec, sec_per_batch))
            sys.stdout.flush()

          if step % 100 == 0:  # write the event file
            summary_str = sess.run(summary_op, feed_dict=feed_dict)
            summary_writer.add_summary(summary_str, step)

          if step % 5000 == 0:  # periodic checkpoint
            saver2.save(sess, checkpoint_path, global_step=step)

        # Checkpoint the final weights too; otherwise every step after the
        # last periodic checkpoint (up to 4999 of them) would be lost.
        last_step = self.max_iterators - 1
        if last_step >= 0 and last_step % 5000 != 0:
          saver2.save(sess, checkpoint_path, global_step=last_step)
      finally:
        # Flush any buffered events to disk.
        summary_writer.close()

测试相关代码:demo.py

# 对网络给出的预测结果做处理
def process_predicts(predicts):
    # predicts 的shape是 (N,grid_size,grid_size,30), 30=(4+1)*2+20
    p_classes = predicts[0, :, :, 0:20] # 类别的概率
    C = predicts[0, :, :, 20:22]        # Bbox中有物体的概率
    coordinate = predicts[0, :, :, 22:] # 预测的Bbox坐标
    print(predicts.shape)

    p_classes = np.reshape(p_classes, (7, 7, 1, 20))
    C = np.reshape(C, (7, 7, 2, 1))

    # P = 有物体的概率 * 类别的概率
    P = C * p_classes
    print(P.shape)

    # 找到有最大的概率P的Bbox
    index = np.argmax(P)
    index = np.unravel_index(index, P.shape)

    class_num = index[3]

    coordinate = np.reshape(coordinate, (7, 7, 2, 4))

    max_coordinate = coordinate[index[0], index[1], index[2], :]

    # 对网络输出的坐标值进行处理
    # 网络输出的Bbox的中心坐标是相对于格子左上角的坐标,并且用格子的宽度进行归一化(偏移+归一化),这里需要处理成在原图中的坐标
    # 网络输出的Bbox的宽高是相对于图片大小归一化的,这里也要恢复成原始大小
    xcenter = max_coordinate[0]
    ycenter = max_coordinate[1]
    w = max_coordinate[2]
    h = max_coordinate[3]

    # ‘恢复’中心坐标:反偏移,反归一化
    xcenter = (index[1] + xcenter) * (448/7.0)
    ycenter = (index[0] + ycenter) * (448/7.0)
    # ‘恢复’宽高到原始像素大小
    w = w * 448
    h = h * 448

    xmin = xcenter - w/2.0
    ymin = ycenter - h/2.0

    xmax = xmin + w
    ymax = ymin + h

    # 这里检测部分写的比较‘简单’,直接取了物体概率*类别概率最大的那个Bbox和class的结果
    # 实际上应该对每一个类分别进行检测,并用NMS去除多余的候选框
    return xmin, ymin, xmax, ymax, class_num


# Demo: run YOLO-tiny on one image and draw its single best detection.
# NOTE(review): YoloTinyNet, classes_name, cv2, tf and np are expected to be
# imported/defined elsewhere in demo.py -- they are not visible in this excerpt.
common_params = {'image_size': 448, 'num_classes': 20, 'batch_size': 1}
net_params = {'cell_size': 7, 'boxes_per_cell': 2, 'weight_decay': 0.0005}

# Network input side length; must match the image size process_predicts
# decodes box coordinates for.
IMAGE_SIZE = common_params['image_size']

# Network: input placeholder and output tensor.
net = YoloTinyNet(common_params, net_params, test=True)
image = tf.placeholder(tf.float32, (1, IMAGE_SIZE, IMAGE_SIZE, 3))
predicts = net.inference(image)

# Read the image. cv2.imread returns None instead of raising on failure.
np_img = cv2.imread('cat.jpg')
if np_img is None:
    raise IOError("could not read image 'cat.jpg'")
height, width, channels = np_img.shape
print(height, width, channels)

# Preprocess: resize to the network input size, convert BGR -> RGB, and map
# pixel values from [0, 255] to [-1, 1].
resized_img = cv2.resize(np_img, (IMAGE_SIZE, IMAGE_SIZE))
np_img = cv2.cvtColor(resized_img, cv2.COLOR_BGR2RGB)
np_img = np_img.astype(np.float32)
np_img = np_img / 255.0 * 2 - 1
np_img = np.reshape(np_img, (1, IMAGE_SIZE, IMAGE_SIZE, 3))

# Load the weights and run the forward pass; the session is closed on exit
# from the block, even if restore or run fails.
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, 'models/pretrain/yolo_tiny.ckpt')
    np_predict = sess.run(predicts, feed_dict={image: np_img})

xmin, ymin, xmax, ymax, class_num = process_predicts(np_predict)
class_name = classes_name[class_num]

# The box is in pixels of the resized network input, so draw on that image.
# (0, 0, 255) is red in OpenCV's BGR channel order.
cv2.rectangle(resized_img, (int(xmin), int(ymin)), (int(xmax), int(ymax)), (0, 0, 255))
cv2.putText(resized_img, class_name, (int(xmin), int(ymin)),
            cv2.FONT_HERSHEY_DUPLEX, 1.5, (0, 0, 255))
cv2.imwrite('cat_out.jpg', resized_img)

其他没有提到的部分代码请见完整代码
另外,我对代码中用到的一些TensorFlow函数的用法做了简单整理,详见《TensorFlow函数》一文。

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值