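# Elapsed time: 30 min
# Original image size: 3.5 GB
# TFRecord file size: 65.3 GB (amazing! Note that the original images are JPEG-compressed, while the TFRecord stores the raw decoded pixel bytes.)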
# -*- coding: utf-8 -*-
"""
Created on Thu Sep 7 19:25:38 2017
@author: wayne
http://wiki.jikexueyuan.com/project/tensorflow-zh/how_tos/reading_data.html
"""
import json
import os
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import datetime
# Target folder that receives the generated TFRecord file.
record_PATH = 'ai_challenger_scene_train_20170904/'
# record_PATH already ends with a separator, so joining yields the same path
# as plain concatenation.
tfrecord_file = os.path.join(record_PATH, 'train.tfrecord')
# Module-level writer used by write_to_tfrecord(); closed at the end of the script.
writer = tf.python_io.TFRecordWriter(tfrecord_file)
def _int64_feature(value):
    """Wrap a single integer in a ``tf.train.Feature`` backed by an ``Int64List``."""
    int_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int_list)
def _bytes_feature(value):
    """Wrap a single bytes value in a ``tf.train.Feature`` backed by a ``BytesList``."""
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)
def get_image_binary(filename):
    """Load an image file and return its shape together with its raw pixel bytes.

    The image could be read with TensorFlow as well, but that requires
    building a graph; Pillow and NumPy are much simpler for this job.

    Args:
        filename: Path of the image file to read.

    Returns:
        A ``(shape, data)`` tuple. ``shape`` is an ``int32`` NumPy array
        holding the dimensions of the decoded pixel array, and ``data`` is
        that ``uint8`` pixel array serialized to raw bytes.
    """
    # The context manager releases the underlying file handle as soon as the
    # pixels have been copied out, instead of leaving it to garbage collection.
    with Image.open(filename) as image:
        pixels = np.asarray(image, np.uint8)
        shape = np.array(pixels.shape, np.int32)
        raw_bytes = pixels.tobytes()  # raw, uncompressed pixel data
    return shape, raw_bytes
def write_to_tfrecord(label, shape, binary_image, tfrecord_file):
    """Serialize one sample and append it to the open TFRecord writer.

    Writes a single example; call it in a loop to store more samples.

    Args:
        label: Integer class label of the sample.
        shape: Dimensions of the pixel array, either ``(h, w)`` or
            ``(h, w, c)``.
        binary_image: Raw image bytes.
        tfrecord_file: Not used by this function; kept so existing callers
            keep working. The record is written through the module-level
            ``writer``.
    """
    # A two-dimensional array (e.g. a grayscale image) has no channel axis;
    # store a single channel instead of failing on shape[2].
    channels = shape[2] if len(shape) > 2 else 1

    # Pack label, shape and image content into one Example.
    features = tf.train.Features(feature={
        'label': _int64_feature(label),
        'h': _int64_feature(shape[0]),
        'w': _int64_feature(shape[1]),
        'c': _int64_feature(channels),
        'image': _bytes_feature(binary_image),
    })
    example = tf.train.Example(features=features)
    writer.write(example.SerializeToString())
def write_tfrecord(label, image_file, tfrecord_file):
    """Read ``image_file`` and store it together with ``label`` as one record."""
    image_shape, raw_pixels = get_image_binary(image_file)
    write_to_tfrecord(
        label=label,
        shape=image_shape,
        binary_image=raw_pixels,
        tfrecord_file=tfrecord_file,
    )
# Load the annotation (label) file into memory.
with open('ai_challenger_scene_train_20170904/scene_train_annotations_20170904.json', 'r') as f:
    label_raw = json.loads(f.read())
def file_name2(file_dir):
    """Collect every file with a ``.jpg`` extension found below ``file_dir``.

    Returns two parallel lists: the full paths and the bare file names, in
    the order ``os.walk`` yields them.
    """
    full_paths = []
    base_names = []
    for root, _dirs, files in os.walk(file_dir):
        # Keep only files of the wanted type (exact, case-sensitive match).
        jpg_names = [name for name in files
                     if os.path.splitext(name)[1] == '.jpg']
        full_paths.extend(os.path.join(root, name) for name in jpg_names)
        base_names.extend(jpg_names)
    return full_paths, base_names
# Image directory: collect the full paths and the matching bare file names.
path, image = file_name2('ai_challenger_scene_train_20170904/scene_train_images_20170904')

# Store everything in the TFRecord file.

# Map each image file name to its integer label id.
label = {item['image_id']: int(item['label_id']) for item in label_raw}

starttime = datetime.datetime.now()
# Long-running loop: one record per image.
num = len(path)  # total number of images to convert
try:
    for i, (image_name, image_path) in enumerate(zip(image, path)):
        write_tfrecord(label[image_name], image_path, tfrecord_file)
        if i % 1000 == 0:
            print(i)  # progress indicator
finally:
    # Close the writer even when a sample fails, so the output file is not
    # left open.
    writer.close()
endtime = datetime.datetime.now()
# The whole expression is parenthesized so the elapsed seconds are printed on
# Python 3 too; total_seconds() also counts whole days, unlike ``.seconds``.
print(int((endtime - starttime).total_seconds()))