tensoflow: 从文件夹加载数据管道的两种方式

 1. ImageDataGenerator

train_datagen = ImageDataGenerator(rescale=1./255, 
                                   zoom_range=0.20, 
                                   fill_mode="nearest")
print(type(train_datagen)) # >>> <class 'keras.preprocessing.image.ImageDataGenerator'>

validation_datagen = ImageDataGenerator(rescale=1./255)
      
train_generator = train_datagen.flow_from_directory(train_path, 
                                                    target_size=(img_rows, img_cols), 
                                                    batch_size=batch_size, 
                                                    class_mode='categorical', 
                                                    subset='training')
print(type(train_generator)) # >>> <class 'keras_preprocessing.image.directory_iterator.DirectoryIterator'>
validation_generator = validation_datagen.flow_from_directory(validation_path, 
                                                              target_size=(img_rows, img_cols), 
                                                              batch_size=batch_size, 
                                                              class_mode=None,  # only data, no labels 
                                                              shuffle=False)   
 
history = model.fit_generator(train_generator, 
                              steps_per_epoch=len(train_generator), 
                              epochs=epochs)
    
predictions = model.predict_generator(validation_generator,
                                          steps=len(validation_generator),
                                          verbose=1)

2.  image_dataset_from_directory

train_dataset = image_dataset_from_directory(train_dir,
                                             shuffle=True, # 默认: True
                                             batch_size=BATCH_SIZE, # 默认:(256 x 256)
                                             image_size=IMG_SIZE) # 默认: 32

validation_dataset = image_dataset_from_directory(validation_dir,
                                                  label_mode='binary',
                                                  shuffle=True, 
                                                  batch_size=BATCH_SIZE, 
                                                  image_size=IMG_SIZE)
print(type(validation_dataset)) # >>> <class 'tensorflow.python.data.ops.dataset_ops.BatchDataset'>
test_dataset = image_dataset_from_directory(test_dir, 
                                            label_mode='binary',
                                            shuffle=True, 
                                            image_size=IMG_SIZE)
# print(dir(test_dataset)) # dir(object):返回对象的属性和方法。
print(test_dataset.class_names) # 数据集属性中类的名称

history = model.fit(train_dataset,
                    epochs=initial_epochs, 
                    validation_data=validation_dataset)

 

  • 2
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

despacito,

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值