pytorch中DataLoader详解

在深度学习的训练过程中,我们需要将数据分批的放入到训练网络中,批数量的大小也被成为batch_size,通过pytorch提供的dataloader方法,可以自动实现一个迭代器,每次返回一组batch_size个样本和标签来进行训练。
下面是一个dataloader的简单例子:

import torch
import torch.utils.data as Data

BATCH_SIZE = 5
#生成1,2,3,4,5,6,7,8,9,10
x = torch.linspace(1, 10, 10)
#生成10,9,8,7,6,5,4,3,2,1
y = torch.linspace(10, 1, 10)
# 对于给定的tensor数据(样本和标签),将其包装为dataset
torch_dataset = Data.TensorDataset(x, y)
#创建一个dataloader类的实例
loader = Data.DataLoader(
    # 从数据库中每次抽出batch size个样本
    dataset=torch_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
)


def show_batch():
    for epoch in range(3):
        for step, (batch_x, batch_y) in enumerate(loader):
            # training


            print("step:{}, batch_x:{}, batch_y:{}".format(step, batch_x, batch_y))

# def show_batch():
#     for epoch in range(3):
#         for (batch_x, batch_y) in loader:
#             # training
# 
# 
#             print(" batch_x:{}, batch_y:{}".format(batch_x, batch_y))

if __name__ == '__main__':
    show_batch()

dataloader的使用十分简单,重要的是我们需要提前构造dataloader需要的一个重要参数dataset。
下面是一个创建dataset数据集的流程
在图像超分辨中,我们需要创建低分辨图像以及高分辨图像的数据对,低分辨图像为样本,高分辨图像为标签,每次训练要为神经网络传入batch_size个这样的数据对。

  1. 创建.h5文件
import argparse
import glob
import h5py
import numpy as np
import PIL.Image as pil_image
from utils import calc_patch_size, convert_rgb_to_y


@calc_patch_size
def train(args):
    h5_file = h5py.File(args.output_path, 'w')#h5文件输出路劲

    lr_patches = []#低分辨块列表
    hr_patches = []#高分辨块列表

    for image_path in sorted(glob.glob('{}/*'.format(args.images_dir))):#获得对应图像的完整目录
        #glob.glob()函数,*列出该目录下的所有文件列表   
        hr = pil_image.open(image_path).convert('RGB')
        hr_images = []

        if args.with_aug:
            for s in [1.0, 0.9, 0.8, 0.7, 0.6]:
                for r in [0, 90, 180, 270]:
                    tmp = hr.resize((int(hr.width * s), int(hr.height * s)), resample=pil_image.BICUBIC)
                    tmp = tmp.rotate(r, expand=True)
                    hr_images.append(tmp)
        else:
            hr_images.append(hr)

        for hr in hr_images:
            hr_width = (hr.width // args.scale) * args.scale
            hr_height = (hr.height // args.scale) * args.scale#得到一个标准的大小
            hr = hr.resize((hr_width, hr_height), resample=pil_image.BICUBIC)
            lr = hr.resize((hr.width // args.scale, hr_height // args.scale), resample=pil_image.BICUBIC)#到这就得到了HRLR
            hr = np.array(hr).astype(np.float32)
            lr = np.array(lr).astype(np.float32)
            hr = convert_rgb_to_y(hr)
            lr = convert_rgb_to_y(lr)#转化为亮度

            for i in range(0, lr.shape[0] - args.patch_size + 1, args.scale):
                for j in range(0, lr.shape[1] - args.patch_size + 1, args.scale):
                    lr_patches.append(lr[i:i+args.patch_size, j:j+args.patch_size])
                    hr_patches.append(hr[i*args.scale:i*args.scale+args.patch_size*args.scale, j*args.scale:j*args.scale+args.patch_size*args.scale])#将一张图像分成小块写入

    lr_patches = np.array(lr_patches)
    hr_patches = np.array(hr_patches)

    h5_file.create_dataset('lr', data=lr_patches)
    h5_file.create_dataset('hr', data=hr_patches)

    h5_file.close()


评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值