上节博客中我们一起探讨了机器学习中的回归问题,回归问题主要是为了解决连续的问题,生活中还有很多应用场景不是连续的,预测的结果是离散的,比如识别一个垃圾的种类或者是识别一个花卉的具体种类,这种问题如果使用简单的线性回归是达不到目的的,这个时候就需要引入分类去解决,今天我们一起来看下如何去使用Pyotrch实现Softmax分类。
开始之前给大家安利一下我之前写的使用tensorflow2构建物体分类模型的博客,我在博客中详细介绍了数据集收集、模型构建和模型使用三个方面,结合视频你也可以快速构建自己的物体分类模型,快去试试吧!
(2条消息) 手把手教你用tensorflow2.3训练自己的分类数据集_dejahu的博客-CSDN博客
回归VS分类
首先我们来看下回归和分类之间的区别,回归估计的是一个连续值,分类预测的是一个离散的类别,我们可以用下面的一张图来清洗展示回归和分类的区别,可以明显看到,分类问题中,输出的个数变多,并且需要学习的参数也变多了。
softmax数学原理
数学表达
介绍softmax之前我们需要先了解输出如何进行编码,对于一般的物体分类问题,我们会采用独热(one hot)编码
的形式来表示物体的类别。比如对于手写数字问题,一共有10类数字,我们使用大小为10的向量表示,在该数字的位置上标1,其余位置标0,该向量就是这个数字的标签,比如对于数字1而言,他的one-hot向量就是[0,1,0,0,0,0,0,0,0,0]。数学表示如下:
softmax的主要原理是为了将输出的多个值转化为概率值,以手写字体为例,我们如果能将输出的10个值转化为对应数字的概率值就很容易将其应用在手写数字的预测上,其中这10个浮点数的和为1,每个浮点数表示对应数字的概率值,数学表达如下:
模型构建
和线性回归一样,首先我们需要构建一个模型,然后指定损失函数和优化器,采用随机梯度下降的方式来训练模型,主要的不同点在于损失函数的计算,不同于线性回归中的平方差损失,这里的函数采用交叉熵损失函数
,主要是用来衡量两个概率之间的损失,其中样本的标签使用
y
y
y表示,样本的预测值使用
y
^
\hat{y}
y^表示,则他们之前的损失为:
求导的时候,其梯度就是真实概率和预测概率之间的区别:
常见的损失函数
损失函数在机器学习的模型中用来衡量预测值和真实值的误差,好的损失函数可以准确分析出两者之间的误差,并且能够帮助模型快速地梯度下降到最优解,一个好的损失函数好比是一个老教授,能够准确指出你的不足并帮助你快速调整备考状态。下面介绍几个常用的损失函数,其中黄色表示模型,蓝色表示损失函数,绿色表示梯度,注意梯度的大小,将会对模型参数的优化起到重要的作用。
L1 Loss
L2 Loss
Huber Robust Loss
softmax代码实现
代码的实现部分,我们使用的是大名鼎鼎的mnist数据集,也就是手写数字数据集,这个数据集有28*28大小的灰度图片构成,我们这个使用的还不是卷积神经网络,所以读入的时候只能将二维的图像展开,28*28也就是784,每张图像对应的就是784的向量,输出是10,对应10个数字,代码如下:
import torch
from torch import nn
from d2l import torch as d2l
# 加载数据集
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
# PyTorch不会隐式地调整输入的形状。
# 因此,我们定义了展平层(flatten)在线性层前调整网络输入的形状
# 784是输入的向量数量,10是输出的向量数量
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))
# 初始化权重,标准差是0.01
def init_weights(m):
if type(m) == nn.Linear:
nn.init.normal_(m.weight, std=0.01)
net.apply(init_weights)
# 在交叉熵损失函数中传递未归一化的预测,并同时计算softmax及其对数
loss = nn.CrossEntropyLoss()
# 使用学习率为0.1的小批量随机梯度下降作为优化算法
trainer = torch.optim.SGD(net.parameters(), lr=0.1)
num_epochs = 10
# 训练的过程和之前还是一样的,如文中所述,其实就是变了损失函数
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)
可以看到,随着训练的进行,损失在不断下降,准确率在不断提升,最终的准确率大概在0.9左右。