TensorFlow code study: converting images to TFRecords

Yet another throwaway post.

First, read the file names and their labels. Here I'm using the simple cats-vs-dogs dataset.

```python
def _find_image_files(data_dir, labels_file):
    jpeg_file_path = '%s/*.jpg' % (data_dir)  # file path pattern
    matching_files = tf.gfile.Glob(jpeg_file_path)  # find every file matching the pattern, i.e. all *.jpg files under this directory
    labels = [0 if 'cat' in os.path.basename(file) else 1 for file in matching_files]  # file names look like cat.0.jpg / dog.0.jpg, so the basename tells us cat (0) or dog (1)
    c = list(zip(matching_files, labels))  # pair each file name with its label
    shuffle(c)  # shuffle the pairs
    filenames, labels = zip(*c)  # unzip back into shuffled file names and labels
    return filenames, labels
```
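
The zip / shuffle / zip(*...) idiom above keeps file names and labels aligned while shuffling. A minimal sketch of the same pattern on made-up file names:

```python
# A toy run of the labeling and shuffling pattern; the file names here are made up.
import os
from random import shuffle

files = ['train/cat.0.jpg', 'train/dog.0.jpg', 'train/cat.1.jpg', 'train/dog.1.jpg']
labels = [0 if 'cat' in os.path.basename(f) else 1 for f in files]  # [0, 1, 0, 1]
pairs = list(zip(files, labels))
shuffle(pairs)                 # shuffles files and labels in lockstep
files, labels = zip(*pairs)    # still aligned, just in a new order
print(files)
print(labels)
```
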
Next, decide what goes into each record. TensorFlow's standard container is tf.train.Example; here I only store the raw image bytes and the label.

```python
def _convert_to_example(image_buffer, label):
    example = tf.train.Example(features=tf.train.Features(feature={
            'image/label': _int64_feature(label),
            'image/encoded': _bytes_feature(image_buffer)}))
    return example
```
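
If you want to see what actually ends up in a record, an Example can be serialized and parsed back with the protobuf API. A small sketch, assuming the helper functions above are defined and a sample image named cat.0.jpg exists:

```python
# Round-trip one Example to confirm the feature keys and values
# (sketch only; 'cat.0.jpg' is an assumed sample file).
with tf.gfile.FastGFile('cat.0.jpg', 'rb') as f:
    image_buffer = f.read()

serialized = _convert_to_example(image_buffer, 0).SerializeToString()
restored = tf.train.Example.FromString(serialized)
print(restored.features.feature['image/label'].int64_list.value)            # [0]
print(len(restored.features.feature['image/encoded'].bytes_list.value[0]))  # size of the JPEG in bytes
```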

Then write the data. We can write every image and label into a single file, or split them across several files: split the file list into a few segments and write each segment to its own TFRecord shard.

```python
def _process_image_files(output_directory, name, filenames, labels, num_shards):
    num_images = len(filenames)  # how many images in total
    num_batch = np.linspace(0, num_images, num_shards + 1).astype(int)  # split the data into num_shards chunks; see the np.linspace sketch after this function for why we add 1

    for counter in range(num_shards):
        output_filename = '%s-%.5d-of-%.5d' % (name, counter, num_shards)  # shard file name, e.g. train-00001-of-00008 (name, shard index, total shards)
        output_file = os.path.join(output_directory, output_filename)
        start, end = num_batch[counter], num_batch[counter + 1]  # index range of this shard
        writer = tf.python_io.TFRecordWriter(output_file)  # writer for this shard
        for j in range(start, end):  # iterate over the files in this range
            filename, label = filenames[j], labels[j]
            try:
                with tf.gfile.FastGFile(filename, 'rb') as f:
                    image_buffer = f.read()  # read the raw image bytes
            except Exception as e:
                print(e)
                continue
            example = _convert_to_example(image_buffer, label)  # pack the bytes and label into an Example
            writer.write(example.SerializeToString())
            print('writing {} picture, filename is {}, label is {}, shard is {}'.format(j, filename, label, counter))
        writer.close()
```
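
Why num_shards + 1: np.linspace returns boundary points, and n shards need n + 1 boundaries. A quick check with 25,000 images (the size of this cat/dog training set) and 8 shards:

```python
import numpy as np

boundaries = np.linspace(0, 25000, 8 + 1).astype(int)
print(boundaries)  # [    0  3125  6250  9375 12500 15625 18750 21875 25000]
# shard k covers the half-open index range [boundaries[k], boundaries[k+1])
```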

Putting it all together:

```python
def _process_dataset(output_directory, name, directory, labels_file, num_shards):
    filenames, labels = _find_image_files(directory, labels_file)
    _process_image_files(output_directory, name, filenames, labels, num_shards)

def main(unused_argv):
    # the cat/dog dataset has no separate labels file, so we pass an empty string;
    # if your dataset has one, handle it yourself
    _process_dataset('tfdata', 'train', '../../cat_dog/train', '', 8)

if __name__ == '__main__':
    tf.app.run()
```

Output during execution:

```
writing 24993 picture, filename is ..\..\cat_dog\train\dog.10861.jpg, label is 1, shard is 7
writing 24994 picture, filename is ..\..\cat_dog\train\dog.7031.jpg, label is 1, shard is 7
writing 24995 picture, filename is ..\..\cat_dog\train\cat.7885.jpg, label is 0, shard is 7
writing 24996 picture, filename is ..\..\cat_dog\train\dog.8770.jpg, label is 1, shard is 7
writing 24997 picture, filename is ..\..\cat_dog\train\dog.6193.jpg, label is 1, shard is 7
writing 24998 picture, filename is ..\..\cat_dog\train\cat.11390.jpg, label is 0, shard is 7
writing 24999 picture, filename is ..\..\cat_dog\train\cat.5946.jpg, label is 0, shard is 7
```

Once the files are written we can read them back. First, load the TFRecord files we just created:

```python
dataset = tf.data.TFRecordDataset(filenames)
```

With this dataset we can do pretty much whatever we want: dataset.shuffle(1024) to shuffle the data, dataset.repeat() to repeat it (i.e. loop over the dataset for several epochs), and so on. Before any of that, though, we have to parse the records from the serialized Example format back into images, which is what dataset.map(parse_function) is for:

```python
def _parse_function(example_proto):
    features = {
            'image/label': tf.FixedLenFeature((), tf.int64, default_value=0),
            'image/encoded': tf.FixedLenFeature((), tf.string, default_value="")
        }
    parsed = tf.parse_single_example(example_proto, features)  # parse one Example

    label = tf.cast(parsed['image/label'], tf.int32)
    encoded = tf.image.decode_image(parsed['image/encoded'])
    # encoded = tf.image.decode_jpeg(parsed['image/encoded'])  # with decode_jpeg the set_shape line below is not needed
    encoded.set_shape([None, None, None])  # required: decode_image leaves the static shape unknown, and resize_images needs a known rank
    encoded = tf.image.resize_images(encoded, (224, 224))  # without resizing to a common size the images cannot be batched; this is a bit clumsy, a hand-rolled batch reader (e.g. the old multi-threaded queues) avoids it
    return encoded, label
```
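
The set_shape line is needed because decode_image can also decode GIFs, whose output is 4-D, so TensorFlow cannot promise a rank at graph-construction time, and resize_images rejects tensors of unknown rank. A small sketch of the difference, assuming a TF 1.x graph:

```python
# decode_image vs decode_jpeg: only the latter guarantees a rank-3 static shape.
raw = tf.placeholder(tf.string, shape=[])
img_any = tf.image.decode_image(raw)
img_jpg = tf.image.decode_jpeg(raw)
print(img_any.get_shape())  # <unknown>  -> resize_images would raise without set_shape
print(img_jpg.get_shape())  # (?, ?, ?)  -> resize_images works directly
```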

Here is all the code in one place:

```python
# -*- coding: utf-8 -*-

import tensorflow as tf
import six
import os
import numpy as np
from random import shuffle



def _int64_feature(value):
  if not isinstance(value, list):
    value = [value]
  return tf.train.Feature(int64_list=tf.train.Int64List(value=value))


def _float_feature(value):
  """Wrapper for inserting float features into Example proto."""
  if not isinstance(value, list):
    value = [value]
  return tf.train.Feature(float_list=tf.train.FloatList(value=value))


def _bytes_feature(value):
  """Wrapper for inserting bytes features into Example proto."""
  if isinstance(value, six.string_types):           
    value = six.binary_type(value, encoding='utf-8') 
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _convert_to_example(image_buffer, label):
    example = tf.train.Example(features=tf.train.Features(feature={
            'image/label': _int64_feature(label),
            'image/encoded': _bytes_feature(image_buffer)}))
    return example


def _process_image_files(output_directory, name, filenames, labels, num_shards):
    num_images = len(filenames)
    num_batch = np.linspace(0,num_images,num_shards+1).astype(int) # shard boundary indices, e.g. [0, 3125, 6250, ..., 25000]

    for counter in range(num_shards):
        output_filename = '%s-%.5d-of-%.5d' % (name, counter, num_shards)
        output_file = os.path.join(output_directory, output_filename)
        start,end = num_batch[counter], num_batch[counter+1]
        writer = tf.python_io.TFRecordWriter(output_file)
        for j in range(start,end):
            filename,label = filenames[j],labels[j]
            try:
                with tf.gfile.FastGFile(filename,'rb') as f:
                    image_buffer = f.read()
            except Exception as e:
                print(e)
                continue
            example = _convert_to_example(image_buffer, label)
            writer.write(example.SerializeToString())
            print('writing {} picture, filename is {}, label is {}, shard is {}'.format(j,filename,label,counter))
        writer.close()
  
def _find_image_files(data_dir, labels_file):
    jpeg_file_path = '%s/*.jpg' % (data_dir)
    matching_files = tf.gfile.Glob(jpeg_file_path)
    labels = [0 if 'cat' in os.path.basename(file) else 1 for file in matching_files]
    c = list(zip(matching_files, labels))
    shuffle(c)
    filenames, labels = zip(*c)
    return filenames, labels
     
def _process_dataset(output_directory,name, directory, labels_file, num_shards):
    filenames, labels = _find_image_files(directory, labels_file)        
    _process_image_files(output_directory,name, filenames, labels,num_shards)    
        
def main(unused_argv):
    _process_dataset('tfdata','train', '../../cat_dog/train', '', 8)   

def input_function(filenames):
    def _parse_function(example_proto):
        features = {
                'image/label': tf.FixedLenFeature((), tf.int64, default_value=0),
                'image/encoded': tf.FixedLenFeature((), tf.string, default_value="")
            }
        parsed = tf.parse_single_example(example_proto, features)

        label = tf.cast(parsed['image/label'], tf.int32)
        encoded = tf.image.decode_image(parsed['image/encoded'])
        encoded.set_shape([None, None, None])
        encoded = tf.image.resize_images(encoded,(224,224)) # without resizing to a common size the images cannot be batched
        return encoded, label

    dataset = tf.data.TFRecordDataset(filenames)
    dataset = dataset.map(_parse_function)
    dataset = dataset.repeat()
    dataset = dataset.batch(32)
    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()
    return next_element

if __name__ == '__main__':
    tf.app.run()
```

We can verify that the data reads back correctly:

```python
# -*- coding: utf-8 -*-

import tensorflow as tf
from build1 import input_function
import matplotlib.pyplot as plt
import numpy as np

filenames = ['tfdata/train-00000-of-00008', 'tfdata/train-00001-of-00008']
next_element = input_function(filenames)
with tf.Session() as sess:
    img, lab = sess.run(next_element)
    plt.imshow(img[0].astype(np.uint8))  # resize_images returns float32, cast back for display
    plt.show()
    print(img[0].shape)
    while True:
        try:
            print(sess.run(next_element))
        except tf.errors.OutOfRangeError:  # never raised here, since dataset.repeat() has no count
            break
```
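
As another quick check, you can count the records in a shard with tf.python_io.tf_record_iterator and compare against the number of images you expect; a small sketch:

```python
import tensorflow as tf

count = sum(1 for _ in tf.python_io.tf_record_iterator('tfdata/train-00000-of-00008'))
print('records in shard 0:', count)  # should be roughly 25000 / 8
```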