def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
filename = ('存储recoders文件路径/hhh.tfrecoders')
writer = tf.python_io.TFRecordWriter(filename)
V_patch = loadmat('.mat文件路径/xxxx.mat')
b = V_patch['patch_3d']
a = b.astype(np.float32)
shape = a.shape
l1 = shape[0]
print(l1)
for i in range(l1):
volume = a[i]
print(i)
volume1 = np.delete(volume, -1, 0) #以下三行为截取图片代码,不重要,可根据需要删除
volume2 = np.delete(volume1, -1, 1)
volume3 = np.delete(volume2, -1, 2)
d = volume3.shape
vol_str = volume3.tostring()
e = len(vol_str)
example = tf.train.Example(features=tf.train.Features(feature={'vol_raw': _bytes_feature(vol_str)}))
writer.write(example.SerializeToString())
writer.close()
print('Done!')
tf.train.Example()用于将数据处理成二进制,目的是提升IO效率。
也可以对tfrecord文件添加标签
tf.train.Example(features=tf.train.Features(feature={'vol_raw': _bytes_feature(vol_str),
'label':_int64_feature(index)}))
_int64_feature()上面代码中已经定义。index是标签,比如二分类标签index可以填0或者1
特别需要注意的一点!!!代码中定义 .