import oneflow.core.record.record_pb2 as ofrecord
import six
import random
import struct
def int32_feature(value):
    """Build an ``ofrecord.Feature`` whose ``int32_list`` holds *value*.

    A list or tuple is passed through unchanged; any other object is
    treated as a single element and wrapped in a one-item list.
    """
    items = value if isinstance(value, (list, tuple)) else [value]
    int32_list = ofrecord.Int32List(value=items)
    return ofrecord.Feature(int32_list=int32_list)
def int64_feature(value):
    """Build an ``ofrecord.Feature`` whose ``int64_list`` holds *value*.

    A list or tuple is passed through unchanged; any other object is
    treated as a single element and wrapped in a one-item list.
    """
    items = value if isinstance(value, (list, tuple)) else [value]
    int64_list = ofrecord.Int64List(value=items)
    return ofrecord.Feature(int64_list=int64_list)
def float_feature(value):
    """Build an ``ofrecord.Feature`` whose ``float_list`` holds *value*.

    A list or tuple is passed through unchanged; any other object is
    treated as a single element and wrapped in a one-item list before
    being converted into the Feature message.
    """
    items = value if isinstance(value, (list, tuple)) else [value]
    float_list = ofrecord.FloatList(value=items)
    return ofrecord.Feature(float_list=float_list)
def double_feature(value):
    """Build an ``ofrecord.Feature`` whose ``double_list`` holds *value*.

    A list or tuple is passed through unchanged; any other object is
    treated as a single element and wrapped in a one-item list.
    """
    items = value if isinstance(value, (list, tuple)) else [value]
    double_list = ofrecord.DoubleList(value=items)
    return ofrecord.Feature(double_list=double_list)
def bytes_feature(value):
    """Build an ``ofrecord.Feature`` whose ``bytes_list`` holds *value*.

    A list or tuple is used as the element sequence; any other object is
    treated as a single element and wrapped in a one-item list.

    On Python 3 every ``str`` element is encoded with ``str.encode()``
    (default encoding) so that text and bytes items may be mixed freely.
    An empty list or tuple is accepted and yields an empty ``BytesList``.
    """
    if not isinstance(value, (list, tuple)):
        value = [value]
    if not six.PY2:
        # Decide per element rather than by peeking at value[0]: that avoids
        # an IndexError on empty input and handles mixed str/bytes sequences.
        value = [x.encode() if isinstance(x, str) else x for x in value]
    return ofrecord.Feature(bytes_list=ofrecord.BytesList(value=value))
# Number of float values per simulated image (28 x 28).
obserations = 28 * 28
# Write three simulated (image, label) records to an OFRecord file.
# Each record is stored as an 8-byte length prefix followed by the
# serialized OFRecord message.
# NOTE: open() does not create the "./dataset" directory; it must exist.
# The context manager guarantees the file is closed even if a write or
# serialization step raises.
with open("./dataset/part-0", "wb") as f:
    for loop in range(0, 3):
        # Random pixel values in [0, 1) and a random digit label in 0..9.
        image = [random.random() for _ in range(obserations)]
        label = [random.randint(0, 9)]
        # Map feature names to Feature messages built by the helpers above.
        topack = {
            "images": float_feature(image),
            "labels": int64_feature(label),
        }
        # Build the OFRecord message from the feature dictionary.
        ofrecord_features = ofrecord.OFRecord(feature=topack)
        # Serialize the message to bytes.
        serializedBytes = ofrecord_features.SerializeToString()
        # Take the length from the bytes actually written so the prefix
        # can never disagree with the payload.
        length = len(serializedBytes)
        # Length prefix: signed 64-bit integer, native byte order ("q").
        f.write(struct.pack("q", length))
        # Payload: the serialized record itself.
        f.write(serializedBytes)
print("Done!")
import oneflow.core.record.record_pb2 as ofrecord
import struct
# Read three records back from the OFRecord file and print their contents.
with open("./dataset/part-0", "rb") as f:
    for loop in range(3):
        # The first 8 bytes of each record hold the payload length ("q").
        header = f.read(8)
        length = struct.unpack("q", header)
        # struct.unpack returns a tuple; its only item is the byte count.
        serializedBytes = f.read(length[0])
        # Parse the serialized bytes back into an OFRecord message.
        ofrecord_features = ofrecord.OFRecord.FromString(serializedBytes)
        features = ofrecord_features.feature
        # Recover the float values stored under "images".
        image = features["images"].float_list.value
        # Recover the int64 values stored under "labels".
        label = features["labels"].int64_list.value
        # Show the decoded record, followed by a blank line.
        print(image, label, end="\n\n")