读取数据
根据不同任务生成数据集
- 数据集形式:图片数据,文件夹名代表标签
使用函数ImageDataBunch.from_folder,具体示例如下:
data = ImageDataBunch.from_folder(path, ds_tfms=tfms, size=24)
api分解如下:
data = (ImageItemList.from_folder(path) #Where to find the data? -> in path and its subfolders
.split_by_folder() #How to split in train/valid? -> use the folders
.label_from_folder() #How to label? -> depending on the folder of the filenames
.add_test_folder() #Optionally add a test set (here default name is test)
.transform(tfms, size=64) #Data augmentation? -> use tfms with a size of 64
.databunch()) #Finally? -> use the defaults for conversion to ImageDataBunch
- 数据集形式:图片数据,标签保存在csv文件中
使用ImageDataBunch.from_csv方法,示例如下:
data = ImageDataBunch.from_csv(planet, folder='train', size=128, suffix='.jpg', sep = ' ', ds_tfms=planet_tfms)
api分解如下:
data = (ImageItemList.from_csv(planet, 'labels.csv', folder='train', suffix='.jpg')
#Where to find the data? -> in planet 'train' folder
.random_split_by_pct()
#How to split in train/valid? -> randomly with the default 20% in valid
.label_from_df(sep=' ')
#How to label? -> use the csv file
.transform(planet_tfms, size=128)
#Data augmentation? -> use tfms with a size of 128
.databunch())
#Finally -> use the defaults for conversion to databunch
- 对于分割与目标检测,其也可以使用其定义的数据块api快速定义,具体参见文档data_block部分,有详细的介绍。
DataBunch
- DataBunch是fastai中读取数据最基本的类,其针对不同的任务将数据集处理成合适的形式,以便送入learner进行训练。
class DataBunch
DataBunch(`train_dl`:DataLoader, `valid_dl`:DataLoader, `fix_dl`:DataLoader=`None`, `test_dl`:Optional[DataLoader]=`None`, `device`:device=`None`, `tfms`:Optional[Collection[Callable]]=`None`, `path`:PathOrStr=`'.'`, `collate_fn`:Callable=`'data_collate'`, `no_check`:bool=`False`)
此类中一个重要的方法是create,可以使用我们自己的pytorch数据集形式来创造fastai所需要的databunch,如下所示:
create(`train_ds`:Dataset, `valid_ds`:Dataset, `test_ds`:Optional[Dataset]=`None`, `path`:PathOrStr=`'.'`, `bs`:int=`64`, `num_workers`:int=`4`, `tfms`:Optional[Collection[Callable]]=`None`, `device`:device=`None`, `collate_fn`:Callable=`'data_collate'`, `no_check`:bool=`False`) → DataBunch
其中的Dataset代表pytorch中的一种数据形式。
2. 一种可视化数据的方法show_batch
show_batch(`rows`:int=`5`, `ds_type`:DatasetType=``, `kwargs`)
- 一种添加数据变化的方法add_tfm
使用方法如下:add_tfm(
tfm:Callable)
使用自己的数据集
遵循pytorch的方法,需要继承Dataset属性,必须包含__len__、__getitem__两个属性,除此之外还需要带有c属性,其代表模型最后一层的输出个数,classes属性代表数据的类别。具体根据具体的任务来设置。
DeviceDataLoader
将数据送给不同的显卡进行训练
DeviceDataLoader(`dl`:DataLoader, `device`:device, `tfms`:List[Callable]=`None`, `collate_fn`:Callable=`'data_collate'`)