TensorFlow读取目录下的图片

TensorFlow读取目录下的图片时,传入的参数必须是由图片文件路径组成的list(而不是目录本身),然后再转换成tf.string类型的Tensor:

# `path` is a Python list of image file paths; turn it into a string tensor.
path = tf.convert_to_tensor(path, dtype=tf.string)

通过tf.train.string_input_producer生成文件名队列。shuffle表示是否打乱图片的顺序;num_epochs表示每张图片被加载几次——训练过程中一般有epoch的概念,即训练集被完整计算几轮。注意:设置了num_epochs后,队列内部会用一个局部变量计数,因此在会话中需要先运行tf.local_variables_initializer()。

# Filename queue: shuffle=True randomizes the file order; num_epochs=2 produces each file twice.
file_queue = tf.train.string_input_producer(path, shuffle=True, num_epochs=2)

接下来使用tf.WholeFileReader生成image_reader,并调用其read函数从文件队列中读取图片。每次read返回一个(key, value)对:key是文件名,value是该图片文件的完整原始内容(尚未解码)。

image_reader = tf.WholeFileReader()
# Each read dequeues one filename: key is that filename, image is the file's raw (undecoded) contents.
key, image = image_reader.read(file_queue)

接下来解码图像数据。这里使用tf.image.decode_jpeg;根据图片格式的不同,还可以使用decode_png、decode_bmp等函数。

# Decode the JPEG bytes into a pixel tensor (channel count taken from the file itself).
image = tf.image.decode_jpeg(image)

最后创建会话(Session),启动队列线程(queue runners),循环获取并显示图像数据,直到队列抛出OutOfRangeError为止:

with tf.Session() as sess:
	# num_epochs is counted in a local variable, which must be initialized first.
	sess.run(tf.local_variables_initializer())
	coord = tf.train.Coordinator()
	threads = tf.train.start_queue_runners(sess=sess, coord=coord)
	try:
		while not coord.should_stop():
			# Must be called: a bare `plt.figure` reference does nothing.
			plt.figure()
			plt.imshow(image.eval())
			plt.show()
	except tf.errors.OutOfRangeError:
		# Raised once every file has been produced num_epochs times.
		print('done')
	finally:
		coord.request_stop()
	coord.join(threads)

图片显示使用matplotlib.pyplot模块,安装步骤如下(Python 2环境):

# Tk bindings, presumably needed for matplotlib's interactive (Tk) display window
sudo apt-get install python-tk
# Install matplotlib for Python 2 (pip2) from the Tsinghua PyPI mirror
sudo pip2 install -i https://pypi.tuna.tsinghua.edu.cn/simple  matplotlib

最后显示图像如下:

最后贴上调试后的完整代码:

#-*- encoding:utf-8 -*-

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import os
# Show the installed TensorFlow version; this script relies on TF 1.x APIs
# such as tf.Session and the tf.train queue runners.
print(tf.__version__)

def readimg(file_path):
	"""Decode a single JPEG file with TensorFlow and display it.

	Prints the Python type of the raw file contents and of the decode op,
	prints the decoded pixel array, then shows it with matplotlib.

	Args:
		file_path: path of the JPEG file to read.
	"""
	# JPEG data is binary: read it in 'rb' mode instead of the default text
	# mode. The `with` block guarantees the file handle is closed.
	with tf.gfile.FastGFile(file_path, 'rb') as f:
		image_raw = f.read()
	img = tf.image.decode_jpeg(image_raw)  # graph Tensor; decoded only when run

	with tf.Session() as sess:
		print(type(image_raw))
		print(type(img))

		# Run the decode once and reuse the result rather than decoding twice.
		pixels = sess.run(img)
		print(pixels)

		plt.figure(1)
		plt.imshow(pixels)
		plt.show()
def file_name(file_dir):
	"""Print each directory visited by os.walk under file_dir.

	For every visited directory three lines are printed, in order: the
	directory path, the list of its subdirectory names, and the list of its
	file names.
	"""
	for walk_entry in os.walk(file_dir):
		# walk_entry is (root, dirs, files); print each part on its own line.
		for part in walk_entry:
			print (part)
def file_name2(file_dir, ext='.jpg'):
	"""Return the paths of all files under file_dir with the given extension.

	The tree is walked recursively with os.walk. The extension comparison is
	exact (case-sensitive), so 'a.JPG' does not match '.jpg'.

	Args:
		file_dir: root directory to search.
		ext: extension to keep, including the leading dot (default '.jpg').

	Returns:
		A list of os.path.join(root, name) paths in os.walk order. It is empty
		if nothing matches or if file_dir does not exist.
	"""
	paths = []
	last_root = None
	for root, _dirs, files in os.walk(file_dir):
		last_root = root
		matches = [name for name in files if os.path.splitext(name)[1] == ext]
		paths.extend(os.path.join(root, name) for name in matches)
	# Debug output: the last directory visited. print() works on Python 2 and
	# 3, and it is skipped when the walk yielded nothing, because `root`
	# would then be undefined.
	if last_root is not None:
		print(last_root)
	return paths

def readimg2(path, shuffle=True, num_epochs=2):
	"""Display JPEG images through a TF 1.x queue-based input pipeline.

	Builds a filename queue from `path`, reads and decodes one image per
	iteration, and shows each one with matplotlib. The loop ends when the
	queue has produced every file `num_epochs` times.

	Args:
		path: non-empty list of JPEG file paths.
		shuffle: whether the filename queue shuffles the file order.
		num_epochs: how many times each file is produced before the queue
			raises OutOfRangeError.

	Raises:
		ValueError: if `path` is empty.
	"""
	if len(path) == 0:
		# string_input_producer would otherwise only fail at run time.
		raise ValueError('readimg2: path must be a non-empty list of files')
	filenames = tf.convert_to_tensor(path, dtype=tf.string)
	file_queue = tf.train.string_input_producer(
		filenames, shuffle=shuffle, num_epochs=num_epochs)
	image_reader = tf.WholeFileReader()
	# read() yields (filename, raw file contents); only the contents are used.
	_, image_bytes = image_reader.read(file_queue)
	image = tf.image.decode_jpeg(image_bytes)

	with tf.Session() as sess:
		# num_epochs is counted in a local variable, which must be initialized.
		sess.run(tf.local_variables_initializer())
		coord = tf.train.Coordinator()
		threads = tf.train.start_queue_runners(sess=sess, coord=coord)
		try:
			while not coord.should_stop():
				plt.figure()  # fresh figure for each image
				plt.imshow(sess.run(image))
				plt.show()
		except tf.errors.OutOfRangeError:
			# Every file has been produced num_epochs times.
			print('done')
		finally:
			coord.request_stop()
		coord.join(threads)

def load_img(path_queue, size=(224, 224)):
	"""Read one JPEG from a filename queue as a fixed-size RGB tensor.

	Args:
		path_queue: filename queue to read from (assumed to come from
			tf.train.string_input_producer or similar — confirm with caller).
		size: (height, width) of the returned image; default (224, 224).

	Returns:
		uint8 Tensor of shape (height, width, 3).
	"""
	reader = tf.WholeFileReader()
	_, value = reader.read(path_queue)

	img = tf.image.decode_jpeg(value, channels=3)
	# tf.reshape cannot rescale. It fails unless the element count matches,
	# and for a different aspect (e.g. 112x448) it silently scrambles the
	# pixels, so resize instead to accept images of any resolution.
	img = tf.image.resize_images(img, size)  # bilinear -> float32
	# Bilinear output stays within the input range [0, 255]; restore uint8.
	img = tf.cast(img, tf.uint8)
	img.set_shape((size[0], size[1], 3))
	return img

# Collect every .jpg under the sample directory and display them via the queue pipeline.
path = file_name2('/home/jyf/jyf/python/loadimage')

# Alternative: readimg(path[0]) shows only the first image, without a queue.
if path:
	print(path[0])
	readimg2(path)
else:
	# Avoid IndexError on path[0] and the run-time failure of an empty queue.
	print('no .jpg files found')

 

展开阅读全文

没有更多推荐了,返回首页