跟着B站浙江大学教授【深度学习框架pytorch】课程手敲的代码
边听课边注释
utils.py
#四个步骤:load data; bulid model; train; test
import torch
from matplotlib import pyplot as plt
def plot_curve(data): #绘制loss下降的曲线图
fig = plt.figure()
plt.plot(range(len(data)), data, color = 'blue')
plt.legend(['value'], loc = 'upper right')
plt.xlabel('step')
plt.ylabel('value')
plt.show()
def plot_images(img, label, name): #画图片(因为这里涉及到一个图片的识别),这个地方可以方便地看到图片的识别结果
fig = plt.figure()
for i in range(6):
plt.subplot(2, 3, i + 1)
plt.tight_layout()
plt.imshow(img[i][0]*0.3081+0.1307, cmap='gray', interpolation='none')
plt.title("{}:{}".format(name, label[i].item()))
plt.xticks([])
plt.yticks([])
plt.show
def one_hot(label, depth = 10): #需要通过scatter()完成one_hot编码
out = torch.zeros(label.size(0), depth)
idx = torch.LongTensor(label).view(-1,1)
out.scatter_(dim = 1, index = idx, value = 1)
return out
mnist_train.py
import matplotlib.pyplot
import torch
from torch import nn #nn是完成神经网络相关的一些工作
from torch.nn import functional as F #functional是常用的一些函数
from torch import optim #优化的工具包
import torchvision
from matplotlib import pyplot as plt
from utils import plot_images, plot_curve, one_hot
batch_size = 512 #一次并行处理512张图片
#step1. load dataset
train_loader = torch.utils.data.DataLoader(
torchvision.datasets.MNIST('mnist_data', train=True, download=True,
transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor(), #把numpy格式转换为tensor
torchvision.transforms.Normalize( #正则化过程
(0.1307,),(0.3081,)) #图像像素分布在0-1,所以要-0.1307,除以标准差0.3801,使得数据能够在0附近均匀分布
])),
batch_size=batch_size, shuffle=True) #batch_size一次行处理多少张图片,shuffle意味着加载时要做一个随机的打散
test_loader = torch.utils.data.DataLoader(
torchvision.datasets.MNIST('mnist_data', train=False, download=True,
transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(
(0.1307,),(0.3081,))
])),
batch_size=batch_size, shuffle=False)
x, y = next(iter(train_loader))
print(x.shape, y.shape, x.min(), x.max()) #torch.Size([512, 1, 28, 28]) torch.Size([512]) tensor(-0.4242) tensor(2.8215)
#512张图片,1个通道,28行,28列。 因为有512张图片,所以有512个label
plot_images(x, y, 'image sample')
matplotlib.pyplot.show()
#step2. bulid a network
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
# xw+b
self.fc1 = nn.Linear(28*28, 256) #256根据经验确定
self.fc2 = nn.Linear(256, 64) #上面的输出是这里的输入
self.fc3 = nn.Linear(64, 10) #10是因为是十分类问题
def forward(self, x):
# x:[b,1,28,28]
# h1 = relu(xw1+b1)
x = F.relu(self.fc1(x))
# h2 = relu(xw2+b2)
x = F.relu(self.fc2(x))
# h3 = h2w3+b3
x = self.fc3(x) #最后一层加不加激活函数,取决于你的具体任务
return x
#step3:训练。训练的逻辑是:每一次求导,然后再去更新
net = Net() #完成网络的初始化
# [w1, b1, w2, b2, w3, b3]
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9) #通过optimizer优化器优化权值[w1, b1, w2, b2, w3, b3],同时还需要设置lr和momentum,moment是帮助更好优化的策略
train_loss = [] #把train_loss保存起来
for epoch in range(3): #对数据集迭代3次 内外嵌套循环 对数据集迭代
for batch_idx, (x, y) in enumerate(train_loader): #然后每次从train_loader这个数据集中sample(抽取)一个这样的batch,一个batch大概是512张图片,完成这个循环会对整个数据集迭代一遍。 内循环
# x:[b, 1, 28, 28], y:[512]
# print(x.shape, y.shape) #torch.Size([512, 1, 28, 28]) torch.Size([512])
# #torch.Size([512, 1, 28, 28]) torch.Size([512])
# #torch.Size([512, 1, 28, 28]) torch.Size([512])
# break
# [b, 1, 28, 28] => [b, 784] 需要把x打平成2维tensor
x = x.view(x.size(0), 28*28) #size(0)表示batch
# => [b, 10]
out = net(x) #预测的值
# [b,10]
y_onehot = one_hot(y) #将真实的y转换成one_hot 希望out接近y_onehot
# loss = mse(out, y_onehot)
loss = F.mse_loss(out, y_onehot) #通过mse(均方误差)计算真实值与预测值的误差
optimizer.zero_grad() #梯度清零
loss.backward() #计算梯度
# w' = w - lr*grad
optimizer.step() #更新梯度,从而得到新的[w1, b1, w2, b2, w3, b3]
#完成数据集的3次迭代后,我们得到optimal[w1, b1, w2, b2, w3, b3]
#打印loss
train_loss.append(loss.item())
if batch_idx % 10 == 0: #每隔10个batch打印一下loss,输出当前epoch,batch_idx,loss的具体数值
print(epoch, batch_idx, loss.item())
plot_curve(train_loss)
# we get optimal [w1, b1, w2, b2, w3, b3]
#step4:测试
total_correct = 0
#打印loss
for x, y in test_loader:
x = x.view(x.size(0), 28*28)
out = net(x) #得到网络的输出
#out: [b, 10] => pred: [b]
pred = out.argmax(dim = 1)
correct = pred.eq(y).sum().float().item() #item()取数值 当前batch正确的个数
total_correct += correct
total_num = len(test_loader.dataset) #总的测试的数量
acc = total_correct/total_num #准确率
print('test acc:', acc)
x, y = next(iter(test_loader))
out = net(x.view(x.size(0), 28*28))
pred = out.argmax(dim = 1)
plot_images(x, pred, 'test')
#后期可进行的工作:
#def net()中增加网络层数
#def forward()中最后一层可以用softmax()
#loss:F.mse_loss()改成交叉熵函数
#lr调一调 改一改
有错误欢迎指正~