LightGBM

LightGBM is a relatively new member of the boosting-ensemble family, released by Microsoft. Like XGBoost, it is an efficient implementation of GBDT: in principle both fit each new decision tree to the negative gradient of the loss function, used as an approximation of the current residual.

Similarities and differences between LightGBM and XGBoost

  • Model accuracy: LightGBM and XGBoost are comparable.
  • Missing values: both XGBoost and LightGBM handle missing feature values automatically.
  • Categorical features: LightGBM supports categorical features natively, while XGBoost does not and requires one-hot encoding as preprocessing (see the sketch after this list).
  • Tree-growth strategy: most GBDT tools grow trees level-wise; LightGBM grows trees leaf-wise with a depth limit.
  • Training speed: LightGBM is much faster than XGBoost.
  • Memory usage: LightGBM consumes far less memory than XGBoost.
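A minimal sketch of the native categorical support (toy data invented for the example; assumes LightGBM's scikit-learn API, whose default categorical_feature='auto' picks up pandas categorical columns):

import pandas as pd
import lightgbm as lgb

# Toy frame: 'color' is a pandas categorical column, 'x' is numeric
df = pd.DataFrame({
    'color': pd.Categorical(['red', 'blue', 'red', 'green'] * 25),
    'x': list(range(100)),
})
y = (df['x'] % 2 == 0).astype(int)

clf = lgb.LGBMClassifier(n_estimators=10)
clf.fit(df, y)  # 'color' is consumed directly; no one-hot encoding needed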

Why LightGBM trains faster and uses less memory

  • Histogram algorithm: this replaces the pre-sorted algorithm used by XGBoost.

The pre-sorted algorithm first sorts the samples by feature value and then searches over all feature values for the best split point, so the number of candidate split points grows with the number of samples.

The histogram algorithm instead discretizes continuous feature values into a fixed number of bins (e.g. 255), so the number of candidate split points is a constant (num_bins - 1).

In addition, the histogram algorithm allows histogram subtraction: when a node splits in two, the histogram of the right child equals the parent's histogram minus the left child's histogram, which greatly reduces the cost of building histograms. A minimal sketch of histogram-based split finding follows.
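The sketch below is simplified and illustrative only; the quantile binning and the XGBoost-style gain formula are assumptions in the spirit of the paper, not LightGBM's actual implementation:

import numpy as np

def best_split_histogram(feature, grad, hess, num_bins=255, lam=1.0):
    # Discretize the continuous feature into at most num_bins bins
    edges = np.quantile(feature, np.linspace(0, 1, num_bins + 1)[1:-1])
    bin_idx = np.digitize(feature, edges)

    # One O(n) pass accumulates gradient/hessian sums per bin
    g_hist = np.bincount(bin_idx, weights=grad, minlength=num_bins)
    h_hist = np.bincount(bin_idx, weights=hess, minlength=num_bins)

    # Histogram subtraction: a sibling's histogram can be obtained as
    # parent_hist - this_hist instead of rebuilding it from the raw data.
    g_total, h_total = g_hist.sum(), h_hist.sum()
    best_gain, best_bin, g_left, h_left = -np.inf, None, 0.0, 0.0
    # Only num_bins - 1 candidate split points need scanning
    for b in range(num_bins - 1):
        g_left += g_hist[b]
        h_left += h_hist[b]
        g_right, h_right = g_total - g_left, h_total - h_left
        gain = (g_left**2 / (h_left + lam)
                + g_right**2 / (h_right + lam)
                - g_total**2 / (h_total + lam))
        if gain > best_gain:
            best_gain, best_bin = gain, b
    return best_bin, best_gain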

  • GOSS: Gradient-based One-Side Sampling.

GOSS keeps all instances with large gradients and samples randomly among the small-gradient instances. To offset the effect on the data distribution, it multiplies the small-gradient instances by a constant when computing information gain. Concretely, GOSS sorts the data by gradient magnitude and keeps the top a×100% of instances, then randomly samples b×100% from the remainder; when computing gains it multiplies the sampled small-gradient instances by (1 - a)/b. The algorithm thereby pays more attention to under-trained instances without changing the original distribution much. A sampling sketch follows.
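The sampling step can be sketched as follows (simplified, with made-up names; not LightGBM's internal code):

import numpy as np

def goss_sample(grad, a=0.2, b=0.1, rng=np.random.default_rng(0)):
    n = len(grad)
    top_n, rand_n = int(a * n), int(b * n)
    order = np.argsort(-np.abs(grad))       # sort by |gradient|, descending
    top_idx = order[:top_n]                 # keep all large-gradient instances
    rand_idx = rng.choice(order[top_n:], size=rand_n, replace=False)
    weights = np.ones(n)
    weights[rand_idx] = (1 - a) / b         # re-weight small-gradient instances
    idx = np.concatenate([top_idx, rand_idx])
    return idx, weights[idx]

With a=0.2 and b=0.1, each iteration trains on 30% of the data while the sampled small-gradient instances are up-weighted by (1 - 0.2) / 0.1 = 8 when gains are computed.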

  • EFB: Exclusive Feature Bundling.

EFB improves computational efficiency by bundling features together, i.e. reducing the feature dimension (effectively a dimensionality-reduction technique). Bundled features are normally mutually exclusive (where one is nonzero, the other is zero), so bundling them loses no information. When two features are not perfectly exclusive (both nonzero in some rows), their degree of overlap can be measured with a conflict ratio; if this ratio is small, the two features can still be bundled without hurting final accuracy. A toy example follows.
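A toy example of bundling two mutually exclusive sparse features into one (a simplified value-offset illustration of the idea; LightGBM actually offsets histogram bins rather than raw values):

import numpy as np

f1 = np.array([0, 3, 0, 1, 0])    # nonzero exactly where f2 is zero
f2 = np.array([2, 0, 4, 0, 0])
offset = f1.max()                  # shift f2's values past f1's range
bundle = np.where(f1 != 0, f1, np.where(f2 != 0, f2 + offset, 0))
print(bundle)                      # [5 3 7 1 0]; both features remain recoverable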

  • Efficient parallel training and optimized network communication.

The examples below use LightGBM through its scikit-learn API. First, the common imports and display settings:
import pandas as pd
import numpy as np
import scipy
# import xgboost as xgb
import lightgbm as lgb
import tensorflow as tf
import seaborn as sns
import matplotlib.pyplot as plt

# from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import cross_val_score
from sklearn import metrics

# from sklearn.utils import class_weight
from sklearn.model_selection import train_test_split
# from sklearn.metrics import roc_auc_score,precision_recall_curve

# Display settings for the number of rows/columns printed
pd.set_option('display.max_columns', 200)
pd.set_option('display.max_rows', 200)
pd.set_option('display.width', 200)

# Suppress warnings such as version-incompatibility notices
import warnings
warnings.filterwarnings("ignore")
# Model parameter setup
model = lgb.LGBMClassifier(boosting_type='gbdt'
                          ,class_weight=None
                          ,colsample_bytree=1.0
                          ,importance_type='split'
                          ,learning_rate=0.1
                          ,max_depth=-1
                          ,min_child_samples=20
                          ,min_child_weight=0.001
                          ,min_split_gain=0.0
                          ,n_estimators=200
                          ,n_jobs=1
                          ,num_leaves=31
                          ,objective=None
                          ,random_state=None
                          ,reg_alpha=0.0
                          ,reg_lambda=0.0
                          ,silent=True
                          ,subsample=1.0
                          ,subsample_for_bin=200000)

LGBMClassifier parameters

  • boosting, 'boost' or 'boosting_type': a string selecting the base-learner algorithm. Default 'gbdt'. Options:

    'gbdt': traditional gradient-boosted decision trees (the default)

    'rf': random forest

    'dart': gbdt with dropout

    'goss': gbdt with Gradient-based One-Side Sampling

  • colsample_bytree [default=1]: the fraction of features sampled when building each tree. Valid range (0, 1]. subsample, colsample_bytree = 0.8 is the most common starting point; typical values range from 0.5 to 0.9.

  • learning_rate: the learning rate applied at each boosting step (default 0.1; a per-iteration schedule can be supplied via the lgb.reset_parameter callback).

  • max_depth: an integer limiting the maximum depth of each tree; default -1. A value below 0 means no limit.

  • min_data_in_leaf, min_data_per_leaf, min_data or min_child_samples: an integer, the minimum number of samples a leaf must contain. Default 20.

  • min_sum_hessian_in_leaf, min_sum_hessian_per_leaf, min_sum_hessian, min_hessian or min_child_weight: a float, the minimum sum of hessians in a leaf (i.e. the minimum total sample weight of a leaf). Default 1e-3.

  • min_split_gain or min_gain_to_split: a float, the minimum gain required to make a split. Default 0.

  • n_estimators: the number of boosting iterations.

  • num_leaves or num_leaf: an integer, the number of leaves per tree. Default 31.

  • objective: the objective function.

  • lambda_l1 or reg_alpha: a float, the L1 regularization coefficient. Default 0.

  • lambda_l2 or reg_lambda: a float, the L2 regularization coefficient. Default 0.

  • silent=True: whether to suppress log output during training.

  • bagging_fraction, sub_row or subsample: a float in [0.0, 1.0]; default 1.0. If below 1.0, LightGBM randomly selects that fraction of the samples (without replacement) at each iteration; e.g. 0.8 means 80% of the samples are drawn (without replacement) before each tree is trained.

  • bin_construct_sample_cnt or subsample_for_bin: an integer, the number of samples used to construct the histogram bins. Default 200000. For very sparse data this can be increased; a larger value yields better bin boundaries but lengthens data loading.

  • use_missing: a boolean, whether missing-value handling is enabled. Default True; set it to False to disable the feature.

Titanic passenger survival analysis

traindata_path = u'D:/01_Project/99_test/ML/titanic/train.csv'
testdata_path = u'D:/01_Project/99_test/ML/titanic/test.csv'
testresult_path = u'D:/01_Project/99_test/ML/titanic/gender_submission.csv'
df_train = pd.read_csv(traindata_path)
df_test = pd.read_csv(testdata_path)
df_test['Survived'] = pd.read_csv(testresult_path)['Survived']
data_original = pd.concat([df_train,df_test],sort=False)
display (data_original.head(5))
   PassengerId  Survived  Pclass  Name                                                Sex     Age   SibSp  Parch  Ticket            Fare     Cabin  Embarked
0  1            0         3       Braund, Mr. Owen Harris                             male    22.0  1      0      A/5 21171         7.2500   NaN    S
1  2            1         1       Cumings, Mrs. John Bradley (Florence Briggs Th...  female  38.0  1      0      PC 17599          71.2833  C85    C
2  3            1         3       Heikkinen, Miss. Laina                              female  26.0  0      0      STON/O2. 3101282  7.9250   NaN    S
3  4            1         1       Futrelle, Mrs. Jacques Heath (Lily May Peel)        female  35.0  1      0      113803            53.1000  C123   S
4  5            0         3       Allen, Mr. William Henry                            male    35.0  0      0      373450            8.0500   NaN    S

Column descriptions

  • PassengerId => passenger ID
  • Pclass => passenger class (1st/2nd/3rd)
  • Name => passenger name
  • Sex => sex
  • Age => age
  • SibSp => number of siblings/spouses aboard
  • Parch => number of parents/children aboard
  • Ticket => ticket number
  • Fare => fare
  • Cabin => cabin
  • Embarked => port of embarkation
# Inspect the data
data_original.info()
<class 'pandas.core.frame.DataFrame'>
Int64Index: 1309 entries, 0 to 417
Data columns (total 12 columns):
PassengerId    1309 non-null int64
Survived       1309 non-null int64
Pclass         1309 non-null int64
Name           1309 non-null object
Sex            1309 non-null object
Age            1046 non-null float64
SibSp          1309 non-null int64
Parch          1309 non-null int64
Ticket         1309 non-null object
Fare           1308 non-null float64
Cabin          295 non-null object
Embarked       1307 non-null object
dtypes: float64(2), int64(5), object(5)
memory usage: 132.9+ KB
features = list(data_original.columns[data_original.dtypes != 'object'])
print (features)
['PassengerId', 'Survived', 'Pclass', 'Age', 'SibSp', 'Parch', 'Fare']
# Distribution of the numeric features
data_original[features].describe()
       PassengerId  Survived     Pclass       Age          SibSp        Parch        Fare
count  1309.000000  1309.000000  1309.000000  1046.000000  1309.000000  1309.000000  1308.000000
mean   655.000000   0.377387     2.294882     29.881138    0.498854     0.385027     33.295479
std    378.020061   0.484918     0.837836     14.413493    1.041658     0.865560     51.758668
min    1.000000     0.000000     1.000000     0.170000     0.000000     0.000000     0.000000
25%    328.000000   0.000000     2.000000     21.000000    0.000000     0.000000     7.895800
50%    655.000000   0.000000     3.000000     28.000000    0.000000     0.000000     14.454200
75%    982.000000   1.000000     3.000000     39.000000    1.000000     0.000000     31.275000
max    1309.000000  1.000000     3.000000     80.000000    8.000000     9.000000     512.329200
# Value distribution of the categorical features
print (data_original['Sex'].value_counts())
print (data_original['Embarked'].value_counts())
male      843
female    466
Name: Sex, dtype: int64
S    914
C    270
Q    123
Name: Embarked, dtype: int64

Converting the categorical features

# One-hot encode the categorical features
data_onehot = pd.get_dummies(data_original, columns=['Sex','Embarked'])
# data_onehot = data_original.copy()
# data_original['Sex'].replace('male', 0, inplace=True)   # inplace=True replaces in place
data_onehot.head()
   PassengerId  Survived  Pclass  Name                                                Age   SibSp  Parch  Ticket            Fare     Cabin  Sex_female  Sex_male  Embarked_C  Embarked_Q  Embarked_S
0  1            0         3       Braund, Mr. Owen Harris                             22.0  1      0      A/5 21171         7.2500   NaN    0           1         0           0           1
1  2            1         1       Cumings, Mrs. John Bradley (Florence Briggs Th...  38.0  1      0      PC 17599          71.2833  C85    1           0         1           0           0
2  3            1         3       Heikkinen, Miss. Laina                              26.0  0      0      STON/O2. 3101282  7.9250   NaN    1           0         0           0           1
3  4            1         1       Futrelle, Mrs. Jacques Heath (Lily May Peel)        35.0  1      0      113803            53.1000  C123   1           0         0           0           1
4  5            0         3       Allen, Mr. William Henry                            35.0  0      0      373450            8.0500   NaN    0           1         0           0           1
# Drop features that will not be used for training
drop_features = ['PassengerId', 'Survived', 'Name','Ticket','Cabin']
features_filted = list(data_onehot.columns.values)
for feature in drop_features:
    features_filted.remove(feature)
# features_filted = list(set(features_filted) - set(drop_features))
print (features_filted)

# Split into training and validation sets
x_train, x_test, y_train, y_test = train_test_split(data_onehot[features_filted], data_onehot['Survived'], random_state=1, train_size=0.7)
display(x_train.shape)
display(x_test.shape)
display(y_train.shape)
display(y_test.shape)
['Pclass', 'Age', 'SibSp', 'Parch', 'Fare', 'Sex_female', 'Sex_male', 'Embarked_C', 'Embarked_Q', 'Embarked_S']

(916, 10)
(393, 10)
(916,)
(393,)

Model training

model.fit(x_train, y_train,eval_set=[(x_train,y_train),(x_test,y_test)],
       eval_metric=['logloss','auc'],early_stopping_rounds=10,verbose=True)
[1]	training's binary_logloss: 0.595048	training's auc: 0.942605	valid_1's binary_logloss: 0.633623	valid_1's auc: 0.851209
Training until validation scores don't improve for 10 rounds
[2]	training's binary_logloss: 0.548373	training's auc: 0.945537	valid_1's binary_logloss: 0.592688	valid_1's auc: 0.851515
[3]	training's binary_logloss: 0.510333	training's auc: 0.945868	valid_1's binary_logloss: 0.560726	valid_1's auc: 0.849079
[4]	training's binary_logloss: 0.478709	training's auc: 0.948622	valid_1's binary_logloss: 0.534067	valid_1's auc: 0.853446
[5]	training's binary_logloss: 0.451073	training's auc: 0.950018	valid_1's binary_logloss: 0.51243	valid_1's auc: 0.854005
[6]	training's binary_logloss: 0.427676	training's auc: 0.951412	valid_1's binary_logloss: 0.49401	valid_1's auc: 0.858917
[7]	training's binary_logloss: 0.407227	training's auc: 0.952531	valid_1's binary_logloss: 0.479078	valid_1's auc: 0.86363
[8]	training's binary_logloss: 0.389879	training's auc: 0.953206	valid_1's binary_logloss: 0.466908	valid_1's auc: 0.864682
[9]	training's binary_logloss: 0.374723	training's auc: 0.953082	valid_1's binary_logloss: 0.456975	valid_1's auc: 0.865215
[10]	training's binary_logloss: 0.361178	training's auc: 0.954225	valid_1's binary_logloss: 0.446911	valid_1's auc: 0.869328
[11]	training's binary_logloss: 0.349364	training's auc: 0.954401	valid_1's binary_logloss: 0.44059	valid_1's auc: 0.868982
[12]	training's binary_logloss: 0.338051	training's auc: 0.955169	valid_1's binary_logloss: 0.435057	valid_1's auc: 0.869275
[13]	training's binary_logloss: 0.328267	training's auc: 0.956474	valid_1's binary_logloss: 0.429733	valid_1's auc: 0.870247
[14]	training's binary_logloss: 0.318925	training's auc: 0.957555	valid_1's binary_logloss: 0.427283	valid_1's auc: 0.870314
[15]	training's binary_logloss: 0.311077	training's auc: 0.958657	valid_1's binary_logloss: 0.42482	valid_1's auc: 0.869488
[16]	training's binary_logloss: 0.304002	training's auc: 0.959259	valid_1's binary_logloss: 0.423752	valid_1's auc: 0.867877
[17]	training's binary_logloss: 0.296665	training's auc: 0.960373	valid_1's binary_logloss: 0.420685	valid_1's auc: 0.870247
[18]	training's binary_logloss: 0.291046	training's auc: 0.960769	valid_1's binary_logloss: 0.421158	valid_1's auc: 0.86833
[19]	training's binary_logloss: 0.285206	training's auc: 0.961702	valid_1's binary_logloss: 0.421062	valid_1's auc: 0.867185
[20]	training's binary_logloss: 0.279439	training's auc: 0.962313	valid_1's binary_logloss: 0.421684	valid_1's auc: 0.866812
[21]	training's binary_logloss: 0.274464	training's auc: 0.962969	valid_1's binary_logloss: 0.423581	valid_1's auc: 0.86576
[22]	training's binary_logloss: 0.269943	training's auc: 0.963476	valid_1's binary_logloss: 0.423831	valid_1's auc: 0.866479
[23]	training's binary_logloss: 0.264949	training's auc: 0.964407	valid_1's binary_logloss: 0.423879	valid_1's auc: 0.867238
[24]	training's binary_logloss: 0.260584	training's auc: 0.965294	valid_1's binary_logloss: 0.425674	valid_1's auc: 0.866519
Early stopping, best iteration is:
[14]	training's binary_logloss: 0.318925	training's auc: 0.957555	valid_1's binary_logloss: 0.427283	valid_1's auc: 0.870314





LGBMClassifier(n_estimators=200, n_jobs=1)

Parameter notes

  • early_stopping_rounds: stop adding trees once the validation loss has not improved for 10 consecutive rounds; this acts as a training monitor.

  • eval_set: the dataset(s) to evaluate on.

  • verbose=False suppresses the training log.

  • objective: the objective function. The names below follow the XGBoost convention used in the original setup; LightGBM's native equivalents are 'regression', 'binary', 'multiclass', etc.

  Regression tasks

    reg:linear (default)

    reg:logistic

  Binary classification

    binary:logistic — outputs probabilities

    binary:logitraw — outputs raw scores before the logistic transform

  Multi-class classification

    multi:softmax, num_class=n — returns the predicted class

    multi:softprob, num_class=n — returns class probabilities

  Ranking

    rank:pairwise

  • eval_metric

  Regression tasks (default rmse)

    rmse — root mean squared error

    mae — mean absolute error

  Classification tasks (default error)

    auc — area under the ROC curve

    error — error rate (binary classification)

    merror — error rate (multi-class)

    logloss — negative log-likelihood (binary classification)

    mlogloss — negative log-likelihood (multi-class)

Feature importance

importance_df = pd.DataFrame({
    'features':x_train.columns.values,
    'importance':model.feature_importances_.tolist()
})
importance_df = importance_df.sort_values('importance',ascending=False)
importance_df
   features    importance
1  Age         171
4  Fare        151
0  Pclass      26
2  SibSp       18
7  Embarked_C  16
5  Sex_female  14
3  Parch       11
9  Embarked_S  6
8  Embarked_Q  5
6  Sex_male    0
import seaborn as sns
import matplotlib.pyplot as plt
plt.figure(figsize=(10,6))
sns.barplot(importance_df['importance'][:20],importance_df['features'][:20])
plt.show()

[Figure: feature importance bar chart (output_19_0.png)]

Confusion matrix

pred_y_test = model.predict(x_test)
# m = metrics.confusion_matrix(y_test, pred_y_test)
# display (m)
tn, fp, fn, tp = metrics.confusion_matrix(y_test, pred_y_test).ravel()
print ('matrix    label1   label0')
print ('predict1  {:<6d}   {:<6d}'.format(int(tp), int(fp)))
print ('predict0  {:<6d}   {:<6d}'.format(int(fn), int(tn)))
matrix    label1   label0
predict1  116      15    
predict0  48       214   

Cross-validation

Validating the model's scores

score_x = x_train
score_y = y_train
# Accuracy
scores = cross_val_score(model, score_x, score_y, cv=5, scoring='accuracy')
print('Cross-validation accuracy: ' + str(scores.mean()))
Cross-validation accuracy: 0.8394986932763127
# Precision
scores = cross_val_score(model, score_x, score_y, cv=5, scoring='precision')
print('Cross-validation precision: ' + str(scores.mean()))
Cross-validation precision: 0.8043572452360251
# Recall
scores = cross_val_score(model, score_x, score_y, cv=5, scoring='recall')
print('Cross-validation recall: ' + str(scores.mean()))
Cross-validation recall: 0.7424242424242424
# f1_score
scores = cross_val_score(model, score_x, score_y, cv=5, scoring='f1')
print('Cross-validation f1_score: ' + str(scores.mean()))
Cross-validation f1_score: 0.7695365479961331

TopN

TopN is used to evaluate a model when the classes are imbalanced and recall is the main concern; the Titanic survival task is not really suited to judging a model by TopN.

ratio_list = [0.01,0.02,0.05,0.1,0.2]
test_label = pd.DataFrame(y_test)
index_of_label1 = model.classes_.tolist().index(1)
pred_y_test = model.predict(x_test)
proba_y_test = model.predict_proba(x_test)
test_label['predict'] = pred_y_test
test_label['label_1'] = proba_y_test[:,index_of_label1]
display (test_label.head())

label_1_nbr = len(test_label[test_label['Survived']==1])
print ('label_1_nbr:',label_1_nbr)
print ('sample number:',len(test_label))

for ratio in ratio_list:
    num = test_label.sort_values('label_1',ascending=False)[:int(ratio*test_label.shape[0])]['Survived'].sum()
    count = test_label.sort_values('label_1',ascending=False)[:int(ratio*test_label.shape[0])]['Survived'].count()
    print ('Top %.2f label_1_nbr:%d,sample_nbr:%d,recall:%f'%(ratio,num,count,1.0*num/label_1_nbr))
     Survived  predict  label_1
201  0         0        0.111261
115  0         0        0.163787
255  1         0        0.451979
212  0         0        0.192941
195  1         1        0.852887
label_1_nbr: 164
sample number: 393
Top 0.01 label_1_nbr:3,sample_nbr:3,recall:0.018293
Top 0.02 label_1_nbr:7,sample_nbr:7,recall:0.042683
Top 0.05 label_1_nbr:19,sample_nbr:19,recall:0.115854
Top 0.10 label_1_nbr:38,sample_nbr:39,recall:0.231707
Top 0.20 label_1_nbr:73,sample_nbr:78,recall:0.445122

Grid search for the best parameters

param_grid = [
{'n_estimators': [3, 10, 30],'learning_rate': [0.01,0.05,0.1]}
]

clf = lgb.LGBMClassifier()
grid_search = GridSearchCV(clf, param_grid, cv=5,scoring='neg_mean_squared_error')
grid_search.fit(x_train, y_train)
print (grid_search.best_params_)
print (grid_search.best_estimator_)
{'learning_rate': 0.1, 'n_estimators': 30}
LGBMClassifier(n_estimators=30)

MNIST handwritten digit recognition

# Load the data
traindata_path = u'D:/01_Project/99_test/ML/mnist/mnist_train.csv'
testdata_path = u'D:/01_Project/99_test/ML/mnist/mnist_test.csv'
# testresult_path = u'D:/01_Project/99_test/ML/titanic/gender_submission.csv'
df_train = pd.read_csv(traindata_path)
x_train = df_train.iloc[:,1:]
y_train = df_train.iloc[:,0]
print (x_train.shape)
display (y_train[:5])
display (x_train.head(5))

df_test = pd.read_csv(testdata_path)
x_test = df_test.iloc[:,1:]
y_test = df_test.iloc[:,0]
print (x_test.shape)
display (y_test[:5])
display (x_test.head(5))
(59999, 784)

[y_train head and x_train head omitted: 5 rows × 784 pixel columns, mostly zeros]

(9999, 784)

[y_test head and x_test head omitted: 5 rows × 784 pixel columns, mostly zeros]

# Display an image
img = np.array(x_train.iloc[0,:]).reshape(28, 28)
plt.imshow(img, cmap='Greys', interpolation='nearest')
plt.show()

[Figure: the first training sample rendered as a 28×28 grayscale digit (output_34_0.png)]

# Model parameter setup
model = lgb.LGBMClassifier(boosting_type='gbdt'
                          ,class_weight=None
                          ,colsample_bytree=1.0
                          ,importance_type='split'
                          ,learning_rate=0.1
                          ,max_depth=-1
                          ,min_child_samples=20
                          ,min_child_weight=0.001
                          ,min_split_gain=0.0
                          ,n_estimators=200
                          ,n_jobs=1
                          ,num_leaves=31
                          ,objective='multi:softmax'
                          ,random_state=None
                          ,reg_alpha=0.0
                          ,reg_lambda=0.0
                          ,silent=True
                          ,subsample=1.0
                          ,subsample_for_bin=200000)
model.fit(x_train, y_train,eval_set=[(x_train,y_train),(x_test,y_test)],
       eval_metric=['logloss'],early_stopping_rounds=20,verbose=True)
[1]	training's multi_logloss: 1.67939	valid_1's multi_logloss: 1.68286
Training until validation scores don't improve for 20 rounds
[2]	training's multi_logloss: 1.37326	valid_1's multi_logloss: 1.38176
[3]	training's multi_logloss: 1.15968	valid_1's multi_logloss: 1.17275
[4]	training's multi_logloss: 0.997105	valid_1's multi_logloss: 1.01395
[5]	training's multi_logloss: 0.867815	valid_1's multi_logloss: 0.88653

.
.
.
[196]	training's multi_logloss: 0.000162272	valid_1's multi_logloss: 0.0638985
[197]	training's multi_logloss: 0.000156578	valid_1's multi_logloss: 0.0639573
[198]	training's multi_logloss: 0.000151258	valid_1's multi_logloss: 0.0640798
[199]	training's multi_logloss: 0.000145979	valid_1's multi_logloss: 0.0640982
Early stopping, best iteration is:
[179]	training's multi_logloss: 0.000294704	valid_1's multi_logloss: 0.0635952

LGBMClassifier(n_estimators=200, n_jobs=1, objective='multi:softmax')
pred_y_test = model.predict(x_test)
m = metrics.confusion_matrix(y_test, pred_y_test)
display (m)

array([[ 969,    0,    0,    0,    0,    2,    3,    1,    5,    0],
       [   0, 1125,    2,    3,    0,    1,    1,    1,    2,    0],
       [   3,    0, 1012,    5,    1,    0,    1,    6,    4,    0],
       [   0,    0,    3,  992,    0,    3,    0,    8,    4,    0],
       [   0,    0,    5,    0,  962,    0,    3,    0,    2,   10],
       [   2,    0,    1,    8,    0,  869,    7,    2,    2,    1],
       [   5,    2,    0,    0,    2,    7,  938,    0,    4,    0],
       [   1,    0,   12,    3,    2,    0,    0, 1003,    1,    5],
       [   4,    0,    2,    0,    3,    0,    0,    2,  958,    5],
       [   4,    4,    1,    7,    6,    1,    0,    4,    1,  981]],
      dtype=int64)

Multi-class evaluation

  • precision_score
  • accuracy_score
  • recall_score
  • f1_score

The average parameter defaults to 'binary' for two-class problems; for multi-class problems the options are micro, macro, weighted and samples.

micro pools all classes together: for precision, it sums the TP over all classes and divides by the sum of TP and FP over all classes. Under micro averaging, precision and recall therefore both equal accuracy.

macro computes each class's precision separately and takes the arithmetic mean.

weighted is a refinement of macro: instead of giving every class the same weight (1/n_classes), it weights each class's score by that class's share of the total sample count. A small example follows.
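A quick check of how the averaging modes differ, on hypothetical labels invented purely for illustration:

from sklearn import metrics

y_true = [0, 0, 1, 1, 2, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0, 2]
for avg in ('micro', 'macro', 'weighted'):
    # micro precision equals accuracy here; macro averages per-class precision
    # equally; weighted averages it by each class's support
    print(avg, metrics.precision_score(y_true, y_pred, average=avg))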

print ('precision_score:',metrics.precision_score(y_test,pred_y_test,labels=None,pos_label=1,average='weighted',sample_weight=None))
print ('accuracy_score:',metrics.accuracy_score(y_test,pred_y_test))
print ('recall_score:',metrics.recall_score(y_test,pred_y_test,labels=None,pos_label=1,average='weighted',sample_weight=None))
print ('f1_score:',metrics.f1_score(y_test,pred_y_test,labels=None,pos_label=1,average='weighted',sample_weight=None))
precision_score: 0.9810242986297515
accuracy_score: 0.980998099809981
recall_score: 0.980998099809981
f1_score: 0.9809998617011098
