最近在做图像分类的任务,使用了densenet进行fine-tune,因为图片数据比较多且占内存因此数据的加载使用ImageDataGenerator生成器,使用flow_from_directory从文件夹中获取各个类别的数据。
因为在测试的时候,需要知道哪个数据被判错了,要找到对应的文件名。
test_datagen = ImageDataGenerator(rescale=1./255)
val_generator = test_datagen.flow_from_directory( test_dir,
target_size=(img_size, img_size),
batch_size=32,
shuffle=False)
print(val_generator.class_indices) # 输出对应的标签文件夹
print(val_generator.filenames) # 按顺序输出文件的名字