PyTorch: softmax and classification models

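  1. Basic concepts of softmax regression
  2. How to obtain the Fashion-MNIST dataset and read the data
  3. Implementing softmax regression from scratch: a model that classifies the images in the Fashion-MNIST training set
  4. Reimplementing the softmax regression model with PyTorch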


  • Classification problem
    Denote the 4 pixels of the image by $x_1, x_2, x_3, x_4$.
    Suppose the true label is dog, cat, or chicken, and that these labels correspond to the discrete values $y_1, y_2, y_3$.
    We usually represent categories with discrete numbers, for example $y_1=1, y_2=2, y_3=3$.

  • Weight vectors

$$
\begin{aligned}
o_1 &= x_1 w_{11} + x_2 w_{21} + x_3 w_{31} + x_4 w_{41} + b_1, \\
o_2 &= x_1 w_{12} + x_2 w_{22} + x_3 w_{32} + x_4 w_{42} + b_2, \\
o_3 &= x_1 w_{13} + x_2 w_{23} + x_3 w_{33} + x_4 w_{43} + b_3.
\end{aligned}
$$

  • Neural network diagram
    The figure below depicts the computation above as a neural network diagram. Like linear regression, softmax regression is a single-layer neural network. Because the computation of each output $o_1, o_2, o_3$ depends on all of the inputs $x_1, x_2, x_3, x_4$, the output layer of softmax regression is also a fully connected layer.

[Figure: softmax regression is a single-layer neural network]

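Since a classification problem needs a discrete prediction, a simple approach is to treat the output value $o_i$ as the confidence that the predicted class is $i$, and to take the class with the largest output as the prediction, that is, to output $\underset{i}{\arg\max}\, o_i$. For example, if $o_1, o_2, o_3$ are $0.1, 10, 0.1$ respectively, then $o_2$ is the largest, so the predicted class is 2, which stands for cat.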

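  • Problems with using the outputs directly
    1. On the one hand, because the range of the output layer's values is not fixed, it is hard to judge intuitively what these values mean. In the example above, the output value 10 means we are "very confident" that the image is a cat, because it is 100 times the output of the other two classes. But if $o_1=o_3=10^3$, the same output value 10 means that the probability of the image being a cat is very low.
    2. On the other hand, because the true labels are discrete values, the error between these discrete values and outputs of unbounded range is hard to measure.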

The softmax operator solves both problems. It transforms the output values into a probability distribution whose values are positive and sum to 1:

$$
\hat{y}_1, \hat{y}_2, \hat{y}_3 = \text{softmax}(o_1, o_2, o_3),
$$

where

$$
\hat{y}_1 = \frac{\exp(o_1)}{\sum_{i=1}^3 \exp(o_i)},\quad
\hat{y}_2 = \frac{\exp(o_2)}{\sum_{i=1}^3 \exp(o_i)},\quad
\hat{y}_3 = \frac{\exp(o_3)}{\sum_{i=1}^3 \exp(o_i)}.
$$

It is easy to see that $\hat{y}_1 + \hat{y}_2 + \hat{y}_3 = 1$ and $0 \leq \hat{y}_1, \hat{y}_2, \hat{y}_3 \leq 1$, so $\hat{y}_1, \hat{y}_2, \hat{y}_3$ is a valid probability distribution. Now, if $\hat{y}_2=0.8$, then whatever the values of $\hat{y}_1$ and $\hat{y}_3$ are, we know that the probability of the image being a cat is 80%. In addition, note that

$$
\underset{i}{\arg\max}\, o_i = \underset{i}{\arg\max}\, \hat{y}_i,
$$

so the softmax operation does not change the predicted class.
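As a rough illustration of the two properties above (the result is a probability distribution, and the arg max is unchanged), here is a minimal plain-Python sketch. The helper names `softmax` and `argmax` are chosen for this example only, and subtracting the maximum before exponentiating is a standard stability trick that is not part of the formulas above; it leaves the result unchanged but avoids overflow for inputs such as $10^3$.

```python
import math

def softmax(outputs):
    """Map a list of raw outputs to positive values that sum to 1."""
    # Subtracting the maximum does not change the result mathematically,
    # but it keeps math.exp from overflowing on large inputs.
    m = max(outputs)
    exps = [math.exp(o - m) for o in outputs]
    total = sum(exps)
    return [e / total for e in exps]

def argmax(values):
    """Return the 0-based index of the largest value."""
    return max(range(len(values)), key=lambda i: values[i])

# The example from the text: o_1, o_2, o_3 = 0.1, 10, 0.1
o = [0.1, 10.0, 0.1]
y_hat = softmax(o)
print(y_hat, sum(y_hat))
print(argmax(o) == argmax(y_hat))  # softmax keeps the predicted class

# With o_1 = o_3 = 10**3, the same output 10 receives almost no probability.
y_hat_large = softmax([1e3, 10.0, 1e3])
print(y_hat_large)
```

Note that the indices here are 0-based, so index 1 corresponds to class 2 (cat) in the text.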


  • Computational efficiency
    • Vectorized expression for a single sample

Let the weight and bias parameters of softmax regression be

$$
\boldsymbol{W} = \begin{bmatrix} w_{11} & w_{12} & w_{13} \\ w_{21} & w_{22} & w_{23} \\ w_{31} & w_{32} & w_{33} \\ w_{41} & w_{42} & w_{43} \end{bmatrix},\quad
\boldsymbol{b} = \begin{bmatrix} b_1 & b_2 & b_3 \end{bmatrix}.
$$

Let the features of image sample $i$, whose height and width are both 2 pixels, be

$$
\boldsymbol{x}^{(i)} = \begin{bmatrix} x_1^{(i)} & x_2^{(i)} & x_3^{(i)} & x_4^{(i)} \end{bmatrix},
$$

let the output of the output layer be

$$
\boldsymbol{o}^{(i)} = \begin{bmatrix} o_1^{(i)} & o_2^{(i)} & o_3^{(i)} \end{bmatrix},
$$

and let the predicted probability distribution over dog, cat and chicken be

$$
\boldsymbol{\hat{y}}^{(i)} = \begin{bmatrix} \hat{y}_1^{(i)} & \hat{y}_2^{(i)} & \hat{y}_3^{(i)} \end{bmatrix}.
$$

The vectorized expression for classifying sample $i$ with softmax regression is

$$
\begin{aligned}
\boldsymbol{o}^{(i)} &= \boldsymbol{x}^{(i)} \boldsymbol{W} + \boldsymbol{b},\\
\boldsymbol{\hat{y}}^{(i)} &= \text{softmax}(\boldsymbol{o}^{(i)}).
\end{aligned}
$$

    • Vectorized expression for a mini-batch
    To improve computational efficiency further, we usually perform vectorized computation on a mini-batch of data. In general, consider a mini-batch of samples with batch size $n$, number of inputs (features) $d$, and number of outputs (classes) $q$. Let the batch features be $\boldsymbol{X} \in \mathbb{R}^{n \times d}$, and suppose the weight and bias parameters of softmax regression are $\boldsymbol{W} \in \mathbb{R}^{d \times q}$ and $\boldsymbol{b} \in \mathbb{R}^{1 \times q}$. The vectorized expression of softmax regression is

$$
\begin{aligned}
\boldsymbol{O} &= \boldsymbol{X} \boldsymbol{W} + \boldsymbol{b},\\
\boldsymbol{\hat{Y}} &= \text{softmax}(\boldsymbol{O}),
\end{aligned}
$$

where the addition uses broadcasting, $\boldsymbol{O}, \boldsymbol{\hat{Y}} \in \mathbb{R}^{n \times q}$, and the $i$-th rows of these two matrices are the output $\boldsymbol{o}^{(i)}$ and the probability distribution $\boldsymbol{\hat{y}}^{(i)}$ of sample $i$, respectively.
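The mini-batch expression can be traced by hand on a tiny example. The sketch below uses plain Python lists instead of tensors, with made-up numbers for $n=2$, $d=4$, $q=3$; the helpers `matmul` and `softmax_rows` are illustrative names, not functions from any library.

```python
import math

def matmul(X, W):
    """Multiply an n x d matrix by a d x q matrix (both lists of rows)."""
    return [[sum(x_k * w_kj for x_k, w_kj in zip(row, col)) for col in zip(*W)]
            for row in X]

def softmax_rows(O):
    """Apply softmax to each row of O independently."""
    result = []
    for row in O:
        m = max(row)  # subtract the row maximum for numerical stability
        exps = [math.exp(v - m) for v in row]
        s = sum(exps)
        result.append([e / s for e in exps])
    return result

# Hypothetical values: n = 2 samples, d = 4 features, q = 3 classes
X = [[1.0, 2.0, 3.0, 4.0],
     [0.5, 0.5, 0.5, 0.5]]
W = [[0.1, 0.2, 0.3],
     [0.0, 0.1, 0.0],
     [0.2, 0.0, 0.1],
     [0.1, 0.1, 0.1]]
b = [0.0, 0.1, -0.1]  # shape 1 x q

# Broadcasting: the same bias row b is added to every row of XW.
O = [[v + b_j for v, b_j in zip(row, b)] for row in matmul(X, W)]
Y_hat = softmax_rows(O)

print(O)      # n x q matrix of outputs
print(Y_hat)  # each row is a probability distribution
```

Each row of `Y_hat` sums to 1, matching the statement that row $i$ is the probability distribution of sample $i$.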


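For sample $i$, we construct a vector $\boldsymbol{y}^{(i)}\in \mathbb{R}^{q}$ whose $y^{(i)}$-th element (where $y^{(i)}$ is the discrete value of the class of sample $i$) is 1 and whose other elements are 0. The training objective can then be set to make the predicted probability distribution $\boldsymbol{\hat y}^{(i)}$ as close as possible to the true label distribution $\boldsymbol{y}^{(i)}$.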

  • Squared loss

$$
\text{Loss} = \left\| \boldsymbol{\hat y}^{(i)} - \boldsymbol{y}^{(i)} \right\|^2 / 2
$$

However, to classify correctly we do not actually need the predicted probabilities to equal the label probabilities exactly. For example, in the image classification example, if $y^{(i)}=3$, we only need $\hat{y}^{(i)}_3$ to be larger than the other two predictions $\hat{y}^{(i)}_1$ and $\hat{y}^{(i)}_2$. Even if $\hat{y}^{(i)}_3$ is only 0.6, the class prediction is correct whatever the other two values are. The squared loss is too strict here: for example, $\hat y^{(i)}_1=\hat y^{(i)}_2=0.2$ gives a smaller loss than $\hat y^{(i)}_1=0, \hat y^{(i)}_2=0.4$, even though both yield the same correct classification.
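To make the comparison concrete, the following small sketch evaluates the squared loss for the two predictions mentioned above, with the true class $y^{(i)}=3$ encoded as the vector $(0, 0, 1)$. The function name `squared_loss` is just for this example.

```python
def squared_loss(y_hat, y):
    """Half the squared Euclidean distance between prediction and label vector."""
    return sum((p - t) ** 2 for p, t in zip(y_hat, y)) / 2

y = [0.0, 0.0, 1.0]        # label vector for true class 3
pred_a = [0.2, 0.2, 0.6]   # y_hat_1 = y_hat_2 = 0.2
pred_b = [0.0, 0.4, 0.6]   # y_hat_1 = 0, y_hat_2 = 0.4

loss_a = squared_loss(pred_a, y)  # about 0.12
loss_b = squared_loss(pred_b, y)  # about 0.16
print(loss_a, loss_b)
```

Both predictions pick class 3, yet the squared loss penalizes them differently.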

One way to improve on this is to use a measure better suited to the difference between two probability distributions. Cross entropy is a commonly used one:

$$
H\left(\boldsymbol y^{(i)}, \boldsymbol {\hat y}^{(i)}\right ) = -\sum_{j=1}^q y_j^{(i)} \log \hat y_j^{(i)},
$$

where the subscripted $y_j^{(i)}$ is an element of the vector $\boldsymbol y^{(i)}$ and is either 0 or 1; take care to distinguish it from the unsubscripted $y^{(i)}$, the discrete value of the class of sample $i$. In the formula above, only the $y^{(i)}$-th element of $\boldsymbol y^{(i)}$, namely $y^{(i)}_{y^{(i)}}$, is 1 and all the others are 0, so $H(\boldsymbol y^{(i)}, \boldsymbol {\hat y}^{(i)}) = -\log \hat y_{y^{(i)}}^{(i)}$. In other words, cross entropy only cares about the predicted probability of the correct class, because as long as that value is large enough the classification is guaranteed to be correct. Of course, when a sample has several labels, for example when an image contains more than one object, this simplification is not available. Even in that case, cross entropy still only cares about the predicted probabilities of the object classes that appear in the image.
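The simplification $H = -\log \hat y_{y^{(i)}}^{(i)}$ can be checked numerically. This sketch builds the 0/1 label vector for class 3 and evaluates the cross entropy of two different predictions that assign the same probability 0.6 to the true class; `one_hot` and `cross_entropy` are illustrative helper names.

```python
import math

def one_hot(label, num_classes):
    """Label vector with a 1 at position `label` (1-based, as in the text)."""
    return [1.0 if j == label else 0.0 for j in range(1, num_classes + 1)]

def cross_entropy(y, y_hat):
    """-sum_j y_j * log(y_hat_j); terms with y_j = 0 are skipped,
    which also avoids evaluating log(0)."""
    return -sum(t * math.log(p) for t, p in zip(y, y_hat) if t > 0)

y = one_hot(3, 3)                         # [0.0, 0.0, 1.0]
h_a = cross_entropy(y, [0.2, 0.2, 0.6])
h_b = cross_entropy(y, [0.0, 0.4, 0.6])
print(h_a, h_b)  # both equal -log(0.6)
```

Only the probability assigned to the true class enters the result, so the two predictions receive the same loss.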

Suppose the training dataset contains $n$ samples. The cross-entropy loss function is defined as

$$
\ell(\boldsymbol{\Theta}) = \frac{1}{n} \sum_{i=1}^n H\left(\boldsymbol y^{(i)}, \boldsymbol {\hat y}^{(i)}\right ),
$$

where $\boldsymbol{\Theta}$ denotes the model parameters. Likewise, if every sample has only one label, the cross-entropy loss can be abbreviated as $\ell(\boldsymbol{\Theta}) = -(1/n) \sum_{i=1}^n \log \hat y_{y^{(i)}}^{(i)}$. From another point of view, minimizing $\ell(\boldsymbol{\Theta})$ is equivalent to maximizing $\exp(-n\ell(\boldsymbol{\Theta}))=\prod_{i=1}^n \hat y_{y^{(i)}}^{(i)}$; that is, minimizing the cross-entropy loss is equivalent to maximizing the joint predicted probability of all the label classes in the training dataset.
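The identity $\exp(-n\ell(\boldsymbol{\Theta}))=\prod_{i=1}^n \hat y_{y^{(i)}}^{(i)}$ is easy to verify on a few made-up numbers. In this sketch, `true_class_probs` is a hypothetical list of the probabilities a model assigns to the correct class of each of $n=4$ samples.

```python
import math

# Hypothetical predicted probabilities of the true class for n = 4 samples
true_class_probs = [0.6, 0.9, 0.3, 0.8]
n = len(true_class_probs)

# Cross-entropy loss in its abbreviated single-label form
loss = -sum(math.log(p) for p in true_class_probs) / n

# Joint predicted probability of all the true labels
joint = math.prod(true_class_probs)

# exp(-n * loss) recovers the joint probability
recovered = math.exp(-n * loss)
print(loss, joint, recovered)
```

Because $\exp$ is monotonically increasing, a smaller loss always corresponds to a larger joint probability.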






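The torchvision package consists mainly of the following parts:

  1. torchvision.datasets: functions for loading data and interfaces to commonly used datasets;
  2. torchvision.models: commonly used model architectures (including pretrained models), such as AlexNet, VGG and ResNet;
  3. torchvision.transforms: commonly used image transforms, such as cropping and rotation;
  4. torchvision.utils: other useful utilities.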
# Import the required packages
%matplotlib inline
from IPython import display
import matplotlib.pyplot as plt

import torch
import torchvision
import torchvision.transforms as transforms  # image transforms
import time

import sys
import d2lzh1981 as d2l


get dataset

mnist_train = torchvision.datasets.FashionMNIST(root='/home/kesci/input/FashionMNIST2065', train=True, download=True, transform=transforms.ToTensor())
# train set and test set
mnist_test = torchvision.datasets.FashionMNIST(root='/home/kesci/input/FashionMNIST2065', train=False, download=True, transform=transforms.ToTensor())

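class torchvision.datasets.FashionMNIST(root, train=True, transform=None, target_transform=None, download=False)

  • root (string): root directory of the dataset, which holds the files processed/training.pt and processed/test.pt.
  • train (bool, optional): if True, the dataset is created from training.pt, otherwise from test.pt.
  • download (bool, optional): if True, the data is downloaded from the internet and placed in the root directory. If the data already exists under root, it is not downloaded again.
  • transform (callable, optional): a function or transform that takes a PIL image and returns the transformed data, e.g. transforms.RandomCrop.
  • target_transform (callable, optional): a function or transform that takes the target and transforms it.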
# show result
print(type(mnist_train))
print(len(mnist_train), len(mnist_test))
<class 'torchvision.datasets.mnist.FashionMNIST'>
60000 10000
print(mnist_train[0])  # the image has a single channel and is 28 x 28 pixels
# plt.imshow(mnist_train[0])
(tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0039, 0.0000, 0.0000, 0.0510,
          0.2863, 0.0000, 0.0000, 0.0039, 0.0157, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0039, 0.0039, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0118, 0.0000, 0.1412, 0.5333,
          0.4980, 0.2431, 0.2118, 0.0000, 0.0000, 0.0000, 0.0039, 0.0118,
          0.0157, 0.0000, 0.0000, 0.0118],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0235, 0.0000, 0.4000, 0.8000,
          0.6902, 0.5255, 0.5647, 0.4824, 0.0902, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0471, 0.0392, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6078, 0.9255,
          0.8118, 0.6980, 0.4196, 0.6118, 0.6314, 0.4275, 0.2510, 0.0902,
          0.3020, 0.5098, 0.2824, 0.0588],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0039, 0.0000, 0.2706, 0.8118, 0.8745,
          0.8549, 0.8471, 0.8471, 0.6392, 0.4980, 0.4745, 0.4784, 0.5725,
          0.5529, 0.3451, 0.6745, 0.2588],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0039, 0.0039, 0.0039, 0.0000, 0.7843, 0.9098, 0.9098,
          0.9137, 0.8980, 0.8745, 0.8745, 0.8431, 0.8353, 0.6431, 0.4980,
          0.4824, 0.7686, 0.8980, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7176, 0.8824, 0.8471,
          0.8745, 0.8941, 0.9216, 0.8902, 0.8784, 0.8706, 0.8784, 0.8667,
          0.8745, 0.9608, 0.6784, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7569, 0.8941, 0.8549,
          0.8353, 0.7765, 0.7059, 0.8314, 0.8235, 0.8275, 0.8353, 0.8745,
          0.8627, 0.9529, 0.7922, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0039, 0.0118, 0.0000, 0.0471, 0.8588, 0.8627, 0.8314,
          0.8549, 0.7529, 0.6627, 0.8902, 0.8157, 0.8549, 0.8784, 0.8314,
          0.8863, 0.7725, 0.8196, 0.2039],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0235, 0.0000, 0.3882, 0.9569, 0.8706, 0.8627,
          0.8549, 0.7961, 0.7765, 0.8667, 0.8431, 0.8353, 0.8706, 0.8627,
          0.9608, 0.4667, 0.6549, 0.2196],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0157, 0.0000, 0.0000, 0.2157, 0.9255, 0.8941, 0.9020,
          0.8941, 0.9412, 0.9098, 0.8353, 0.8549, 0.8745, 0.9176, 0.8510,
          0.8510, 0.8196, 0.3608, 0.0000],
         [0.0000, 0.0000, 0.0039, 0.0157, 0.0235, 0.0275, 0.0078, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.9294, 0.8863, 0.8510, 0.8745,
          0.8706, 0.8588, 0.8706, 0.8667, 0.8471, 0.8745, 0.8980, 0.8431,
          0.8549, 1.0000, 0.3020, 0.0000],
         [0.0000, 0.0118, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.2431, 0.5686, 0.8000, 0.8941, 0.8118, 0.8353, 0.8667,
          0.8549, 0.8157, 0.8275, 0.8549, 0.8784, 0.8745, 0.8588, 0.8431,
          0.8784, 0.9569, 0.6235, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0706, 0.1725, 0.3216, 0.4196,
          0.7412, 0.8941, 0.8627, 0.8706, 0.8510, 0.8863, 0.7843, 0.8039,
          0.8275, 0.9020, 0.8784, 0.9176, 0.6902, 0.7373, 0.9804, 0.9725,
          0.9137, 0.9333, 0.8431, 0.0000],
         [0.0000, 0.2235, 0.7333, 0.8157, 0.8784, 0.8667, 0.8784, 0.8157,
          0.8000, 0.8392, 0.8157, 0.8196, 0.7843, 0.6235, 0.9608, 0.7569,
          0.8078, 0.8745, 1.0000, 1.0000, 0.8667, 0.9176, 0.8667, 0.8275,
          0.8627, 0.9098, 0.9647, 0.0000],
         [0.0118, 0.7922, 0.8941, 0.8784, 0.8667, 0.8275, 0.8275, 0.8392,
          0.8039, 0.8039, 0.8039, 0.8627, 0.9412, 0.3137, 0.5882, 1.0000,
          0.8980, 0.8667, 0.7373, 0.6039, 0.7490, 0.8235, 0.8000, 0.8196,
          0.8706, 0.8941, 0.8824, 0.0000],
         [0.3843, 0.9137, 0.7765, 0.8235, 0.8706, 0.8980, 0.8980, 0.9176,
          0.9765, 0.8627, 0.7608, 0.8431, 0.8510, 0.9451, 0.2549, 0.2863,
          0.4157, 0.4588, 0.6588, 0.8588, 0.8667, 0.8431, 0.8510, 0.8745,
          0.8745, 0.8784, 0.8980, 0.1137],
         [0.2941, 0.8000, 0.8314, 0.8000, 0.7569, 0.8039, 0.8275, 0.8824,
          0.8471, 0.7255, 0.7725, 0.8078, 0.7765, 0.8353, 0.9412, 0.7647,
          0.8902, 0.9608, 0.9373, 0.8745, 0.8549, 0.8314, 0.8196, 0.8706,
          0.8627, 0.8667, 0.9020, 0.2627],
         [0.1882, 0.7961, 0.7176, 0.7608, 0.8353, 0.7725, 0.7255, 0.7451,
          0.7608, 0.7529, 0.7922, 0.8392, 0.8588, 0.8667, 0.8627, 0.9255,
          0.8824, 0.8471, 0.7804, 0.8078, 0.7294, 0.7098, 0.6941, 0.6745,
          0.7098, 0.8039, 0.8078, 0.4510],
         [0.0000, 0.4784, 0.8588, 0.7569, 0.7020, 0.6706, 0.7176, 0.7686,
          0.8000, 0.8235, 0.8353, 0.8118, 0.8275, 0.8235, 0.7843, 0.7686,
          0.7608, 0.7490, 0.7647, 0.7490, 0.7765, 0.7529, 0.6902, 0.6118,
          0.6549, 0.6941, 0.8235, 0.3608],
         [0.0000, 0.0000, 0.2902, 0.7412, 0.8314, 0.7490, 0.6863, 0.6745,
          0.6863, 0.7098, 0.7255, 0.7373, 0.7412, 0.7373, 0.7569, 0.7765,
          0.8000, 0.8196, 0.8235, 0.8235, 0.8275, 0.7373, 0.7373, 0.7608,
          0.7529, 0.8471, 0.6667, 0.0000],
         [0.0078, 0.0000, 0.0000, 0.0000, 0.2588, 0.7843, 0.8706, 0.9294,
          0.9373, 0.9490, 0.9647, 0.9529, 0.9569, 0.8667, 0.8627, 0.7569,
          0.7490, 0.7020, 0.7137, 0.7137, 0.7098, 0.6902, 0.6510, 0.6588,
          0.3882, 0.2275, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1569,
          0.2392, 0.1725, 0.2824, 0.1608, 0.1373, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000]]]), 9)
# Any sample can be accessed by its index
feature, label = mnist_train[0]
print(feature.shape, label)  # Channel x Height x Width

torch.Size([1, 28, 28]) 9


# Without a transform, each sample is returned as a PIL image
mnist_PIL = torchvision.datasets.FashionMNIST(root='/home/kesci/input/FashionMNIST2065', train=True, download=True)
PIL_feature, label = mnist_PIL[0]
print(PIL_feature)
<PIL.Image.Image image mode=L size=28x28 at 0x7F1A28117E48>
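# This function is saved in the d2lzh package for later use
def get_fashion_mnist_labels(labels):
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]
def show_fashion_mnist(images, labels):
    # The underscore denotes a variable we ignore (do not use)
    _, figs = plt.subplots(1, len(images), figsize=(12, 12))
    for f, img, lbl in zip(figs, images, labels):
        f.imshow(img.view((28, 28)).numpy())
        f.set_title(lbl)                  # show the label above each image
        f.get_xaxis().set_visible(False)  # hide the axis ticks
        f.get_yaxis().set_visible(False)
    plt.show()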
X, y = [], []
# Collect the first 10 samples
for i in range(10):
    X.append(mnist_train[i][0])  # append the i-th feature to X
    y.append(mnist_train[i][1])  # append the i-th label to y
# Show the 10 example images
show_fashion_mnist(X, get_fashion_mnist_labels(y))
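# Read the data in mini-batches
batch_size = 256  # each mini-batch holds 256 samples
num_workers = 4   # use 4 worker processes to read the data
train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)
# Time one full pass over the training set
start = time.time()
for X, y in train_iter:
    continue
print('%.2f sec' % (time.time() - start))
4.76 sec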
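To get a feel for what one pass over the training iterator involves, the sketch below computes how 60000 samples split into batches of 256. It is a plain-Python stand-in, not the DataLoader itself: `batch_slices` is a hypothetical helper, and it assumes the final, smaller batch is kept rather than dropped (which is DataLoader's default behavior, `drop_last=False`).

```python
def batch_slices(num_samples, batch_size):
    """Yield (start, stop) index pairs that cover the dataset batch by batch."""
    for start in range(0, num_samples, batch_size):
        yield start, min(start + batch_size, num_samples)

slices = list(batch_slices(60000, 256))
num_batches = len(slices)                        # 235 batches per epoch
last_batch_size = slices[-1][1] - slices[-1][0]  # the last batch has 96 samples
print(num_batches, last_batch_size)
```

With `shuffle=True` the samples inside each batch change from epoch to epoch, but the number and sizes of the batches stay the same.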


import torch
import torchvision
import numpy as np
import sys
import d2lzh1981 as d2l



batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, root='/home/kesci/input/FashionMNIST2065')


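num_inputs = 784   # each 28 x 28 image is flattened into a vector of length 784
num_outputs = 10   # there are 10 classes
# A batch X has shape (batch_size, 784) and W has shape (784, 10),
# so X multiplied by W has shape (batch_size, 10)
W = torch.tensor(np.random.normal(0, 0.01, (num_inputs, num_outputs)), dtype=torch.float)
b = torch.zeros(num_outputs, dtype=torch.float)
# Enable gradient tracking for the model parameters
W.requires_grad_(requires_grad=True)
b.requires_grad_(requires_grad=True)
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], requires_grad=True)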


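# Operating on a tensor along a given dimension
X = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(X.sum(dim=0, keepdim=True))   # dim=0: sum down each column; the reduced dimension is kept (shape 1 x 3)
print(X.sum(dim=1, keepdim=True))   # dim=1: sum across each row; the reduced dimension is kept (shape 2 x 1)
print(X.sum(dim=0, keepdim=False))  # dim=0: sum down each column; the reduced dimension is dropped (shape 3)
print(X.sum(dim=1, keepdim=False))  # dim=1: sum across each row; the reduced dimension is dropped (shape 2)
tensor([[5, 7, 9]])
tensor([[ 6],
        [15]])
tensor([5, 7, 9])
tensor([ 6, 15])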


$$
\hat{y}_j = \frac{\exp(o_j)}{\sum_{i=1}^3 \exp(o_i)}
$$

def softmax(X):
    X_exp = X.exp()
    # Sum over each row: one row is one input image, and its 10 columns are the 10 classes
    partition = X_exp.sum(dim=1, keepdim=True)
    # print("X size is ", X_exp.size())
    # print("partition size is ", partition, partition.size())
    return X_exp / partition  # broadcasting divides every row by its own sum
X = torch.rand((2, 5))
print(X)
X_prob = softmax(X)
print(X_prob, '\n', X_prob.sum(dim=1))
tensor([[0.0488, 0.5695, 0.7776, 0.7719, 0.8610],
        [0.3604, 0.1872, 0.3215, 0.2875, 0.5442]])
tensor([[0.1103, 0.1856, 0.2285, 0.2272, 0.2484],
        [0.2027, 0.1704, 0.1949, 0.1884, 0.2436]]) 
 tensor([1.0000, 1.0000])


$$
\begin{aligned}
\boldsymbol{o}^{(i)} &= \boldsymbol{x}^{(i)} \boldsymbol{W} + \boldsymbol{b},\\
\boldsymbol{\hat{y}}^{(i)} &= \text{softmax}(\boldsymbol{o}^{(i)}).
\end{aligned}
$$

def net(X):
    # Flatten each image to a row of length num_inputs = 784, then compute XW + b by matrix multiplication
    return softmax(torch.mm(X.view((-1, num_inputs)), W) + b)


$$
H\left(\boldsymbol y^{(i)}, \boldsymbol {\hat y}^{(i)}\right ) = -\sum_{j=1}^q y_j^{(i)} \log \hat y_j^{(i)},
$$

$$
\ell(\boldsymbol{\Theta}) = \frac{1}{n} \sum_{i=1}^n H\left(\boldsymbol y^{(i)}, \boldsymbol {\hat y}^{(i)}\right ),
$$

$$
\ell(\boldsymbol{\Theta}) = -(1/n) \sum_{i=1}^n \log \hat y_{y^{(i)}}^{(i)}
$$

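y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
y = torch.LongTensor([0, 2])
# gather along dim 1: for each row i, pick the entry in column y[i],
# i.e. the predicted probability of the true class
y_hat.gather(1, y.view(-1, 1))
def cross_entropy(y_hat, y):
    # Negative log of the probability assigned to the true class, one value per sample
    return - torch.log(y_hat.gather(1, y.view(-1, 1)))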
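For readers unfamiliar with `gather`, the following plain-Python sketch reproduces what the call above selects, using the same `y_hat` and `y` values as nested lists. `gather_true_class` is a hypothetical helper written for this illustration, not a PyTorch function.

```python
import math

def gather_true_class(y_hat, y):
    """For each row, pick the predicted probability at the index of the true class."""
    return [row[label] for row, label in zip(y_hat, y)]

y_hat = [[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]]
y = [0, 2]

picked = gather_true_class(y_hat, y)     # [0.1, 0.5]
losses = [-math.log(p) for p in picked]  # per-sample cross entropy
print(picked, losses)
```

The first sample assigns only 0.1 to its true class, so its loss is much larger than that of the second sample.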



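def accuracy(y_hat, y):
    # argmax(dim=1) reduces over the columns and returns, for each row, the column index of the largest value.
    # With dim=0 it would instead return, for each column, the row index of the largest value.
    return (y_hat.argmax(dim=1) == y).float().mean().item()
print(accuracy(y_hat, y))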
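The same accuracy computation can be followed step by step without tensors. This sketch mirrors the logic of `accuracy` above on plain lists, reusing the earlier example values for `y_hat` and `y`.

```python
def accuracy(y_hat, y):
    """Fraction of rows whose largest predicted probability is at the true class index."""
    predictions = [max(range(len(row)), key=lambda j: row[j]) for row in y_hat]
    correct = sum(int(p == t) for p, t in zip(predictions, y))
    return correct / len(y)

y_hat = [[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]]
y = [0, 2]

# Row 0 predicts class 2 (wrong), row 1 predicts class 2 (right)
print(accuracy(y_hat, y))
```

One of the two predictions is correct, so the accuracy is 0.5.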
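# This function is saved in the d2lzh_pytorch package for later use. It will be improved step by step; its full implementation is described in the "Image Augmentation" section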
def evaluate_accuracy(data_iter, net):
    acc_sum, n = 0.0, 0
    for X, y in data_iter:
        acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()
        n += y.shape[0]
    return acc_sum / n
print(evaluate_accuracy(test_iter, net))


num_epochs, lr = 5, 0.1

# This function is saved in the d2lzh_pytorch package for later use
def train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size,
              params=None, lr=None, optimizer=None):
    for epoch in range(num_epochs):
        train_l_sum, train_acc_sum, n = 0.0, 0.0, 0
        for X, y in train_iter:
            y_hat = net(X)
            l = loss(y_hat, y).sum()
            # Zero the gradients
            if optimizer is not None:
                optimizer.zero_grad()
            elif params is not None and params[0].grad is not None:
                for param in params:
                    param.grad.data.zero_()
            # Backpropagate, then update the parameters
            l.backward()
            if optimizer is None:
                d2l.sgd(params, lr, batch_size)
            else:
                optimizer.step()
            train_l_sum += l.item()
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()
            n += y.shape[0]
        test_acc = evaluate_accuracy(test_iter, net)
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f'
              % (epoch + 1, train_l_sum / n, train_acc_sum / n, test_acc))

train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, batch_size, [W, b], lr)
epoch 1, loss 0.7866, train acc 0.750, test acc 0.793
epoch 2, loss 0.5703, train acc 0.812, test acc 0.810
epoch 3, loss 0.5250, train acc 0.825, test acc 0.817
epoch 4, loss 0.5015, train acc 0.833, test acc 0.825
epoch 5, loss 0.4858, train acc 0.837, test acc 0.826
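When no optimizer is passed, the training loop relies on `d2l.sgd(params, lr, batch_size)`, whose body is not shown here. The sketch below assumes it follows the usual mini-batch stochastic gradient descent rule, in which each parameter moves against its gradient, scaled by the learning rate and divided by the batch size because the loss was summed over the batch. `sgd_step` and the numbers are hypothetical.

```python
def sgd_step(params, grads, lr, batch_size):
    """One mini-batch SGD update (assumed rule): p <- p - lr * grad / batch_size."""
    return [p - lr * g / batch_size for p, g in zip(params, grads)]

# Hypothetical scalar parameters and gradients of the summed batch loss
params = [0.5, -0.2]
grads = [25.6, -51.2]

new_params = sgd_step(params, grads, lr=0.1, batch_size=256)
print(new_params)  # approximately [0.49, -0.18]
```

Dividing by the batch size turns the gradient of the summed loss into the gradient of the average loss, so the effective step does not grow with the batch size.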



X, y = next(iter(test_iter))  # take the first batch of the test set
# print(y)
true_labels = d2l.get_fashion_mnist_labels(y.numpy())
pred_labels = d2l.get_fashion_mnist_labels(net(X).argmax(dim=1).numpy())
titles = [true + '\n' + pred for true, pred in zip(true_labels, pred_labels)]

d2l.show_fashion_mnist(X[0:9], titles[0:9])


# Load the required packages and modules
import torch
from torch import nn
from torch.nn import init
import numpy as np
import sys
import d2lzh1981 as d2l



batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, root='/home/kesci/input/FashionMNIST2065')


num_inputs = 784
num_outputs = 10

class LinearNet(nn.Module):
    def __init__(self, num_inputs, num_outputs):
        super(LinearNet, self).__init__()
        self.linear = nn.Linear(num_inputs, num_outputs)
    def forward(self, x):  # shape of x: (batch, 1, 28, 28)
        y = self.linear(x.view(x.shape[0], -1))
        return y
# net = LinearNet(num_inputs, num_outputs)

class FlattenLayer(nn.Module):
    def __init__(self):
        super(FlattenLayer, self).__init__()
    def forward(self, x):  # shape of x: (batch, *, *, ...)
        return x.view(x.shape[0], -1)

from collections import OrderedDict
net = nn.Sequential(
        # FlattenLayer(),
        # LinearNet(num_inputs, num_outputs)
        OrderedDict([
           ('flatten', FlattenLayer()),
           ('linear', nn.Linear(num_inputs, num_outputs))])  # our own LinearNet(num_inputs, num_outputs) would also work here
        )


init.normal_(net.linear.weight, mean=0, std=0.01)
init.constant_(net.linear.bias, val=0)
Parameter containing:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], requires_grad=True)


loss = nn.CrossEntropyLoss()  # its prototype is shown below
# class torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean')
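Note that `nn.CrossEntropyLoss` expects the raw outputs of the linear layer (not probabilities) together with integer class indices, and it applies the log-softmax internally; this is why the network in this section has no softmax layer. The plain-Python sketch below illustrates that combined computation for the default `reduction='mean'`. It is an illustration of the math, not PyTorch's actual implementation, and the function names are made up for this example.

```python
import math

def cross_entropy_from_logits(logits, label):
    """-log softmax(logits)[label], computed with the log-sum-exp trick."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(o - m) for o in logits))
    return log_sum_exp - logits[label]

def mean_cross_entropy(batch_logits, labels):
    """Average the per-sample losses over the batch (the 'mean' reduction)."""
    losses = [cross_entropy_from_logits(o, t) for o, t in zip(batch_logits, labels)]
    return sum(losses) / len(losses)

# Hypothetical raw outputs for two samples and their 0-based true classes
batch_logits = [[0.1, 10.0, 0.1], [1000.0, 10.0, 1000.0]]
labels = [1, 0]

loss = mean_cross_entropy(batch_logits, labels)
print(loss)
```

Working with logarithms directly avoids first computing a probability that may underflow to 0 and then taking its log, which is one reason to combine softmax and cross entropy in a single step.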


optimizer = torch.optim.SGD(net.parameters(), lr=0.1)  # its prototype is shown below
# class torch.optim.SGD(params, lr=<required parameter>, momentum=0, dampening=0, weight_decay=0, nesterov=False)


num_epochs = 100
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, None, None, optimizer)
epoch 1, loss 0.0019, train acc 0.841, test acc 0.831
epoch 2, loss 0.0018, train acc 0.844, test acc 0.820
epoch 3, loss 0.0018, train acc 0.845, test acc 0.834
epoch 4, loss 0.0018, train acc 0.847, test acc 0.834
epoch 5, loss 0.0018, train acc 0.847, test acc 0.829
epoch 6, loss 0.0017, train acc 0.850, test acc 0.834
epoch 7, loss 0.0017, train acc 0.850, test acc 0.832
epoch 8, loss 0.0017, train acc 0.852, test acc 0.836
epoch 9, loss 0.0017, train acc 0.852, test acc 0.834
epoch 10, loss 0.0017, train acc 0.853, test acc 0.834
epoch 11, loss 0.0017, train acc 0.854, test acc 0.829
epoch 12, loss 0.0017, train acc 0.855, test acc 0.839
epoch 13, loss 0.0017, train acc 0.856, test acc 0.838
epoch 14, loss 0.0017, train acc 0.856, test acc 0.833
epoch 15, loss 0.0016, train acc 0.856, test acc 0.839
epoch 16, loss 0.0016, train acc 0.857, test acc 0.840
epoch 17, loss 0.0016, train acc 0.858, test acc 0.834
epoch 18, loss 0.0016, train acc 0.858, test acc 0.838
epoch 19, loss 0.0016, train acc 0.858, test acc 0.841
epoch 20, loss 0.0016, train acc 0.859, test acc 0.836
epoch 21, loss 0.0016, train acc 0.859, test acc 0.837
epoch 22, loss 0.0016, train acc 0.860, test acc 0.843
epoch 23, loss 0.0016, train acc 0.859, test acc 0.834
epoch 24, loss 0.0016, train acc 0.861, test acc 0.839
epoch 25, loss 0.0016, train acc 0.860, test acc 0.840
epoch 26, loss 0.0016, train acc 0.861, test acc 0.842
epoch 27, loss 0.0016, train acc 0.861, test acc 0.842
epoch 28, loss 0.0016, train acc 0.861, test acc 0.842
epoch 29, loss 0.0016, train acc 0.861, test acc 0.831
epoch 30, loss 0.0016, train acc 0.861, test acc 0.840
epoch 31, loss 0.0016, train acc 0.863, test acc 0.840
epoch 32, loss 0.0016, train acc 0.863, test acc 0.843
epoch 33, loss 0.0016, train acc 0.863, test acc 0.841
epoch 34, loss 0.0016, train acc 0.863, test acc 0.843
epoch 35, loss 0.0016, train acc 0.863, test acc 0.826
epoch 36, loss 0.0016, train acc 0.863, test acc 0.843
epoch 37, loss 0.0016, train acc 0.863, test acc 0.843
epoch 38, loss 0.0016, train acc 0.863, test acc 0.842
epoch 39, loss 0.0016, train acc 0.863, test acc 0.844
epoch 40, loss 0.0016, train acc 0.864, test acc 0.838
epoch 41, loss 0.0016, train acc 0.864, test acc 0.834
epoch 42, loss 0.0016, train acc 0.864, test acc 0.842
epoch 43, loss 0.0016, train acc 0.864, test acc 0.839
epoch 44, loss 0.0015, train acc 0.865, test acc 0.843
epoch 45, loss 0.0015, train acc 0.864, test acc 0.842
epoch 46, loss 0.0015, train acc 0.864, test acc 0.843
epoch 47, loss 0.0015, train acc 0.864, test acc 0.844
epoch 48, loss 0.0015, train acc 0.864, test acc 0.843
epoch 49, loss 0.0015, train acc 0.865, test acc 0.843
epoch 50, loss 0.0015, train acc 0.866, test acc 0.844
epoch 51, loss 0.0015, train acc 0.865, test acc 0.837
epoch 52, loss 0.0015, train acc 0.865, test acc 0.842
epoch 53, loss 0.0015, train acc 0.866, test acc 0.844
epoch 54, loss 0.0015, train acc 0.866, test acc 0.841
epoch 55, loss 0.0015, train acc 0.864, test acc 0.841
epoch 56, loss 0.0015, train acc 0.866, test acc 0.843
epoch 57, loss 0.0015, train acc 0.865, test acc 0.845
epoch 58, loss 0.0015, train acc 0.866, test acc 0.832
epoch 59, loss 0.0015, train acc 0.866, test acc 0.844
epoch 60, loss 0.0015, train acc 0.866, test acc 0.844
epoch 61, loss 0.0015, train acc 0.866, test acc 0.845
epoch 62, loss 0.0015, train acc 0.867, test acc 0.843
epoch 63, loss 0.0015, train acc 0.866, test acc 0.840
epoch 64, loss 0.0015, train acc 0.866, test acc 0.844
epoch 65, loss 0.0015, train acc 0.866, test acc 0.843
epoch 66, loss 0.0015, train acc 0.866, test acc 0.841
epoch 67, loss 0.0015, train acc 0.867, test acc 0.843
epoch 68, loss 0.0015, train acc 0.866, test acc 0.844
epoch 69, loss 0.0015, train acc 0.867, test acc 0.845
epoch 70, loss 0.0015, train acc 0.867, test acc 0.843
epoch 71, loss 0.0015, train acc 0.867, test acc 0.842
epoch 72, loss 0.0015, train acc 0.867, test acc 0.839
epoch 73, loss 0.0015, train acc 0.867, test acc 0.842
epoch 74, loss 0.0015, train acc 0.867, test acc 0.842
epoch 75, loss 0.0015, train acc 0.868, test acc 0.843
epoch 76, loss 0.0015, train acc 0.867, test acc 0.845
epoch 77, loss 0.0015, train acc 0.867, test acc 0.842
epoch 78, loss 0.0015, train acc 0.867, test acc 0.840
epoch 79, loss 0.0015, train acc 0.868, test acc 0.843
epoch 80, loss 0.0015, train acc 0.867, test acc 0.842
epoch 81, loss 0.0015, train acc 0.867, test acc 0.841
epoch 82, loss 0.0015, train acc 0.868, test acc 0.844
epoch 83, loss 0.0015, train acc 0.868, test acc 0.845
epoch 84, loss 0.0015, train acc 0.869, test acc 0.840
epoch 85, loss 0.0015, train acc 0.868, test acc 0.839
epoch 86, loss 0.0015, train acc 0.868, test acc 0.843
epoch 87, loss 0.0015, train acc 0.868, test acc 0.841
epoch 88, loss 0.0015, train acc 0.868, test acc 0.845
epoch 89, loss 0.0015, train acc 0.869, test acc 0.840
epoch 90, loss 0.0015, train acc 0.868, test acc 0.841
epoch 91, loss 0.0015, train acc 0.869, test acc 0.844
epoch 92, loss 0.0015, train acc 0.868, test acc 0.844
epoch 93, loss 0.0015, train acc 0.869, test acc 0.845
epoch 94, loss 0.0015, train acc 0.869, test acc 0.845
epoch 95, loss 0.0015, train acc 0.869, test acc 0.845
epoch 96, loss 0.0015, train acc 0.869, test acc 0.846
epoch 97, loss 0.0015, train acc 0.868, test acc 0.843
epoch 98, loss 0.0015, train acc 0.869, test acc 0.845
epoch 99, loss 0.0015, train acc 0.870, test acc 0.839
epoch 100, loss 0.0015, train acc 0.868, test acc 0.840
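The printed loss here (about 0.0015) is far smaller than in the from-scratch run (about 0.49), even though the accuracies are similar. One plausible explanation, assuming `d2l.train_ch3` matches the `train_ch3` listed earlier: `nn.CrossEntropyLoss` already averages over each batch, so the loop accumulates one mean loss per batch and then divides the total by the number of samples. The sketch below converts the printed value back to an approximate per-sample loss under that assumption.

```python
import math

num_samples = 60000
batch_size = 256
num_batches = math.ceil(num_samples / batch_size)  # 235 batches per epoch

printed_loss = 0.0015  # value shown in the last epochs of the log above

# printed_loss ~ (sum of per-batch mean losses) / num_samples,
# so the average per-batch mean loss is roughly:
approx_mean_batch_loss = printed_loss * num_samples / num_batches
print(approx_mean_batch_loss)  # roughly 0.38
```

Under this reading, the underlying per-sample loss is of the same order as in the from-scratch implementation, and the two logs simply use different scales.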


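  • Normalizing with the exponential function is not an arbitrary choice: it is the result of simplifying under the assumption that the error from unknown sources follows a Gaussian distribution. If the deviation between the data and the theoretical model followed some other probability distribution, the resulting normalization function would not necessarily be an exponential.
  • If I remember correctly, linear regression is derived under the assumption that the error follows a Gaussian distribution, whereas softmax regression is derived under the assumption of a multinomial distribution.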
