在本文中,我们将介绍如何使用 PyTorch 实现一个基于 Transformer 架构的模型来对鸢尾花数据集进行分类。Transformer 架构最初在自然语言处理领域取得了巨大的成功,这里我们将其应用于分类任务中。
一、导入库
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
二、定义 Transformer 模型
class MyTransformer(nn.Module):
def __init__(self, d_model, nhead, num_layers, dropout, mean, scale):
super(MyTransformer, self).__init__()
self.mean = nn.Parameter(torch.tensor(mean, dtype=torch.float32), requires_grad=False)
self.scale = nn.Parameter(torch.tensor(scale, dtype=torch.float32), requires_grad=False)
self.input_linear = nn.Linear(d_model, d_model)
self.positional_encoding = nn.Parameter(torch.zeros(1, 100, d_model))
self.transformer = nn.Transformer(d_model=d_model, nhead=nhead, num_encoder_layers=num_layers,
num_decoder_layers=num_layers, dropout=dropout)
self.output_linear = nn.Linear(d_model, 3) # Assuming 3 classes for classification
def forward(self, x):
# 标准化
x = (x - self.mean) / self.scale
x = x.unsqueeze(1)
x = self.input_linear(x) + self.positional_encoding[:, :x.size(1)]
# 使用相同的 x 作为目标序列(这是一个简化示例)
tgt = x
x = self.transformer(x, tgt)
x = x.squeeze(1)
x = self.output_linear(x)
return x
这个自定义的MyTransformer类继承自nn.Module,在构造函数中,它初始化了线性层、参数化的均值和标准差、位置编码以及 Transformer 模块。前向传播过程中,首先对输入进行标准化,然后进行线性变换和添加位置编码,接着将输入作为目标序列传入 Transformer 模块,最后通过线性层输出分类结果。
三、加载示例数据
这里我们使用sklearn的load_iris函数加载鸢尾花数据集,并将其分为训练集和测试集。
iris = load_iris()
X = iris.data
y = iris.target
# 数据分割
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
四、数据标准化及保存参数
对训练数据进行标准化,并将计算得到的均值和标准差保存下来。
# 计算均值和标准差
scaler = StandardScaler()
scaler.fit(X_train)
mean, scale = scaler.mean_, scaler.scale_
# 保存均值和标准差
mean, scale = mean.tolist(), scale.tolist()
五、定义模型并训练
定义模型的超参数,创建模型实例,设置优化器和损失函数。将训练数据转换为 PyTorch 的 Tensor 类型,然后进行训练循环,在每个 epoch 中计算损失、反向传播并更新模型参数。
# 定义模型
d_model = X_train.shape[1]
nhead = 2
num_layers = 2
dropout = 0.1
model = MyTransformer(d_model, nhead, num_layers, dropout, mean, scale)
# 训练模型
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
# 转换数据为 Tensor
X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.long)
# 训练循环
model.train()
for epoch in range(10):
optimizer.zero_grad()
outputs = model(X_train_tensor)
loss = criterion(outputs, y_train_tensor)
loss.backward()
optimizer.step()
print(f'Epoch {epoch+1}, Loss: {loss.item()}')
六、保存模型
最后,将训练好的模型参数以及均值和标准差保存到文件中。
torch.save({
'model_state_dict': model.state_dict(),
'mean': mean,
'scale': scale,
}, 'my_transformer_with_params.pth')
print("Model trained and saved successfully.")
通过以上步骤,我们成功地实现了一个基于 Transformer 的模型来对鸢尾花数据集进行分类。
如有错误,还望及时指正。
末尾附上GitHub完整代码(包含训练后的鸢尾花数据集、模型文件以及推理代码):转至GitHub