在深度学习的训练过程中,我们需要将数据分批的放入到训练网络中,批数量的大小也被成为batch_size,通过pytorch提供的dataloader方法,可以自动实现一个迭代器,每次返回一组batch_size个样本和标签来进行训练。
下面是一个dataloader的简单例子:
import torch
import torch.utils.data as Data
BATCH_SIZE = 5
#生成1,2,3,4,5,6,7,8,9,10
x = torch.linspace(1, 10, 10)
#生成10,9,8,7,6,5,4,3,2,1
y = torch.linspace(10, 1, 10)
# 对于给定的tensor数据(样本和标签),将其包装为dataset
torch_dataset = Data.TensorDataset(x, y)
#创建一个dataloader类的实例
loader = Data.DataLoader(
# 从数据库中每次抽出batch size个样本
dataset=torch_dataset,
batch_size=BATCH_SIZE,
shuffle=True,
num_workers=2,
)
def show_batch():
for epoch in range(3):
for step, (batch_x, batch_y) in enumerate(loader):
# training
print("step:{}, batch_x:{}, batch_y:{}".format(step, batch_x, batch_y))
# def show_batch():
# for epoch in range(3):
# for (batch_x, batch_y) in loader:
# # training
#
#
# print(" batch_x:{}, batch_y:{}".format(batch_x, batch_y))
if __name__ == '__main__':
show_batch()
dataloader的使用十分简单,重要的是我们需要提前构造dataloader需要的一个重要参数dataset。
下面是一个创建dataset数据集的流程
在图像超分辨中,我们需要创建低分辨图像以及高分辨图像的数据对,低分辨图像为样本,高分辨图像为标签,每次训练要为神经网络传入batch_size个这样的数据对。
- 创建.h5文件
import argparse
import glob
import h5py
import numpy as np
import PIL.Image as pil_image
from utils import calc_patch_size, convert_rgb_to_y
@calc_patch_size
def train(args):
h5_file = h5py.File(args.output_path, 'w')#h5文件输出路劲
lr_patches = []#低分辨块列表
hr_patches = []#高分辨块列表
for image_path in sorted(glob.glob('{}/*'.format(args.images_dir))):#获得对应图像的完整目录
#glob.glob()函数,*列出该目录下的所有文件列表
hr = pil_image.open(image_path).convert('RGB')
hr_images = []
if args.with_aug:
for s in [1.0, 0.9, 0.8, 0.7, 0.6]:
for r in [0, 90, 180, 270]:
tmp = hr.resize((int(hr.width * s), int(hr.height * s)), resample=pil_image.BICUBIC)
tmp = tmp.rotate(r, expand=True)
hr_images.append(tmp)
else:
hr_images.append(hr)
for hr in hr_images:
hr_width = (hr.width // args.scale) * args.scale
hr_height = (hr.height // args.scale) * args.scale#得到一个标准的大小
hr = hr.resize((hr_width, hr_height), resample=pil_image.BICUBIC)
lr = hr.resize((hr.width // args.scale, hr_height // args.scale), resample=pil_image.BICUBIC)#到这就得到了HRLR
hr = np.array(hr).astype(np.float32)
lr = np.array(lr).astype(np.float32)
hr = convert_rgb_to_y(hr)
lr = convert_rgb_to_y(lr)#转化为亮度
for i in range(0, lr.shape[0] - args.patch_size + 1, args.scale):
for j in range(0, lr.shape[1] - args.patch_size + 1, args.scale):
lr_patches.append(lr[i:i+args.patch_size, j:j+args.patch_size])
hr_patches.append(hr[i*args.scale:i*args.scale+args.patch_size*args.scale, j*args.scale:j*args.scale+args.patch_size*args.scale])#将一张图像分成小块写入
lr_patches = np.array(lr_patches)
hr_patches = np.array(hr_patches)
h5_file.create_dataset('lr', data=lr_patches)
h5_file.create_dataset('hr', data=hr_patches)
h5_file.close()