这篇不再介绍 DataLoader 的概念，直接给出代码。
import os
import cv2
import random
import numpy as np
import data.utils as utils
import torch
import torch.utils.data as data
class ImgsDataset(data.Dataset):
    """Paired low-quality (LQ) / ground-truth (GT) image dataset.

    Image pairs are read directly from two folders. The file lists of both
    folders are sorted and paired by position, so the two folders must hold
    the same number of images.

    Args:
        opt: Options object accessed by attribute (not a dict). It must
            provide ``color`` (the original comment lists "RGB | Y"),
            ``GT_size`` and ``LQ_size`` (crop sizes for the GT and LQ images),
            ``scale`` (GT/LQ size ratio), and ``dataroot_LQ`` /
            ``dataroot_GT`` (image folders).

    Raises:
        AssertionError: If the two folders hold a different number of files,
            or (in ``__getitem__``) if a GT image is not exactly ``scale``
            times the size of its LQ image.
    """

    def __init__(self, opt):
        super(ImgsDataset, self).__init__()
        self.color = opt.color  # RGB | Y
        self.GT_cropSize = opt.GT_size
        self.LQ_cropSize = opt.LQ_size
        self.scale = opt.scale
        self.paths_LQ = opt.dataroot_LQ
        self.paths_GT = opt.dataroot_GT
        # Raw strings keep the regex backslash literal without relying on an
        # invalid (deprecated) escape sequence in a normal string.
        # NOTE(review): presumably load_file_list returns the paths matching
        # `regx` under the given folder -- confirm in data.utils.
        self.LQ_list = sorted(utils.load_file_list(self.paths_LQ, regx=r'\.(png|bmp)', keep_prefix=True))
        self.GT_list = sorted(utils.load_file_list(self.paths_GT, regx=r'\.(png|bmp)', keep_prefix=True))
        assert len(self.LQ_list) == len(self.GT_list), "PLEASE CHECK DATASETS"

    def __getitem__(self, index):
        """Return the ``(img_LQ, img_GT)`` float tensor pair at ``index``.

        Both images are loaded, color-converted, cropped at matching
        locations, augmented together and converted to CHW tensors.
        """
        img_LQ = utils.read_img(self.LQ_list[index])
        img_GT = utils.read_img(self.GT_list[index])
        img_LQ = utils.channel_convert(img_LQ.shape[2], self.color, [img_LQ])[0]
        img_GT = utils.channel_convert(img_GT.shape[2], self.color, [img_GT])[0]

        # randomly crop
        img_LQ, img_GT = self._paired_random_crop(img_LQ, img_GT)

        # augmentation - flip, rotate
        img_LQ, img_GT = utils.augment([img_LQ, img_GT])

        # BGR to RGB, HWC to CHW, numpy to tensor
        # NOTE(review): assumes utils.read_img yields HWC arrays in BGR
        # channel order -- confirm in data.utils.
        if img_GT.shape[2] == 3:
            img_GT = img_GT[:, :, [2, 1, 0]]
            img_LQ = img_LQ[:, :, [2, 1, 0]]
        img_GT = self._to_tensor(img_GT)
        img_LQ = self._to_tensor(img_LQ)
        return img_LQ, img_GT

    def __len__(self):
        """Return the number of image pairs."""
        return len(self.GT_list)

    def _paired_random_crop(self, img_LQ, img_GT):
        """Crop a random LQ patch and the GT patch at the matching location."""
        GH, GW = img_GT.shape[:2]
        H, W = img_LQ.shape[:2]
        assert (GH / H == self.scale and GW / W == self.scale), "PLEASE CHECK DATASET SIZE"

        # Both offsets index the LQ image, so both ranges are bounded by the
        # LQ crop size; bounding the width by the (larger) GT crop size would
        # keep crops from ever reaching the right part of the image.
        rnd_h = random.randint(0, max(0, H - self.LQ_cropSize))
        rnd_w = random.randint(0, max(0, W - self.LQ_cropSize))
        img_LQ = img_LQ[rnd_h:rnd_h + self.LQ_cropSize, rnd_w:rnd_w + self.LQ_cropSize, :]

        # Map the LQ offsets into GT coordinates.
        rnd_h_GT, rnd_w_GT = int(rnd_h * self.scale), int(rnd_w * self.scale)
        img_GT = img_GT[rnd_h_GT:rnd_h_GT + self.GT_cropSize, rnd_w_GT:rnd_w_GT + self.GT_cropSize, :]
        return img_LQ, img_GT

    @staticmethod
    def _to_tensor(img):
        """Convert an HWC numpy array into a contiguous CHW float tensor."""
        return torch.from_numpy(np.ascontiguousarray(np.transpose(img, (2, 0, 1)))).float()