#!/usr/bin/env python # -*- coding: utf-8 -*- import os import tensorflow as tf from PIL import Image import matplotlib.pyplot as plt import numpy as np # 制作TFRecord文件 def makeTFRecord(cwd, classes, fileNames): """ 这个方法是将我们的图片数据集转化成一个TFRcord格式的 二进制文件 输入: - cwd: 分好类文件存放的路径 - classes: 图片的种类 - fileNames:要生成的文件名 返回值: 无 """ #每个文件最大图片数 bestnum = 1000 #第几张图片 num = 0 #第几个文件名 recordfilenum = 0 fileName = "%s/%s" % (cwd,fileNames[recordfilenum]) # print(fileName) writer = tf.python_io.TFRecordWriter(fileName) # 遍历所有类别 for index, name in enumerate(classes): class_path = "%s/%s" % (cwd, name) # 某类图片的路径 # 遍历某个类别中的所有文件 for img_name in os.listdir(class_path): num += 1 if num > bestnum: num = 1 recordfilenum += 1 fileName = "%s/%s" % (cwd, fileNames[recordfilenum]) # print(fileName) writer = tf.python_io.TFRecordWriter(fileName) img_path = "%s/%s" % (class_path, img_name) # 某类图片中某张图片的路径 print(img_path) img = Image.open(img_path) img = img.resize((128,128)) img_raw = img.tobytes() # 将图片转成二进制 example= tf.train.Example(features = tf.train.Features(feature={ 'label': tf.train.Feature(int64_list = tf.train.Int64List(value = [index])), 'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])) })) # example对象对label和image数据进行封装 writer.write(example.SerializeToString()) # 序列化为字符串 writer.close() # 将二进制文件读入图中 def read_and_decode(cwd ,fileName): """ 这个函数是要将TFRecord文件中的数据提取出来,放入计算图中 输入: - cwd: 文件存放的路径 - fileName: 文件名字(需要加后缀) 返回值: - img: 128*128的3通道图片张量 - label: 对应的标签 """ fileNames = [cwd + str(i) for i in fileName] # 生成每个文件的路径 fileName_Queue = tf.trian.string_input_producer([fileNames])# 生成一个文件队列 reader = tf.TFRcordReader() _, serialized_example = reader.read(fileName_Queue)# 返回文件名和文件 features = tf.parse_sigle_example(serialized_example, features = { 'label': tf.FixedLenFeature([], tf.int64), 'img_raw': tf.FixedLenFeature([], tf.string) })# 将image数据个label提取出来 img = tf.decode_raw(featrues['img_raw'], tf.uint8) img = tf.reshape(img, [128, 128, 3]) # 将图片的reshape为128*128的3通道图片 img = tf.cast(img, tf.float32) * (1.0 / 255) - 0.5 # 在流中抛出img张量 label = tf.cast(featrues['label'], tf.int32) return img, label def dispaly_image(cwd ,save_dir,fileName,classes): """ 这个函数是要将TFRecord文件中的数据提取出来,生成图片并保存 输入: - cwd: tfrecord文件存放的路径 - save_dir: 图片文件保存路径 - fileName: 文件名字(需要加后缀) - classes: 图片的种类 返回值: 无 """ # 创建文件存放目录 for i in range(len(classes)): dir = "%s/%s" % (save_dir, i) if not os.path.exists(dir): print("目录 %s 不存在,自动创建中..." % (dir)) os.makedirs(dir) # 生成每个文件的路径 fileNames = [cwd +"/"+ str(i) for i in fileName] fileName_Queue = tf.train.string_input_producer(fileNames, shuffle=True)# 生成一个文件队列 reader = tf.TFRecordReader() _, serialized_example = reader.read(fileName_Queue)# 返回文件名和文件 features = tf.parse_single_example(serialized_example, features={ 'label': tf.FixedLenFeature([], tf.int64), 'img_raw': tf.FixedLenFeature([], tf.string) })# 取出包含image和label的feature对象 img = tf.decode_raw(features['img_raw'], tf.uint8) img = tf.reshape(img, [128, 128, 3]) # 将图片的reshape为128*128的3通道图片 label = tf.cast(features['label'], tf.int32) #print(img.shape) with tf.Session() as sess: #开始一个会话 init_op = tf.initialize_all_variables() sess.run(init_op) coord=tf.train.Coordinator() threads= tf.train.start_queue_runners(coord=coord) #print("----------",sess.run(img.shape),sess.run(label.shape)) for i in range(4): example, l = sess.run([img, label]) # 在会话中取出image和label #print (example.shape, l.shape) # 变量名同名的话要注意 #img = Image.fromarray(exaple, 'RBG') image=Image.fromarray(example, 'RGB')#这里Image是之前提到的 path = "%s/%s/%s.jpg" % (save_dir,l,i) image.save(path)#存下图片 #print(example, l) coord.request_stop() coord.join(threads) if __name__ == "__main__": # 图片分类存放的源路径 cwd = "D:/1500130226/ml/cnn/cnn_tfrecord/data/image" # 解析tfrecord文件后图片存放的路径 save_dir = "D:/1500130226/ml/cnn/cnn_tfrecord/data" # 图片的类别 classes = {'niu','gou'} # 每一千张图片需要一个文件名 fileName = ["train0.tfrecords"] # 将图片转成tfrecord文件(图片会变成120*120*3的格式) makeTFRecord(cwd, classes, fileName) # 将tfrecord文件转成图片(图片会变成120*120*3的格式)可以自己在源码中修改图片大小 dispaly_image(cwd, save_dir,fileName,classes)
tensorflow之图片与tfrecord之间的转化
最新推荐文章于 2024-08-25 13:36:07 发布