import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
"""
torchvision内置了常用数据集和常见的模型
"""
import torchvision
from torchvision import datasets, transforms
# Only converts images to tensors (ToTensor); no normalisation step is applied.
transformation = transforms.Compose([transforms.ToTensor()])

# Build the MNIST train and test splits; download=True lets torchvision fetch them into ./data.
train_ds, test_ds = (
    datasets.MNIST("./data", train=split_is_train, transform=transformation, download=True)
    for split_is_train in (True, False)
)

# Shuffle only the training data; the test loader uses larger batches.
train_dl = DataLoader(train_ds, batch_size=64, shuffle=True)
test_dl = DataLoader(test_ds, batch_size=256)

# Grab one training batch to inspect tensor shapes.
imgs, labels = next(iter(train_dl))
print("imgs.shape:\t", imgs.shape)
"""
在pytorch里面图片的表示形式:[batch, channel, height, width]
(In PyTorch, a batch of images is laid out as [batch, channel, height, width].)
"""
# Inspect the first sample of the batch: torch tensor -> numpy array -> 2-D image.
img = imgs[0]
print("img.shape:\t", img.shape)

img = img.numpy()
print("img.shape:\t", img.shape)

# Drop the singleton channel axis so matplotlib receives an (H, W) array.
img = img.squeeze()
print("np.squeeze(img):\t", img)
print("img.shape:\t", img.shape)

plt.imshow(img)
plt.show()

print("labels[0]:\t", labels[0])
def imshow(img, cmap=None):
    """Draw a single image tensor on the current matplotlib axes.

    Parameters
    ----------
    img : torch.Tensor
        Image in PyTorch channel-first layout, e.g. ``(1, H, W)`` for grayscale
        or ``(3, H, W)`` for RGB. It may live on any device and may require grad.
    cmap : str or Colormap, optional
        Forwarded to ``plt.imshow``; ``None`` keeps matplotlib's default colormap.

    Returns
    -------
    matplotlib.image.AxesImage
        The artist created by ``plt.imshow``.
    """
    # .numpy() refuses tensors that require grad or live on a GPU, so detach and
    # move to host memory first (both are no-ops for plain CPU tensors).
    npimg = img.detach().cpu().numpy()
    # Remove singleton axes, e.g. the channel axis of a grayscale (1, H, W) image.
    npimg = np.squeeze(npimg)
    if npimg.ndim == 3:
        # Multi-channel image: matplotlib expects (H, W, C), PyTorch stores (C, H, W).
        npimg = np.transpose(npimg, (1, 2, 0))
    return plt.imshow(npimg, cmap=cmap)
# Show (up to) the first ten images of the batch in a single row of subplots.
plt.figure(figsize=(10, 1))
for i, img in zip(range(10), imgs):
    plt.subplot(1, 10, i + 1)  # subplot positions are 1-based
    imshow(img)
plt.show()
class Model(nn.Module):
    """Three-layer fully connected classifier operating on flattened images.

    The defaults match 28x28 single-channel inputs (784 features) and 10 output
    classes, i.e. the original hard-coded architecture 784 -> 120 -> 84 -> 10.

    Parameters
    ----------
    in_features : int
        Number of input features per sample after flattening.
    hidden_1, hidden_2 : int
        Widths of the two hidden ReLU layers.
    num_classes : int
        Number of output logits.
    """

    def __init__(self, in_features=784, hidden_1=120, hidden_2=84, num_classes=10):
        super().__init__()
        self.in_features = in_features
        # Attribute names ("liner_*", sic) are kept unchanged because they form the
        # state_dict keys of any saved checkpoints.
        self.liner_1 = nn.Linear(in_features, hidden_1)
        self.liner_2 = nn.Linear(hidden_1, hidden_2)
        self.liner_3 = nn.Linear(hidden_2, num_classes)

    def forward(self, input):
        """Return raw (un-softmaxed) class logits of shape ``(N, num_classes)``.

        ``input`` is flattened to ``(N, in_features)``; e.g. a ``(N, 1, 28, 28)``
        image batch becomes ``(N, 784)`` with the defaults.
        """
        # reshape (unlike view) also accepts non-contiguous tensors, copying only
        # when the memory layout requires it.
        x = input.reshape(-1, self.in_features)
        x = F.relu(self.liner_1(x))
        x = F.relu(self.liner_2(x))
        return self.liner_3(x)
# Network instance and classification criterion used by `fit` and the training loop.
model = Model()
loss_fn = nn.CrossEntropyLoss()
def _ratio(numerator, count):
    """Return numerator / count, or NaN when count is zero (empty loader)."""
    return numerator / count if count else float("nan")


def _train_one_epoch(model, loader, optimizer, criterion):
    """Run one optimisation pass over *loader*; return (mean_loss, accuracy)."""
    model.train()
    correct = 0
    total = 0
    loss_sum = 0.0
    for x, y in loader:
        logits = model(x)
        loss = criterion(logits, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            correct += (logits.argmax(dim=1) == y).sum().item()
        batch_size = y.size(0)
        total += batch_size
        # The criterion is assumed to return the batch *mean*; weight it by the
        # batch size so the epoch value is a true per-sample mean, even when the
        # last batch is shorter.
        loss_sum += loss.item() * batch_size
    return _ratio(loss_sum, total), _ratio(correct, total)


def _evaluate(model, loader, criterion):
    """Score *model* on *loader* without gradients; return (mean_loss, accuracy)."""
    model.eval()
    correct = 0
    total = 0
    loss_sum = 0.0
    with torch.no_grad():
        for x, y in loader:
            logits = model(x)
            batch_size = y.size(0)
            loss_sum += criterion(logits, y).item() * batch_size
            correct += (logits.argmax(dim=1) == y).sum().item()
            total += batch_size
    return _ratio(loss_sum, total), _ratio(correct, total)


def fit(epoch, model, trainloader, testloader, optimizer=None, criterion=None):
    """Train *model* for one epoch, evaluate it on *testloader*, print and return metrics.

    Parameters
    ----------
    epoch : int
        Epoch index; used only in the printed summary line.
    model : torch.nn.Module
        Network producing class logits for each batch.
    trainloader, testloader : iterable of (x, y) batches
        Training data (one optimisation pass) and held-out evaluation data.
    optimizer : torch.optim.Optimizer, optional
        Defaults to the module-level ``optim`` (bound to an Adam optimizer by the
        training script in this file), looked up at call time.
    criterion : callable, optional
        ``criterion(logits, targets)``; defaults to the module-level ``loss_fn``.
        Assumed to average over the batch (``reduction="mean"``, the default of
        ``CrossEntropyLoss``) because per-batch losses are re-weighted by batch size.

    Returns
    -------
    tuple of float
        ``(train_loss, train_acc, test_loss, test_acc)``. Losses are per-sample
        means, accuracies are fractions in [0, 1]; an empty loader yields NaN
        for its pair. Training metrics use the predictions made before each
        optimizer step.
    """
    if optimizer is None:
        optimizer = optim
    if criterion is None:
        criterion = loss_fn

    epoch_loss, epoch_acc = _train_one_epoch(model, trainloader, optimizer, criterion)
    epoch_test_loss, epoch_test_acc = _evaluate(model, testloader, criterion)

    print("epoch:\t{} loss:\t{} accuracy:\t{} test_loss:\t{} test_accuracy:\t{}".format(epoch, round(epoch_loss, 3),
                                                                                     round(epoch_acc, 3),
                                                                                     round(epoch_test_loss, 3),
                                                                                     round(epoch_test_acc, 3)))
    return epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc
# NOTE: this rebinds `optim`, shadowing the `torch.optim` module imported at the
# top of the file; `fit` reads this global `optim` to step the model, so the name
# is kept deliberately.
optim = torch.optim.Adam(model.parameters(), lr=0.001)
epochs = 20

# Per-epoch history of the metrics returned by `fit`.
train_loss = []
train_acc = []
test_loss = []
test_acc = []

for epoch in range(epochs):
    epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc = fit(epoch, model, train_dl, test_dl)
    train_loss.append(epoch_loss)
    train_acc.append(epoch_acc)
    test_loss.append(epoch_test_loss)
    # Previously appended to test_loss by mistake, leaving test_acc empty.
    test_acc.append(epoch_test_acc)
imgs.shape: torch.Size([64, 1, 28, 28])
img.shape: torch.Size([1, 28, 28])
img.shape: (1, 28, 28)
np.squeeze(img): [[0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. ]
[0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. ]
[0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. ]
[0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. ]
[0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. ]
[0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. ]
[0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0.09803922 0.69411767 0.61960787 0.8784314 0.39607844
0.00784314 0.1254902 0.7490196 0.04705882 0. 0.
0. 0. 0. 0. ]
[0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.14509805
0.6901961 0.9137255 0.9882353 0.99215686 0.9882353 0.9882353
0.15294118 0.33333334 0.9882353 0.27450982 0. 0.
0. 0. 0. 0. ]
[0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0.2901961 0.9019608
0.9882353 0.9882353 0.9882353 0.6039216 0.6 0.6
0.5254902 0.92941177 0.9882353 0.08235294 0. 0.
0. 0. 0. 0. ]
[0. 0. 0. 0. 0. 0.
0. 0. 0. 0.34117648 0.68235296 0.9882353
0.972549 0.5803922 0.09411765 0. 0. 0.07450981
0.5568628 0.8745098 0.32156864 0.00392157 0. 0.
0. 0. 0. 0. ]
[0. 0. 0. 0. 0. 0.
0. 0. 0.05882353 0.87058824 0.9882353 0.77254903
0.28235295 0. 0. 0.03137255 0.33333334 0.92941177
0.9882353 0.6784314 0. 0. 0. 0.
0. 0. 0. 0. ]
[0. 0. 0. 0. 0. 0.
0. 0. 0.39215687 0.9882353 0.7176471 0.01960784
0. 0. 0.15294118 0.80784315 0.9882353 0.9882353
0.92156863 0.07450981 0. 0. 0. 0.
0. 0. 0. 0. ]
[0. 0. 0. 0. 0. 0.
0. 0. 0.9490196 0.9882353 0.6431373 0.02352941
0.12156863 0.60784316 0.9411765 0.6117647 0.6745098 0.9137255
0.88235295 0. 0. 0. 0. 0.
0. 0. 0. 0. ]
[0. 0. 0. 0. 0. 0.
0. 0. 0.9490196 0.9882353 0.85490197 0.7882353
0.9882353 0.9137255 0.6509804 0.6039216 0.8745098 0.9882353
0.4745098 0. 0. 0. 0. 0.
0. 0. 0. 0. ]
[0. 0. 0. 0. 0. 0.
0. 0. 0.30980393 0.9882353 0.9882353 0.9882353
0.8039216 0.25882354 0.03529412 0.13333334 0.9882353 0.9882353
0.15294118 0. 0. 0. 0. 0.
0. 0. 0. 0. ]
[0. 0. 0. 0. 0. 0.
0. 0. 0.01176471 0.4117647 0.5568628 0.3529412
0.05098039 0. 0.19215687 0.69411767 0.9882353 0.6509804
0.00784314 0. 0. 0. 0. 0.
0. 0. 0. 0. ]
[0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0.43529412 1. 0.99215686 0.17254902
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. ]
[0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0.08235294 0.7647059 0.99215686 0.80784315 0.05098039
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. ]
[0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0.65882355 0.9882353 0.99215686 0.56078434 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. ]
[0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0.12941177 0.8666667 0.9882353 0.83137256 0.01176471 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. ]
[0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.01176471
0.6862745 0.9882353 0.9882353 0.43529412 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. ]
[0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.04705882
0.9882353 0.9882353 0.8392157 0.09411765 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. ]
[0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.4627451
0.9882353 0.9882353 0.5568628 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. ]
[0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0.0627451 0.92156863
0.9882353 0.9372549 0.27450982 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. ]
[0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0.34901962 0.9882353
0.9882353 0.36862746 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. ]
[0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0.34901962 0.9882353
0.6509804 0.03529412 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. ]
[0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. ]
[0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. ]]
img.shape: (28, 28)
labels[0]: tensor(9)
epoch: 0 loss: 0.005 accuracy: 0.906 test_loss: 0.005 test_accuracy: 0.906
epoch: 1 loss: 0.002 accuracy: 0.958 test_loss: 0.002 test_accuracy: 0.958
epoch: 2 loss: 0.001 accuracy: 0.971 test_loss: 0.001 test_accuracy: 0.971
epoch: 3 loss: 0.001 accuracy: 0.978 test_loss: 0.001 test_accuracy: 0.978
epoch: 4 loss: 0.001 accuracy: 0.983 test_loss: 0.001 test_accuracy: 0.983
epoch: 5 loss: 0.001 accuracy: 0.987 test_loss: 0.001 test_accuracy: 0.987
epoch: 6 loss: 0.001 accuracy: 0.988 test_loss: 0.001 test_accuracy: 0.988
epoch: 7 loss: 0.0 accuracy: 0.991 test_loss: 0.0 test_accuracy: 0.991
epoch: 8 loss: 0.0 accuracy: 0.992 test_loss: 0.0 test_accuracy: 0.992
epoch: 9 loss: 0.0 accuracy: 0.993 test_loss: 0.0 test_accuracy: 0.993
epoch: 10 loss: 0.0 accuracy: 0.994 test_loss: 0.0 test_accuracy: 0.994
epoch: 11 loss: 0.0 accuracy: 0.995 test_loss: 0.0 test_accuracy: 0.995
epoch: 12 loss: 0.0 accuracy: 0.995 test_loss: 0.0 test_accuracy: 0.995
epoch: 13 loss: 0.0 accuracy: 0.995 test_loss: 0.0 test_accuracy: 0.995
epoch: 14 loss: 0.0 accuracy: 0.996 test_loss: 0.0 test_accuracy: 0.996
epoch: 15 loss: 0.0 accuracy: 0.996 test_loss: 0.0 test_accuracy: 0.996
epoch: 16 loss: 0.0 accuracy: 0.996 test_loss: 0.0 test_accuracy: 0.996
epoch: 17 loss: 0.0 accuracy: 0.997 test_loss: 0.0 test_accuracy: 0.997
epoch: 18 loss: 0.0 accuracy: 0.997 test_loss: 0.0 test_accuracy: 0.997
epoch: 19 loss: 0.0 accuracy: 0.997 test_loss: 0.0 test_accuracy: 0.997