4-4 tf.io.decode_csv使用

代码承接上一篇

打印生成的CSV文件名

// A code block
import pprint
print("train filenames:")
pprint.pprint(train_filenames)
print("valid filenames:")
pprint.pprint(valid_filenames)
print("test filenames:")
pprint.pprint(test_filenames)

那么tensorflow中如何把这些小的CSV文件集合到一起生成dataset呢?

tf需要两步

// A code block
# 1. filename -> dataset
# 2. read file -> dataset -> datasets -> merge
# 3. parse csv

filename_dataset = tf.data.Dataset.list_files(train_filenames)
for filename in filename_dataset:
    print(filename)

filename_dataset是一个文件名的数据集,list_files是专门用来处理文件名的,会把文件名生成一个dataset,用train_filenames做一个演示。filename_dataset有20个tensor,每个tensor都是一个文件名

遍历文件名dataset

// A code block
n_readers = 5
dataset = filename_dataset.interleave(
    lambda filename: tf.data.TextLineDataset(filename).skip(1),
    cycle_length = n_readers
)
for line in dataset.take(15):
    print(line.numpy())
    

interleave遍历文件名数据集里面的每一个元素。
TextLineDataset按行读取文本,生成一个dataset。
cycle_length用来控制读取文件的并行度。
因为整个数据集是很大的,就用take函数读取前面15个。
skip(1)就把head省略掉。
以上操作就把训练集得到了,但是每个元素都是一个字符串,字段没有分开。

解析字段

// A code block
# tf.io.decode_csv(str, record_defaults)

sample_str = '1,2,3,4,5'
record_defaults = [
    tf.constant(0, dtype=tf.int32),
    0,
    np.nan,
    "hello",
    tf.constant([])
]
parsed_fields = tf.io.decode_csv(sample_str, record_defaults)
print(parsed_fields)

record_defaults指定每个字段的默认值。
tf.io.decode_csv解析之后的到的是tensor格式

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值