FTTransformer,一个很能打的模型

FTTransformer,是一个BERT模型架构在结构化数据集上的迁移变体。和BERT一样,它非常能打。

它可能是少数能够在大多数结构化数据集上取得超过或者匹配LightGBM结果的深度模型。

本范例我们将应用它在来对Covertype植被覆盖数据集进行一个多分类任务。

我们在测试集取得了91%的准确率,相比之下LightGBM只有83%的准确率。

公众号算法美食屋后台回复关键词:torchkeras,获取本文notebook源码和所用Covertype数据集下载链接。

〇,原理讲解

FTTransformer是一个可以用于结构化(tabular)数据的分类和回归任务的模型。

FT 即 Feature Tokenizer的意思,把结构化数据中的离散特征和连续特征都像单词一样编码成一个向量。

从而可以像对text数据那样 应用 Transformer对 Tabular数据进行特征抽取。

值得注意的是,它对Transformer作了一些微妙的改动以适应 Tabular数据。

例如:去除第一个Transformer输入的LayerNorm层,仿照BERT的设计增加了output token(CLS token) 与features token 一起进行进入Transformer参与注意力计算。

一,准备数据

 
 
import numpy as np 
import pandas as pd 
from sklearn.model_selection import train_test_split


file_path = "covertype.parquet"
dfdata = pd.read_parquet(file_path)
...


dftmp, dftest_raw = train_test_split(dfdata, random_state=42, test_size=0.2)
dftrain_raw, dfval_raw = train_test_split(dftmp, random_state=42, test_size=0.2)


print("len(dftrain) = ",len(dftrain_raw))
print("len(dfval) = ",len(dfval_raw))
print("len(dftest) = ",len(dftest_raw))
dfdata.shape =  (581012, 13)
target_col =  Cover_Type
cat_cols =  ['Wilderness_Area', 'Soil_Type']
num_cols =  ['Elevation', 'Aspect', 'Slope', '...']
len(dftrain) =  371847
len(dfval) =  92962
len(dftest) =  116203
 
 
from torchkeras.tabular import TabularPreprocessor
from sklearn.preprocessing import OrdinalEncoder


#特征工程
...


dftest = pipe.transform(dftest_raw.drop(target_col,axis=1))
dftest[target_col] = encoder.transform(
    dftest_raw[target_col].values.reshape(-1,1)).astype(np.int32)
 
 
from torchkeras.tabular import TabularDataset
from torch.utils.data import Dataset,DataLoader 


def get_dataset(dfdata):
    return TabularDataset(
                data = dfdata,
                task = 'classification',
                target = [target_col],
                continuous_cols = pipe.get_numeric_features(),
                categorical_cols = pipe.get_embedding_features()
        )


def get_dataloader(ds,batch_size=1024,num_workers=0,shuffle=False):
    dl = DataLoader(
            ds,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=False,
        )
    return dl 
    
ds_train = get_dataset(dftrain)
ds_val = get_dataset(dfval)
ds_test = get_dataset(dftest)


dl_train = get_dataloader(ds_train,shuffle=True)
dl_val = get_dataloader(ds_val,shuffle=False)
dl_test = get_dataloader(ds_test,shuffle=False)
 
 
for batch in dl_train:
    break

二,定义模型

 
 
from torchkeras.tabular.models import FTTransformerConfig,FTTransformerModel


model_config = FTTransformerConfig(
    task="classification",
    num_attn_blocks=3
)


config = model_config.merge_dataset_config(ds_train)
net = FTTransformerModel(config = config)


#初始化参数
net.reset_weights()
net.data_aware_initialization(dl_train)


print(net.backbone.output_dim)
print(net.hparams.output_dim)

三,训练模型

 
 
from torchkeras import KerasModel 
from torchkeras.tabular import StepRunner 
KerasModel.StepRunner = StepRunner
 
 
import torch 
from torch import nn 
class Accuracy(nn.Module):
    def __init__(self):
        super().__init__()


        self.correct = nn.Parameter(torch.tensor(0.0),requires_grad=False)
        self.total = nn.Parameter(torch.tensor(0.0),requires_grad=False)


    def forward(self, preds: torch.Tensor, targets: torch.Tensor):
        preds = preds.argmax(dim=-1)
        targets = targets.reshape(-1)
        m = (preds == targets).sum()
        n = targets.shape[0] 
        self.correct += m 
        self.total += n
        
        return m/n


    def compute(self):
        return self.correct.float() / self.total 
    
    def reset(self):
        self.correct -= self.correct
        self.total -= self.total
 
 
keras_model = KerasModel(net,
                   loss_fn=None,
                   optimizer = torch.optim.AdamW(net.parameters(),lr = 1e-3),
                   metrics_dict = {"acc":Accuracy()}
                   )
 
 
keras_model.fit(
    train_data = dl_train,
    val_data= dl_val,
    ckpt_path='checkpoint',
    epochs=20,
    patience=10,
    monitor="val_acc", 
    mode="max",
    plot = True,
    wandb = False
)

e12eab2822f1a30f4c356b0086c24504.png

四,评估模型

 
 
keras_model.evaluate(dl_val)
{'val_loss': 0.22164690216164012, 'val_acc': 0.9103181958198547}
 
 
keras_model.evaluate(dl_test)
{'val_loss': 0.22033428426897317, 'val_acc': 0.9109489321708679}

五,使用模型

 
 
from tqdm import tqdm 
net = net.cpu()
net.eval()
preds = []
with torch.no_grad():
    for batch in tqdm(dl_test):
        preds.append(net.predict(batch))
 
 
yhat_list = [yd.argmax(dim=-1).tolist() for yd in preds]
yhat = []
for yd in yhat_list:
    yhat.extend(yd)
yhat = encoder.inverse_transform(np.array(yhat).reshape(-1,1))
 
 
dftest_raw = dftest_raw.rename(columns = {target_col: 'y'})
dftest_raw['yhat'] = yhat
 
 
from sklearn.metrics import classification_report
print(classification_report(y_true = dftest_raw['y'],y_pred = dftest_raw['yhat']))
precision    recall  f1-score   support

           1       0.90      0.91      0.91     42557
           2       0.92      0.92      0.92     56500
           3       0.92      0.90      0.91      7121
           4       0.85      0.82      0.83       526
           5       0.78      0.75      0.77      1995
           6       0.84      0.82      0.83      3489
           7       0.92      0.91      0.91      4015

    accuracy                           0.91    116203
   macro avg       0.88      0.86      0.87    116203
weighted avg       0.91      0.91      0.91    116203
 
 
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix




# 计算混淆矩阵
cm = confusion_matrix(dftest_raw['y'], dftest_raw['yhat'])


# 将混淆矩阵转换为DataFrame
df_cm = pd.DataFrame(cm, index=['Actual {}'.format(i) for i in range(cm.shape[0])],
                     columns=['Predicted {}'.format(i) for i in range(cm.shape[1])])


# 使用seaborn绘制混淆矩阵
plt.figure(figsize=(10,7))
sns.heatmap(df_cm, annot=True, fmt='d', cmap='Blues', cbar=False)
plt.title('Confusion Matrix')
plt.xlabel('Predicted labels')
plt.ylabel('True labels')
plt.show()

5b7bae7113e44c9b7ab6efd4eed5e6aa.png

六,保存模型

最佳模型权重已经保存在ckpt_path = 'checkpoint'位置了。

 
 
net.load_state_dict(torch.load('checkpoint'))

七,与LightGBM对比

 
 
import pandas as pd 
import lightgbm as lgb
from sklearn.preprocessing import OrdinalEncoder
from sklearn.metrics import accuracy_score 


dftmp, dftest_raw = train_test_split(dfdata, random_state=42, test_size=0.2)
dftrain_raw, dfval_raw = train_test_split(dftmp, random_state=42, test_size=0.2)


dftrain = dftrain_raw.copy()
dfval = dfval_raw.copy()
dftest = dftest_raw.copy()


target_col = 'Cover_Type'
cat_cols = ['Wilderness_Area', 'Soil_Type']


encoder = OrdinalEncoder()


dftrain[target_col] = encoder.fit_transform(dftrain[target_col].values.reshape(-1,1)) 
dfval[target_col] = encoder.transform(dfval[target_col].values.reshape(-1,1))
dftest[target_col] = encoder.transform(dftest[target_col].values.reshape(-1,1))


for col in cat_cols:
    dftrain[col] = dftrain[col].astype(int)
    dfval[col] = dfval[col].astype(int)
    dftest[col] = dftest[col].astype(int)


ds_train = lgb.Dataset(dftrain.drop(columns=[target_col]), label=dftrain[target_col],categorical_feature=cat_cols)
ds_val = lgb.Dataset(dfval.drop(columns=[target_col]), label=dfval[target_col],categorical_feature=cat_cols)
ds_test = lgb.Dataset(dftest.drop(columns=[target_col]), label=dftest[target_col],categorical_feature=cat_cols)




import lightgbm as lgb


params = {
    'n_estimators':500,
    'boosting_type': 'gbdt',
    'objective':'multiclass',
    'num_class': 7,  # 类别数量
    'metric': 'multi_logloss', 
    'learning_rate': 0.01,
    'verbose': 1,
    'early_stopping_round':50
}
model = lgb.train(params, ds_train, 
        valid_sets=[ds_val], 
        valid_names=['validate']
        )


y_pred_val = model.predict(dfval.drop(target_col,axis = 1), num_iteration=model.best_iteration)
y_pred_val = np.argmax(y_pred_val, axis=1)


y_pred_test = model.predict(dftest.drop(target_col,axis = 1), num_iteration=model.best_iteration)
y_pred_test = np.argmax(y_pred_test, axis=1)


val_score = accuracy_score(dfval[target_col], y_pred_val)
test_score = accuracy_score(dftest[target_col], y_pred_test) 


print('val_score = ',val_score)
print('test_score = ' , test_score)
val_score =  0.8321464684494739
test_score =  0.8329389086340284

公众号算法美食屋后台回复关键词:torchkeras,获取本文notebook源码和更多有趣范例。

5d60346ee03cf89d0202eb2b8f477145.png

fb972353b35008610686024e8a6163d0.png

f69f8e559906511f1514627edee6675a.png

de4e1ccb4d3db90ce5041038c14c7014.png

  • 3
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
校园失物招领系统管理系统按照操作主体分为管理员和用户。管理员的功能包括字典管理、论坛管理、公告信息管理、失物招领管理、失物认领管理、寻物启示管理、寻物认领管理、用户管理、管理员管理。用户的功能等。该系统采用了Mysql数据库,Java语言,Spring Boot框架等技术进行编程实现。 校园失物招领系统管理系统可以提高校园失物招领系统信息管理问题的解决效率,优化校园失物招领系统信息处理流程,保证校园失物招领系统信息数据的安全,它是一个非常可靠,非常安全的应用程序。 ,管理员权限操作的功能包括管理公告,管理校园失物招领系统信息,包括失物招领管理,培训管理,寻物启事管理,薪资管理等,可以管理公告。 失物招领管理界面,管理员在失物招领管理界面中可以对界面中显示,可以对失物招领信息的失物招领状态进行查看,可以添加新的失物招领信息等。寻物启事管理界面,管理员在寻物启事管理界面中查看寻物启事种类信息,寻物启事描述信息,新增寻物启事信息等。公告管理界面,管理员在公告管理界面中新增公告,可以删除公告。公告类型管理界面,管理员在公告类型管理界面查看公告的工作状态,可以对公告的数据进行导出,可以添加新公告的信息,可以编辑公告信息,删除公告信息。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值