机器学习——Softmax分类模型

Softmax分类很多时候,在多分类问题中我们希望输出的是取到某个类别的概率,或者说,我们希望分值大的那个类别被经常取到,而分值较小的那一项也有一定可能性偶尔被取到。Softmax即是这样一种模型,最后的输出是每个类别被取到的概率值

假设有一个数组表示中的第个元素,那么这个元素的Softmax值为:

通常我们采用“交叉熵

它用来衡量两个取值为正的函数的相似性。对于两个完全相同的函数,它们的交叉熵为零;交叉熵越大,两个函数差异越大,反之,两个函数差异越大;对于概率分布或者概率密度函数,如果取值均大于零,交叉熵可以度量两个随机分布的差异性

作为Softmax分类模型的代价函数(原因是该模型试图通过Softmax值去拟合样本的真实类别标记,那么根据交叉熵,我们就需要使得预测函数和真实函数的差异尽可能小

其中表示类别函数,目标类的,其余类的表示Softmax函数,可以作为某个样本属于某个类别的概率的估计;代表类别。 

然后对求导,再根据梯度下降算法就可以得到参数的更新公式。

需要注意的是,若各个类别之间若为互斥关系时,采用Softmax分类模型,若各类之间存在重叠划分情况,则采用个二分类,即OvR模型。这一点可由Softmax的代价函数看出:当样本属于第个类时,交叉熵中的这一项才存在(如若不然,该项为0)。换言之,在Softmax分类模型中,某个样本属于且只属于所有类别中的一个类,不能同时属于两个或多个类。而个二分类模型,从其操作实施流程可知,一个样本可能属于两个或多个类。

### Softmax 回归概述 Softmax 回归是一种用于解决多分类问题的监督学习算法。该方法通过将输入特征映射到多个离散类别上的概率分布来工作,从而允许模型预测给定输入最可能所属的类别。 #### 1. Softmax 函数基本概念 Softmax 函数接收一组实数值作为输入,并将其转换为表示各个类别的相对概率值向量[^1]。具体来说,对于任意一个输入 \( z \),其对应的第 i 类的概率可由下式给出: \[ p(y=i|z)=\frac{e^{z_i}}{\sum_{j=1}^{n}{e^{z_j}}} \] 其中 n 表示总共有多少个不同的类别;\( e^{z_i} \) 是自然常数 e 的幂运算结果;分母是对所有类别求和的结果,确保最终输出的是有效的概率分布——即各成分之和等于 1。 这种机制使得即使当某些原始得分远大于其他得分时,也能合理地反映不同选项之间的差异程度,而不会因为极端值的存在而导致整个系统的崩溃或失真。 #### 2. 底层实现细节 为了更好地理解如何实际操作这一过程,在 Python 中可以通过 PyTorch 或 TensorFlow 等框架轻松构建并训练一个基于 Softmax 的神经网络模型来进行图像识别任务。下面是一个使用 PyTorch 实现简单版本的例子: ```python import torch.nn as nn class SimpleNet(nn.Module): def __init__(self, input_size, num_classes): super(SimpleNet, self).__init__() self.linear = nn.Linear(input_size, num_classes) def forward(self, x): out = self.linear(x) return nn.functional.softmax(out, dim=-1) ``` 此代码片段展示了怎样创建一个接受固定大小输入并向指定数量的目标类别输出概率估计的小型全连接前馈网络结构。`forward()` 方法内部调用了 `nn.functional.softmax()` 来完成最后一步转化处理。 #### 3. 数据准备与可视化 以 Fashion-MNIST 图像分类为例,这是一个包含7万个灰度服装图片的数据集,分为十个不同类型的商品标签。加载数据之后通常还需要做一些预处理工作,比如标准化像素强度范围、划分训练测试集合等。接着就可以利用 matplotlib 等工具展示部分样例以便直观感受所要解决问题的特点[^2]: ```python from torchvision import datasets, transforms import matplotlib.pyplot as plt transform = transforms.Compose([transforms.ToTensor()]) train_dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform) fig, axes = plt.subplots(2, 5, figsize=(10, 4)) for ax, img in zip(axes.flatten(), [train_dataset[i][0] for i in range(10)]): ax.imshow(img.squeeze().numpy(), cmap='gray') plt.show() ``` 这段脚本会下载Fashion-MNIST数据集并将前十张图片显示出来供观察者查看。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Eureka丶

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值