第一章第5节 softmax回归的从零开始实现

这一节我们来动手实现softmax回归。首先导入本节实现所需的包或模块。

1.5.1 获取和读取数据

1.5.2 初始化模型参数

跟线性回归中的例子一样,我们将使用向量表示每个样本。已知每个样本输入时高和宽均为28像素的图像。模型的输入人向量的长度是28x28=784:该向量的每个元素对应图像中每个像素。由于图像有10个类别,单层神经网络输出层的输出个数为10,因此softmax回归的权重和偏差参数分别为784x10和1x10的矩阵。

同之前一样,我们要为模型参数附上梯度。

1.5.3 实现softmax运算

在介绍如何定义softmax回归之前,我们先描述一下对如何对多为NDArray按维度操作。在下面的例子中,给定一个NDArray矩阵X。我们可以只对其中同一列(axis=0)或同一行(axis=1)的元素求和,并在结果中保留行和列这两个维度(Keepdims=True)。

下面我们就可以定义前面小节里介绍的softmax运算了。在下面的函数中,矩阵X的行数是样本数,列数是输出个数。为了表达样本预测各个输出的概率,softmax运算会先通过exp函数对每个元素做指数运算,再对exp矩阵同行元素求和,最后令矩阵每行元素与改行元素之和相除。这样一来,最终得到的矩阵每行元素和为1且非负。因此,该矩阵每行都是合法的概率分布。softmax运算的输出矩阵中的任一行元素代表了一个样本在各个输出类别上的预测概率。

可以看到,对于随机输入,我们将每个元素变成了非负数,且每一行和为1。

1.5.4 定义模型

有了softmax运算,我们可以定义上节描述的softmax回归模型了。这里通过reshape函数将每张原始图像改成长度为num_inputs的向量。

1.5.5 定义损失函数

上一节中,我们介绍了softmax回归使用的交叉熵损失函数。为了得到标签的预测概率,我们可以使用pick函数。在下面的例子中,变量y_hat是2个样本在3个类别的预测概率,变量y是这2个样本的标签类别。通过使用pick函数,我们得到了2个样本的标签的预测概率。与“softmax回归”一节数学表述中标签类别离散值从1开始逐一递增不同,在代码中,标签类别的离散值是从0开始逐一递增的。

下面实现了“softmax回归”一节中介绍的交叉熵损失函数。

1.5.6 计算分类准确率

给定一个类别的预测概率分布y_hat,我们把预测概率最大的类别作为输出类别。如果它与真实类别y一致,说明这次预测时正确的。分类准确率即正确预测数量与总预测数量之比。

为了演示准确率的计算,下面定义准确率accuracy函数。其中y_hat.argmax(axis=1)返回矩阵y_hat每行中最大元素的索引,且返回结果与变量y形状相同。我们在“数据操作”一节介绍过,相等条件判断式(y_hat.argmax(axis=1) == y)是一个值为0(相等为假)或1(相等为真)的NDArray。由于标签类型为整数,我们先将变量变换为浮点数再进行相等条件判断。

让我们继续使用在演示pick函数时定义的变量y_hat和y,并将它们分别作为预测概率分布和标签。可以看到,第一个样本预测类别为2(该行元素0.6在本行的索引为2),与真实标签0不一致;第二个样本预测类别为2(该行最大元素0.5在本行的索引为2),与真实标签2一致。因此,这两个样本上的分类准确率为0.5。

类似地,我们可以评价模型net在数据集data_iter上的准确率。

因为我们随机初始化了模型net,所以这个随机模型的准确率应该接近于类别个数10的倒数0.1。

1.5.7 训练模型

训练softmax回归的实现跟“线性回归的从零开始实现”一节介绍的线性回归中的实现非常相似。我们同样适用小批量随机梯度下降来优化模型的损失函数。在训练模型是,迭代周期数num_epochs和学习率lr都是可以调的超参数。改变它们的值可能会得到分类更准确的模型。

 

1.5.8 预测

训练完成后,现在就可以演示如何对图像进行分类了。给定一系列图像(第三行图像输出),我们比较一下它们的真实标签(第一行文本输出)和模型预测的结果(第二行文本输出)。

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值