HW2 三分类

该代码示例展示了如何利用PyTorch构建一个基于Transformer的三分类模型。首先,生成2D数据点并分配标签,然后划分训练集和测试集。接着,定义一个自定义数据集类并使用DataLoader加载数据。模型采用Transformer架构,经过训练后,计算并输出测试集的准确性。最后,通过散点图展示分类结果。
摘要由CSDN通过智能技术生成

使用ChatGPT生成三分类模型

1.生成代码

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt

# 生成2D数据点
num_samples = 300
input_size = 2
num_classes = 3
sequence_length = 1

data = np.random.randn(num_samples, sequence_length, input_size)

# 为数据点分配标签
labels = []
for point in data:
    if point[0, 0] + point[0, 1] > 0.3:
        labels.append(0)
    elif point[0, 0] - point[0, 1] > 0.3:
        labels.append(1)
    else:
        labels.append(2)
labels = np.array(labels)

# 划分训练集和测试集
train_data, test_data = data[:200], data[200:]
train_labels, test_labels = labels[:200], labels[200:]


class CustomDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]


train_dataset = CustomDataset(train_data, train_labels)
test_dataset = CustomDataset(test_data, test_labels)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)


class TransformerClassifier(nn.Module):
    def __init__(self, input_size, num_classes):
        super(TransformerClassifier, self).__init__()
        self.pos_encoder = nn.Embedding(sequence_length, input_size)

        self.encoder_layer = nn.TransformerEncoderLayer(d_model=input_size, nhead=2)
        self.transformer = nn.TransformerEncoder(self.encoder_layer, num_layers=2)

        self.fc = nn.Linear(input_size, num_classes)

    def forward(self, x):
        pos = torch.arange(0, sequence_length).unsqueeze(0).repeat(x.size(0), 1)
        x = x + self.pos_encoder(pos)
        x = self.transformer(x.permute(1, 0, 2))
        x = self.fc(x[-1])
        return x


model = TransformerClassifier(input_size, num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练
num_epochs = 100
loss_values = []

for epoch in range(num_epochs):
    epoch_loss = 0
    for i, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.float(), labels.long()

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    epoch_loss /= len(train_loader)
    loss_values.append(epoch_loss)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss}")


# 测试
with torch.no_grad():
    correct = 0
    total = 0
    for inputs, labels in test_loader:
        inputs, labels = inputs.float(), labels.long()
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print(f"Accuracy: {100 * correct / total}%")

# 绘制散点图
fig, ax = plt.subplots()
colors = ['r', 'g', 'b']

# 绘制训练集
for i in range(num_classes):
    indices = train_labels == i
    ax.scatter(train_data[indices, 0, 0], train_data[indices, 0, 1], c=colors[i], marker='o',
               label=f'Train Class {i}')

# 绘制测试集
with torch.no_grad():
    for i in range(num_classes):
        indices = test_labels == i
        inputs = torch.tensor(test_data[indices], dtype=torch.float32)
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        correct_indices = predicted == i
        incorrect_indices = predicted != i
        ax.scatter(test_data[indices, 0, 0][correct_indices], test_data[indices, 0, 1][correct_indices],
                   c=colors[i], marker='^', label=f'Correct Test Class {i}')
        ax.scatter(test_data[indices, 0, 0][incorrect_indices], test_data[indices, 0, 1][incorrect_indices],
                   c=colors[i], marker='x', label=f'Incorrect Test Class {i}')

ax.legend()
plt.xlabel('X1')
plt.ylabel('X2')
plt.title('Scatter Plot of Classification')
plt.show()

2.运行结果

在这里插入图片描述
在这里插入图片描述

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值