基于pytorch的手写数字识别
手写数字识别
本文基于MNIST数据集,使用torch基于自定义网络结构实现对数据集的识别,并且对交叉熵损失函数进行详细说明。
数据获取
import torch
from torch import nn
from torch import utils
from torch.nn import functional as F
from torch import optim
import torchvision
from visdom import Visdom
batch_size = 512
train_loader = utils.data.DataLoader(
torchvision.datasets.MNIST('mnist_data', train=True, download=True,
transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.1307,), (0.3801,))
])), batch_size=batch_size, shuffle=