这是笔者项目中用到的部分代码:
with tf.Session() as sess:
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
for i in range(TRAIN_NUM*5):
# ! 这里注意,不能在迭代中进行增加节点的操作,例如 tf.cast, tf.convert等等,否则计算机会因为节点数量过多而溢出。
# ! 这里,返回的结果应当用同一个sess.run()得到,不要在同一个循环中两次调用sess.run(),否则在多线程中可能会导致混乱。
image, label, file_name, image_encode_eval = sess.run([images_decode, labels, file_names, image_encode])
file_name = str(file_name, encoding='utf8').replace('.jpg', '')
# 转换后的展示图片
# plt.imshow(image_transfer.eval())
# plt.title(file_name)
# plt.show()
with tf.gfile.GFile(TRAIN_SUPPLEMENT_PATH+file_name+'_'+str(i)+'.jpg', 'wb') as f:
f.write(image_encode_eval)
if i%100==0:
print('step %d, pic: %s' % (i,file_name))
coord.request_stop()
coord.join(threads)
- 这里是对图像进行处理的迭代过程,用到了多线程;
- 在迭代过程中,有两点需要注意:
- 1 不能在迭代中使用tf.cast() 或者 tf.convert() 等会增加图节点的方法,否则随着迭代的进行,tf的图会越来越大,最终导致溢出;
- 2 在每轮的迭代中,用sess.run()一次性返回所有需要的结果,不要多次调用sess.run(),因为在同一轮迭代中调用多次sess.run(),意味着会从多个线程中取值,这回导致数据的混乱。