MNIST Classification Task
Basic network construction and training; commonly used functions explained:
The torch.nn.functional module
The nn.Module module
Reading the MNIST Dataset
The dataset is downloaded automatically (if the download fails, you can fetch it from the GitHub repo mentioned above instead: create data/mnist/ under the project root and place the mnist file there).
%matplotlib inline
from pathlib import Path
import requests

DATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"
PATH.mkdir(parents=True, exist_ok=True)

# URL and FILENAME are concatenated below, so URL must end with the
# directory path, not repeat the file name.
URL = "http://deeplearning.net/data/mnist/"
FILENAME = "mnist.pkl.gz"

if not (PATH / FILENAME).exists():
    content = requests.get(URL + FILENAME).content
    (PATH / FILENAME).open("wb").write(content)
import pickle
import gzip

with gzip.open((PATH / FILENAME).as_posix(), "rb") as f:
    ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")

from matplotlib import pyplot
import numpy as np

pyplot.imshow(x_train[0].reshape((28, 28)), cmap="gray")
print(x_train.shape)
(50000, 784)
50000 is the number of samples and 784 is the number of pixels per image: 28*28*1 = height * width * color channels; a grayscale image has only one channel.
import torch

x_train, y_train, x_valid, y_valid = map(
    torch.tensor, (x_train, y_train, x_valid, y_valid)
)
n, c = x_train.shape

print(x_train, y_train)
print(x_train.shape)
print(y_train.min(), y_train.max())
tensor([[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.]]) tensor([5, 0, 4, ..., 8, 4, 8])
torch.Size([50000, 784])
tensor(0) tensor(9)
torch.nn.functional contains many functions we will use frequently. So when should you use nn.Module and when nn.functional? As a rule of thumb: if the model has learnable parameters, it is usually better to use nn.Module; otherwise nn.functional is simpler.
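For example (a minimal sketch to illustrate the distinction, not part of the original notebook): F.relu is a plain stateless function, while nn.Linear is a module that owns learnable Parameters.

import torch
import torch.nn.functional as F
from torch import nn

x = torch.randn(2, 784)
h = F.relu(x)                # stateless: nothing to learn, just a function call
layer = nn.Linear(784, 128)  # stateful: owns a weight and a bias Parameter
print(sum(p.numel() for p in layer.parameters()))  # 784*128 + 128 = 100480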
import torch.nn.functional as F

loss_func = F.cross_entropy

def model(xb):
    # a bare linear model; weights and bias are looked up at call time
    return xb.mm(weights) + bias

bs = 64
xb = x_train[0:bs]  # one mini-batch of inputs
yb = y_train[0:bs]  # and the matching labels
weights = torch.randn([784, 10], dtype=torch.float, requires_grad=True)
bias = torch.zeros(10, requires_grad=True)

print(loss_func(model(xb), yb))
tensor(13.7654, grad_fn=<NllLossBackward0>)
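To see what nn.Module and the optimizer will automate for us, here is one manual gradient-descent step on the raw weights and bias defined above (a sketch, not from the original notebook; the learning rate 0.001 is an arbitrary choice):

loss = loss_func(model(xb), yb)
loss.backward()
with torch.no_grad():
    weights -= 0.001 * weights.grad  # plain SGD update, lr = 0.001
    bias -= 0.001 * bias.grad
    weights.grad.zero_()             # gradients accumulate, so reset them
    bias.grad.zero_()
print(loss_func(model(xb), yb))      # the loss on this batch typically drops a little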
Create a model class to simplify the code:
You must subclass nn.Module and call nn.Module's constructor inside your own constructor.
There is no need to write a backward function; nn.Module uses autograd to implement backpropagation automatically.
The learnable parameters of a Module can be iterated over via named_parameters() or parameters().
from torch import nn

class Mnist_NN(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden1 = nn.Linear(784, 128)
        self.hidden2 = nn.Linear(128, 256)
        self.out = nn.Linear(256, 10)

    def forward(self, x):
        x = F.relu(self.hidden1(x))
        x = F.relu(self.hidden2(x))
        x = self.out(x)
        return x
net = Mnist_NN()
print(net)
Mnist_NN(
  (hidden1): Linear(in_features=784, out_features=128, bias=True)
  (hidden2): Linear(in_features=128, out_features=256, bias=True)
  (out): Linear(in_features=256, out_features=10, bias=True)
)
We can print the weights and biases under the names we defined:
for name, parameter in net.named_parameters():
    print(name, parameter, parameter.size())
hidden1.weight Parameter containing: tensor([[-0.0325, -0.0265, 0.0178, ..., 0.0236, -0.0253, -0.0139],
        ...], requires_grad=True) torch.Size([128, 784])
hidden1.bias Parameter containing: tensor([ 0.0219, 0.0226, ..., -0.0284, -0.0238], requires_grad=True) torch.Size([128])
hidden2.weight Parameter containing: tensor([[-0.0340, -0.0734, 0.0346, ..., 0.0559, -0.0386, -0.0846],
        ...], requires_grad=True) torch.Size([256, 128])
hidden2.bias Parameter containing: tensor([ 3.4251e-02, 6.6245e-02, ..., -8.6721e-02], requires_grad=True) torch.Size([256])
out.weight Parameter containing: tensor([[ 0.0394, -0.0572, -0.0190, ..., 0.0153, 0.0324, 0.0284],
        ...], requires_grad=True) torch.Size([10, 256])
out.bias Parameter containing: tensor([-0.0267, -0.0337, -0.0451, 0.0126, 0.0417, -0.0079, 0.0078, 0.0493, -0.0088, -0.0089], requires_grad=True) torch.Size([10])
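If you only need the sizes rather than the full tensors, the parameter count can be tallied in one line (a small sketch using the net defined above):

total = sum(p.numel() for p in net.parameters())
print(total)  # 784*128 + 128 + 128*256 + 256 + 256*10 + 10 = 136074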
Use TensorDataset and DataLoader to simplify batching:
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

train_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)

valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(valid_ds, batch_size=bs * 2)

def get_data(train_ds, valid_ds, bs):
    return (
        DataLoader(train_ds, batch_size=bs, shuffle=True),
        DataLoader(valid_ds, batch_size=bs * 2),
    )
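To sanity-check a loader, pull one batch and inspect its shapes (a quick sketch using the train_dl defined above):

xb, yb = next(iter(train_dl))
print(xb.shape, yb.shape)  # expected: torch.Size([64, 784]) torch.Size([64])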
During training, call model.train() so that Batch Normalization uses per-batch statistics and Dropout is active; during evaluation, call model.eval() so that Batch Normalization uses its running statistics and Dropout is disabled.
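The network above has neither Batch Normalization nor Dropout, but the effect of the two modes is easy to demonstrate with a standalone Dropout layer (a minimal sketch, not part of Mnist_NN; it reuses the torch and nn imports from above):

drop = nn.Dropout(p=0.5)
x = torch.ones(1, 10)
drop.train()
print(drop(x))  # roughly half the entries zeroed, the rest scaled by 1/(1-p)
drop.eval()
print(drop(x))  # identity: all ones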
import numpy as np

def fit(steps, model, loss_func, opt, train_dl, valid_dl):
    for step in range(steps):
        model.train()
        for xb, yb in train_dl:
            loss_batch(model, loss_func, xb, yb, opt)

        model.eval()
        with torch.no_grad():
            # no optimizer here: validation must not update the weights
            losses, nums = zip(
                *[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl]
            )
        val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)
        print('Step: ' + str(step), 'Validation loss: ' + str(val_loss))
from torch import optim

def get_model():
    model = Mnist_NN()
    return model, optim.SGD(model.parameters(), lr=0.001)
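SGD with lr=0.001 converges slowly on this task; if you want to experiment, optim.Adam is a drop-in replacement (a hedged variant; get_model_adam is a hypothetical helper, and the losses printed below come from the SGD version):

def get_model_adam():  # hypothetical variant, not used for the results below
    model = Mnist_NN()
    return model, optim.Adam(model.parameters(), lr=0.001)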
def loss_batch(model, loss_func, xb, yb, opt=None):
    loss = loss_func(model(xb), yb)

    # the loss already requires grad (the model parameters do), so no
    # requires_grad_ call is needed; only update weights when given an optimizer
    if opt is not None:
        loss.backward()
        opt.step()
        opt.zero_grad()

    return loss.item(), len(xb)
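loss_batch only reports the loss; if you also want classification accuracy, a small helper can compute it from the logits (a sketch; accuracy is a hypothetical helper not defined in the original):

def accuracy(out, yb):
    preds = torch.argmax(out, dim=1)  # predicted class = index of the largest logit
    return (preds == yb).float().mean().item()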
Three lines and we are done:
train_dl, valid_dl = get_data(train_ds, valid_ds, bs)
model, opt = get_model()
fit(25, model, loss_func, opt, train_dl, valid_dl)
Step: 0 Validation loss: 2.2827181243896484
Step: 1 Validation loss: 2.2478832061767577
Step: 2 Validation loss: 2.1945471561431886
Step: 3 Validation loss: 2.1076549449920656
Step: 4 Validation loss: 1.9676035526275635
Step: 5 Validation loss: 1.764887052345276
Step: 6 Validation loss: 1.522310170173645
Step: 7 Validation loss: 1.287570206260681
Step: 8 Validation loss: 1.0948256267547607
Step: 9 Validation loss: 0.9480554515838623
Step: 10 Validation loss: 0.8378450462341308
Step: 11 Validation loss: 0.7521841385841369
Step: 12 Validation loss: 0.6849521107673645
Step: 13 Validation loss: 0.6307255907058715
Step: 14 Validation loss: 0.5866909859657288
Step: 15 Validation loss: 0.5503184311389923
Step: 16 Validation loss: 0.5185965992450714
Step: 17 Validation loss: 0.4931147786140442
Step: 18 Validation loss: 0.4708563284873962
Step: 19 Validation loss: 0.452236319065094
Step: 20 Validation loss: 0.435598934841156
Step: 21 Validation loss: 0.42189919352531435
Step: 22 Validation loss: 0.409069230890274
Step: 23 Validation loss: 0.39823391227722166
Step: 24 Validation loss: 0.3882803885936737
And with that, a handwritten digit classifier is complete.
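As a final sanity check, you can score the trained model on the validation set (a sketch under the setup above; the exact number will vary from run to run):

model.eval()
correct = 0
with torch.no_grad():
    for xb, yb in valid_dl:
        correct += (torch.argmax(model(xb), dim=1) == yb).sum().item()
print(correct / len(valid_ds))  # fraction of the 10000 validation images classified correctly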