我们在使用 U-Net 网络进行训练时，发现程序卡在遍历 data_loader 的位置：
# Training-loop excerpt. Reported symptom: execution stalls on the `for` line,
# i.e. the loader never yields its first batch.
# NOTE(review): data_loader is presumably a torch.utils.data.DataLoader and
# device a torch.device, both defined outside this excerpt -- confirm.
for i, data in enumerate(data_loader):  # execution hangs here
    image, target = data
    # Move the batch (image and segmentation target) onto the training device.
    image, target = image.to(device), target.to(device)
程序一直停在这里不动。有人建议把 num_workers 改为 0，但这根本解决不了问题。经过反复排查，我发现问题的根源在数据集：我们一般使用的图片是 png 或 jpg 格式，而 U-Net 最初是用于医学图像分割的，其训练数据（例如这里的 DRIVE 数据集）是 tif 格式，所以我们在定义 dataset 时，读取图片的方式要与普通图片格式有所不同：
# Read the .tif with OpenCV; cv2.imread returns None (it does not raise) when
# the file is missing or cannot be decoded, so fail loudly here instead of
# getting an obscure error from cvtColor.
img = cv2.imread(self.img_list[idx])
if img is None:
    raise FileNotFoundError(f"file {self.img_list[idx]} could not be read.")
# OpenCV loads channels in BGR order; convert to RGB and wrap as a PIL Image.
img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
我们先通过 OpenCV 读取 tif 格式的数据（OpenCV 读取后的通道顺序是 BGR，需要转换成 RGB），然后再转换成 PIL 的 Image 格式，因为后续的 transforms 是对 PIL 数据进行处理的。
其完整代码如下:
class DriveDataset(Dataset):
    """Dataset for DRIVE retinal-vessel segmentation.

    Expected layout under ``root``::

        DRIVE/<training|test>/images/*.tif
        DRIVE/<training|test>/1st_manual/<id>_manual1.gif
        DRIVE/<training|test>/mask/<id>_<training|test>_mask.gif

    where ``<id>`` is the part of an image file name before its first ``_``.

    Each sample is ``(img, mask)`` (passed through ``transforms`` if given).
    Before transforms, ``img`` is an RGB ``PIL.Image`` and ``mask`` is
    ``clip(manual / 255 + (255 - roi), 0, 255)``. Assuming the manual
    annotation and ROI images are binary 0/255, that gives 1 for vessel and
    0 for background inside the ROI, and 255 outside the ROI (presumably the
    loss's ignore index -- confirm against the training code).

    Args:
        root: directory that contains the ``DRIVE`` folder.
        train: ``True`` selects the ``training`` split, ``False`` the ``test`` split.
        transforms: optional callable ``(img, mask) -> (img, mask)``.

    Raises:
        AssertionError: the split directory does not exist.
        FileNotFoundError: a manual annotation or ROI mask file is missing.
    """

    def __init__(self, root: str, train: bool, transforms=None):
        super().__init__()
        self.flag = "training" if train else "test"
        data_root = os.path.join(root, "DRIVE", self.flag)
        assert os.path.exists(data_root), f"path '{data_root}' does not exists."
        self.transforms = transforms
        # os.listdir returns entries in arbitrary, filesystem-dependent order;
        # sort so that sample indices are reproducible across machines.
        img_names = sorted(
            i for i in os.listdir(os.path.join(data_root, "images")) if i.endswith(".tif")
        )
        self.img_list = [os.path.join(data_root, "images", i) for i in img_names]
        self.manual = [os.path.join(data_root, "1st_manual", i.split("_")[0] + "_manual1.gif")
                       for i in img_names]
        self._check_files(self.manual)
        self.roi_mask = [os.path.join(data_root, "mask", i.split("_")[0] + f"_{self.flag}_mask.gif")
                         for i in img_names]
        self._check_files(self.roi_mask)

    @staticmethod
    def _check_files(paths):
        """Raise FileNotFoundError for the first path in *paths* that does not exist."""
        for path in paths:
            if not os.path.exists(path):
                raise FileNotFoundError(f"file {path} does not exists.")

    def __getitem__(self, idx):
        """Return the ``(img, mask)`` pair at *idx*.

        Raises:
            FileNotFoundError: OpenCV could not read the image file.
        """
        # The .tif is read with OpenCV rather than Image.open(...).convert('RGB').
        # cv2.imread returns None (it does not raise) on failure, so check it.
        # NOTE(review): if DataLoader workers still hang with num_workers > 0,
        # cv2.setNumThreads(0) is a commonly suggested workaround -- unverified.
        img = cv2.imread(self.img_list[idx])
        if img is None:
            raise FileNotFoundError(f"file {self.img_list[idx]} could not be read.")
        # OpenCV yields BGR channel order; convert to RGB before wrapping in PIL.
        img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        # Use `with` so the GIF file handles are closed deterministically.
        with Image.open(self.manual[idx]) as manual_img:
            manual = np.array(manual_img.convert('L')) / 255
        with Image.open(self.roi_mask[idx]) as roi_img:
            roi_mask = 255 - np.array(roi_img.convert('L'))
        mask = np.clip(manual + roi_mask, a_min=0, a_max=255)
        # Convert back to PIL because the transforms operate on PIL images.
        # mask is float64 here (manual was divided by 255), so Pillow stores it
        # as a float ('F') image; presumably the transforms cast it back to an
        # integer target -- confirm.
        mask = Image.fromarray(mask)
        if self.transforms is not None:
            img, mask = self.transforms(img, mask)
        return img, mask

    def __len__(self):
        """Number of images found in the split's ``images`` folder."""
        return len(self.img_list)

    @staticmethod
    def collate_fn(batch):
        """Collate a list of ``(img, mask)`` samples into batched tensors.

        Images are padded with 0 and targets with 255 via ``cat_list``, a
        helper defined elsewhere in the project (presumably it pads to a
        common size and stacks -- confirm).
        """
        images, targets = list(zip(*batch))
        batched_imgs = cat_list(images, fill_value=0)
        batched_targets = cat_list(targets, fill_value=255)
        return batched_imgs, batched_targets