Task01 赛题理解及baseline学习
参考资料:Datawhale指导文档
1. 问题描述
该赛题是一个多分类问题(共四类),训练集给出心跳信号序列及该序列所属类别,目标是训练一个分类模型,当给出心跳序列时,能够自动对该序列进行分类。
平台给出的训练集包含10万条数据,测试集包含2万条数据,训练集和测试集示例如下:
2. baseline学习
2.1 数据预处理
这一部分以baseline中定义的函数及调用过程为主线来描述。
2.1.1 函数 reduce_mem_usage(df)
这一个函数主要是进行了数据转换。以train为例,首先查看数据类型如下(test和train相比只是缺少了label):
train.detypes
out[]:
id int64
heartbeat_signals object
label float64
dtype: object
以int为例,int8/16/32/64所能存储的数据大小不同,占用的存储空间也不同,该函数的作用是要让可以用更少空间存储的数据不占用更大的空间。该函数将’id’和’label’下的数据,能转换就转换到最低位,将’heartbeat_signals’下的数据转换为’categoty’类型。
2.1.2 简单预处理
对于train和test,分别进行如下预处理:
(以train为例)
从csv直接读取的数据中,heartbeat_signals特征是用逗号分隔的一列数,在预处理中首先将他们拆分成多列,然后用s_i给第i列信号命名,最后用2.1.1定义的reduce_mem_usage函数降低train的存储空间。
对test也做相同处理。
2.1.3 训练数据/测试数据准备
将处理后的train中心跳信号作为特征(s_i们),'label’作为标签,分别获得train_x和train_y;将处理后的test中心跳信号作为特征,获得x_test。
2.2 模型训练
2.2.1 函数abs_sum(y_pre,y_tru)
用于计算y_pre(预测值)和y_tru(真实值)每一个对应差绝对值的总和。
2.2.2 函数cv_model(clf, train_x, train_y, test_x, clf_name)
构建交叉验证模型。baseline中设置的是5折交叉验证,打乱数据集,随机种子为2021。同时将lightgbm模型的参数设置在了该模型中。
2.2.3 函数lgb_model(x_train, y_train, x_test)
调用2.2.2中cv_model,训练并输出结果。
最后将结果转换成平台要求的形式,输出csv文件。