Minist手写数据集测试
(个人实践笔记,如有纰漏,烦请指出)
import torch
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
input_size = 28 * 28 # image size of MNIST data
num_classes = 10
num_epochs = 10
batch_size = 100
learning_rate = 1e-3
# MNIST dataset
train_dataset = dsets.MNIST(root = '../../data_sets/mnist', #选择数据的根目录
train = True, # 选择训练集
transform = transforms.ToTensor(), #转换成tensor变量
download = False) # 不从网络上download图片
test_dataset = dsets.MNIST(root = '../../data_sets/mnist', #选择数据的根目录
train = False, # 选择训练集
transform = transforms.ToTensor(), #转换成tensor变量
download = False) # 不从网络上download图片
#加载数据
t