# Note: mind the indentation.
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
import torchvision
from matplotlib import pyplot as plt
batch_size = 512

# MNIST training set: tensors normalized with the dataset's canonical
# per-channel mean/std (0.1307, 0.3081). Shuffled each epoch.
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(
        'mnist_data', train=True, download=True,
        transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.1307,), (0.3081,)),
        ])),
    batch_size=batch_size, shuffle=True,
)

# MNIST test set: same normalization, fixed order (shuffle=False) so
# evaluation is deterministic. Same root directory as the train loader.
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(
        'mnist_data', train=False, download=True,
        transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.1307,), (0.3081,)),
        ])),
    batch_size=batch_size, shuffle=False,
)

# Sanity check: one batch of images, shape (batch, 1, 28, 28).
x, y = next(iter(train_loader))
print(x.shape)
class Net(nn.Module):
    """Three-layer MLP for MNIST: 784 -> 256 -> 64 -> 10 raw logits."""

    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Map flattened images (batch, 784) to class logits (batch, 10).

        No final activation: callers receive raw scores and apply
        argmax / a loss function themselves.
        """
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
def one_hot(label, depth=10):
    """Convert a 1-D tensor of class indices to a (N, depth) one-hot float tensor.

    Args:
        label: 1-D integer tensor of class indices in [0, depth).
        depth: number of classes (columns of the output).

    Returns:
        Float tensor of shape (len(label), depth) with a single 1 per row.
    """
    out = torch.zeros(label.size(0), depth)
    # scatter_ requires int64 column indices shaped (N, 1); avoid the
    # deprecated torch.LongTensor(tensor) legacy constructor.
    idx = label.long().view(-1, 1)
    out.scatter_(dim=1, index=idx, value=1)
    return out
# Model, optimizer, and learning-rate schedule.
net = Net()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
# Decay the LR by 10x every 200 scheduler steps; scheduler.step() is
# called once per batch in the training loop, so step_size counts batches.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.1, last_epoch=-1)
# Train for 3 epochs. The objective is MSE between logits and one-hot
# targets (cross-entropy is the usual choice for classification; MSE is
# kept here to preserve the original training objective).
for epoch in range(3):
    for batch_idx, (x, y) in enumerate(train_loader):
        # Flatten (batch, 1, 28, 28) images to (batch, 784) for the MLP.
        x = x.view(x.size(0), 28 * 28)
        out = net(x)
        y_onehot = one_hot(y)
        loss = F.mse_loss(out, y_onehot)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()  # per-batch LR decay (see StepLR setup above)
        if batch_idx % 20 == 0:
            print(epoch, batch_idx, loss.item())
# Evaluate classification accuracy on the held-out test set.
total_correct = 0
with torch.no_grad():  # inference only: skip autograd bookkeeping
    for x, y in test_loader:
        x = x.view(x.size(0), 28 * 28)
        out = net(x)
        # Predicted class is the index of the largest logit per row.
        pred = out.argmax(dim=1)
        correct = pred.eq(y).sum().float().item()
        total_correct += correct
total_num = len(test_loader.dataset)
acc = total_correct / total_num
print('test acc:', acc)