Python transformer——鸢尾花分类(基于pytorch)

在本文中,我们将介绍如何使用 PyTorch 实现一个基于 Transformer 架构的模型来对鸢尾花数据集进行分类。Transformer 架构最初在自然语言处理领域取得了巨大的成功,这里我们将其应用于分类任务中。

一、导入库

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

二、定义 Transformer 模型

class MyTransformer(nn.Module):
    def __init__(self, d_model, nhead, num_layers, dropout, mean, scale):
        super(MyTransformer, self).__init__()
        self.mean = nn.Parameter(torch.tensor(mean, dtype=torch.float32), requires_grad=False)
        self.scale = nn.Parameter(torch.tensor(scale, dtype=torch.float32), requires_grad=False)
        self.input_linear = nn.Linear(d_model, d_model)
        self.positional_encoding = nn.Parameter(torch.zeros(1, 100, d_model))
        self.transformer = nn.Transformer(d_model=d_model, nhead=nhead, num_encoder_layers=num_layers,
                                          num_decoder_layers=num_layers, dropout=dropout)
        self.output_linear = nn.Linear(d_model, 3)  # Assuming 3 classes for classification

    def forward(self, x):
        # 标准化
        x = (x - self.mean) / self.scale
        x = x.unsqueeze(1)
        x = self.input_linear(x) + self.positional_encoding[:, :x.size(1)]

        # 使用相同的 x 作为目标序列(这是一个简化示例)
        tgt = x

        x = self.transformer(x, tgt)
        x = x.squeeze(1)
        x = self.output_linear(x)
        return x

这个自定义的MyTransformer类继承自nn.Module,在构造函数中,它初始化了线性层、参数化的均值和标准差、位置编码以及 Transformer 模块。前向传播过程中,首先对输入进行标准化,然后进行线性变换和添加位置编码,接着将输入作为目标序列传入 Transformer 模块,最后通过线性层输出分类结果。

三、加载示例数据

这里我们使用sklearn的load_iris函数加载鸢尾花数据集,并将其分为训练集和测试集。

iris = load_iris()
X = iris.data
y = iris.target

# 数据分割
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

四、数据标准化及保存参数

对训练数据进行标准化,并将计算得到的均值和标准差保存下来。

# 计算均值和标准差
scaler = StandardScaler()
scaler.fit(X_train)
mean, scale = scaler.mean_, scaler.scale_

# 保存均值和标准差
mean, scale = mean.tolist(), scale.tolist()

五、定义模型并训练

定义模型的超参数,创建模型实例,设置优化器和损失函数。将训练数据转换为 PyTorch 的 Tensor 类型,然后进行训练循环,在每个 epoch 中计算损失、反向传播并更新模型参数。

# 定义模型
d_model = X_train.shape[1]
nhead = 2
num_layers = 2
dropout = 0.1

model = MyTransformer(d_model, nhead, num_layers, dropout, mean, scale)

# 训练模型
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# 转换数据为 Tensor
X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.long)

# 训练循环
model.train()
for epoch in range(10):
    optimizer.zero_grad()
    outputs = model(X_train_tensor)
    loss = criterion(outputs, y_train_tensor)
    loss.backward()
    optimizer.step()
    print(f'Epoch {epoch+1}, Loss: {loss.item()}')

六、保存模型

最后,将训练好的模型参数以及均值和标准差保存到文件中。

torch.save({
    'model_state_dict': model.state_dict(),
    'mean': mean,
    'scale': scale,
}, 'my_transformer_with_params.pth')

print("Model trained and saved successfully.")

通过以上步骤,我们成功地实现了一个基于 Transformer 的模型来对鸢尾花数据集进行分类。
如有错误,还望及时指正。
末尾附上GitHub完整代码(包含训练后的鸢尾花数据集、模型文件以及推理代码):转至GitHub

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值