softmax回归适用于分类问题。
softmax回归是一种常用的多分类模型,它的基本思想是将每个类别的得分通过softmax函数映射为概率,然后选择概率最大的这个类别作为预测结果。
在softmax回归中,对于每个样本,我们首先将它的特征乘以相应的权重,再求和,得到每个类别的得分。然后,将这些得分通过softmax函数映射为每个类别的概率,即将得分转化为非负数并且和为1的概率分布。我们选择概率最大的类别作为预测结果。softmax回归通常使用交叉熵作为损失函数,用于衡量预测结果与真实标签的相似度。
softmax回归同线性回归一样,也是一个单层神经网络。
假设一张图片只有4个像素:x1,x2,x3,x4.若要将图片分成三类,每张图片中的每个像素点将对应着一个权重,那么4个像素就有4个权重。输出为:
于是得到:
以下是一个使用PyTorch实现softmax回归的代码示例。这段代码实现了一个使用MNIST数据集训练的softmax回归模型。它包括了数据加载、模型定义、损失函数和优化器的设置、训练循环、以及在测试集上计算准确率等部分。
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize
from tqdm import tqdm #tqdm是一个快速的,易扩展的进度条提示模块
# 定义超参数
batch_size = 100
learning_rate = 0.1
num_epochs = 10 #epochs指的就是训练过程中数据将被“轮”多少次
# 加载数据集,并进行预处理
#训练集
train_dataset = MNIST(root='data', train=True, transform=Compose([ToTensor(), Normalize((0.1307,), (0.3081,))]))
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
#测试集
test_dataset = MNIST(root='data', train=False, transform=Compose([ToTensor(), Normalize((0.1307,), (0.3081,))]))
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
# 定义模型
class SoftmaxRegression(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(in_features=784, out_features=10)
def forward(self, x):
x = torch.flatten(x, start_dim=1) # 将输入展平为一维向量
logits = self.linear(x)
probs = nn.functional.softmax(logits, dim=1)
return probs
model = SoftmaxRegression() #实例化
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
# 训练模型
for epoch in range(num_epochs):
for images, labels in tqdm(train_loader):
# 前向传播
logits = model(images)
loss = criterion(logits, labels)
# 反向传播,更新模型参数
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 计算在测试集上的准确率
correct = 0
total = 0
with torch.no_grad(): # 在测试集上不需要计算梯度
for images, labels in test_loader:
logits = model(images)
predicted_labels = torch.argmax(logits, dim=1)
correct += (predicted_labels == labels).sum().item()
total += labels.size(0)
accuracy = correct / total
print(f'Epoch {epoch + 1}: Train Loss={loss.item():.4f}, Test Accuracy={accuracy:.4f}')