TF2.0里面的tfrecord

这篇博客介绍了在TF2.0中如何处理tfrecord文件。内容包括数据类型的一致性要求,以及tfrecord的制作和读取方法。重点讲述了tfrecord的读取,分别讨论了单条数据读取和批量读取的方式,并提供了相关代码示例。
摘要由CSDN通过智能技术生成

记录一下最近遇到的问题

特别注意,在制作tfrercord的时候,写入的数据类型要和读出的时候保持一致,不然会出现多个矩阵的问题。
的萨芬

比如现有一个(100,100)的精度很高的矩阵,他的元素数据类型是float64,制作时是这个类型写进tfrecord,
但是在读出tfrecord的时候你用的类型是tf.float32,那么就会原来的写入的一个矩阵读出来会变成两个矩阵,应该是由于精度的问题,
data = tf.io.decode_raw(features['data'], out_type= tf.float64) #如果这里的out_type=换成tf.float32的话
  1. TF2里面的tfrecord数据格式制作与加载
    制作没什么好讲的,代码如下:
def _make_tfrecords(doc_path, save_dir, train_or_test):

    tfrecords_path = save_dir +'\\' +train_or_test +'.tfrecord'
    if os.path.exists(tfrecords_path):
        os.remove(tfrecords_path)
        print('clean origional tfrecords')

    def _bytes_feature(values):
        return tf.train.Feature(bytes_list= tf.train.BytesList(value= [values]))
    def _int_feature(values):
        return tf.train.Feature(int64_list = tf.train.Int64List(value= [values]))

    writer = tf.io.TFRecordWriter(tfrecords_path)
    for i, single_image_path in enumerate(traverse_document(doc_path)):
        # img = np.array(Image.open(img_path))
        # 这里得到一个图片的真实标签,另一个标签tag_label.图片名字,图片数组,宽以及高
        tru_label, tag_label, img_name, img, width, height= from_name_get_true_and_target_label(single_image_path)
        # 字符串编码为二进制
        img_name = img_name.encode('utf-8')
        example = tf.train.Example(
            features = tf.train.Features(
                feature = {
   
                    'img_name':
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值