# 以下代码将各个文件夹中的图片和标签制作为 TFRecord 格式数据集，使用 PIL 打开图片。
import os
import tensorflow as tf
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
# Root folder holding one sub-directory of training images per class.
cwd = 'D:/python学习/神经网络动物分类/train/'

# Class folder names, in label order: the writer loop uses each name's
# position in this list (via enumerate) as the integer label of its images.
classes = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]

# Output locations for the train/test TFRecord files. Raw strings keep the
# Windows backslashes literal; the values equal the escaped originals.
tfRecord_train = r"D:\python学习\神经网络动物分类\train.tfrecords"
tfRecord_test = r"D:\python学习\神经网络动物分类\test.tfrecords"
# Serialize every image under cwd/<class name>/ into one TFRecord file at
# tfRecord_train. Each record is a tf.train.Example with two features:
#   "label"   - int64: the class's index in `classes`
#   "img_raw" - bytes: the pixel buffer returned by PIL's Image.tobytes()
# NOTE(review): tobytes() stores only raw pixels, not width/height/mode, so the
# reader must already know the image shape (presumably 32x32 RGB for these
# CIFAR-10-style class folders -- confirm). Images with differing sizes or
# modes would yield records of differing byte lengths.
#
# tf.io.TFRecordWriter replaces tf.python_io.TFRecordWriter (removed in
# TensorFlow 2.x). Using it as a context manager closes the file even if an
# image fails to open or decode partway through.
with tf.io.TFRecordWriter(tfRecord_train) as writer:
    for index, name in enumerate(classes):
        class_path = os.path.join(cwd, name)
        # os.listdir order is arbitrary; sorting makes the record order
        # reproducible between runs.
        for img_name in sorted(os.listdir(class_path)):
            img_path = os.path.join(class_path, img_name)
            # The context manager releases the file handle as soon as the
            # pixels are read, instead of leaking one handle per image.
            with Image.open(img_path) as img:
                # Optional: uncomment to resize every image before encoding.
                # img = img.resize((128, 128))
                img_raw = img.tobytes()
            example = tf.train.Example(features=tf.train.Features(feature={
                "label": tf.train.Feature(
                    int64_list=tf.train.Int64List(value=[index])),
                "img_raw": tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[img_raw])),
            }))
            writer.write(example.SerializeToString())