从零实现Softmax回归:深入理解多分类模型

从零实现Softmax回归:深入理解多分类模型

d2l-zh 《动手学深度学习》:面向中文读者、能运行、可讨论。中英文版被70多个国家的500多所大学用于教学。 d2l-zh 项目地址: https://gitcode.com/gh_mirrors/d2/d2l-zh

引言

Softmax回归是深度学习中处理多分类问题的基础模型。与线性回归不同,Softmax回归能够处理多个类别的分类问题。本文将带您从零开始实现一个完整的Softmax回归模型,使用Fashion-MNIST数据集进行训练和评估。

数据准备

我们使用Fashion-MNIST数据集,它包含10个类别的服装图片,每个图片大小为28×28像素。首先设置批量大小为256:

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

模型参数初始化

Softmax回归模型需要为每个输入特征和每个类别学习权重:

  1. 输入特征:将28×28图像展平为784维向量
  2. 输出类别:10个类别对应10个输出

权重矩阵W的形状为784×10,偏置b的形状为10×1:

num_inputs = 784
num_outputs = 10

W = np.random.normal(0, 0.01, (num_inputs, num_outputs))
b = np.zeros(num_outputs)

Softmax运算实现

Softmax函数将原始分数转换为概率分布:

  1. 对每个元素取指数
  2. 计算每行的和(规范化常数)
  3. 将每个元素除以其行的和

数学表达式为: $$ \mathrm{softmax}(\mathbf{X}){ij} = \frac{\exp(\mathbf{X}{ij})}{\sum_k \exp(\mathbf{X}_{ik})} $$

代码实现:

def softmax(X):
    X_exp = d2l.exp(X)
    partition = d2l.reduce_sum(X_exp, 1, keepdims=True)
    return X_exp / partition  # 广播机制

模型定义

将输入图像展平后,通过矩阵乘法和加法得到输出分数,再应用softmax:

def net(X):
    return softmax(d2l.matmul(d2l.reshape(X, (-1, W.shape[0])), W) + b)

损失函数:交叉熵

交叉熵损失衡量预测概率分布与真实分布的差异:

def cross_entropy(y_hat, y):
    return - d2l.log(y_hat[range(len(y_hat)), y])

评估指标:分类精度

分类精度是正确预测的比例:

def accuracy(y_hat, y):
    y_hat = d2l.argmax(y_hat, axis=1)
    cmp = d2l.astype(y_hat, y.dtype) == y
    return float(d2l.reduce_sum(d2l.astype(cmp, y.dtype)))

训练过程

训练循环包括:

  1. 前向传播计算预测
  2. 计算损失
  3. 反向传播计算梯度
  4. 更新参数
def train_epoch_ch3(net, train_iter, loss, updater):
    metric = Accumulator(3)
    for X, y in train_iter:
        with autograd.record():
            y_hat = net(X)
            l = loss(y_hat, y)
        l.backward()
        updater(X.shape[0])
        metric.add(float(l.sum()), accuracy(y_hat, y), y.size)
    return metric[0]/metric[2], metric[1]/metric[2]

完整训练

设置学习率0.1,训练10个epoch:

lr = 0.1
num_epochs = 10
train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, updater)

预测与评估

训练完成后,可以在测试集上评估模型并可视化预测结果:

def predict_ch3(net, test_iter, n=6):
    for X, y in test_iter:
        break
    trues = d2l.get_fashion_mnist_labels(y)
    preds = d2l.get_fashion_mnist_labels(d2l.argmax(net(X), axis=1))
    titles = [true +'\n' + pred for true, pred in zip(trues, preds)]
    d2l.show_images(d2l.reshape(X[0:n], (n, 28, 28)), 1, n, titles=titles[0:n])

关键点总结

  1. 模型结构:Softmax回归通过线性变换加softmax激活实现多分类
  2. 损失函数:交叉熵损失适合衡量概率分布差异
  3. 训练流程:与线性回归类似,但使用不同的损失函数
  4. 数值稳定性:实际实现需考虑数值稳定性问题(如log(0))

常见问题与思考

  1. 数值稳定性:直接实现softmax可能遇到数值上溢/下溢问题,如何改进?
  2. 损失函数定义域:交叉熵中对数函数的定义域限制如何解决?
  3. 应用场景:在医疗诊断等高风险领域,仅选择最大概率类别是否足够?
  4. 大规模分类:当类别数量极大时(如语言模型),softmax计算会有什么挑战?

通过本教程,您应该已经掌握了Softmax回归的核心概念和实现细节。这是理解更复杂神经网络模型的重要基础。

d2l-zh 《动手学深度学习》:面向中文读者、能运行、可讨论。中英文版被70多个国家的500多所大学用于教学。 d2l-zh 项目地址: https://gitcode.com/gh_mirrors/d2/d2l-zh

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

邬筱杉Lewis

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值