目录
项目场景:
tf.data.Dataset
是一种高效好用的数据集载入工具. 使用它的 map
方法对数据进行处理也十分方便, 这个处理的函数最好是用 tf2 已经提供的 API 来实现.
对于一个用文件夹/文件名存储的数据集, 本人想对其文件路径进行解析(用正则库), 从而直接返回每个样本的数据和标签:
def load_data(file_name): # get tf.Tensor here
file_name = file_name.numpy().decode("utf8")
label = label_dict[pattern.search(file_name).group(0)]
data = np.loadtxt(file_name)[..., np.newaxis]
return tf.cast(data, tf.float32), tf.cast(label, tf.uint8)
ds_train = tf.data.Dataset.list_files("./data/dataset/train/*/*.txt", shuffle=True) \
.filter(lambda file_name: tf.strings.regex_full_match(file_name, source_list_regex)) \
.map(load_data, num_parallel_calls=64)
.batch(80) \
.prefetch(tf.data.experimental.AUTOTUNE) \
.cache