tensorflow 图片批处理--- tf.train.batch

当我们使用tensorflow进行深度学习时,进行训练模型时,我们往往要读取大量的图片进行批处理输入模型进行训练.
如果我们一次性读取全部图片或者过多张图片,内存将有可能溢出.
如果我们一次读取小批量图片,再将图片转换成tensor,然后再输入模型,则随着模型的迭代次数增大,内存占用将越来越大,最终内存溢出.如下代码:
sess=tf.Session()
ImgFiles= ***** (包括所有训练集图片的文件名)
for imgFile in imgFiles:
img=scipy.misc.imread(imgFile) #读取图片
img=tf.convert_to_tensor(img,dtype='float32') #将图片转化成tensor
img=preprocessing(img) #图片预处理
res=net(img) #将图片输入网络模型进行训练,得出结果
如上代码,因为tensor结点是不会自动回收的,即使你变量名被覆盖,原来的tensor结点依然占用内存,最终内存占用将越来越大,所以不要在循环里面生成tensor.
可通过如下方法检测是否不断生成计算节点
在sess里面,循环外面,使用graph.finalize()锁定graph.如果运行时保存,则说明有计算节点加入.
所以,我们要使用tf.train.batch进行图片读取训练训练.
代码如下:
def read(Path):
filenames = [join(Path, f) for f in listdir(Path) if isfile(join(Path, f))] #Path为图片训练集的文件夹路径,返回的是所有训练集图片的路径的集合
filename_queue = tf.train.string_input_producer(filenames, shuffle=True, num_epochs=10) #将图片产生一个队列,可控制是否排序,图片的迭代次数
reader = tf.WholeFileReader() #产生一个读取器reader
_, img_bytes = reader.read(filename_queue) #将队列输入读取器reader当中,读取序列
image = tf.image.decode_png(img_bytes, channels=3) #对序列解码,现在image还是一张图片,为tensor
image=preprocessing(image) #对图片进行预处理
image=tf.train.batch([image], 2, dynamic_pad=True) #将图片合并生成一个批次,第二个参数2是控制这个批次包含多少张图片.

with tf.Graph().as_default() as g:
with tf.Session() as sess: #协调器要求在with tf.Session() as sess 里面使用.
img=read(Path)
coord = tf.train.Coordinator() #创建一个协调器,管理线程
threads = tf.train.start_queue_runners(coord=coord)
try:
while not coord.should_stop():
#进行模型训练
finally:
coord.request_stop()
coord.join(threads)

  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

程序猿也可以很哲学

让我尝下打赏的味道吧

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值