【pytorch】学会pytorch dataloader数据加载(四)---直接读图

这篇就不再介绍dataloader的概念了, 直接放code了。

import os
import cv2
import random
import numpy as np
import data.utils as utils
import torch
import torch.utils.data as data

class ImgsDataset(data.Dataset):
    '''
    Read image pairs direct from image folders(GT, LQ)
    '''
    def __init__(self, opt):
        super(ImgsDataset, self).__init__()
        # self.color    = opt["color"]      ## RGB | Y
        # self.GT_cropSize  = opt["GT_size"]
        # self.LQ_cropSize  = opt["LQ_size"]
        # self.scale    = opt["scale"]
        # self.paths_LQ = opt["dataroot_LQ"]
        # self.paths_GT = opt["dataroot_GT"]

        self.color    = opt.color      ## RGB | Y
        self.GT_cropSize  = opt.GT_size
        self.LQ_cropSize  = opt.LQ_size
        self.scale    = opt.scale
        self.paths_LQ = opt.dataroot_LQ
        self.paths_GT = opt.dataroot_GT
        
        self.LQ_list  = sorted(utils.load_file_list(self.paths_LQ, regx='\.(png|bmp)', keep_prefix=True))
        self.GT_list  = sorted(utils.load_file_list(self.paths_GT, regx='\.(png|bmp)', keep_prefix=True))

        assert len(self.LQ_list) == len(self.GT_list), "PLEASE CHECK DATASETS"
    

    def __getitem__(self, index):
        img_LQ = utils.read_img(self.LQ_list[index])
        img_GT = utils.read_img(self.GT_list[index])
        img_LQ = utils.channel_convert(img_LQ.shape[2], self.color, [img_LQ])[0]
        img_GT = utils.channel_convert(img_GT.shape[2], self.color, [img_GT])[0]
        
        GH, GW, GC = img_GT.shape
        H,  W,  C  = img_LQ.shape
        assert (GH / H == self.scale and GW / W == self.scale), "PLEASE CHECK DATASET SIZE"

        # randomly crop
        rnd_h = random.randint(0, max(0, H - self.LQ_cropSize))
        rnd_w = random.randint(0, max(0, W - self.GT_cropSize))
        img_LQ = img_LQ[rnd_h:rnd_h + self.LQ_cropSize, rnd_w:rnd_w + self.LQ_cropSize, :]
        rnd_h_GT, rnd_w_GT = int(rnd_h * self.scale), int(rnd_w * self.scale)
        img_GT = img_GT[rnd_h_GT:rnd_h_GT+self.GT_cropSize, rnd_w_GT:rnd_w_GT+self.GT_cropSize, :]
        
        # augmentation - flip, rotate
        img_LQ, img_GT = utils.augment([img_LQ, img_GT])

        # BGR to RGB, HWC to CHW, numpy to tensor
        if img_GT.shape[2] == 3:
            img_GT = img_GT[:, :, [2, 1, 0]]
            img_LQ = img_LQ[:, :, [2, 1, 0]]
        
        img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
        img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float()

        return img_LQ, img_GT


    def __len__(self):
        return len(self.GT_list)
  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值