tensorflow之图片与tfrecord之间的转化

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import tensorflow as tf
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

# 制作TFRecord文件
def makeTFRecord(cwd, classes, fileNames):
   """
   这个方法是将我们的图片数据集转化成一个TFRcord格式的
   二进制文件

   输入:
   - cwd: 分好类文件存放的路径
   - classes: 图片的种类
   - fileNames:要生成的文件名

   返回值:
   """
   #每个文件最大图片数
   bestnum = 1000
   #第几张图片
   num = 0
   #第几个文件名
   recordfilenum = 0
   fileName = "%s/%s" % (cwd,fileNames[recordfilenum])
   # print(fileName)
   writer = tf.python_io.TFRecordWriter(fileName)
   # 遍历所有类别
   for index, name in enumerate(classes):
      class_path = "%s/%s" % (cwd, name) # 某类图片的路径
      # 遍历某个类别中的所有文件
      for img_name in os.listdir(class_path):
         num += 1
         if num > bestnum:
            num = 1
            recordfilenum += 1
            fileName = "%s/%s" % (cwd, fileNames[recordfilenum])
            # print(fileName)
            writer = tf.python_io.TFRecordWriter(fileName)
         img_path = "%s/%s" % (class_path, img_name) # 某类图片中某张图片的路径
         print(img_path)
         img = Image.open(img_path)
         img = img.resize((128,128))
         img_raw = img.tobytes() # 将图片转成二进制
         example= tf.train.Example(features = tf.train.Features(feature={
            'label': tf.train.Feature(int64_list = tf.train.Int64List(value = [index])),
            'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
         })) # example对象对label和image数据进行封装
         writer.write(example.SerializeToString()) # 序列化为字符串

   writer.close()

# 将二进制文件读入图中
def read_and_decode(cwd ,fileName):
   """
   这个函数是要将TFRecord文件中的数据提取出来,放入计算图中
   输入:
   - cwd: 文件存放的路径
   - fileName: 文件名字(需要加后缀)

   返回值:
   - img: 128*128的3通道图片张量
   - label: 对应的标签
   """
   fileNames = [cwd + str(i) for i in fileName] # 生成每个文件的路径
   fileName_Queue = tf.trian.string_input_producer([fileNames])# 生成一个文件队列

   reader = tf.TFRcordReader()
   _, serialized_example = reader.read(fileName_Queue)# 返回文件名和文件
   features = tf.parse_sigle_example(serialized_example,
                        features = {
                           'label': tf.FixedLenFeature([], tf.int64),
                           'img_raw': tf.FixedLenFeature([], tf.string)
                        })# 将image数据个label提取出来
   img = tf.decode_raw(featrues['img_raw'], tf.uint8)
   img = tf.reshape(img, [128, 128, 3]) # 将图片的reshape为128*128的3通道图片
   img = tf.cast(img, tf.float32) * (1.0 / 255) - 0.5 # 在流中抛出img张量
   label = tf.cast(featrues['label'], tf.int32)

   return img, label

def dispaly_image(cwd ,save_dir,fileName,classes):
   """
   这个函数是要将TFRecord文件中的数据提取出来,生成图片并保存
   输入:
   - cwd: tfrecord文件存放的路径
   - save_dir: 图片文件保存路径
   - fileName: 文件名字(需要加后缀)
   - classes: 图片的种类
   返回值:
   """
   # 创建文件存放目录
   for i in range(len(classes)):
      dir = "%s/%s" % (save_dir, i)
      if not os.path.exists(dir):
         print("目录 %s 不存在,自动创建中..." % (dir))
         os.makedirs(dir)
   # 生成每个文件的路径
   fileNames = [cwd +"/"+ str(i) for i in fileName]
   fileName_Queue = tf.train.string_input_producer(fileNames, shuffle=True)# 生成一个文件队列

   reader = tf.TFRecordReader()
   _, serialized_example = reader.read(fileName_Queue)# 返回文件名和文件
   features = tf.parse_single_example(serialized_example,
                        features={
                           'label': tf.FixedLenFeature([], tf.int64),
                           'img_raw': tf.FixedLenFeature([], tf.string)
                        })# 取出包含image和label的feature对象
   img = tf.decode_raw(features['img_raw'], tf.uint8)
   img = tf.reshape(img, [128, 128, 3]) # 将图片的reshape为128*128的3通道图片
   label = tf.cast(features['label'], tf.int32)
   #print(img.shape)



   with tf.Session() as sess: #开始一个会话
      init_op = tf.initialize_all_variables()
      sess.run(init_op)
      coord=tf.train.Coordinator()
      threads= tf.train.start_queue_runners(coord=coord)

      #print("----------",sess.run(img.shape),sess.run(label.shape))
      for i in range(4):
         example, l = sess.run([img, label])  # 在会话中取出image和label
         #print (example.shape, l.shape)
         # 变量名同名的话要注意
         #img = Image.fromarray(exaple, 'RBG')
         image=Image.fromarray(example, 'RGB')#这里Image是之前提到的
         path = "%s/%s/%s.jpg" % (save_dir,l,i)
         image.save(path)#存下图片
         #print(example, l)
      coord.request_stop()
      coord.join(threads)


if __name__ == "__main__":
   # 图片分类存放的源路径
   cwd = "D:/1500130226/ml/cnn/cnn_tfrecord/data/image"
   # 解析tfrecord文件后图片存放的路径
   save_dir = "D:/1500130226/ml/cnn/cnn_tfrecord/data"
   # 图片的类别
   classes = {'niu','gou'}
   # 每一千张图片需要一个文件名
   fileName = ["train0.tfrecords"]
   # 将图片转成tfrecord文件(图片会变成120*120*3的格式)
   makeTFRecord(cwd, classes, fileName)
   # 将tfrecord文件转成图片(图片会变成120*120*3的格式)可以自己在源码中修改图片大小
   dispaly_image(cwd, save_dir,fileName,classes)
 

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值