从零实现Softmax回归:深入理解多分类模型
引言
Softmax回归是深度学习中处理多分类问题的基础模型。与线性回归不同,Softmax回归能够处理多个类别的分类问题。本文将带您从零开始实现一个完整的Softmax回归模型,使用Fashion-MNIST数据集进行训练和评估。
数据准备
我们使用Fashion-MNIST数据集,它包含10个类别的服装图片,每个图片大小为28×28像素。首先设置批量大小为256:
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
模型参数初始化
Softmax回归模型需要为每个输入特征和每个类别学习权重:
- 输入特征:将28×28图像展平为784维向量
- 输出类别:10个类别对应10个输出
权重矩阵W的形状为784×10,偏置b的形状为10×1:
num_inputs = 784
num_outputs = 10
W = np.random.normal(0, 0.01, (num_inputs, num_outputs))
b = np.zeros(num_outputs)
Softmax运算实现
Softmax函数将原始分数转换为概率分布:
- 对每个元素取指数
- 计算每行的和(规范化常数)
- 将每个元素除以其行的和
数学表达式为: $$ \mathrm{softmax}(\mathbf{X}){ij} = \frac{\exp(\mathbf{X}{ij})}{\sum_k \exp(\mathbf{X}_{ik})} $$
代码实现:
def softmax(X):
X_exp = d2l.exp(X)
partition = d2l.reduce_sum(X_exp, 1, keepdims=True)
return X_exp / partition # 广播机制
模型定义
将输入图像展平后,通过矩阵乘法和加法得到输出分数,再应用softmax:
def net(X):
return softmax(d2l.matmul(d2l.reshape(X, (-1, W.shape[0])), W) + b)
损失函数:交叉熵
交叉熵损失衡量预测概率分布与真实分布的差异:
def cross_entropy(y_hat, y):
return - d2l.log(y_hat[range(len(y_hat)), y])
评估指标:分类精度
分类精度是正确预测的比例:
def accuracy(y_hat, y):
y_hat = d2l.argmax(y_hat, axis=1)
cmp = d2l.astype(y_hat, y.dtype) == y
return float(d2l.reduce_sum(d2l.astype(cmp, y.dtype)))
训练过程
训练循环包括:
- 前向传播计算预测
- 计算损失
- 反向传播计算梯度
- 更新参数
def train_epoch_ch3(net, train_iter, loss, updater):
metric = Accumulator(3)
for X, y in train_iter:
with autograd.record():
y_hat = net(X)
l = loss(y_hat, y)
l.backward()
updater(X.shape[0])
metric.add(float(l.sum()), accuracy(y_hat, y), y.size)
return metric[0]/metric[2], metric[1]/metric[2]
完整训练
设置学习率0.1,训练10个epoch:
lr = 0.1
num_epochs = 10
train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, updater)
预测与评估
训练完成后,可以在测试集上评估模型并可视化预测结果:
def predict_ch3(net, test_iter, n=6):
for X, y in test_iter:
break
trues = d2l.get_fashion_mnist_labels(y)
preds = d2l.get_fashion_mnist_labels(d2l.argmax(net(X), axis=1))
titles = [true +'\n' + pred for true, pred in zip(trues, preds)]
d2l.show_images(d2l.reshape(X[0:n], (n, 28, 28)), 1, n, titles=titles[0:n])
关键点总结
- 模型结构:Softmax回归通过线性变换加softmax激活实现多分类
- 损失函数:交叉熵损失适合衡量概率分布差异
- 训练流程:与线性回归类似,但使用不同的损失函数
- 数值稳定性:实际实现需考虑数值稳定性问题(如log(0))
常见问题与思考
- 数值稳定性:直接实现softmax可能遇到数值上溢/下溢问题,如何改进?
- 损失函数定义域:交叉熵中对数函数的定义域限制如何解决?
- 应用场景:在医疗诊断等高风险领域,仅选择最大概率类别是否足够?
- 大规模分类:当类别数量极大时(如语言模型),softmax计算会有什么挑战?
通过本教程,您应该已经掌握了Softmax回归的核心概念和实现细节。这是理解更复杂神经网络模型的重要基础。
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考