Tensorflow训练过程中validation


Tensorflow因为静态图的原因,边train边validation的过程相较于pytorch来说复杂一些。

载入数据

分别获取训练集和验证集的数据。我这里使用的是从tfrecoed读入数据。

# training data
img_name_batch_train, img_batch_train, gtboxes_and_label_batch_train, num_objs_batch_train, img_h_batch_train, img_w_batch_train = \
next_batch(dataset_name = xxx, ..., is_training = True)

# validation data
img_name_batch_val, img_batch_val, gtboxes_and_label_batch_val, num_objs_batch_val, img_h_batch_val, img_w_batch_val = \
next_batch(dataset_name = xxx, ..., is_training = False)

注意is training

定义is_training占位符

is_trainging = tf.placeholder(tf.bool, shape=())

用一个tf.placeholder来控制是否训练、验证。
使用这种方式就可以在一个graph里创建一个分支条件,从而通过控制placeholder来控制是否进行验证。

用is_training控制图结点唯一

img_name_batch, img_batch, gtboxes_and_label_batch, num_objs_batch, img_h_batch, img_w_batch = \
tf.cond(is_training, lambda:(img_name_batch_train, img_batch_train, gtboxes_and_label_batch_train, num_objs_batch_train, img_h_batch_train, img_w_batch_train), lambda:(img_name_batch_val, img_batch_val, gtboxes_and_label_batch_val, num_objs_batch_val, img_h_batch_val, img_w_batch_val))

如果不适用tf.cond(),会在原图上添加上许多新的结点,这些结点的参数都是需要重新初始化的,也是就是说,验证的时候并不是使用训练的权重。

sess运行

_, global_stepnp, total_loss_dict_ = sess.run([train_op, global_step, total_loss_dict], feed_dict = {is_training:True})

val_loss_list = []
total_loss_dict_ = sess.run(total_loss_dict_, feed_dict={is_training: False})
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值