- sess.run() 的时候还可以[]组合到一起,一组写在一行里面。
rpn_box_loss, rpn_cls_loss, refined_box_loss, refined_cls_loss, mask_loss, \
gt_boxesnp, \
rpn_batch_pos, rpn_batch, refine_batch_pos, refine_batch, mask_batch_pos, mask_batch = \
sess.run([update_op, total_loss, reg_loss, img_id] +
losses +
[gt_boxes] +
batch_info)
- 每1000轮迭代或者最大次数迭代时,保存ckpt
if (step % 1000 == 0 or step + 1 == FLAGS.max_iters) and step != 0:
saver.save(sess,path,glabal_step = step)
- 训练框架
saver = tf.train.Saver(max_to_keep=20) #保存20个ckpt文件(默认是5)
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
with tf.Session(config = config) as sess:
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord = coord)
step = 0
summary_writer = tf.summary.FileWriter('output', sess.graph)
while (True):
step += 1
try:
#这里写主程序
if step % 100 == 0:
summary_writer.add_summary(summary_, step)
summary_writer.flush()
save_path = saver.save(sess, os.path.join(design.saver_folder, design.path_ckpt), global_step=step)
if step % 10000 == 0:
coord.request_stop()
coord.join(threads)
except tf.errors.OutOfRangeError:
break
# Finish off the fiqlename queue coordinator.
coord.request_stop()
coord.join(threads)
4.条件判断函数tf.cond(),满足条件执行第一个表达式,不满足条件执行第二个表达式
imsize = tf.size(image)
image = tf.cond(tf.equal(imsize, ih * iw), # tf.cond()条件表达式,满足条件执行第一个表达式,否则执行第二个lambda表达式
lambda: tf.image.grayscale_to_rgb(tf.reshape(image, (ih, iw, 1))),
lambda: tf.reshape(image, (ih, iw, 3)))
5.找label的操作(已知label的shape为(N,))
keeps = tf.where(tf.greater_equal(labels, 0)) #tf.where()函数返回矩阵中true的位置
keeps = tf.reshape(keeps, [-1]) #输出正label的索引(过滤掉了值为负的labels,只输出了值为正的labels)
顺便提一下tf.where()的另一个操作:将bool型的,True OR False值转换成了数字0.0或1.0,加起来求了个和~
res.append(tf.reduce_sum(tf.cast(tf.greater_equal(labels, 0), tf.float32)))
6.拆分bbox坐标:[boxes[:,i] for i in range(5)]
boxes = np.array([[10,10,40,80,0.5],[30,25,65,78,0.8],[25,30,70,80,0.9]])
x1,y1,x2,y2,score = [boxes[:,i] for i in range(5)]
print(x1,y1,x2,y2,score)
#输出:[10. 30. 25.] [10. 25. 30.] [40. 65. 70.] [80. 78. 80.] [0.5 0.8 0.9]
7.python手动垃圾回收:gc.collect()命令。
8.将多个图像拼接成一个图像的方法:
from skimage.util.montage import montage2d as montage
montage_rgb = lambda x: np.stack([montage(x[:, :, :, i]) for i in range(x.shape[3])], -1)