pytorch中的torchvision常用功能简单介绍

        这里给大家带来的pytorch读库代码,相信大家在做深度学习的图像质量评价时,第一步就会卡在读库代码,那么如何读取LIVE、LIVEC、TID2013、ESPL-LIVE、TMID等数据库呢?

涉及到的一些库函数

一、torchvision

作为pytorch的一个图形库,torchvision发挥着很重要的作用。

1.torchvision.datasets: 一些加载数据的函数及常用的数据集接口;
2.torchvision.models: 包含常用的模型结构(含预训练模型),例如AlexNet、VGG、ResNet等;
3.torchvision.transforms: 常用的图片变换,例如裁剪、旋转等;
4.torchvision.utils: 其他的一些有用的方法。

1. torchvision.transforms

主要是做图像的一些预处理,比如裁剪、归一化等。

torchvision.transforms.Compose 主要是串联多个图像变化的操作,构造如下:

transforms = torchvision.transforms.Compose([
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.RandomCrop(size=args.patch_size),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                         std=(0.229, 0.224, 0.225))
    ])

2.torchvision.datasets

torchvision.datasets 是用来进行数据加载的,PyTorch团队在这个包中帮我们提前处理好了很多很多图片数据集。

"LSUN",
"LSUNClass",
"ImageFolder",
"CIFAR10",
"CIFAR100",
.........
"QMNIST",
"MNIST",
"KMNIST",
"DTD",
"FE
"FGVCAircraft",
"EuroSAT",
"RenderedSST2",

导入torchvision数据集:

# 图像处理
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),   #灰度图用这个进行归一化
])
# MNIST数据集
'''
train=True表示导入训练集数据,如果自己电脑中带了数据集,那么download=True可以改成download=False,并在root中改成自己的路径
'''
mnist_train = datasets.MNIST(
    root='./data/', train=True, transform=img_transform, download=True) 

mnist_test = datasets.MNIST(
    root='./data/', train=False, transform=img_transform, download=True) 

# 批量数据读取
train_loader = data_utils.DataLoader(dataset = mnist_train,
                                  batch_size = 64,
                                  shuffle = True)  #训练数据可以加个Shuffle
 
test_loader = data_utils.DataLoader(dataset = mnist_test,
                                  batch_size = 64)


3.torchvision.models 

 torchvision.models 中为我们提供了已经训练好的模型,让我们可以加载之后,直接使用。

 AlexNet
VGG
ResNet

..........
SqueezeNet
DenseNet

调用代码 

import torchvision.models as models

resnet50 = models.resnet50(pretrained=True)  #加入pretrained 调用预训练模型
alexnet = models.alexnet(pretrained=True)

调用mnist数据集整体代码

# 我们这里还是对MNIST进行处理,初始的MNIST是 28 * 28,我们把它处理成 96 * 96 的torch.Tensor的格式
from torchvision import transforms as transforms
import torchvision
from torch.utils.data import DataLoader
 
# 图像预处理步骤
transform = transforms.Compose([
    transforms.Resize(96), # 缩放到 96 * 96 大小
    transforms.ToTensor(),
    transforms.Normalize((0.5), (0.5)) # 归一化
])
 

BATCH_SIZE = 64
 
train_dataset = torchvision.datasets.MNIST(root='./data/', train=True, transform=transform, download=True)

test_dataset = torchvision.datasets.MNIST(root='./data/', train=False, transform=transform, download=True)
 
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=BATCH_SIZE,
                          shuffle=True)
 
print(len(train_dataset))
print(len(train_loader))
  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值