-- coding:utf-8 --
import torch
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.autograd import Variable
模型参数
单个图像大小28*28=784
input_size = 784
输出维度
num_classes = 10
训练轮数
num_epochs = 10
加载批训练数据个数
batch_size = 50
学习率
learning_rate = 0.001
训练集(数据集如果已经下载了,就不会在运行时下载,root参数指向的是数据集目录)
train_dataset = dsets.MNIST(root=’./data’, train=True, transform=transforms.ToTensor(),
download=True)
测试集
test_dataset = dsets.MNIST(root=’./data’, train=False, transform=transforms.ToTensor())