Pytorch中DataLoader自定义数据集载入时报错“TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found 《class 'NoneType'》”
问题描述
创建自己的数据集并用DataLoader载入时报出如下错误:
产生原因
该问题表明在用迭代器循环载入数据及标签时,存在某一数据类型为’None’
解决办法
提前剔除数据类型为’None’的数据
- 自定义数据集中__getitem__(self, idx)函数的循环长度由同一类中的__len__(self)函数确定;
- 如果该循环长度内的数据中有类型为’None’的数据,则在主程序中用DataLoader迭代载入数据时就会出现上述错误;
- 故此,在自定义数据集_pre_process(self)函数中要先去除数据类型为’None’的数据,再计算去除数据类型为’None’后剩余数据的长度,该长度就为自定义数据集的长度,也就是1中的循环长度;
- 用不含’None’数据类型的循环长度,再去实现__getitem__中的内容,才能保证主程序中用DataLoader迭代载入数据时不报上述错误。
程序展示
该程序为提取病理切片肿瘤区域所对应的patch
wsi_path:jpg格式病理切片的存储位置
mask_path:该病理切片npy格式的肿瘤掩码(肿瘤区域值为255,其他区域元素值为0)
import numpy as np
from torch.utils.data import Dataset
from PIL import Image
Image.MAX_IMAGE_PIXELS = None
class WSIPatchDataset(Dataset):
def __init__(self, wsi_path, mask_path, image_size=256, crop_size=224,
normalize=True, flip='NONE', rotate='NONE'):
self._wsi_path = wsi_path
self._