PyTorch多分类处理

下面是使用PyTorch完成多分类处理的代码。首先,通过torch.tensor将x_data和y_data导入为Tensor。然后,创建包含4个输入特征和3个输出类别的多分类模型。接下来,定义损失函数和优化器,使用梯度下降进行训练。在训练过程中,打印每200次迭代的代价函数值。使用[[1, 11, 7, 9], [1, 3, 4, 3], [1, 1, 0, 1]]预测结果。计算并打印上述数据的准确率。绘制损失函数图像。最后,打印模型的准确率、权重和截距。

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

# 导入数据
x_data = torch.tensor([[1, 2, 1, 1], [2, 1, 3, 2], [3, 1, 3, 4], [4, 1, 5, 5],
                       [1, 7, 5, 5], [1, 2, 5, 6], [1, 6, 6, 6], [1, 7, 7, 7]], dtype=torch.float32)
y_data = torch.tensor([[0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 1, 0],
                       [0, 1, 0], [0, 1, 0], [1, 0, 0], [1, 0, 0]], dtype=torch.float32)

# 创建多分类模型
class MultiClassModel(nn.Module):
    def __init__(self):
        super(MultiClassModel, self).__init__()
        self.linear = nn.Linear(4, 3)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        out = self.linear(x)
        out = self.softmax(out)
        return out

model = MultiClassModel()

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 训练模型
losses = []
for epoch in range(1000):
    optimizer.zero_grad()
    outputs = model(x_data)
    loss = criterion(outputs, torch.argmax(y_data, dim=1))
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

    if (epoch+1) % 200 == 0:
        print(f'Epoch: {epoch+1}, Loss: {loss.item()}')

# 使用[[1, 11, 7, 9], [1, 3, 4, 3], [1, 1, 0, 1]]打印预测结果
test_data = torch.tensor([[1, 11, 7, 9], [1, 3, 4, 3], [1, 1, 0, 1]], dtype=torch.float32)
predictions = model(test_data)
_, predicted_labels = torch.max(predictions, dim=1)
print(f'Predictions: {predicted_labels}')

# 计算准确率
_, actual_labels = torch.max(y_data, dim=1)
correct = (predicted_labels == actual_labels).sum().item()
accuracy = correct / len(actual_labels)
print(f'Accuracy: {accuracy}')

# 绘制损失函数图像
plt.plot(range(len(losses)), losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.show()

# 打印模型的准确率、权重和截距
print(f'Model Accuracy: {accuracy}')
print(f'Model Weights: {model.linear.weight}')
print(f'Model Bias: {model.linear.bias}')
 

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值