Converting MNIST data to raw image files and then to TFRecord format

1. Converting the MNIST data to raw image files

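import os

import cv2
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data


def convert_mnist_img(data, save_path):
    # cv2.imwrite does not create missing directories, so make sure the target exists.
    os.makedirs(save_path, exist_ok=True)
    for i in range(data.images.shape[0]):
        img = data.images[i].reshape([28, 28, 1])
        # Pixels are floats in [0, 1]; round before casting so float error cannot drop a level.
        img = np.rint(img * 255).astype(np.uint8)
        label = data.labels[i]
        # The file name encodes the label and the sample index: <label>_<index>.jpg
        filename = os.path.join(save_path, '{}_{}.jpg'.format(label, i))
        cv2.imwrite(filename, img)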

if __name__ == '__main__':
    mnist = input_data.read_data_sets('./data', source_url='http://yann.lecun.com/exdb/mnist/')
    convert_mnist_img(mnist.train, 'img_train')
    print('convert training data to image complete')
    convert_mnist_img(mnist.test, 'img_test')
    print('convert test data to image complete')
    convert_mnist_img(mnist.validation, 'img_validation')
    print('convert validation data to image complete')

This saves the training, validation, and test images separately, in the img_train, img_validation, and img_test directories.
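
Since every file name has the form <label>_<index>.jpg, one quick way to check the export is to count the files per label. The following is a minimal sketch and not part of the original code: the helper names label_from_filename and count_labels are made up for illustration, and only the standard library is used.

```python
import os
from collections import Counter


def label_from_filename(filename):
    """Return the integer label encoded before the first underscore."""
    return int(os.path.basename(filename).split('_')[0])


def count_labels(image_dir):
    """Count the .jpg files in image_dir per label (hypothetical helper)."""
    counts = Counter()
    for name in os.listdir(image_dir):
        # Ignore anything that is not one of the exported images.
        if name.lower().endswith('.jpg'):
            counts[label_from_filename(name)] += 1
    return counts
```

Calling count_labels('img_train') returns a Counter keyed by digit; its values should sum to the number of images in that split.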

2. Converting the image files to TFRecord format

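import os

import cv2
import tensorflow as tf


def convert_img_tfrecords(data_path, record_dir):
    # Despite its name, record_dir is the path of the output .tfrecords file.
    writer = tf.python_io.TFRecordWriter(record_dir)
    for file in os.listdir(data_path):
        img = cv2.imread(os.path.join(data_path, file), cv2.IMREAD_GRAYSCALE)
        if img is None:
            # cv2.imread returns None for files it cannot decode; skip them.
            continue
        # Only the raw uint8 pixels are stored; the 28x28 shape must be known when parsing.
        img_raw = img.tobytes()
        # File names look like <label>_<index>.jpg, so the label precedes the underscore.
        label = int(file.split('_')[0])
        example = tf.train.Example(features=tf.train.Features(feature={
            'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
            'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
        }))
        writer.write(example.SerializeToString())
    writer.close()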

if __name__ == '__main__':
    convert_img_tfrecords('./img_validation', 'validation_img.tfrecords')
    print('convert validation image to tfrecords complete')
    convert_img_tfrecords('./img_test', 'test_img.tfrecords')
    print('convert test image to tfrecords complete')
    convert_img_tfrecords('./img_train', 'train_img.tfrecords')
    print('convert train image to tfrecords complete')

This produces three TFRecord files, one each for the training, validation, and test sets.

3. Parsing the TFRecord files

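import cv2
import numpy as np
import tensorflow as tf


def read_record(record_dir):
    # record_dir is the path of a .tfrecords file written in step 2.
    for serialized_exam in tf.python_io.tf_record_iterator(record_dir):
        example = tf.train.Example()
        example.ParseFromString(serialized_exam)

        image = example.features.feature['img_raw'].bytes_list.value[0]
        label = example.features.feature['label'].int64_list.value[0]
        # np.fromstring is deprecated for binary data; np.frombuffer replaces it.
        image = np.frombuffer(image, dtype=np.uint8)
        image = image.reshape([28, 28, 1])

        # Show each image for one second as a visual check.
        cv2.imshow('image', image)
        cv2.waitKey(1000)

        print(image.shape, label)
    cv2.destroyAllWindows()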

Parsing a TFRecord file this way lets you check that the images and labels were written correctly.
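
A lighter check than viewing every image is to count the records in a file and compare the result with the number of images in the source directory. The sketch below is an illustration rather than anything from the original code. It assumes the TFRecord on-disk framing described in the TensorFlow documentation: an 8-byte little-endian length, a 4-byte checksum of the length, the payload, then a 4-byte checksum of the payload. It walks that framing with the standard library only and does not verify the checksums. The name count_tfrecords is hypothetical.

```python
import os
import struct


def count_tfrecords(path):
    """Count records by walking the framing; CRC fields are skipped, not verified."""
    count = 0
    with open(path, 'rb') as f:
        while True:
            header = f.read(8)              # uint64 little-endian payload length
            if not header:
                break                       # clean end of file
            if len(header) < 8:
                raise ValueError('truncated length header')
            (length,) = struct.unpack('<Q', header)
            f.seek(4, os.SEEK_CUR)          # skip the checksum of the length
            payload = f.read(length)
            if len(payload) < length:
                raise ValueError('truncated record payload')
            footer = f.read(4)              # checksum of the payload, read but not checked
            if len(footer) < 4:
                raise ValueError('truncated payload checksum')
            count += 1
    return count
```

For instance, count_tfrecords('test_img.tfrecords') should equal the number of readable images in ./img_test if the conversion in step 2 completed.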

For actual training, tf.train.string_input_producer can be combined with tf.train.Coordinator() so that queues produce the batches of data that training consumes.
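
One possible shape of such a queue-based pipeline is sketched below. It is an illustration under several assumptions, not the author's training code: it targets the TensorFlow 1.x API used in the rest of this post, it expects records with the 'label' and 'img_raw' features written in step 2, and the function names, batch size, and queue sizes are made up. TensorFlow is imported inside the functions only so that the small capacity helper can be used on its own. The capacity formula is a common rule of thumb, not a requirement.

```python
def suggested_capacity(batch_size, min_after_dequeue, num_threads=2, margin=1):
    """Rule-of-thumb queue size: shuffle buffer plus a few batches of headroom."""
    return min_after_dequeue + (num_threads + margin) * batch_size


def batch_inputs(record_path, batch_size=64, min_after_dequeue=1000, num_threads=2):
    """Build queue-based ops that yield shuffled (images, labels) batches."""
    import tensorflow as tf  # TensorFlow 1.x assumed

    filename_queue = tf.train.string_input_producer([record_path])
    reader = tf.TFRecordReader()
    _, serialized = reader.read(filename_queue)
    features = tf.parse_single_example(serialized, features={
        'label': tf.FixedLenFeature([], tf.int64),
        'img_raw': tf.FixedLenFeature([], tf.string),
    })
    image = tf.decode_raw(features['img_raw'], tf.uint8)
    image = tf.reshape(image, [28, 28, 1])      # the shape is not stored in the record
    image = tf.cast(image, tf.float32) / 255.0  # back to floats in [0, 1]
    label = tf.cast(features['label'], tf.int32)
    return tf.train.shuffle_batch(
        [image, label], batch_size=batch_size, num_threads=num_threads,
        capacity=suggested_capacity(batch_size, min_after_dequeue, num_threads),
        min_after_dequeue=min_after_dequeue)


def run_batches(record_path, steps=10):
    """Pull a few batches; the Coordinator starts and stops the queue threads."""
    import tensorflow as tf  # TensorFlow 1.x assumed

    images, labels = batch_inputs(record_path)
    with tf.Session() as sess:
        # Local variables only matter if num_epochs is set on the producer;
        # initializing them is harmless otherwise.
        sess.run([tf.global_variables_initializer(),
                  tf.local_variables_initializer()])
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for _ in range(steps):
                if coord.should_stop():
                    break
                img_batch, label_batch = sess.run([images, labels])
                print(img_batch.shape, label_batch.shape)
        finally:
            # Always stop and join the reader threads, even after an error.
            coord.request_stop()
            coord.join(threads)
```

In a real training loop, the sess.run call would fetch the training op instead of the raw batches, for example after calling run_batches('train_img.tfrecords') once to confirm the shapes.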

 
