代码承接上一篇
打印生成的CSV文件名
。
// A code block
import pprint
print("train filenames:")
pprint.pprint(train_filenames)
print("valid filenames:")
pprint.pprint(valid_filenames)
print("test filenames:")
pprint.pprint(test_filenames)
那么tensorflow中如何把这些小的CSV文件集合到一起生成dataset呢?
tf需要两步
。
// A code block
# 1. filename -> dataset
# 2. read file -> dataset -> datasets -> merge
# 3. parse csv
filename_dataset = tf.data.Dataset.list_files(train_filenames)
for filename in filename_dataset:
print(filename)
filename_dataset是一个文件名的数据集,list_files是专门用来处理文件名的,会把文件名生成一个dataset,用train_filenames做一个演示。filename_dataset有20个tensor,每个tensor都是一个文件名
遍历文件名dataset
。
// A code block
n_readers = 5
dataset = filename_dataset.interleave(
lambda filename: tf.data.TextLineDataset(filename).skip(1),
cycle_length = n_readers
)
for line in dataset.take(15):
print(line.numpy())
interleave遍历文件名数据集里面的每一个元素。
TextLineDataset按行读取文本,生成一个dataset。
cycle_length用来控制读取文件的并行度。
因为整个数据集是很大的,就用take函数读取前面15个。
skip(1)就把head省略掉。
以上操作就把训练集得到了,但是每个元素都是一个字符串,字段没有分开。
解析字段
。
// A code block
# tf.io.decode_csv(str, record_defaults)
sample_str = '1,2,3,4,5'
record_defaults = [
tf.constant(0, dtype=tf.int32),
0,
np.nan,
"hello",
tf.constant([])
]
parsed_fields = tf.io.decode_csv(sample_str, record_defaults)
print(parsed_fields)
record_defaults指定每个字段的默认值。
tf.io.decode_csv解析之后的到的是tensor格式