MachineLearning 15. Ensemble Classifiers (LogitBoost)


This installment introduces LogitBoost, one of the strongest ensemble classification methods, and demonstrates its application on a concrete dataset, with a focus on clinical data.


Introduction

A study presented at the 2019 International Conference on Nuclear Cardiology and Cardiac CT (ICNC) reported that machine learning can predict death or heart attack with more than 90% accuracy.

With the arrival of the big-data era, our lives increasingly depend on artificial intelligence, made possible by tracking large volumes of personal data and building precise algorithmic profiles. In recent years, AI has also steadily made its way into medicine.

In the previous installment we introduced AdaBoost, which penalizes misclassified points so heavily that it is sensitive to outliers, and which cannot predict class probabilities. LogitBoost remedies both shortcomings: it adopts a forward stagewise additive model, takes the log-loss as its loss function, and at each iteration applies a Newton step to the error function to update the model.
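To make the update concrete, here is a minimal sketch of a single LogitBoost Newton step, following Friedman, Hastie and Tibshirani (2000). It assumes y is coded 0/1 and a single numeric predictor x, and uses weighted linear regression as the base learner purely for illustration; the caTools implementation used below fits decision stumps instead.

# One LogitBoost iteration (an illustrative sketch, not the caTools code);
# F holds the current additive model's scores on the training points
logitboost_step <- function(F, x, y) {
  p <- 1 / (1 + exp(-2 * F))      # current probability estimates
  w <- pmax(p * (1 - p), 1e-10)   # Newton weights, bounded away from zero
  z <- (y - p) / w                # working response
  fit <- lm(z ~ x, weights = w)   # weighted least-squares base learner
  F + 0.5 * predict(fit)          # half-step additive update to the scores
}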


Software Installation
if(!require('caTools')) {
  install.packages('caTools')
}
## Loading required package: caTools
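The rest of this walkthrough also relies on caret, tidyverse, sampling, and ROCR; a sketch to install whichever of these are missing:

pkgs <- c("caret", "tidyverse", "sampling", "ROCR")
for (p in pkgs) {
  if (!require(p, character.only = TRUE)) install.packages(p)
}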

Reading the Data

library("caTools")
library(caret)
library(tidyverse)
BreastCancer <- read.csv("wisc_bc_data.csv", stringsAsFactors = FALSE)
dim(BreastCancer)
## [1] 568  32
str(BreastCancer)
## 'data.frame':	568 obs. of  32 variables:
##  $ id                     : int  842517 84300903 84348301 84358402 843786 844359 84458202 844981 84501001 845636 ...
##  $ diagnosis              : chr  "M" "M" "M" "M" ...
##  $ radius_mean            : num  20.6 19.7 11.4 20.3 12.4 ...
##  $ texture_mean           : num  17.8 21.2 20.4 14.3 15.7 ...
##  $ perimeter_mean         : num  132.9 130 77.6 135.1 82.6 ...
##  $ area_mean              : num  1326 1203 386 1297 477 ...
##  $ smoothness_mean        : num  0.0847 0.1096 0.1425 0.1003 0.1278 ...
##  $ compactne_mean         : num  0.0786 0.1599 0.2839 0.1328 0.17 ...
##  $ concavity_mean         : num  0.0869 0.1974 0.2414 0.198 0.1578 ...
##  $ concave_points_mean    : num  0.0702 0.1279 0.1052 0.1043 0.0809 ...
##  $ symmetry_mean          : num  0.181 0.207 0.26 0.181 0.209 ...
##  $ fractal_dimension_mean : num  0.0567 0.06 0.0974 0.0588 0.0761 ...
##  $ radius_se              : num  0.543 0.746 0.496 0.757 0.335 ...
##  $ texture_se             : num  0.734 0.787 1.156 0.781 0.89 ...
##  $ perimeter_se           : num  3.4 4.58 3.44 5.44 2.22 ...
##  $ area_se                : num  74.1 94 27.2 94.4 27.2 ...
##  $ smoothness_se          : num  0.00522 0.00615 0.00911 0.01149 0.00751 ...
##  $ compactne_se           : num  0.0131 0.0401 0.0746 0.0246 0.0335 ...
##  $ concavity_se           : num  0.0186 0.0383 0.0566 0.0569 0.0367 ...
##  $ concave_points_se      : num  0.0134 0.0206 0.0187 0.0188 0.0114 ...
##  $ symmetry_se            : num  0.0139 0.0225 0.0596 0.0176 0.0216 ...
##  $ fractal_dimension_se   : num  0.00353 0.00457 0.00921 0.00511 0.00508 ...
##  $ radius_worst           : num  25 23.6 14.9 22.5 15.5 ...
##  $ texture_worst          : num  23.4 25.5 26.5 16.7 23.8 ...
##  $ perimeter_worst        : num  158.8 152.5 98.9 152.2 103.4 ...
##  $ area_worst             : num  1956 1709 568 1575 742 ...
##  $ smoothness_worst       : num  0.124 0.144 0.21 0.137 0.179 ...
##  $ compactne_worst        : num  0.187 0.424 0.866 0.205 0.525 ...
##  $ concavity_worst        : num  0.242 0.45 0.687 0.4 0.535 ...
##  $ concave_points_worst   : num  0.186 0.243 0.258 0.163 0.174 ...
##  $ symmetry_worst         : num  0.275 0.361 0.664 0.236 0.399 ...
##  $ fractal_dimension_worst: num  0.089 0.0876 0.173 0.0768 0.1244 ...
table(BreastCancer$diagnosis)
## 
##   B   M 
## 357 211

data <- select(BreastCancer, -1) %>%
    mutate_at("diagnosis", as.factor)
sum(is.na(data))
## [1] 0

Worked Example

Splitting the Data

library(sampling)
set.seed(123)
# Sample 70% of the data within each stratum; rev() is needed because strata()
# expects stratum sizes in the order the strata appear in the data (M first
# here), while table() sorts the levels alphabetically (B, M)
train_id <- strata(data, "diagnosis", size = rev(round(table(data$diagnosis) * 0.7)))$ID_unit
# Training data
trainData <- data[train_id, ]
# Test data
testData <- data[-train_id, ]

# Check the class proportions of the training and test sets
prop.table(table(trainData$diagnosis))
## 
##         B         M 
## 0.6281407 0.3718593

prop.table(table(testData$diagnosis))
## 
##         B         M 
## 0.6294118 0.3705882

prop.table(table(BreastCancer$diagnosis))
## 
##         B         M 
## 0.6285211 0.3714789
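As an aside, caret's createDataPartition() gives an equivalent stratified 70/30 split in one call; a sketch (trainData2/testData2 are illustrative names, and the strata()-based split above is the one used below):

set.seed(123)
idx <- createDataPartition(data$diagnosis, p = 0.7, list = FALSE)
trainData2 <- data[idx, ]   # stratified training set
testData2 <- data[-idx, ]   # stratified test set
prop.table(table(trainData2$diagnosis))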

Here we use BreastCancer, a dataset we have analyzed in earlier machine-learning posts, so the results can be compared with our previous methods:

Visualizing Variable Importance

We can use the featurePlot() function to visualize the range of each predictor and compare its distribution across the classes.

For classification models, choose: box, strip, density, pairs, or ellipse.

For regression models, choose: pairs or scatter.

#4. How to visualize the importance of variables using featurePlot()
featurePlot(x = trainData[, 2:31], 
            y = as.factor(trainData$diagnosis), 
            plot = "box", #For classification:box, strip, density, pairs or ellipse. For regression, pairs or scatter
            strip=strip.custom(par.strip.text=list(cex=.7)),
            scales = list(x = list(relation="free"), 
                          y = list(relation="free"))
)

[Figure: featurePlot box plots of each predictor, grouped by diagnosis]

Defining the Training Control Parameters

fitControl <- trainControl(
  method = 'cv',                   # k-fold cross validation
  number = 5,                      # number of folds
  savePredictions = 'final',       # saves predictions for optimal tuning parameter
  classProbs = T,                  # should class probabilities be returned
  summaryFunction=twoClassSummary  # results summary function
)

Building the LogitBoost Classifier

Which algorithms can be used with the train() function in the caret package? The full list is below:

names(getModelInfo())
##   [1] "ada"                 "AdaBag"              "AdaBoost.M1"        
##   [4] "adaboost"            "amdai"               "ANFIS"              
##   [7] "avNNet"              "awnb"                "awtan"              
##  [10] "bag"                 "bagEarth"            "bagEarthGCV"        
##  [13] "bagFDA"              "bagFDAGCV"           "bam"                
##  [16] "bartMachine"         "bayesglm"            "binda"              
##  [19] "blackboost"          "blasso"              "blassoAveraged"     
##  [22] "bridge"              "brnn"                "BstLm"              
##  [25] "bstSm"               "bstTree"             "C5.0"               
##  [28] "C5.0Cost"            "C5.0Rules"           "C5.0Tree"           
##  [31] "cforest"             "chaid"               "CSimca"             
##  [34] "ctree"               "ctree2"              "cubist"             
##  [37] "dda"                 "deepboost"           "DENFIS"             
##  [40] "dnn"                 "dwdLinear"           "dwdPoly"            
##  [43] "dwdRadial"           "earth"               "elm"                
##  [46] "enet"                "evtree"              "extraTrees"         
##  [49] "fda"                 "FH.GBML"             "FIR.DM"             
##  [52] "foba"                "FRBCS.CHI"           "FRBCS.W"            
##  [55] "FS.HGD"              "gam"                 "gamboost"           
##  [58] "gamLoess"            "gamSpline"           "gaussprLinear"      
##  [61] "gaussprPoly"         "gaussprRadial"       "gbm_h2o"            
##  [64] "gbm"                 "gcvEarth"            "GFS.FR.MOGUL"       
##  [67] "GFS.LT.RS"           "GFS.THRIFT"          "glm.nb"             
##  [70] "glm"                 "glmboost"            "glmnet_h2o"         
##  [73] "glmnet"              "glmStepAIC"          "gpls"               
##  [76] "hda"                 "hdda"                "hdrda"              
##  [79] "HYFIS"               "icr"                 "J48"                
##  [82] "JRip"                "kernelpls"           "kknn"               
##  [85] "knn"                 "krlsPoly"            "krlsRadial"         
##  [88] "lars"                "lars2"               "lasso"              
##  [91] "lda"                 "lda2"                "leapBackward"       
##  [94] "leapForward"         "leapSeq"             "Linda"              
##  [97] "lm"                  "lmStepAIC"           "LMT"                
## [100] "loclda"              "logicBag"            "LogitBoost"         
## [103] "logreg"              "lssvmLinear"         "lssvmPoly"          
## [106] "lssvmRadial"         "lvq"                 "M5"                 
## [109] "M5Rules"             "manb"                "mda"                
## [112] "Mlda"                "mlp"                 "mlpKerasDecay"      
## [115] "mlpKerasDecayCost"   "mlpKerasDropout"     "mlpKerasDropoutCost"
## [118] "mlpML"               "mlpSGD"              "mlpWeightDecay"     
## [121] "mlpWeightDecayML"    "monmlp"              "msaenet"            
## [124] "multinom"            "mxnet"               "mxnetAdam"          
## [127] "naive_bayes"         "nb"                  "nbDiscrete"         
## [130] "nbSearch"            "neuralnet"           "nnet"               
## [133] "nnls"                "nodeHarvest"         "null"               
## [136] "OneR"                "ordinalNet"          "ordinalRF"          
## [139] "ORFlog"              "ORFpls"              "ORFridge"           
## [142] "ORFsvm"              "ownn"                "pam"                
## [145] "parRF"               "PART"                "partDSA"            
## [148] "pcaNNet"             "pcr"                 "pda"                
## [151] "pda2"                "penalized"           "PenalizedLDA"       
## [154] "plr"                 "pls"                 "plsRglm"            
## [157] "polr"                "ppr"                 "pre"                
## [160] "PRIM"                "protoclass"          "qda"                
## [163] "QdaCov"              "qrf"                 "qrnn"               
## [166] "randomGLM"           "ranger"              "rbf"                
## [169] "rbfDDA"              "Rborist"             "rda"                
## [172] "regLogistic"         "relaxo"              "rf"                 
## [175] "rFerns"              "RFlda"               "rfRules"            
## [178] "ridge"               "rlda"                "rlm"                
## [181] "rmda"                "rocc"                "rotationForest"     
## [184] "rotationForestCp"    "rpart"               "rpart1SE"           
## [187] "rpart2"              "rpartCost"           "rpartScore"         
## [190] "rqlasso"             "rqnc"                "RRF"                
## [193] "RRFglobal"           "rrlda"               "RSimca"             
## [196] "rvmLinear"           "rvmPoly"             "rvmRadial"          
## [199] "SBC"                 "sda"                 "sdwd"               
## [202] "simpls"              "SLAVE"               "slda"               
## [205] "smda"                "snn"                 "sparseLDA"          
## [208] "spikeslab"           "spls"                "stepLDA"            
## [211] "stepQDA"             "superpc"             "svmBoundrangeString"
## [214] "svmExpoString"       "svmLinear"           "svmLinear2"         
## [217] "svmLinear3"          "svmLinearWeights"    "svmLinearWeights2"  
## [220] "svmPoly"             "svmRadial"           "svmRadialCost"      
## [223] "svmRadialSigma"      "svmRadialWeights"    "svmSpectrumString"  
## [226] "tan"                 "tanSearch"           "treebag"            
## [229] "vbmpRadial"          "vglmAdjCat"          "vglmContRatio"      
## [232] "vglmCumulative"      "widekernelpls"       "WM"                 
## [235] "wsrf"                "xgbDART"             "xgbLinear"          
## [238] "xgbTree"             "xyf"
set.seed(2863)
model_LogitBoost <- train(diagnosis ~ ., data = trainData, method = "LogitBoost",
    tuneLength = 4, metric = "ROC", trControl = fitControl)
plot(model_LogitBoost, main = "LogitBoost")

[Figure: cross-validated ROC versus number of boosting iterations for LogitBoost]
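The tuning parameter caret settled on (the number of boosting iterations, nIter) can be read directly off the fitted object:

model_LogitBoost$bestTune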

Finding the Threshold

The thresholder() function evaluates the model's cross-validated predictions over a grid of probability cutoffs. Among the statistics it reports, J is Youden's index (Sensitivity + Specificity - 1) and Dist is the distance to the ideal point (0, 1) in ROC space; good cutoffs maximize J or minimize Dist.

resample_stats <- thresholder(model_LogitBoost, threshold = seq(0.5, 1, by = 0.05),
    final = TRUE)
ggplot(resample_stats, aes(x = prob_threshold, y = J)) + geom_point()

[Figure: Youden's J versus probability threshold]

ggplot(resample_stats, aes(x = prob_threshold, y = Dist)) + geom_point()

[Figure: Dist versus probability threshold]

ggplot(resample_stats, aes(x = prob_threshold, y = Sensitivity)) + geom_point() +
    geom_point(aes(y = Specificity), col = "red")

[Figure: Sensitivity (black) and Specificity (red) versus probability threshold]
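From the same resample_stats table we can pull the cutoff that maximizes Youden's J or the one that minimizes Dist; a minimal sketch:

# Cutoff that maximizes Youden's J
resample_stats[which.max(resample_stats$J), "prob_threshold"]
# Cutoff that minimizes the distance to the point (0, 1)
resample_stats[which.min(resample_stats$Dist), "prob_threshold"]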

计算变量重要性

For both regression and classification models we need to assess variable importance. The BreastCancer dataset contains 30 predictors, so we should screen for the important, decisive variables and discard those that contribute little, reducing computational complexity and making later steps more manageable.

# 6.2 How to compute variable importance?
varimp_mars <- varImp(model_LogitBoost)
plot(varimp_mars, main = "Variable Importance with BreastCancer")

[Figure: variable importance for the BreastCancer predictors]
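To act on the ranking, a sketch for extracting the top ten predictors from the varImp() table; taking the row-wise maximum keeps it working whether importance is reported overall or per class:

imp <- varimp_mars$importance
head(rownames(imp)[order(-apply(imp, 1, max))], 10)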

Computing the Confusion Matrix

For a classification model, the confusion matrix gives a clear view of how accurate the classification is.

# 6.5. Confusion Matrix Compute the confusion matrix
predProb <- predict(model_LogitBoost, testData, type = "prob")
head(predProb)
##              B         M
## 1 1.026188e-10 1.0000000
## 2 9.110512e-04 0.9990889
## 3 4.742587e-02 0.9525741
## 4 4.139938e-08 1.0000000
## 5 1.670142e-05 0.9999833
## 6 3.059022e-07 0.9999997
predicted = predict(model_LogitBoost, testData)
testData$predProb = predProb$B
confusionMatrix(reference = testData$diagnosis, data = predicted, mode = "everything",
    positive = "B")
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction   B   M
##          B 105   4
##          M   2  59
##                                           
##                Accuracy : 0.9647          
##                  95% CI : (0.9248, 0.9869)
##     No Information Rate : 0.6294          
##     P-Value [Acc > NIR] : <2e-16          
##                                           
##                   Kappa : 0.9238          
##                                           
##  Mcnemar's Test P-Value : 0.6831          
##                                           
##             Sensitivity : 0.9813          
##             Specificity : 0.9365          
##          Pos Pred Value : 0.9633          
##          Neg Pred Value : 0.9672          
##               Precision : 0.9633          
##                  Recall : 0.9813          
##                      F1 : 0.9722          
##              Prevalence : 0.6294          
##          Detection Rate : 0.6176          
##    Detection Prevalence : 0.6412          
##       Balanced Accuracy : 0.9589          
##                                           
##        'Positive' Class : B               
##

Plotting the ROC Curve

After building the model, we still need to evaluate its accuracy, so we compute the AUC and draw the ROC curve.

library(ROCR)
# testData$predProb holds the probability of class B, so declare B as the
# positive class; otherwise ROCR treats the second factor level (M) as
# positive and the curve is inverted
pred = prediction(testData$predProb, testData$diagnosis,
    label.ordering = c("M", "B"))
perf = performance(pred, measure = "tpr", x.measure = "fpr")
plot(perf, lwd = 2, col = "blue", main = "ROC")
abline(a = 0, b = 1, col = 2, lwd = 1, lty = 2)

[Figure: ROC curve of the LogitBoost model on the test set]
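To report a single summary number alongside the curve, the AUC can be extracted from the same ROCR prediction object:

auc <- performance(pred, measure = "auc")
auc@y.values[[1]]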

Building Random Forest and AdaBoost Classifiers

# Train the model using rf
model_rf = train(diagnosis ~ ., data = trainData, method = "rf", tuneLength = 5,
    trControl = fitControl)
model_rf
## Random Forest 
## 
## 398 samples
##  30 predictor
##   2 classes: 'B', 'M' 
## 
## No pre-processing
## Resampling: Cross-Validated (5 fold) 
## Summary of sample sizes: 319, 318, 318, 318, 319 
## Resampling results across tuning parameters:
## 
##   mtry  ROC        Sens   Spec     
##    2    0.9913655  0.976  0.9124138
##    9    0.9890874  0.976  0.9193103
##   16    0.9884667  0.964  0.9193103
##   23    0.9861310  0.960  0.9193103
##   30    0.9868920  0.960  0.9259770
## 
## ROC was used to select the optimal model using the largest value.
## The final value used for the model was mtry = 2.
# Train the model using adaboost
model_adaboost = train(diagnosis ~ ., data = trainData, method = "adaboost", tuneLength = 2,
    trControl = fitControl)
model_adaboost
## AdaBoost Classification Trees 
## 
## 398 samples
##  30 predictor
##   2 classes: 'B', 'M' 
## 
## No pre-processing
## Resampling: Cross-Validated (5 fold) 
## Summary of sample sizes: 319, 318, 319, 318, 318 
## Resampling results across tuning parameters:
## 
##   nIter  method         ROC        Sens   Spec     
##    50    Adaboost.M1    0.9843356  0.960  0.9186207
##    50    Real adaboost  0.8821609  0.972  0.9050575
##   100    Adaboost.M1    0.9883678  0.972  0.9186207
##   100    Real adaboost  0.8758391  0.976  0.9119540
## 
## ROC was used to select the optimal model using the largest value.
## The final values used for the model were nIter = 100 and method = Adaboost.M1.

Comparing Multiple Classifiers

models_compare <- resamples(list(ADABOOST = model_adaboost, RF = model_rf, LOGITBOOST = model_LogitBoost))
summary(models_compare)
## 
## Call:
## summary.resamples(object = models_compare)
## 
## Models: ADABOOST, RF, LOGITBOOST 
## Number of resamples: 5 
## 
## ROC 
##                 Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
## ADABOOST   0.9800000 0.9820000 0.9886667 0.9883678 0.9951724 0.9960000    0
## RF         0.9760000 0.9862069 0.9973333 0.9913655 0.9986207 0.9986667    0
## LOGITBOOST 0.9573333 0.9951724 0.9958621 0.9895402 0.9993333 1.0000000    0
## 
## Sens 
##            Min. 1st Qu. Median  Mean 3rd Qu. Max. NA's
## ADABOOST   0.90    0.98   0.98 0.972    1.00    1    0
## RF         0.94    0.98   0.98 0.976    0.98    1    0
## LOGITBOOST 0.96    0.98   1.00 0.988    1.00    1    0
## 
## Spec 
##                 Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
## ADABOOST   0.8620690 0.9000000 0.9310345 0.9186207 0.9333333 0.9666667    0
## RF         0.8333333 0.8965517 0.9000000 0.9124138 0.9655172 0.9666667    0
## LOGITBOOST 0.8666667 0.8965517 0.9310345 0.9388506 1.0000000 1.0000000    0
# Draw box plots to compare models
scales <- list(x = list(relation = "free"), y = list(relation = "free"))
bwplot(models_compare, scales = scales)

[Figure: box plots comparing resampled ROC, Sens, and Spec across the three models]
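Beyond eyeballing the box plots, caret can also test whether the resampled differences between the models are statistically meaningful, using paired comparisons across the five folds:

# Paired t-tests on the resampled ROC, Sens and Spec differences
summary(diff(models_compare))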

Clinical Prediction Models Based on Machine Learning

MachineLearning 1. Principal Component Analysis (PCA)

MachineLearning 2. Factor Analysis

MachineLearning 3. Cluster Analysis

MachineLearning 4. K-Nearest Neighbors (KNN) for Cancer Diagnosis

MachineLearning 5. Support Vector Machines (SVM) for Cancer Diagnosis and Molecular Subtyping

MachineLearning 6. Classification Trees for Cancer Diagnosis

MachineLearning 7. Regression Trees for Cancer Diagnosis

MachineLearning 8. Random Forest for Cancer Diagnosis

MachineLearning 9. Gradient Boosting for Cancer Diagnosis

MachineLearning 10. Neural Networks for Cancer Diagnosis

MachineLearning 11. Random Survival Forests (randomForestSRC)

MachineLearning 14. Ensemble Classifiers (AdaBoost)

References

  1. Dettling M, Bühlmann P (2002). Boosting for Tumor Classification of Gene Expression Data.

  2. Robin X, Turck N, Hainard A, et al. (2011). pROC: an open-source package for R and S+ to analyze and compare ROC curves. BMC Bioinformatics, 12, 77. doi: 10.1186/1471-2105-12-77.

  3. Friedman J, Hastie T, Tibshirani R (2000). Additive Logistic Regression: A Statistical View of Boosting. The Annals of Statistics, 28(2), 337-374.

