1、将mnist数据转成原始图片数据
def convert_mnist_img(data, save_path):
    """Dump every sample of an MNIST split to disk as a grayscale JPEG.

    Each image is written to ``<save_path>/<label>_<index>.jpg`` where
    ``index`` is the sample's position inside the split.

    Args:
        data: Dataset split exposing ``images`` (one flattened 28x28 image
            per row; the ``* 255`` scaling assumes float pixels in [0, 1]
            -- TODO confirm against the loader) and ``labels`` (one label
            per row).
        save_path: Directory that receives the images. It is created when
            it does not exist yet.

    Raises:
        IOError: If OpenCV reports that an image could not be written.
    """
    import os  # local import: the file's import block is not visible here

    # cv2.imwrite does not create missing directories; it only reports failure.
    if not os.path.isdir(save_path):
        os.makedirs(save_path)

    for i, (pixels, label) in enumerate(zip(data.images, data.labels)):
        img = pixels.reshape([28, 28, 1])
        # Round before casting: a bare astype() truncates, so a value such as
        # 253.99998 would otherwise be stored as 253.
        img = np.rint(img * 255).astype(np.uint8)
        filename = save_path + '/{}_{}.jpg'.format(label, i)
        if not cv2.imwrite(filename, img):
            raise IOError('failed to write image: {}'.format(filename))
if __name__ == '__main__':
    # Load MNIST through input_data.read_data_sets, then export each split
    # to its own image directory, reporting after every split.
    mnist = input_data.read_data_sets('./data', source_url='http://yann.lecun.com/exdb/mnist/')
    export_jobs = (
        (mnist.train, 'img_train', 'convert training data to image complete'),
        (mnist.test, 'img_test', 'convert test data to image complete'),
        (mnist.validation, 'img_validation', 'convert validation data to image complete'),
    )
    for split, out_dir, done_msg in export_jobs:
        convert_mnist_img(split, out_dir)
        print(done_msg)
这样就可以把训练、验证、测试集的图片分别保存到img_train、img_validation、img_test三个目录下,文件名格式为"标签_序号.jpg"。
2、将图片数据转成TFRecord格式文件
def convert_img_tfrecords(data_path, record_dir):
    """Pack a directory of labelled grayscale images into one TFRecord file.

    Every entry of ``data_path`` is read as a grayscale image and stored as
    one ``tf.train.Example`` with two features:

    * ``'label'``   -- int64, parsed from the file name prefix before the
      first ``'_'``.
    * ``'img_raw'`` -- the raw pixel bytes of the decoded image.

    Args:
        data_path: Directory containing image files named
            ``<label>_<anything>``.
        record_dir: Path of the TFRecord *file* to create (despite the
            parameter name, this is passed straight to ``TFRecordWriter``).

    Raises:
        IOError: If an entry of ``data_path`` cannot be decoded as an image.
        ValueError: If a file name does not start with an integer label.
    """
    # The context manager closes the writer even when an iteration raises.
    with tf.python_io.TFRecordWriter(record_dir) as writer:
        for name in os.listdir(data_path):
            img_file = os.path.join(data_path, name)
            img = cv2.imread(img_file, cv2.IMREAD_GRAYSCALE)
            # cv2.imread signals failure by returning None instead of raising.
            if img is None:
                raise IOError('cannot read image: {}'.format(img_file))
            img_raw = img.tobytes()
            label = int(name.split('_')[0])
            example = tf.train.Example(features=tf.train.Features(feature={
                'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
                'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
            }))
            writer.write(example.SerializeToString())
if __name__ == '__main__':
    # Build one TFRecord file per image directory, reporting after each one.
    record_jobs = (
        ('./img_validation', 'validation_img.tfrecords',
         'convert validation image to tfrecords complete'),
        ('./img_test', 'test_img.tfrecords',
         'convert test image to tfrecords complete'),
        ('./img_train', 'train_img.tfrecords',
         'convert train image to tfrecords complete'),
    )
    for image_dir, record_file, done_msg in record_jobs:
        convert_img_tfrecords(image_dir, record_file)
        print(done_msg)
针对训练集、验证集、测试集生成对应的三个TFRecord格式文件。
3、解析TFRecord格式文件
def read_record(record_dir):
    """Preview the contents of a TFRecord file, one example at a time.

    For every serialized ``tf.train.Example`` in the file, the ``'img_raw'``
    bytes are decoded into a 28x28x1 uint8 image, shown in an OpenCV window
    named ``'image'`` for up to one second, and the image shape and label
    are printed.

    Args:
        record_dir: Path of the TFRecord *file* to read (despite the
            parameter name, it is passed straight to ``tf_record_iterator``).
    """
    try:
        for serialized_exam in tf.python_io.tf_record_iterator(record_dir):
            example = tf.train.Example()
            example.ParseFromString(serialized_exam)
            features = example.features.feature
            raw_pixels = features['img_raw'].bytes_list.value[0]
            label = features['label'].int64_list.value[0]
            # np.fromstring is deprecated for binary input; np.frombuffer is
            # the documented replacement and avoids an extra copy.
            image = np.frombuffer(raw_pixels, dtype=np.uint8)
            image = image.reshape([28, 28, 1])
            cv2.imshow('image', image)
            cv2.waitKey(1000)
            print(image.shape, label)
    finally:
        # Close the preview window even if parsing fails midway.
        cv2.destroyAllWindows()
调用read_record可以逐条解析TFRecord文件,显示每张图片并打印其形状和标签,以检查转换结果是否正确。
真正训练的时候,可以结合tf.train.string_input_producer和tf.train.Coordinator()使用,利用队列生成批量数据,以供训练。