softmax和分类模型
softmax的基本概念
-
分类问题
一个简单的图像分类问题,输入图像的高和宽均为2像素,色彩为灰度。
图像中的4像素分别记为 x 1 , x 2 , x 3 , x 4 x_1,x_2,x_3,x_4 x1,x2,x3,x4。
假设真实标签为狗、猫或者鸡,这些标签对应的离散值为 y 1 , y 2 , y 3 y_1,y_2,y_3 y1,y2,y3。
通常使用离散的数值来表示类别,例如 y 1 = 1 , y 2 = 2 , y 3 = 3 y_1=1,y_2=2,y_3=3 y1=1,y2=2,y3=3。 -
权重矢量
o 1 = x 1 w 11 + x 2 w 21 + x 3 w 31 + x 4 w 41 + b 1 o 2 = x 1 w 12 + x 2 w 22 + x 3 w 32 + x 4 w 42 + b 2 o 3 = x 1 w 13 + x 2 w 23 + x 3 w 33 + x 4 w 43 + b 3 \begin{aligned} o_1&=x_1w_{11}+x_2w_{21}+x_3w_{31}+x_4w_{41}+b_1\\ o_2&=x_1w_{12}+x_2w_{22}+x_3w_{32}+x_4w_{42}+b_2\\ o_3&=x_1w_{13}+x_2w_{23}+x_3w_{33}+x_4w_{43}+b_3\\ \end{aligned} o1o2o3=x1w11+x2w21+x3w31+x4w41+b1=x1w12+x2w22+x3w32+x4w42+b2=x1w13+x2w23+x3w33+x4w43+b3 -
神经网络图
下图用神经网络图描绘了上面的计算。softmax回归同线性回归一样,也是一个单层神经网络。由于每个输出 o 1 , o 2 , o 3 o_1,o_2,o_3 o1,o2,o3的计算都要依赖于所有的输入 x 1 , x 2 , x 3 , x 4 x_1,x_2,x_3,x_4 x1,x2,x3,x4,softmax回归的输出层也是一个全连接层。 -
输出问题
直接使用输出层的输出有两个问题:
- 一方面,由于输出层的输出值的范围不确定,我们难以直观上判断这些值的意义。例如,刚才举的例子中的输出值10表示“很置信”图像类别为猫,因为该输出值是其他两类的输出值的100倍。但如果 o 1 = o 3 = 1 0 3 o_1=o_3=10^3 o1=o3=103,那么输出值10却又表示图像类别为猫的概率很低。
- 另一方面,由于真实标签是离散值,这些离散值与不确定范围的输出值之间的误差难以衡量。
softmax运算符(softmax operator)解决了以上两个问题。它通过下式将输出值变换成值为正且和为1的概率分布:
y
^
1
,
y
^
2
,
y
^
3
=
s
o
f
t
e
m
a
x
(
o
1
,
o
2
,
o
3
)
\hat{y}_1,\hat{y}_2,\hat{y}_3=softemax(o_1,o_2,o_3)
y^1,y^2,y^3=softemax(o1,o2,o3)其中
y
^
1
=
e
x
p
(
o
1
)
∑
i
=
1
3
e
x
p
(
o
i
)
,
y
^
2
=
e
x
p
(
o
2
)
∑
i
=
1
3
e
x
p
(
o
i
)
,
y
^
3
=
e
x
p
(
o
3
)
∑
i
=
1
3
e
x
p
(
o
i
)
\hat{y}_1=\dfrac{exp(o_1)}{\sum_{i=1}^3 exp(o_i)},\hat{y}_2=\dfrac{exp(o_2)}{\sum_{i=1}^3 exp(o_i)},\hat{y}_3=\dfrac{exp(o_3)}{\sum_{i=1}^3 exp(o_i)}
y^1=∑i=13exp(oi)exp(o1),y^2=∑i=13exp(oi)exp(o2),y^3=∑i=13exp(oi)exp(o3)容易看出
y
^
1
+
y
^
2
+
y
^
3
=
1
\hat{y}_1+\hat{y}_2+\hat{y}_3=1
y^1+y^2+y^3=1且
0
<
y
^
1
,
y
^
2
,
y
^
3
<
1
0<\hat{y}_1,\hat{y}_2,\hat{y}_3<1
0<y^1,y^2,y^3<1,因此
y
^
1
,
y
^
2
,
y
^
3
\hat{y}_1,\hat{y}_2,\hat{y}_3
y^1,y^2,y^3是一个合法的概率分布。这时候,如果
y
^
2
=
0.8
\hat{y}_2=0.8
y^2=0.8,不管
y
^
1
\hat{y}_1
y^1和
y
^
3
\hat{y}_3
y^3的值是多少,我们都知道图像类别为猫的概率是80%。此外,可以注意到
m
a
x
(
o
i
)
=
m
a
x
(
y
^
i
)
max(o_i)=max(\hat{y}_i)
max(oi)=max(y^i)因此softmax运算不改变预测类别输出。
def softmax(X):
X_exp = X.exp()
partition = X_exp.sum(dim=1, keepdim=True)
# print("X size is ", X_exp.size())
# print("partition size is ", partition, partition.size())
return X_exp / partition
X = torch.rand((2, 5))
X_prob = softmax(X)
print(X_prob, '\n', X_prob.sum(dim=1))
tensor([[0.1927, 0.2009, 0.1823, 0.1887, 0.2355],
[0.1274, 0.1843, 0.2536, 0.2251, 0.2096]])
tensor([1., 1.])
交叉熵损失函数
对于样本 i i i,构造向量 y ( i ) ∈ R q y^{(i)}\in R^q y(i)∈Rq,使其第 y ( i ) y^{(i)} y(i)(样本 i i i类别的离散数值)个元素为1,其余为0。这样训练目标可以设为使预测概率分布 y ^ ( i ) \hat{y}^{(i)} y^(i)尽可能接近真实的标签概率分布 y ( i ) y^{(i)} y(i)。
- 平方损失估计
L o s s = ∣ y ^ ( i ) − y ( i ) ∣ 2 / 2 Loss=|\hat{y}^{(i)}-y^{(i)}|^2/2 Loss=∣y^(i)−y(i)∣2/2
然而,想要预测分类结果正确,其实并不需要预测概率完全等于标签概率。例如,在图像分类的例子里,如果
y
(
i
)
=
3
y^{(i)}=3
y(i)=3,那么只需要
y
3
(
i
)
y^{(i)}_3
y3(i)比其他两个预测值
y
1
(
i
)
y^{(i)}_1
y1(i)和
y
2
(
i
)
y^{(i)}_2
y2(i)大就行了。即使
y
3
(
i
)
y^{(i)}_3
y3(i)值为0.6,不管其他两个预测值为多少,类别预测均正确。而平方损失则过于严格,例如
y
1
(
i
)
=
y
2
(
i
)
=
0.2
y^{(i)}_1=y^{(i)}_2=0.2
y1(i)=y2(i)=0.2比
y
1
(
i
)
=
0
,
y
2
(
i
)
=
0.4
y^{(i)}_1=0,y^{(i)}_2=0.4
y1(i)=0,y2(i)=0.4的损失要小很多,虽然两者都有同样正确的分类预测结果。
改善上述问题的一个方法是使用更适合衡量两个概率分布差异的测量函数。其中,交叉熵(cross entropy)是一个常用的衡量方法:
H
(
y
(
i
)
,
y
^
(
i
)
)
=
−
∑
j
=
1
q
y
j
(
i
)
log
y
j
(
i
)
H(y^{(i)},\hat{y}^{(i)})=-\sum_{j=1}^q y^{(i)}_j\log y^{(i)}_j
H(y(i),y^(i))=−j=1∑qyj(i)logyj(i)其中带下标的
y
j
(
i
)
y^{(i)}_j
yj(i)是向量中非0即1的元素,需要注意将它与样本类别的离散数值,即不带下标的
y
(
i
)
y^{(i)}
y(i)区分。在上式中,我们知道向量中只有第个元素为1,其余全为0,于是
H
(
y
(
i
)
,
y
^
(
i
)
)
=
−
log
y
^
y
(
i
)
(
i
)
H(y^{(i)},\hat{y}^{(i)})=-\log \hat y^{(i)}_{y^{(i)}}
H(y(i),y^(i))=−logy^y(i)(i)。也就是说,交叉熵只关心对正确类别的预测概率,因为只要其值足够大,就可以确保分类结果正确。当然,遇到一个样本有多个标签时,例如图像里含有不止一个物体时,并不能做这一步简化。但即便对于这种情况,交叉熵同样只关心对图像中出现的物体类别的预测概率。
假设训练数据集的样本数为 n n n,交叉熵损失函数定义为 l ( Θ ) = 1 n ∑ i = 1 n H ( y ( i ) , y ^ ( i ) ) l(\Theta)=\dfrac{1}{n}\sum_{i=1}^nH(y^{(i)},\hat{y}^{(i)}) l(Θ)=n1i=1∑nH(y(i),y^(i))其中 Θ \Theta Θ代表模型参数。同样地,如果每个样本只有一个标签,那么交叉熵损失可以简写成 l ( Θ ) = − 1 n ∑ i = 1 n log y ^ y ( i ) ( i ) l(\Theta)=-\dfrac{1}{n}\sum_{i=1}^n\log \hat y^{(i)}_{y^{(i)}} l(Θ)=−n1∑i=1nlogy^y(i)(i)。从另一个角度来看,我们知道最小化 l ( Θ ) l(\Theta) l(Θ)等价于 e x p ( − n l ( Θ ) ) = ∏ i = 1 n y ^ y ( i ) ( i ) exp(-nl(\Theta))=\prod_{i=1}^n \hat y^{(i)}_{y^{(i)}} exp(−nl(Θ))=∏i=1ny^y(i)(i)最大化,即最小化交叉熵损失函数等价于最大化训练数据集所有标签类别的联合预测概率。
def cross_entropy(y_hat, y):
'''gather(dim,index) 取值 dim=0/1 0表示按列取 1表示按行取;index表示取值的下标'''
return - torch.log(y_hat.gather(1, y.view(-1, 1)))
loss = nn.CrossEntropyLoss() # 下面是他的函数原型
# class torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean')
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
y = torch.LongTensor([0, 2])
y_hat.gather(1, y.view(-1, 1))#必须为y.view(-1, 1) y.view(1, -1)会报错
tensor([[0.1000],
[0.5000]])
获取Fashion-MNIST训练集和读取数据
图像分类数据集中最常用的是手写数字识别数据集MNIST。但大部分模型在MNIST上的分类精度都超过了95%。为了更直观地观察算法之间的差异,将使用一个图像内容更加复杂的数据集Fashion-MNIST。
这里会使用torchvision包,它是服务于PyTorch深度学习框架的,主要用来构建计算机视觉模型。torchvision主要由以下几部分构成:
- torchvision.datasets: 一些加载数据的函数及常用的数据集接口;
- torchvision.models: 包含常用的模型结构(含预训练模型),例如AlexNet、VGG、ResNet等;
- torchvision.transforms: 常用的图片变换,例如裁剪、旋转等;
- torchvision.utils: 其他的一些有用的方法。
class torchvision.datasets.FashionMNIST(root, train=True, transform=None, target_transform=None, download=False)
- root(string) – 数据集的根目录,其中存放processed/training.pt和processed/test.pt文件;
- train(bool, 可选) – 如果设置为True,从training.pt创建数据集,否则从test.pt创建;
- download(bool, 可选) – 如果设置为True,从互联网下载数据并放到root文件夹下。如果root目录下已经存在数据,不会再次下载;
- transform(可被调用 , 可选) – 一种函数或变换,输入PIL图片,返回变换之后的数据。如:transforms.RandomCrop;
- target_transform(可被调用 , 可选) – 一种函数或变换,输入目标,进行变换。
对多维Tensor按维度操作
X = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(X.sum(dim=0, keepdim=True)) # dim为0,按照相同的列求和,并在结果中保留列特征
print(X.sum(dim=1, keepdim=True)) # dim为1,按照相同的行求和,并在结果中保留行特征
print(X.sum(dim=0, keepdim=False)) # dim为0,按照相同的列求和,不在结果中保留列特征
print(X.sum(dim=1, keepdim=False)) # dim为1,按照相同的行求和,不在结果中保留行特征
tensor([[5, 7, 9]])
tensor([[ 6],
[15]])
tensor([5, 7, 9])
tensor([ 6, 15])
其他函数
#iter()函数:生成了一个迭代器
list_ = [1, 2, 3, 4, 5]
it = iter(list_)
for i in range(5):
line = next(it)
print("第%d 行, %s" %(i, line))
第0 行, 1
第1 行, 2
第2 行, 3
第3 行, 4
第4 行, 5