本篇笔记是PyTorch学习的最后一篇,使用PyTorch搭建神经网络,完成手写体数字识别问题。
1. 导入数据并可视化
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
#定义batch size
batch_size = 64
#下载MNIST数据集
train_dataset = datasets.MNIST(root='./data/',
train=True,
transform=transforms.ToTensor(),
download=True)
test_dataset = datasets.MNIST(root='./data/',
train=False,
transform=transforms.ToTensor())
#将下载的MNIST数据导入到dataloader中
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=batch_size,
shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
batch_size=batch_size,
s