6 个答案:
答案 0 :(得分:4)
tf.data.Dataset.list_files创建一个名为MatchingFiles:0的张量(如果适用,使用适当的前缀)。
你可以评估
tf.shape(tf.get_default_graph().get_tensor_by_name('MatchingFiles:0'))[0]
获取文件数。
当然,这仅适用于简单的情况,特别是如果每张图像只有一个样本(或已知数量的样本)。
在更复杂的情况下,例如当您不知道每个文件中的样本数量时,您只能观察到一个时期结束时的样本数量。
为此,您可以观看Dataset计算的时期数。 repeat()创建一个名为_count的成员,用于计算时期数。通过在迭代期间观察它,您可以发现它何时发生变化并从那里计算数据集大小。
这个计数器可能埋没在连续调用成员函数时创建的Dataset层次结构中,所以我们必须像这样挖掘它。
d = my_dataset
# RepeatDataset seems not to be exposed -- this is a possible workaround
RepeatDataset = type(tf.data.Dataset().repeat())
try:
while not isinstance(d, RepeatDataset):
d = d._input_dataset
except AttributeError:
warnings.warn('no epoch counter found')
epoch_counter = None
e