思路:
网络是三层:
h1=relu(w1*x+b1)
-输入28*28,输出256h2=relu(w2*h1+b2)
-输入256,输出64h3=w3*h2+b3
-输入64,输出10
其中256,64是自己构想的,10是固定的结果10分类
每次是512个28*28的1通道图片进入一起处理,处理三次得出结果
注意
每次计算梯度都要先清零:optimizer.zero_grad()
,不然梯度累加就不正确
import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
import torchvision
from matplotlib import pyplot as plt
from utils import plot_image, plot_curve, one_hot
batch_size = 512
# step1. load dataset
train_loader = torch.utils.data.DataLoader(
torchvision.datasets.MNIST('mnist_data', train=True, download=False,
transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),