基于深度学习的超分辨实践,包含:数据集的创建,模型的搭建,模型的训练,模型的测试。
这里有两点疑惑:数据集的创建,模型的训练
数据集的创建
数据集从文件夹获取。这里与原文的区别:1 没有数据增强 2 随机采样好像和原文有区别
import random
import glob
import numpy as np
import PIL.Image as pil_image
from torch import nn
class Dataset(object):
def __init__(self, images_dir, patch_size, scale):
self.image_files = sorted(glob.glob(images_dir + '/*'))
self.patch_size = patch_size
self.scale = scale
def __getitem__(self, idx):
hr = pil_image.open(self.image_files[idx]).convert('RGB')
# 随机裁剪
crop_x = random.randint(0, hr.width - self.patch_size * self.scale)
crop_y = random.randint(0, hr.height - self.patch_size * self.scale)
hr = hr.crop((crop_x, crop_y, crop_x + self.patch_size * self.scal