OpenPose -tensorflow代码解析(4)—— 训练脚本 train.py

前言

该openpose-tensorflow的工程是自己实现的,所以有些地方写的会比较简单,但阅读性强、方便使用。

论文翻译 || openpose – Realtime Multi-Person 2D Pose Estimation using Part Affinity Fields
工程实现 || 基于opencv使用openpose完成人体姿态估计

OpenPose -tensorflow代码解析(1)——工程概述&&训练前的准备
OpenPose -tensorflow代码解析(2)—— 数据增强和处理 dataset.py
OpenPose -tensorflow代码解析(3)—— 网络结构的搭建 Net.py
OpenPose -tensorflow代码解析(4)—— 训练脚本 train.py
OpenPose -tensorflow代码解析(5)—— 预测代码解析 predict.py

1 训练脚本train.py解析

将openpose 的训练定义成了一个类,初始化+训练操作

初始化:

  • 基础参数的设置:训练轮次、batchsize、保存路径、session的实例
  • placeholder、network、loss function、learning rate、optimizer、summary、saver
    滑动平均 tf.train.ExponentialMovingAverage这里没有使用,如需添加优化 参考 我的yolov3训练详解

训练:

  • 根据提供的路径,查看是否具有已训练的权重进行恢复
  • 从队列中获取数据,进行 sess.run()。
    其中数据读取,是在 dataset.py脚本中定义的,实现了多进程往队列中读取数据
  • 保存模型、训练日志
# from progress.bar import Bar, FillingCirclesBar, ChargingBar

from eval import *
from dataset import *
from NET import *
from opt import *

import time
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"


class OpenPoseTrain(object):
    """End-to-end training driver for the OpenPose network (TF1 graph mode).

    ``__init__`` builds the whole graph: input placeholders, network and
    losses, learning-rate schedule, optimizer, summaries and saver.
    ``train()`` restores the latest checkpoint (if any), then runs the
    epoch loop, pulling batches from the queues filled by ``Dataset``
    (see dataset.py, which feeds the queues from worker processes).
    """

    def __init__(self):
        self.start_epoch = 0  # first epoch to run; updated after a checkpoint restore
        self.epoch     = cfg.TRAIN.first_stage_epoch    # total number of epochs
        self.batchsize = cfg.TRAIN.batch_size
        self.lr        = cfg.TRAIN.learn_rate_init
        self.cpm_num   = cfg.OP.cpm_num + 1             # keypoint heatmap channels (+1, presumably background — confirm in Net.py)
        self.paf_num   = cfg.OP.paf_num                 # part-affinity-field channels

        self.checkpoint_dir = cfg.TRAIN.ckpt_path
        self.log_path       = cfg.TRAIN.log_path

        self.trainset = Dataset('train')
        self.testset  = Dataset('test')

        self.num_step_one_epoch       = len(self.trainset)
        self.num_step_one_epoch_valid = len(self.testset)

        config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
        config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=config)

        # ======== placeholders =====================================================
        with tf.name_scope('define_input'):
            self.input_node = tf.placeholder(tf.float32, shape=(None, None, None, 3), name='image')
            self.heatmap_node = tf.placeholder(tf.float32, shape=(None, None, None, self.cpm_num), name='heatmap')
            self.vectmap_node = tf.placeholder(tf.float32, shape=(None, None, None, self.paf_num), name='vectmap')

        # ======== network structure and losses =====================================
        with tf.name_scope("define_loss"):
            with tf.device(tf.DeviceSpec(device_type="GPU")):
                self.net = OpenPose(self.input_node, True)
                self.total_loss, self.total_loss_paf, self.total_loss_heat, self.total_loss_ll = \
                    self.net.loss_layer(self.heatmap_node, self.vectmap_node)

        # ======== learning-rate schedule ===========================================
        self.global_step = tf.Variable(0, trainable=False, name='global_step')
        with tf.name_scope('learn_rate'):

            case = "case1"
            if case == "case1":
                # exponential decay from a single initial learning rate
                print("learn rate: ", self.lr)
                starter_learning_rate = float(self.lr)
                self.learning_rate = tf.train.exponential_decay(starter_learning_rate,
                                                                self.global_step,
                                                                decay_steps=800,
                                                                decay_rate=0.96,
                                                                staircase=True)
            else:
                # Piecewise-constant schedule: self.lr is a comma-separated list
                # of rates, switched every 5 epochs.
                # BUG FIX: the original wrote "for i, _ in range(len(lrs))",
                # which raises TypeError (an int cannot be unpacked into i, _).
                print('yes')
                lrs = [float(x) for x in self.lr.split(',')]
                boundaries = [self.num_step_one_epoch * 5 * i for i in range(1, len(lrs))]
                self.learning_rate = tf.train.piecewise_constant(self.global_step, boundaries, lrs)

        # ======== optimizer ========================================================
        with tf.name_scope("train_stage_1"):

            optimizer = tf.train.AdamOptimizer(self.learning_rate, epsilon=1e-8)
            train_op = optimizer.minimize(self.total_loss, self.global_step, colocate_gradients_with_ops=True)

            # run the graph's UPDATE_OPS (e.g. batch-norm moving stats) together
            # with the optimizer step
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                with tf.control_dependencies([train_op]):
                    self.train_op_1 = tf.no_op()

        # ======== summaries ========================================================
        with tf.name_scope('summary'):
            tf.summary.scalar("loss", self.total_loss)
            tf.summary.scalar("loss_lastlayer", self.total_loss_ll)
            tf.summary.scalar("loss_lastlayer_paf", self.total_loss_paf)
            tf.summary.scalar("loss_lastlayer_heat", self.total_loss_heat)
            tf.summary.scalar("lr", self.learning_rate)
            self.merged = tf.summary.merge_all()

            self.train_writer = tf.summary.FileWriter(self.log_path + '/train', self.sess.graph)
            self.valid_writer = tf.summary.FileWriter(self.log_path + '/valid')

        # ======== saver ============================================================
        with tf.name_scope('loader_and_saver'):
            self.saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=500)

    def train(self):
        """Run the training loop: restore, train per epoch, validate, log."""

        # ======== start the data-feeding worker processes (see dataset.py) =========
        Q_traindata = self.trainset.start(3)
        Q_validdata = self.testset.start(2)

        # ======== restore weights from checkpoint_dir when one exists ==============
        self.sess.run(tf.global_variables_initializer())
        if self.checkpoint_dir and os.path.isdir(self.checkpoint_dir):
            print('model restore =================')
            self.saver.restore(self.sess, tf.train.latest_checkpoint(self.checkpoint_dir))
            # recover the epoch counter from the restored global step
            self.start_epoch = self.sess.run(self.global_step)//self.num_step_one_epoch
            print("继续训练的轮次为: ",  self.start_epoch)
            print("继续训练的开始学习率为: ",  self.sess.run(self.learning_rate))

        # ======== train / save model / write logs ==================================
        for epo in range(self.start_epoch, self.epoch):

            for step in range(self.num_step_one_epoch):
                feed_dict_data = {}
                name, _, input_data, label_cmp, label_paf = Q_traindata.get()
                feed_dict_data[self.input_node] = input_data
                feed_dict_data[self.heatmap_node] = label_cmp
                feed_dict_data[self.vectmap_node] = label_paf

                merged_train, _, gs_num = self.sess.run([self.merged, self.train_op_1, self.global_step], feed_dict=feed_dict_data)
                self.train_writer.add_summary(merged_train, gs_num)

                # extra forward pass to score the training batch
                net_out = self.net.CPM[-1]
                heamap = self.sess.run(net_out, feed_dict=feed_dict_data)
                train_acc = accuracy(heamap, label_cmp)
                train_summary = tf.Summary(value=[tf.Summary.Value(tag="accuracy", simple_value=train_acc)])
                self.train_writer.add_summary(train_summary, gs_num)

                # BUG FIX: the fetch list was ordered [total, paf, heat, ll] while
                # the unpack targets are (loss, loss_ll, loss_ll_paf, loss_ll_heat),
                # so every printed label showed the wrong value. Fetch order now
                # matches the names.
                loss, loss_ll, loss_ll_paf, loss_ll_heat, lr_val = \
                    self.sess.run([self.total_loss, self.total_loss_ll, self.total_loss_paf, self.total_loss_heat, self.learning_rate], feed_dict=feed_dict_data)
                print('epoch=%d step=%d || lr=%f, loss=%g, loss_ll=%g, '
                          'loss_ll_paf=%g, loss_ll_heat=%g, acc=%g' %
                          (epo, step, lr_val, loss, loss_ll, loss_ll_paf, loss_ll_heat, train_acc))

            # NOTE(review): assumes ckpt_path ends with a path separator —
            # otherwise 'model' is glued onto the directory name; confirm cfg value.
            self.saver.save(self.sess, self.checkpoint_dir+'model', global_step=epo)

            # ======== validation pass over the whole test set ======================
            ACC_ALL = 0
            for step in range(self.num_step_one_epoch_valid):
                valid_dict_data = {}
                name, _, input_data, label_cmp, label_paf = Q_validdata.get()

                valid_dict_data[self.input_node] = input_data
                valid_dict_data[self.vectmap_node] = label_paf
                valid_dict_data[self.heatmap_node] = label_cmp
                merged_value = self.sess.run(self.merged, feed_dict=valid_dict_data)
                # BUG FIX: same fetch/unpack ordering mismatch as in the training
                # loop; fetch order now matches the unpack targets.
                loss, loss_ll, loss_ll_paf, loss_ll_heat = \
                    self.sess.run([self.total_loss, self.total_loss_ll, self.total_loss_paf, self.total_loss_heat], feed_dict=valid_dict_data)
                # gs_num is the last global step seen in the training loop above
                self.valid_writer.add_summary(merged_value, gs_num+step)

                net_out = self.net.CPM[-1]
                heamap = self.sess.run(net_out, feed_dict=valid_dict_data)
                acc = accuracy(heamap, label_cmp)
                ACC_ALL = ACC_ALL + acc/self.num_step_one_epoch_valid

                print('========================================================================')
                print('epoch=%d step=%d, loss=%g, loss_ll=%g, '
                      'loss_ll_paf=%g, loss_ll_heat=%g, acc=%g' %
                      (epo, step, loss, loss_ll, loss_ll_paf, loss_ll_heat, acc))

            summary = tf.Summary(value=[tf.Summary.Value(tag="accuracy", simple_value=ACC_ALL)])
            self.valid_writer.add_summary(summary, gs_num)


if __name__ == '__main__':
    # Build the training graph and launch the train/validation loop.
    trainer = OpenPoseTrain()
    trainer.train()

2 eval.py 解析

在训练过程、模型预测过程中,我们需要对 heatmap 进行解析。

  • 热量图解析:由于我们工程的前提,就是单张图片中只会出现1个目标,所以解析热量图变得十分简单。只需要对关键点的热量图进行解析: 对于每张关键点热量图,查找到像素值最高的一个坐标即可。
  • 精度计算:将网络的label、output分别解析,然后计算存在的关键点 与对应预测的关键点 的偏差的平均值即可

import numpy as np

def get_preds(hm, return_conf=False):
    """Decode heatmaps into per-joint peak coordinates.

    Args:
        hm: heatmap array shaped [batch, height, width, joints].
        return_conf: when True, also return the peak value of each joint map.

    Returns:
        preds: [batch, joints, 2] array of (x, y) peak locations.
        conf (only if return_conf): [batch, joints, 1] peak values.
    """
    w = hm.shape[2]
    # flatten the spatial dimensions: [batch, h*w, joints]
    hm = hm.reshape(hm.shape[0], hm.shape[1]*hm.shape[2], hm.shape[3])
    idx = np.argmax(hm, axis=1)  # flat spatial index of each joint's peak

    preds = np.zeros((hm.shape[0], hm.shape[2], 2))
    for i in range(hm.shape[0]):
        for j in range(hm.shape[2]):
            # recover (x, y) from the flat index
            preds[i, j, 0], preds[i, j, 1] = idx[i, j] % w, idx[i, j] // w
    if return_conf:
        # BUG FIX: the peak confidence is the max over the *spatial* axis
        # (axis=1), one value per joint. The original took axis=2 (the joint
        # axis), yielding a [batch, h*w, 1] array inconsistent with preds.
        conf = np.amax(hm, axis=1).reshape(hm.shape[0], hm.shape[2], 1)
        return preds, conf
    else:
        return preds

def calc_dists(preds, gt):
    """Per-joint Euclidean distance between predictions and labels.

    Args:
        preds: [batch, joints, 2] predicted (x, y) coordinates.
        gt:    [batch, joints, 2] ground-truth (x, y) coordinates.

    Returns:
        [batch, joints] array of distances; joints whose ground-truth
        coordinates are not both positive are marked with -1.
    """
    present = (gt[..., 0] > 0) & (gt[..., 1] > 0)
    euclid = np.sqrt(((gt - preds) ** 2).sum(axis=2))
    return np.where(present, euclid, -1.0)

def accuracy(output, target):
    """Mean pixel offset between predicted and labelled joint peaks.

    output / target: heatmaps shaped [batch, h, w, joints].
    The last channel is dropped before averaging (presumably the
    background map — confirm against Net.py), and joints absent from
    the label (distance flag -1) are excluded from the mean.
    """
    pred_pts = get_preds(output)   # network peaks: [batch, joints, 2]
    label_pts = get_preds(target)  # label peaks:   [batch, joints, 2]
    offsets = calc_dists(pred_pts, label_pts)[:, 0:-1]

    # average only over joints that are present in the label
    present = offsets != -1
    return np.mean(offsets[present])
   
  • 2
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值