需求:想要给dataset的每一个图像加上路径信息。
train_dl = DataLoader(dataset=train_ds, batch_size=batch_size, shuffle=True, collate_fn=my_collate_fn)
DataLoader根据batch_size大小,每次得到一组长度为batch_size的index,根据index调用dataset.__getitem__()方法,将返回的结果存在列表中,最后将列表作为输入参数传递给collate_fn,collate_fn的返回值即为每次next(iter(train_dl))的返回值。
例如,如果batch_size=3,那么index可能为5,27,49,则将列表
batch=[dataset.__getitem__(5),dataset.__getitem__(27),dataset.__getitem__(49)]
输入到collate_fn(batch)中,随后,collate_fn(batch)的返回值即为next(iter(train_dl))的返回值。由此可知,为了使得dataset和train_dl有路径信息,需要重写__getitem__()方法。于是,我们创建一个名为ImageFolderWithID的类,继承自ImageFolder
class ImageFolderWithID(ImageFolder):
def __init__(self,root,img_transforms):
super(ImageFolderWithID, self).__init__(root,transform=img_transforms)
def __getitem__(self, item):
return super(ImageFolderWithID, self).__getitem__(item),self.samples[item][0]
def __len__(self):
return super(ImageFolderWithID, self).__len__()
imgset = ImageFolderWithID(root_path,data_transforms)
这样一来,当调用imgset.__getitem__(5)的时候,返回的元组的第一个元素是原来的既包含图像信息,又包含图像类别的子元组;第二个元素是该图片的路径,即
imgset.__getitem__(5)=( ( 图片矩阵,图片类别 ) , 图片路径 )
为了在训练的时候只把矩阵和类别提取出来,而忽略路径,需要重写DataLoader的collate_fn函数
def my_collate_fn(the_batch):
origin_batch, path_list = [], []
for item in the_batch:
origin_batch.append(item[0])
path_list.append(item[1])
img_x, label_y = default_collate(origin_batch)
return img_x, label_y, path_list
train_ds, test_ds = random_split(imgset, [0.8, 0.2])
train_dl = DataLoader(dataset=train_ds,
batch_size=batch_size,
shuffle=True,
collate_fn=my_collate_fn)
前面提到过,the_batch是
the_batch=[dataset.__getitem__(5),dataset.__getitem__(27),dataset.__getitem__(49)]
而每一个__getitem__的返回值是
imgset.__getitem__(5)=( ( 图片矩阵,图片类别 ) , 图片路径 )
因此my_collate_fn先将每一个__getitem__()的值提取出前面一个元组,将它们送给默认的collate函数:default_collate得到我们训练需要的img_x和label_y,至于path_list,在训练的时候解包用_字符占位即可。