torchvision中datasets.MNIST介绍

用法介绍

torchvisiondatasets中所有封装的数据集都是torch.utils.data.Dataset的子类,它们都实现了__getitem____len__方法。因此,它们都可以用torch.utils.data.DataLoader进行数据加载。以datasets.MNIST类为例,具体参数和用法如下所示:

CLASS torchvision.datasets.MNIST(root: str, train: bool = True, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False)

  • root (string): 表示数据集的根目录,其中根目录存在MNIST/processed/training.ptMNIST/processed/test.pt的子目录
  • train (bool, optional): 如果为True,则从training.pt创建数据集,否则从test.pt创建数据集
  • download (bool, optional): 如果为True,则从internet下载数据集并将其放入根目录。如果数据集已下载,则不会再次下载
  • transform (callable, optional): 接收PIL图片并返回转换后版本图片的转换函数
  • target_transform (callable, optional): 接收PIL接收目标并对其进行变换的转换函数

datasets.MNIST(“mnist-data”)下载mnist数据集之后生成的目录结构如下所示

代码实例

以下是用datasets.MNIST()类加载mnist数据集,并进行训练的完整代码

from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.Sq1 = nn.Sequential(         
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2),   # (16, 28, 28)                           #  output: (16, 28, 28)
            nn.ReLU(),                    
            nn.MaxPool2d(kernel_size=2),    # (16, 14, 14)
        )
        self.Sq2 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),  # (32, 14, 14)
            nn.ReLU(),                      
            nn.MaxPool2d(2),                # (32, 7, 7)
        )
        self.out = nn.Linear(32 * 7 * 7, 10)   

    def forward(self, x):
        x = self.Sq1(x)
        x = self.Sq2(x)
        x = x.view(x.size(0), -1)          
        output = self.out(x)
        return output
        
def train():
    epoches = 1
    mnist_net = CNN()
    mnist_net.train()
    loss_fn = nn.CrossEntropyLoss()
    opitimizer = optim.SGD(mnist_net.parameters(), lr=0.002)
    mnist_train = datasets.MNIST("mnist-data", train=True, download=True, transform=transforms.ToTensor())
    train_loader = torch.utils.data.DataLoader(mnist_train, batch_size= 5, shuffle=True)

    loss = 0 
    for epoch in range(epoches):
    	for batch_X, batch_Y in train_loader:
    		opitimizer.zero_grad()
    		outputs = mnist_net(batch_X)
    		loss = loss_fn(outputs, batch_Y)
    		loss.backward()
    		opitimizer.step()
    		loss += loss.item()
    		print(loss)

if __name__ == '__main__':
	train()
  • 24
    点赞
  • 102
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

道2024

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值