这里给大家带来的pytorch读库代码,相信大家在做深度学习的图像质量评价时,第一步就会卡在读库代码,那么如何读取LIVE、LIVEC、TID2013、ESPL-LIVE、TMID等数据库呢?
涉及到的一些库函数
一、torchvision
作为pytorch的一个图形库,torchvision发挥着很重要的作用。
1.torchvision.datasets: 一些加载数据的函数及常用的数据集接口;
2.torchvision.models: 包含常用的模型结构(含预训练模型),例如AlexNet、VGG、ResNet等;
3.torchvision.transforms: 常用的图片变换,例如裁剪、旋转等;
4.torchvision.utils: 其他的一些有用的方法。
1. torchvision.transforms
主要是做图像的一些预处理,比如裁剪、归一化等。
torchvision.transforms.Compose 主要是串联多个图像变化的操作,构造如下:
transforms = torchvision.transforms.Compose([
torchvision.transforms.RandomHorizontalFlip(),
torchvision.transforms.RandomCrop(size=args.patch_size),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225))
])
2.torchvision.datasets
torchvision.datasets
是用来进行数据加载的,PyTorch团队在这个包中帮我们提前处理好了很多很多图片数据集。
"LSUN", "LSUNClass", "ImageFolder", "CIFAR10", "CIFAR100", ......... "QMNIST", "MNIST", "KMNIST", "DTD", "FE "FGVCAircraft", "EuroSAT", "RenderedSST2",
导入torchvision数据集:
# 图像处理
img_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,)), #灰度图用这个进行归一化
])
# MNIST数据集
'''
train=True表示导入训练集数据,如果自己电脑中带了数据集,那么download=True可以改成download=False,并在root中改成自己的路径
'''
mnist_train = datasets.MNIST(
root='./data/', train=True, transform=img_transform, download=True)
mnist_test = datasets.MNIST(
root='./data/', train=False, transform=img_transform, download=True)
# 批量数据读取
train_loader = data_utils.DataLoader(dataset = mnist_train,
batch_size = 64,
shuffle = True) #训练数据可以加个Shuffle
test_loader = data_utils.DataLoader(dataset = mnist_test,
batch_size = 64)
3.torchvision.models
torchvision.models
中为我们提供了已经训练好的模型,让我们可以加载之后,直接使用。
..........
SqueezeNet
DenseNet
调用代码
import torchvision.models as models
resnet50 = models.resnet50(pretrained=True) #加入pretrained 调用预训练模型
alexnet = models.alexnet(pretrained=True)
调用mnist数据集整体代码
# 我们这里还是对MNIST进行处理,初始的MNIST是 28 * 28,我们把它处理成 96 * 96 的torch.Tensor的格式
from torchvision import transforms as transforms
import torchvision
from torch.utils.data import DataLoader
# 图像预处理步骤
transform = transforms.Compose([
transforms.Resize(96), # 缩放到 96 * 96 大小
transforms.ToTensor(),
transforms.Normalize((0.5), (0.5)) # 归一化
])
BATCH_SIZE = 64
train_dataset = torchvision.datasets.MNIST(root='./data/', train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.MNIST(root='./data/', train=False, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset,
batch_size=BATCH_SIZE,
shuffle=True)
print(len(train_dataset))
print(len(train_loader))