matlab 转换 tfrecord,训练数据集与TFRecord互相转换的两种方式

TensorFlow使用TFRecord格式来统一存储数据,该格式可以将图像数据、标签信息、图像路径以及宽高等不同类型的信息放在一起进行统一存储,从而方便有效的管理不同的属性。

将训练数据集转成TFRecord

这里采用的数据集为目前正在做的项目的数据集,共包含两个目标文件夹(分别包含100幅图像)及对应的label.txt,label文件中的每一条内容分别对应两个文件夹中的一幅图像的路径及目标物的位置信息,即左上顶点和右下顶点的坐标信息(),接下来我们将上面的数据制作成TFRecord文件,由于后续需要验证制作的TFRecord数据是否正确,而每张图像的尺寸并不一致,因此在生成的TFRecord文件中除了包含图像内容和标签信息,还包括了图像的宽、高及通道的信息,这样在解析图像的时候,才能把图像数据重新reshape成图像。

根据读取图像数据方式的不同,共有两种方式将自己的数据集转换成TFRecord格式,同样对应两种方式对TFRecord格式进行解析。具体代码如下:

# Convert own_data to TFRecord of TF-Example protos.

import tensorflow as tf

from PIL import Image

import numpy as np

import os

# 生成整数型的属性

def int64_feature(values):

return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

# 生成浮点型的属性

def float_feature(values):

return tf.train.Feature(float_list=tf.train.FloatList(value=values))

# 生成字符串型的属性

def bytes_feature(values):

return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))

# 标签信息的地址

dataset_dir = "/Users/**/**/label.txt"

# 图像存放的根目录地址

root_dir = '"/Users/**/**/'

# 输出TFRecord文件的地址

output_filename = "/Users/**/**/output.tfrecord"

file_lines = open(dataset_dir).readlines()

# 创建一个writer来写TFRecord文件

writer = tf.python_io.TFRecordWriter(output_filename)

# 统计有效数据

valid_record_count = 0

# 从label.txt循环读入要写入的数据信息

for idx, line in enumerate(file_lines):

line = line.strip('\n')

image_target_path = line.split(",")[0]

image_search_path = line.split(",")[1]

image_labels_str = line.split(",")[2:]

image_format = str(image_target_path.split('.')[-1]).lower()

image_target_path = os.path.join(root_dir, image_target_path)

image_search_path = os.path.join(root_dir, image_search_path)

# 使用tf.gfile.FastGFile读取图像的原始数据,method_1

image_target_data = tf.gfile.FastGFile(image_target_path, 'r').read()

image_search_data = tf.gfile.FastGFile(image_search_path, 'r').read()

# 使用tf.image.decode_jpeg对图像进行解码,并利用img.eval().shape获得图像的宽高和通道信息

T_height, T_width, channels = tf.image.decode_jpeg(image_target_data).eval().shape

S_height, S_width, channels = tf.image.decode_jpeg(image_search_data).eval().shape

# 使用PIL的Image.open读取图像,method_2

image_target = Image.open(image_target_path, 'r')

image_target_data = image_target.tobytes()

T_height, T_width = image_target.size

image_search = Image.open(image_search_path, 'r')

image_search_data = image_search.tobytes()

S_height, S_width = image_search.size

image_labels = [float(x) for x in image_labels_str]

if not len(image_labels) == 4:

print("invalid label: " + line)

continue

# 将一个样例转化为Example Protocol Buffer,并将所有信息写入数据结构

example = tf.train.Example(features=tf.train.Features(feature={

'image_target/encoded': bytes_feature(image_target_data),

'image_search/encoded': bytes_feature(image_search_data),

'image_target/format': bytes_feature(image_format),

'image_search/format': bytes_feature(image_format),

'image/class/label': float_feature(image_labels),

'image_target/height': int64_feature(T_height),

'image_target/width': int64_feature(T_width),

'image_search/height': int64_feature(S_height),

'image_search/width': int64_feature(S_width),

'image/channels': int64_feature(channels),

'image_target/path': bytes_feature(image_target_path),

'image_search/path': bytes_feature(image_search_path) }))

# 将一个Example写入TFRecord文件

writer.write(example.SerializeToString())

valid_record_count += 1

writer.close()

print("\nvalid image count: " + str(valid_record_count))

读取TFRecord文件,具体代码如下:

# 使用 tf.image.decode_jpeg对jpg格式图像进行解码,对应tf.gfile读取图像,method_1

image_target = tf.image.decode_jpeg(features['image_target/encoded'])

# 使用tf.decode_raw将字符串解析成图像对应的像素数组,对应Image.open读取图像,method_2

image_target = tf.decode_raw(features['image_target/encoded'], tf.uint8)

label = features['image/class/label']

T_height = tf.cast(features['image_target/height'], tf.int32)

T_width = tf.cast(features['image_target/width'], tf.int32)

channels = tf.cast(features['image/channels'], tf.int32)

image_target_path = features['image_target/path']

sess = tf.Session()

coord = tf.train.Coordinator()

threads = tf.train.start_queue_runners(sess=sess,coord=coord)

# 每次运行可以读取TFRecord文件中的一个样例

for i in range(100):

image_t, label_info,t_height, t_width, channnel, path = sess.run([image_target,label,T_height, T_width,channels,image_target_path])

image_name = path.split("/")[-1].split(".")[0]

sample = sess.run(tf.reshape(image_t, [t_height, t_width, channnel]))

image= Image.fromarray(sample,'RGB')

# 以图像名称_label信息对图像命名,并进行存储

image.save(decode_path+ image_name+'_'+ str(label_info[0])+'.jpg')

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值