PyTorch中常用工具
torchvision
- models:提供深度学习中各种经典的网络结构以及预训练好的模型,包括AlexNet, VGG, ResNet, Inception
- datsets:提供常用的数据集加载,设计上均继承torch.utils.data.dataset,主要包括MNIST,CIFAR10/100,ImageNet,COCO
- transforms:提供常用的数据预处理操作,主要包括对tensor和PILImage对象的操作
- torchvision.utils.save_image:直接将tensor保存成图片
from torch.utils import data import os from PIL import Image from torchvision import transforms from torchvision import utils import numpy import torch class Data(data.Dataset): def __init__(self, root): # 返回指定路径下的文件和文件夹列表。 imgs_HR = os.listdir(os.path.join(root, 'gt')) self.imgs_HR = [os.path.join(root, 'gt', img) for img in imgs_HR] imgs_LR = os.listdir(os.path.join(root, 'lr')) self.imgs_LR = [os.path.join(root, 'lr', img) for img in imgs_LR] self.transform = transforms.ToTensor() def _