【pytorch笔记】第五篇 torchvision,Dataloader,nn.Module的使用

1 torchvision数据集介绍

① torchvision中有很多数据集,当我们写代码时指定相应的数据集指定一些参数,它就可以自行下载。

② CIFAR-10数据集包含60000张32×32的彩色图片,一共10个类别,其中50000张训练图片,10000张测试图片。

1.1 torchvision数据集使用

import torchvision
help(torchvision.datasets.CIFAR10)
Output exceeds the size limit. Open the full output data in a text editor
Help on class CIFAR10 in module torchvision.datasets.cifar:

class CIFAR10(torchvision.datasets.vision.VisionDataset)
 |  `CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
 |  
 |  Args:
 |      root (string): Root directory of dataset where directory
 |          ``cifar-10-batches-py`` exists or will be saved to if download is set to True.
 |      train (bool, optional): If True, creates dataset from training set, otherwise
 |          creates from test set.
 |      transform (callable, optional): A function/transform that takes in an PIL image
 |          and returns a transformed version. E.g, ``transforms.RandomCrop``
 |      target_transform (callable, optional): A function/transform that takes in the
 |          target and transforms it.
 |      download (bool, optional): If true, downloads the dataset from the internet and
 |          puts it in root directory. If dataset is already downloaded, it is not
 |          downloaded again.
 |  
 |  Method resolution order:
 |      CIFAR10
 |      torchvision.datasets.vision.VisionDataset
 |      torch.utils.data.dataset.Dataset
 |      typing.Generic
 |      builtins.object
 |  
...
 |  
 |  __new__(cls, *args, **kwds)
 |      Create and return a new object.  See help(type) for accurate signature.

1.2 查看CIFAR10数据集内容

import torchvision
train_set = torchvision.datasets.CIFAR10(root="./dataset",train=True,download=True) # root为存放数据集的相对路线
test_set = torchvision.datasets.CIFAR10(root="./dataset",train=False,download=True) # train=True是训练集,train=False是测试集  

print(test_set[0])       # 输出的3是target 
print(test_set.classes)  # 测试数据集中有多少种

img, target = test_set[0] # 分别获得图片、target
print(img)
print(target)

print(test_set.classes[target]) # 3号target对应的种类
img.show()
Files already downloaded and verified
Files already downloaded and verified
(<PIL.Image.Image image mode=RGB size=32x32 at 0x1A4275AAF28>, 3)
['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
<PIL.Image.Image image mode=RGB size=32x32 at 0x1A4275AAA58>
3
cat

2. Dataloader使用

① Dataset只是去告诉我们程序,我们的数据集在什么位置,数据集第一个数据给它一个索引0,它对应的是哪一个数据。

② Dataloader就是把数据加载到神经网络当中,Dataloader所做的事就是每次从Dataset中取数据,至于怎么取,是由Dataloader中的参数决定的。

import torchvision
from torch.utils.data import DataLoader

# 准备的测试数据集
test_data = torchvision.datasets.CIFAR10("./dataset",train=False,transform=torchvision.transforms.ToTensor())               
img, target = test_data[0]
print(img.shape)
print(img)

# batch_size=4 使得 img0, target0 = dataset[0]、img1, target1 = dataset[1]、img2, target2 = dataset[2]、img3, target3 = dataset[3],然后这四个数据作为Dataloader的一个返回      
test_loader = DataLoader(dataset=test_data,batch_size=4,shuffle=True,num_workers=0,drop_last=False)      
# 用for循环取出DataLoader打包好的四个数据
for data in test_loader:
    imgs, targets = data # 每个data都是由4张图片组成,imgs.size 为 [4,3,32,32],四张32×32图片三通道,targets由四个标签组成             
    print(imgs.shape)
    print(targets)

这里输出结果信息有点长,可以自行运行看结果

2.1 Dataloader多轮次

import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# 准备的测试数据集
test_data = torchvision.datasets.CIFAR10("./dataset",train=False,transform=torchvision.transforms.ToTensor())               
# batch_size=4 使得 img0, target0 = dataset[0]、img1, target1 = dataset[1]、img2, target2 = dataset[2]、img3, target3 = dataset[3],然后这四个数据作为Dataloader的一个返回      
test_loader = DataLoader(dataset=test_data,batch_size=64,shuffle=True,num_workers=0,drop_last=True)      
# 用for循环取出DataLoader打包好的四个数据
writer = SummaryWriter("logs")
for epoch in range(2):
    step = 0
    for data in test_loader:
        imgs, targets = data # 每个data都是由4张图片组成,imgs.size 为 [4,3,32,32],四张32×32图片三通道,targets由四个标签组成             
        writer.add_images("Epoch:{}".format(epoch),imgs,step)
        step = step + 1
    
writer.close()

在这里插入图片描述

3. nn.Module模块使用

① nn.Module是对所有神经网络提供一个基本的类。

② 我们的神经网络是继承nn.Module这个类,即nn.Module为父类,nn.Module为所有神经网络提供一个模板,对其中一些我们不满意的部分进行修改。

import torch
from torch import nn

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()  # 继承父类的初始化
        
    def forward(self, input):          # 将forward函数进行重写
        output = input + 1
        return output
    
myModule= MyModule()
x = torch.tensor(1.0)  # 创建一个值为 1.0 的tensor
output = myModule(x)
print(output)
tensor(2.)

3.1 super(Myclass, self)._init_()

① 简单理解就是子类把父类的__init__()放到自己的__init__()当中,这样子类就有了父类的_init_()的那些东西。

② Myclass类继承nn.Module,super(Myclass, self).__init__()就是对继承自父类nn.Module的属性进行初始化。而且是用nn.Module的初始化方法来初始化继承的属性。

③ super().__init()__()来通过初始化父类属性以初始化自身继承了父类的那部分属性;这样一来,作为nn.Module的子类就无需再初始化那一部分属性了,只需初始化新加的元素。

③ 子类继承了父类的所有属性和方法,父类属性自然会用父类方法来进行初始化。

3.2 forward函数

① 使用pytorch的时候,不需要手动调用forward函数,只要在实例化一个对象中传入对应的参数就可以自动调用 forward 函数。

② 因为 PyTorch 中的大部分方法都继承自 torch.nn.Module,而 torch.nn.Module 的__call__(self)函数中会返回 forward()函数 的结果,因此PyTroch中的 forward()函数等于是被嵌套在了__call__(self)函数中;因此forward()函数可以直接通过类名被调用,而不用实例化对象。

class A():
    def __call__(self, param):
        print('i can called like a function')
        print('传入参数的类型是:{}   值为: {}'.format(type(param), param))
        res = self.forward(param)
        return res
    
    def forward(self, input_):
        print('forward 函数被调用了')
        print('in  forward, 传入参数类型是:{}  值为: {}'.format( type(input_), input_))
        return input_

a = A()
input_param = a('i')
print("对象a传入的参数是:", input_param)
i can called like a function
传入参数的类型是:<class 'str'>   值为: i
forward 函数被调用了
in  forward, 传入参数类型是:<class 'str'>  值为: i
对象a传入的参数是: i
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值