I've been working on an OCR project recently and don't have enough training images on hand, so I want to generate some additional training images with a CGAN. The first problem I ran into was saving the existing images as a TFRecord file.
The code is as follows:
# coding: utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gzip
import os
import sys
import numpy as np
from six.moves import urllib
import tensorflow as tf
from PIL import Image
# The URLs where the MNIST data can be downloaded.
# _DATA_URL = 'http://yann.lecun.com/exdb/mnist/'
_TRAIN_DATA_FILENAME = './data/orginal_image/train/'
# _TRAIN_LABELS_FILENAME = 'train-labels-idx1-ubyte.gz'
_TEST_DATA_FILENAME = './data/orginal_image/test/'
# _TEST_LABELS_FILENAME = 't10k-labels-idx1-ubyte.gz'
_IMAGE_HEIGHT = 56
_IMAGE_WIDTH = 256
_NUM_CHANNELS = 3
# # The names of the classes.
# _CLASS_NAMES = [
# 'zero',
# 'one',
# 'two',
# 'three',
# 'four',
# 'five',
# 'size',
# 'seven',
# 'eight',
# 'nine',
# ]
def get_filelist(filename):
    return [os.path.join(filename, f) for f in os.listdir(filename)]
def _extract_images(filename):
"""Extract the images into a numpy array.
Args:
filename: The path to an MNIST images file.
num_images: The number of images in the file.
Returns:
A numpy array of shape [number_of_images, height, width, channels].
"""
print('Extracting images from: ', filename)
image_list = get_filelist(filename)
data = []
num_images = len(image_list)
for image in image_list:
if image.endswith('DS_Store') or image.endswith('txt'):
num_images = num_images-1
continue
img = Image.open(image)
img_resize = np.array(img.resize((_IMAGE_WIDTH,_IMAGE_HEIGHT))) # 注意顺序,一开始就把宽高顺序写反了,导致图片一致显示有误。
img_resize = img_resize.astype(np.uint8)
data.append(img_resize)
# data = np.array(data).astype(np.uint8)
# data = data.reshape(num_images, _IMAGE_HEIGHT,_IMAGE_WIDTH,3)
return data,num_images
def _extract_labels(filename):
"""Extract the labels into a vector of int64 label IDs.
Args:
filename: The path to an MNIST labels file.
num_labels: The number of labels in the file.
Returns:
A numpy array of shape [number_of_labels]
"""
print('Extracting labels from: ', filename)
image_list = get_filelist(filename)
labels = []
num_labels = len(image_list)
for image in image_list:
if image.endswith('DS_Store') or image.endswith('txt'):
num_labels = num_labels-1
continue
label = image.split('/')[-1].split('_')[0]
labels.append(label)
# print(label)
return labels,num_labels
def int64_feature(values):
"""Returns a TF-Feature of int64s.
Args:
values: A scalar or list of values.
Returns:
A TF-Feature.
"""
if not isinstance(values, (tuple, list)):
values = [values]
return tf.train.Feature(int64_list=tf.train.Int64List(value=values))
def bytes_feature(values):
"""Returns a TF-Feature of bytes.
Args:
values: A string.
Returns:
A TF-Feature.
"""
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))
def image_to_tfexample(image_data, image_format, height, width, class_id):
    return tf.train.Example(features=tf.train.Features(feature={
        'image/encoded': bytes_feature(image_data),
        'image/format': bytes_feature(image_format),
        # Changed from the MNIST script: the label is stored as bytes rather than int64.
        'image/class/label': bytes_feature(class_id),
        'image/height': int64_feature(height),
        'image/width': int64_feature(width),
    }))
def _add_to_tfrecord(data_filename, tfrecord_writer):
    """Loads images and labels from a directory and writes them to a TFRecord.

    Args:
        data_filename: The directory that holds the images.
        tfrecord_writer: The TFRecord writer to use for writing.
    """
    images, num_images = _extract_images(data_filename)
    labels, num_labels = _extract_labels(data_filename)
    shape = (_IMAGE_HEIGHT, _IMAGE_WIDTH, _NUM_CHANNELS)
    with tf.Graph().as_default():
        image = tf.placeholder(dtype=tf.uint8, shape=shape)
        # Mind the encoding here. This script is adapted from the official
        # TensorFlow tutorial, which handles the MNIST images in PNG format;
        # my own training images are JPEGs, and using the wrong encoder also
        # causes problems.
        encoded_jpeg = tf.image.encode_jpeg(image)
        with tf.Session('') as sess:
            for j in range(num_images):
                sys.stdout.write('\r>> Converting image %d/%d' % (j + 1, num_images))
                sys.stdout.flush()
                jpeg_string = sess.run(encoded_jpeg, feed_dict={image: images[j]})
                # Same caveat as above: the format string must match the encoder
                # ('jpg', not 'png'). The label is encoded to bytes so it can be
                # stored in a BytesList.
                example = image_to_tfexample(
                    jpeg_string, 'jpg'.encode(), _IMAGE_HEIGHT, _IMAGE_WIDTH,
                    labels[j].encode())
                tfrecord_writer.write(example.SerializeToString())
def _get_output_filename(dataset_dir, split_name):
"""Creates the output filename.
Args:
dataset_dir: The directory where the temporary files are stored.
split_name: The name of the train/test split.
Returns:
An absolute file path.
"""
return '%s/invoice_%s.tfrecord' % (dataset_dir, split_name)
def run(dataset_dir):
"""Runs the download and conversion operation.
Args:
dataset_dir: The dataset directory where the dataset is stored.
"""
if not tf.gfile.Exists(dataset_dir):
tf.gfile.MakeDirs(dataset_dir)
training_filename = _get_output_filename(dataset_dir, 'train')
testing_filename = _get_output_filename(dataset_dir, 'test')
print(training_filename)
print(testing_filename)
if tf.gfile.Exists(training_filename) and tf.gfile.Exists(testing_filename):
print('Dataset files already exist. Exiting without re-creating them.')
return
# _download_dataset(dataset_dir)
# First, process the training data:
with tf.python_io.TFRecordWriter(training_filename) as tfrecord_writer:
data_filename = _TRAIN_DATA_FILENAME
_add_to_tfrecord(data_filename, tfrecord_writer)
# Next, process the testing data:
with tf.python_io.TFRecordWriter(testing_filename) as tfrecord_writer:
data_filename = _TEST_DATA_FILENAME
_add_to_tfrecord(data_filename, tfrecord_writer)
# Finally, write the labels file:
# labels_to_class_names = dict(zip(range(len(_CLASS_NAMES)), _CLASS_NAMES))
# dataset_utils.write_label_file(labels_to_class_names, dataset_dir)
print('\nFinished converting the invoice dataset!')
if __name__ == '__main__':
    dataset_dir = './data/tfrecord'
    run(dataset_dir)
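To sanity-check the result (mainly the JPEG encoding and the label stored as bytes), one record can be read back and decoded. The following is only a minimal sketch rather than part of the script above; it assumes TensorFlow 1.x (tf.python_io) and the ./data/tfrecord/invoice_train.tfrecord path produced by _get_output_filename.

# coding: utf-8
# Read-back check: parse the first record and decode the JPEG bytes.
import io
import numpy as np
import tensorflow as tf
from PIL import Image

_TRAIN_TFRECORD = './data/tfrecord/invoice_train.tfrecord'

for serialized in tf.python_io.tf_record_iterator(_TRAIN_TFRECORD):
    example = tf.train.Example()
    example.ParseFromString(serialized)
    feature = example.features.feature
    label = feature['image/class/label'].bytes_list.value[0]
    height = feature['image/height'].int64_list.value[0]
    width = feature['image/width'].int64_list.value[0]
    encoded = feature['image/encoded'].bytes_list.value[0]
    img = np.array(Image.open(io.BytesIO(encoded)))  # decode the JPEG bytes
    # Expect the label as bytes, 56, 256, and an array of shape (56, 256, 3).
    print(label, height, width, img.shape)
    break  # only inspect the first record

If the printed shape and label look right, the TFRecord file is ready to be fed into the CGAN training pipeline.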