MachineLearning 18. Naive Bayes Classifiers (Naive Bayes)



Introduction

    Bayesian classification occupies an important place among classification techniques and belongs to the family of statistical classification methods. It is a non-rule-based approach: a classifier is trained on a subset of already-labeled samples, from which a classification function is learned (predicting a discrete variable is called classification, while predicting a continuous variable is called regression), and the trained classifier is then used to assign classes to unlabeled data. Comparisons of different classification algorithms have found that the naive Bayes algorithm (Naive Bayes), a simple Bayesian classifier, can perform even better than neural network and decision tree classifiers, and it retains high accuracy in particular when the amount of data to be classified is very large.

Bayesian classification is a statistical method based mainly on Bayes' theorem: a given instance is classified by computing the probability that it belongs to each particular class. Bayesian classification has the following characteristics:

(1) A Bayesian classifier does not assign an instance to a class absolutely; it computes the probability that the instance belongs to each class, and the class with the highest probability is the one assigned to the instance;

(2) In general, every attribute potentially contributes to the classification result, so all attributes take part in the classification;

(3) The attributes of an instance may be discrete, continuous, or a mixture of the two.

Because Bayesian methods provide, in theory, the optimal solution that minimizes classification error, they are widely applied to classification problems. Building on them, the Bayesian network (Bayesian Network, BN) approach was proposed. Naive Bayes classification assumes that the effect of an attribute on a given class is independent of the other attributes. This assumption about the instance attributes is called conditional independence, and it greatly reduces the amount of computation needed for classification. A large body of research has shown that, although the naive Bayes model restricts the connection structure between attribute nodes, its classification performance is better than that of standard Bayesian network classifiers.
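To make the conditional-independence assumption concrete, the naive Bayes decision rule can be written as follows (a standard formulation added here for reference; the notation is mine, not from the original article):

P(c \mid x_1,\dots,x_n) \;\propto\; P(c)\prod_{i=1}^{n} P(x_i \mid c),
\qquad
\hat{y} \;=\; \arg\max_{c}\; P(c)\prod_{i=1}^{n} P(x_i \mid c)

The product over individual attribute likelihoods is exactly what conditional independence buys: each P(x_i | c) can be estimated separately, which is why the method remains cheap even with many attributes and large data sets.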


Package Installation

Here we mainly use the caret package; the "naive_bayes" method called later through train() is backed by the naivebayes package, and other packages (for example e1071 and klaR) also provide naive Bayes implementations. caret can be installed as follows:

if (!require(caret))
  install.packages("caret")
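The "naive_bayes" method used below in train() depends on the naivebayes package, so it is worth installing as well; a sketch in the same style:

# Backend package used by caret's "naive_bayes" method
if (!require(naivebayes))
  install.packages("naivebayes")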

Data Loading

Here we again use the BreastCancer data set (wisc_bc_data.csv) that appeared in earlier posts of this machine learning series, so the results can be compared with our previous methods:

library(caret)
BreastCancer <- read.csv("wisc_bc_data.csv", stringsAsFactors = FALSE)
BreastCancer = BreastCancer[, -1]  # drop the id column
dim(BreastCancer)
## [1] 568  31
str(BreastCancer)
## 'data.frame':	568 obs. of  31 variables:
##  $ diagnosis              : chr  "M" "M" "M" "M" ...
##  $ radius_mean            : num  20.6 19.7 11.4 20.3 12.4 ...
##  $ texture_mean           : num  17.8 21.2 20.4 14.3 15.7 ...
##  $ perimeter_mean         : num  132.9 130 77.6 135.1 82.6 ...
##  $ area_mean              : num  1326 1203 386 1297 477 ...
##  $ smoothness_mean        : num  0.0847 0.1096 0.1425 0.1003 0.1278 ...
##  $ compactne_mean         : num  0.0786 0.1599 0.2839 0.1328 0.17 ...
##  $ concavity_mean         : num  0.0869 0.1974 0.2414 0.198 0.1578 ...
##  $ concave_points_mean    : num  0.0702 0.1279 0.1052 0.1043 0.0809 ...
##  $ symmetry_mean          : num  0.181 0.207 0.26 0.181 0.209 ...
##  $ fractal_dimension_mean : num  0.0567 0.06 0.0974 0.0588 0.0761 ...
##  $ radius_se              : num  0.543 0.746 0.496 0.757 0.335 ...
##  $ texture_se             : num  0.734 0.787 1.156 0.781 0.89 ...
##  $ perimeter_se           : num  3.4 4.58 3.44 5.44 2.22 ...
##  $ area_se                : num  74.1 94 27.2 94.4 27.2 ...
##  $ smoothness_se          : num  0.00522 0.00615 0.00911 0.01149 0.00751 ...
##  $ compactne_se           : num  0.0131 0.0401 0.0746 0.0246 0.0335 ...
##  $ concavity_se           : num  0.0186 0.0383 0.0566 0.0569 0.0367 ...
##  $ concave_points_se      : num  0.0134 0.0206 0.0187 0.0188 0.0114 ...
##  $ symmetry_se            : num  0.0139 0.0225 0.0596 0.0176 0.0216 ...
##  $ fractal_dimension_se   : num  0.00353 0.00457 0.00921 0.00511 0.00508 ...
##  $ radius_worst           : num  25 23.6 14.9 22.5 15.5 ...
##  $ texture_worst          : num  23.4 25.5 26.5 16.7 23.8 ...
##  $ perimeter_worst        : num  158.8 152.5 98.9 152.2 103.4 ...
##  $ area_worst             : num  1956 1709 568 1575 742 ...
##  $ smoothness_worst       : num  0.124 0.144 0.21 0.137 0.179 ...
##  $ compactne_worst        : num  0.187 0.424 0.866 0.205 0.525 ...
##  $ concavity_worst        : num  0.242 0.45 0.687 0.4 0.535 ...
##  $ concave_points_worst   : num  0.186 0.243 0.258 0.163 0.174 ...
##  $ symmetry_worst         : num  0.275 0.361 0.664 0.236 0.399 ...
##  $ fractal_dimension_worst: num  0.089 0.0876 0.173 0.0768 0.1244 ...
table(BreastCancer$diagnosis)
## 
##   B   M 
## 357 211

Building Clinical Prediction Models with Machine Learning

MachineLearning 1. Principal Component Analysis (PCA)

MachineLearning 2. Factor Analysis

MachineLearning 3. Cluster Analysis

MachineLearning 4. K-Nearest Neighbors (KNN) for cancer diagnosis

MachineLearning 5. Support Vector Machines (SVM) for cancer diagnosis and molecular subtyping

MachineLearning 6. Classification Trees for cancer diagnosis

MachineLearning 7. Regression Trees for cancer diagnosis

MachineLearning 8. Random Forest for cancer diagnosis

MachineLearning 9. Gradient Boosting for cancer diagnosis

MachineLearning 10. Neural Networks for cancer diagnosis

MachineLearning 11. Random forest survival analysis (randomForestSRC)

MachineLearning 12. Dimensionality reduction and visualization with t-SNE (Rtsne)

MachineLearning 13. Dimensionality reduction and visualization with UMAP (umap)

MachineLearning 14. Ensemble classification (AdaBoost)

MachineLearning 15. Ensemble classification (LogitBoost)

MachineLearning 16. Gradient Boosting Machines (GBM)

MachineLearning 17. Partitioning Around Medoids (PAM)

Worked Example

Data Preprocessing

Data preprocessing consists of five steps. First check whether the data contain missing values and how many (a quick check is sketched after the list below), then proceed as follows:

  1. Remove variables with (near) zero variance

  2. Remove variables that are strongly correlated with other predictors

  3. Remove multicollinearity

  4. Standardize the data and impute missing values

  5. Feature selection with recursive feature elimination (RFE)
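A minimal sketch of the missing-value check mentioned above (base R only, applied to the data set loaded earlier):

# Total number of missing values and the count per column
sum(is.na(BreastCancer))
colSums(is.na(BreastCancer))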

# Identify near-zero-variance predictors
zerovar = nearZeroVar(BreastCancer[, -1])
zerovar
## integer(0)
# BreastCancer=BreastCancer[,-zerovar]

# First remove highly correlated predictors
descrCorr = cor(BreastCancer[, -1])
descrCorr[1:5, 1:5]
##                 radius_mean texture_mean perimeter_mean area_mean
## radius_mean       1.0000000   0.32938305      0.9978764 0.9873442
## texture_mean      0.3293830   1.00000000      0.3359176 0.3261929
## perimeter_mean    0.9978764   0.33591759      1.0000000 0.9865482
## area_mean         0.9873442   0.32619289      0.9865482 1.0000000
## smoothness_mean   0.1680940  -0.01776898      0.2045046 0.1748380
##                 smoothness_mean
## radius_mean          0.16809398
## texture_mean        -0.01776898
## perimeter_mean       0.20450464
## area_mean            0.17483805
## smoothness_mean      1.00000000
highCorr = findCorrelation(descrCorr, 0.9)
highCorr
##  [1]  7  8 23 21  3 24  1 13 14  2
BreastCancer = BreastCancer[, -(highCorr + 1)]
dim(BreastCancer)
## [1] 568  21
# Then check for multicollinearity; there is none in this example
comboInfo = findLinearCombos(BreastCancer[, -1])
comboInfo
## $linearCombos
## list()
## 
## $remove
## NULL
# BreastCancer=BreastCancer[, -(comboInfo$remove+2)]
Process = preProcess(BreastCancer)
Process
## Created from 568 samples and 21 variables
## 
## Pre-processing:
##   - centered (20)
##   - ignored (1)
##   - scaled (20)
BreastCancer = predict(Process, BreastCancer)
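Step 4 above also mentions imputing missing values, whereas the preProcess() call here only centers and scales. If the data did contain missing values, the call could instead be written as below (a hedged sketch using k-nearest-neighbour imputation; Process_imp and BreastCancer_imp are illustrative names, and this data set happens to have no missing values):

# Center, scale and impute missing values in a single preprocessing object
Process_imp <- preProcess(BreastCancer[, -1], method = c("center", "scale", "knnImpute"))
BreastCancer_imp <- predict(Process_imp, BreastCancer[, -1])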

Feature Selection

When mining data we do not need to feed every predictor into the model; instead we select a number of the most important variables, a process known as feature selection. One algorithm is backward selection: start with all variables in the model, compute its performance (e.g. error or prediction accuracy) together with a variable-importance ranking, keep only the most important variables, recompute the performance, and iterate until a suitable number of predictors is found. A drawback of this algorithm is that it can overfit, so an outer resampling loop over data splits is wrapped around it. The rfe function in the caret package performs this task. The functions argument determines which model is used to rank the predictors, and includes:

  1. rfFuncs (random forest),

  2. lmFuncs (linear regression),

  3. nbFuncs (naive Bayes),

  4. treebagFuncs (bagged trees),

  5. caretFuncs (a user-specified model trained through caret).

The method argument selects the resampling scheme: cv for cross-validation, boot for bootstrap resampling, and LOOCV for leave-one-out cross-validation.

ctrl = rfeControl(functions = caretFuncs, method = "repeatedcv", verbose = FALSE,
    returnResamp = "final")
BreastCancer$diagnosis = as.factor(BreastCancer$diagnosis)
Profile = rfe(BreastCancer[, -1], BreastCancer$diagnosis, rfeControl = ctrl)
print(Profile)
## 
## Recursive feature selection
## 
## Outer resampling method: Cross-Validated (10 fold, repeated 1 times) 
## 
## Resampling performance over subset size:
## 
##  Variables Accuracy  Kappa AccuracySD KappaSD Selected
##          4   0.9348 0.8597    0.03471 0.07512         
##          8   0.9559 0.9040    0.02564 0.05713         
##         16   0.9594 0.9112    0.03271 0.07354        *
##         20   0.9576 0.9076    0.03177 0.07148         
## 
## The top 5 variables (out of 16):
##    concave_points_worst, area_mean, concavity_worst, radius_se, compactne_mean
plot(Profile)

(Figure: RFE performance profile from plot(Profile))

xyplot(Profile$results$Kappa ~ Profile$results$Variables, ylab = "Kappa", xlab = "Variables",
    type = c("g", "p", "l"), auto.key = TRUE)

(Figure: Kappa versus number of variables)

xyplot(Profile$results$Accuracy ~ Profile$results$Variables, ylab = "Accuracy", xlab = "Variables",
    type = c("g", "p", "l"), auto.key = TRUE)

(Figure: Accuracy versus number of variables)

Data Splitting

Data splitting divides the data into a training set and a test set; for background see Topic 5 (sample size determination and splitting). The procedure is as follows:

library(tidyverse)
library(sampling)
set.seed(123)
# Stratified sampling: take 70% of each class (sizes are given in the order
# the strata appear in the data, hence the rev())
train_id <- strata(BreastCancer, "diagnosis", size = rev(round(table(BreastCancer$diagnosis) *
    0.7)))$ID_unit
# Training data
trainData <- BreastCancer[train_id, ]
# Test data
testData <- BreastCancer[-train_id, ]

# Check the class proportions in the training, test, and full data
prop.table(table(trainData$diagnosis))
## 
##         B         M 
## 0.6281407 0.3718593

prop.table(table(testData$diagnosis))
## 
##         B         M 
## 0.6294118 0.3705882

prop.table(table(BreastCancer$diagnosis))
## 
##         B         M 
## 0.6285211 0.3714789
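As an aside, caret itself provides stratified splitting; a minimal equivalent sketch using createDataPartition (the object names train_idx, trainData2 and testData2 are mine):

set.seed(123)
# createDataPartition samples within each level of diagnosis, preserving the B/M ratio
train_idx <- createDataPartition(BreastCancer$diagnosis, p = 0.7, list = FALSE)
trainData2 <- BreastCancer[train_idx, ]
testData2 <- BreastCancer[-train_idx, ]
prop.table(table(trainData2$diagnosis))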

Visualizing Important Variables

We can use the featurePlot() function to visualize the value range of each predictor and compare the distributions across classes.

For classification models choose: box, strip, density, pairs or ellipse

For regression models choose: pairs or scatter

#4. How to visualize the importance of variables using featurePlot()
featurePlot(x = trainData[, 2:21], 
            y = as.factor(trainData$diagnosis), 
            plot = "box", #For classification:box, strip, density, pairs or ellipse. For regression, pairs or scatter
            strip=strip.custom(par.strip.text=list(cex=.7)),
            scales = list(x = list(relation="free"), 
                          y = list(relation="free"))
)

(Figure: box plots of each predictor by diagnosis, from featurePlot())

Defining the Training Parameters

Before the actual training we define the model-training parameters with the trainControl function: method selects the resampling scheme for cross-validation, number sets the number of folds, and repeats sets how many times the procedure is repeated.

fitControl <- trainControl(
  method = 'cv',                   # k-fold cross validation
  number = 5,                      # number of folds
  savePredictions = 'final',       # saves predictions for optimal tuning parameter
  classProbs = T,                  # should class probabilities be returned
  summaryFunction=twoClassSummary  # results summary function
)

Building the NB Classifier

We train the model with train(). In this example the naive_bayes method is used; its tuning parameters (laplace, usekernel and adjust) can be tuned by hand, as sketched after the figure below, or the default tuning can be kept.

names(getModelInfo())
##   [1] "ada"                 "AdaBag"              "AdaBoost.M1"        
##   [4] "adaboost"            "amdai"               "ANFIS"              
##   [7] "avNNet"              "awnb"                "awtan"              
##  [10] "bag"                 "bagEarth"            "bagEarthGCV"        
##  [13] "bagFDA"              "bagFDAGCV"           "bam"                
##  [16] "bartMachine"         "bayesglm"            "binda"              
##  [19] "blackboost"          "blasso"              "blassoAveraged"     
##  [22] "bridge"              "brnn"                "BstLm"              
##  [25] "bstSm"               "bstTree"             "C5.0"               
##  [28] "C5.0Cost"            "C5.0Rules"           "C5.0Tree"           
##  [31] "cforest"             "chaid"               "CSimca"             
##  [34] "ctree"               "ctree2"              "cubist"             
##  [37] "dda"                 "deepboost"           "DENFIS"             
##  [40] "dnn"                 "dwdLinear"           "dwdPoly"            
##  [43] "dwdRadial"           "earth"               "elm"                
##  [46] "enet"                "evtree"              "extraTrees"         
##  [49] "fda"                 "FH.GBML"             "FIR.DM"             
##  [52] "foba"                "FRBCS.CHI"           "FRBCS.W"            
##  [55] "FS.HGD"              "gam"                 "gamboost"           
##  [58] "gamLoess"            "gamSpline"           "gaussprLinear"      
##  [61] "gaussprPoly"         "gaussprRadial"       "gbm_h2o"            
##  [64] "gbm"                 "gcvEarth"            "GFS.FR.MOGUL"       
##  [67] "GFS.LT.RS"           "GFS.THRIFT"          "glm.nb"             
##  [70] "glm"                 "glmboost"            "glmnet_h2o"         
##  [73] "glmnet"              "glmStepAIC"          "gpls"               
##  [76] "hda"                 "hdda"                "hdrda"              
##  [79] "HYFIS"               "icr"                 "J48"                
##  [82] "JRip"                "kernelpls"           "kknn"               
##  [85] "knn"                 "krlsPoly"            "krlsRadial"         
##  [88] "lars"                "lars2"               "lasso"              
##  [91] "lda"                 "lda2"                "leapBackward"       
##  [94] "leapForward"         "leapSeq"             "Linda"              
##  [97] "lm"                  "lmStepAIC"           "LMT"                
## [100] "loclda"              "logicBag"            "LogitBoost"         
## [103] "logreg"              "lssvmLinear"         "lssvmPoly"          
## [106] "lssvmRadial"         "lvq"                 "M5"                 
## [109] "M5Rules"             "manb"                "mda"                
## [112] "Mlda"                "mlp"                 "mlpKerasDecay"      
## [115] "mlpKerasDecayCost"   "mlpKerasDropout"     "mlpKerasDropoutCost"
## [118] "mlpML"               "mlpSGD"              "mlpWeightDecay"     
## [121] "mlpWeightDecayML"    "monmlp"              "msaenet"            
## [124] "multinom"            "mxnet"               "mxnetAdam"          
## [127] "naive_bayes"         "nb"                  "nbDiscrete"         
## [130] "nbSearch"            "neuralnet"           "nnet"               
## [133] "nnls"                "nodeHarvest"         "null"               
## [136] "OneR"                "ordinalNet"          "ordinalRF"          
## [139] "ORFlog"              "ORFpls"              "ORFridge"           
## [142] "ORFsvm"              "ownn"                "pam"                
## [145] "parRF"               "PART"                "partDSA"            
## [148] "pcaNNet"             "pcr"                 "pda"                
## [151] "pda2"                "penalized"           "PenalizedLDA"       
## [154] "plr"                 "pls"                 "plsRglm"            
## [157] "polr"                "ppr"                 "pre"                
## [160] "PRIM"                "protoclass"          "qda"                
## [163] "QdaCov"              "qrf"                 "qrnn"               
## [166] "randomGLM"           "ranger"              "rbf"                
## [169] "rbfDDA"              "Rborist"             "rda"                
## [172] "regLogistic"         "relaxo"              "rf"                 
## [175] "rFerns"              "RFlda"               "rfRules"            
## [178] "ridge"               "rlda"                "rlm"                
## [181] "rmda"                "rocc"                "rotationForest"     
## [184] "rotationForestCp"    "rpart"               "rpart1SE"           
## [187] "rpart2"              "rpartCost"           "rpartScore"         
## [190] "rqlasso"             "rqnc"                "RRF"                
## [193] "RRFglobal"           "rrlda"               "RSimca"             
## [196] "rvmLinear"           "rvmPoly"             "rvmRadial"          
## [199] "SBC"                 "sda"                 "sdwd"               
## [202] "simpls"              "SLAVE"               "slda"               
## [205] "smda"                "snn"                 "sparseLDA"          
## [208] "spikeslab"           "spls"                "stepLDA"            
## [211] "stepQDA"             "superpc"             "svmBoundrangeString"
## [214] "svmExpoString"       "svmLinear"           "svmLinear2"         
## [217] "svmLinear3"          "svmLinearWeights"    "svmLinearWeights2"  
## [220] "svmPoly"             "svmRadial"           "svmRadialCost"      
## [223] "svmRadialSigma"      "svmRadialWeights"    "svmSpectrumString"  
## [226] "tan"                 "tanSearch"           "treebag"            
## [229] "vbmpRadial"          "vglmAdjCat"          "vglmContRatio"      
## [232] "vglmCumulative"      "widekernelpls"       "WM"                 
## [235] "wsrf"                "xgbDART"             "xgbLinear"          
## [238] "xgbTree"             "xyf"
set.seed(2863)
model_NB <- train(diagnosis ~ ., data = trainData, method = "naive_bayes", tuneLength = 2,
    metric = "ROC", trControl = fitControl)
plot(model_NB, main = "NB")

(Figure: cross-validated performance of the naive Bayes model, from plot(model_NB))
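If manual tuning is preferred, a hedged sketch of a tuning grid for this method looks like the following (the grid values are illustrative; nb_grid and model_NB_tuned are names introduced here):

# Tuning parameters of caret's "naive_bayes" method:
#   laplace   - Laplace smoothing constant
#   usekernel - use a kernel density estimate instead of a Gaussian for numeric predictors
#   adjust    - bandwidth adjustment for the kernel estimate
nb_grid <- expand.grid(laplace = c(0, 1), usekernel = c(TRUE, FALSE), adjust = c(0.5, 1))
set.seed(2863)
model_NB_tuned <- train(diagnosis ~ ., data = trainData, method = "naive_bayes",
    tuneGrid = nb_grid, metric = "ROC", trControl = fitControl)
model_NB_tuned$bestTune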

Computing Variable Importance

#6.2 How to compute variable importance?
varimp_mars <- varImp(model_NB)
plot(varimp_mars, main="Variable Importance with BreastCancer")

(Figure: variable importance for the BreastCancer model)

Computing the Confusion Matrix

For a classification model, the confusion matrix gives a clear picture of how correctly the classes are predicted.

# 6.5. Confusion Matrix Compute the confusion matrix
predProb <- predict(model_NB, testData, type = "prob")
head(predProb)
##              B            M
## 1 2.842085e-20 1.000000e+00
## 2 1.290245e-56 1.000000e+00
## 3 9.999659e-01 3.411905e-05
## 4 1.031544e-09 1.000000e+00
## 5 7.900317e-23 1.000000e+00
## 6 2.765838e-19 1.000000e+00
predicted = predict(model_NB, testData)
testData$predProb = predProb$B
testData$diagnosis = as.factor(testData$diagnosis)
confusionMatrix(reference = testData$diagnosis, data = predicted, mode = "everything",
    positive = "B")
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction  B  M
##          B 99 12
##          M  8 51
##                                           
##                Accuracy : 0.8824          
##                  95% CI : (0.8242, 0.9266)
##     No Information Rate : 0.6294          
##     P-Value [Acc > NIR] : 1.108e-13       
##                                           
##                   Kappa : 0.7445          
##                                           
##  Mcnemar's Test P-Value : 0.5023          
##                                           
##             Sensitivity : 0.9252          
##             Specificity : 0.8095          
##          Pos Pred Value : 0.8919          
##          Neg Pred Value : 0.8644          
##               Precision : 0.8919          
##                  Recall : 0.9252          
##                      F1 : 0.9083          
##              Prevalence : 0.6294          
##          Detection Rate : 0.5824          
##    Detection Prevalence : 0.6529          
##       Balanced Accuracy : 0.8674          
##                                           
##        'Positive' Class : B               
##

Plotting the ROC Curve

Once the model has been built, its accuracy needs to be assessed, so we compute the AUC and draw the ROC curve (an AUC sketch follows the figure below).

library(ROCR)
pred = prediction(testData$predProb, testData$diagnosis)
# predProb holds P(B), while ROCR by default treats the second factor level ("M")
# as the positive class; plotting fpr against tpr therefore recovers the usual
# ROC curve for class "B"
perf = performance(pred, measure = "fpr", x.measure = "tpr")
plot(perf, lwd = 2, col = "blue", main = "ROC")
abline(a = 0, b = 1, col = 2, lwd = 1, lty = 2)

(Figure: ROC curve of the naive Bayes model on the test set)
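The AUC mentioned above can be read off the same ROCR objects; a short sketch (pred_B and auc_B are names introduced here, with "B" made the explicit positive class to match the P(B) scores):

# Declare "M" as negative and "B" as positive, matching the P(B) scores, then extract the AUC
pred_B = prediction(testData$predProb, testData$diagnosis, label.ordering = c("M", "B"))
auc_B = performance(pred_B, measure = "auc")@y.values[[1]]
auc_B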

Comparing Multiple Classifiers

# Train the model using rf
model_rf = train(diagnosis ~ ., data = trainData, method = "rf", tuneLength = 2,
    trControl = fitControl)
model_rf
## Random Forest 
## 
## 398 samples
##  20 predictor
##   2 classes: 'B', 'M' 
## 
## No pre-processing
## Resampling: Cross-Validated (5 fold) 
## Summary of sample sizes: 319, 318, 319, 318, 318 
## Resampling results across tuning parameters:
## 
##   mtry  ROC        Sens   Spec     
##    2    0.9892575  0.980  0.9197701
##   20    0.9815977  0.956  0.9262069
## 
## ROC was used to select the optimal model using the largest value.
## The final value used for the model was mtry = 2.
# Train the model using adaboost
model_adaboost = train(diagnosis ~ ., data = trainData, method = "adaboost", tuneLength = 2,
    trControl = fitControl)
model_adaboost
## AdaBoost Classification Trees 
## 
## 398 samples
##  20 predictor
##   2 classes: 'B', 'M' 
## 
## No pre-processing
## Resampling: Cross-Validated (5 fold) 
## Summary of sample sizes: 319, 318, 319, 318, 318 
## Resampling results across tuning parameters:
## 
##   nIter  method         ROC        Sens   Spec     
##    50    Adaboost.M1    0.9881103  0.980  0.9183908
##    50    Real adaboost  0.8791931  0.988  0.9114943
##   100    Adaboost.M1    0.9894759  0.984  0.9250575
##   100    Real adaboost  0.8670460  0.984  0.9114943
## 
## ROC was used to select the optimal model using the largest value.
## The final values used for the model were nIter = 100 and method = Adaboost.M1.
# Train the model using Logitboost
model_LogitBoost = train(diagnosis ~ ., data = trainData, method = "LogitBoost",
    tuneLength = 2, trControl = fitControl)
model_LogitBoost
## Boosted Logistic Regression 
## 
## 398 samples
##  20 predictor
##   2 classes: 'B', 'M' 
## 
## No pre-processing
## Resampling: Cross-Validated (5 fold) 
## Summary of sample sizes: 318, 318, 318, 319, 319 
## Resampling results across tuning parameters:
## 
##   nIter  ROC        Sens   Spec     
##   11     0.9834069  0.952  0.9252874
##   21     0.9853494  0.980  0.9455172
## 
## ROC was used to select the optimal model using the largest value.
## The final value used for the model was nIter = 21.

# Train the model using GBM
model_GBM = train(diagnosis ~ ., data = trainData, method = "gbm", tuneLength = 2,
    trControl = fitControl)
model_GBM
## Stochastic Gradient Boosting 
## 
## 398 samples
##  20 predictor
##   2 classes: 'B', 'M' 
## 
## No pre-processing
## Resampling: Cross-Validated (5 fold) 
## Summary of sample sizes: 319, 319, 318, 318, 318 
## Resampling results across tuning parameters:
## 
##   interaction.depth  n.trees  ROC        Sens   Spec     
##   1                   50      0.9866759  0.972  0.8993103
##   1                  100      0.9880828  0.972  0.9193103
##   2                   50      0.9921057  0.972  0.9264368
##   2                  100      0.9874437  0.976  0.9466667
## 
## Tuning parameter 'shrinkage' was held constant at a value of 0.1
## 
## Tuning parameter 'n.minobsinnode' was held constant at a value of 10
## ROC was used to select the optimal model using the largest value.
## The final values used for the model were n.trees = 50, interaction.depth =
##  2, shrinkage = 0.1 and n.minobsinnode = 10.

# Train the model using PAM
model_PAM = train(diagnosis ~ ., data = trainData, method = "pam", tuneLength = 2,
    trControl = fitControl)
## (progress counter printed by pamr during training)
model_PAM
## Nearest Shrunken Centroids 
## 
## 398 samples
##  20 predictor
##   2 classes: 'B', 'M' 
## 
## No pre-processing
## Resampling: Cross-Validated (5 fold) 
## Summary of sample sizes: 318, 319, 318, 318, 319 
## Resampling results across tuning parameters:
## 
##   threshold   ROC        Sens  Spec     
##    0.3730189  0.9552092  0.94  0.8108046
##   10.4445296  0.5000000  1.00  0.0000000
## 
## ROC was used to select the optimal model using the largest value.
## The final value used for the model was threshold = 0.3730189.

models_compare <- resamples(list(ADABOOST = model_adaboost, RF = model_rf, LOGITBOOST = model_LogitBoost,
    GBM = model_GBM, PAM = model_PAM, NB = model_NB))
summary(models_compare)
## 
## Call:
## summary.resamples(object = models_compare)
## 
## Models: ADABOOST, RF, LOGITBOOST, GBM, PAM, NB 
## Number of resamples: 5 
## 
## ROC 
##                 Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
## ADABOOST   0.9724138 0.9889655 0.9893333 0.9894759 0.9966667 1.0000000    0
## RF         0.9620000 0.9856667 0.9986207 0.9892575 1.0000000 1.0000000    0
## LOGITBOOST 0.9717241 0.9806897 0.9810000 0.9853494 0.9933333 1.0000000    0
## GBM        0.9780000 0.9900000 0.9966667 0.9921057 0.9972414 0.9986207    0
## PAM        0.9486667 0.9500000 0.9503448 0.9552092 0.9560000 0.9710345    0
## NB         0.9627586 0.9646667 0.9673333 0.9718575 0.9686667 0.9958621    0
## 
## Sens 
##            Min. 1st Qu. Median  Mean 3rd Qu. Max. NA's
## ADABOOST   0.98    0.98   0.98 0.984    0.98 1.00    0
## RF         0.96    0.98   0.98 0.980    0.98 1.00    0
## LOGITBOOST 0.92    0.98   1.00 0.980    1.00 1.00    0
## GBM        0.94    0.96   0.96 0.972    1.00 1.00    0
## PAM        0.90    0.94   0.94 0.940    0.96 0.96    0
## NB         0.90    0.92   0.92 0.928    0.94 0.96    0
## 
## Spec 
##                 Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
## ADABOOST   0.8275862 0.9310345 0.9333333 0.9250575 0.9333333 1.0000000    0
## RF         0.8333333 0.8333333 0.9655172 0.9197701 0.9666667 1.0000000    0
## LOGITBOOST 0.8965517 0.9310345 0.9333333 0.9455172 0.9666667 1.0000000    0
## GBM        0.8333333 0.9000000 0.9333333 0.9264368 0.9655172 1.0000000    0
## PAM        0.7666667 0.7931034 0.8000000 0.8108046 0.8275862 0.8666667    0
## NB         0.8666667 0.8965517 0.9333333 0.9190805 0.9333333 0.9655172    0
# Draw box plots to compare models
scales <- list(x = list(relation = "free"), y = list(relation = "free"))
bwplot(models_compare, scales = scales)

(Figure: box-and-whisker comparison of resampled ROC, Sens and Spec across models)
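If a formal comparison of the resampled metrics is wanted, caret can also summarize the pairwise differences between models; a brief sketch (models_diff is a name introduced here):

# Pairwise differences of the resampled performance metrics between models
models_diff <- diff(models_compare)
summary(models_diff)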

Generating Test-Set Results

Plotting Calibration Curves

## Generate the test set results
results <- data.frame(Diagnosis = testData$diagnosis)
results$RF <- predict(model_rf, testData, type = "prob")[, "B"]
results$adaboost <- predict(model_adaboost, testData, type = "prob")[, "B"]
results$LogitBoost <- predict(model_LogitBoost, testData, type = "prob")[, "B"]
results$GBM <- predict(model_GBM, testData, type = "prob")[, "B"]
results$PAM <- predict(model_PAM, testData, type = "prob")[, "B"]
results$NB <- predict(model_NB, testData, type = "prob")[, "B"]
head(results)
##   Diagnosis    RF   adaboost   LogitBoost        GBM          PAM           NB
## 1         M 0.010 0.03767170 1.670142e-05 0.01580114 1.538730e-02 2.842085e-20
## 2         M 0.228 0.23080168 4.742587e-02 0.07651071 6.673408e-05 1.290245e-56
## 3         M 0.732 0.73120191 7.310586e-01 0.81300770 9.392887e-01 9.999659e-01
## 4         M 0.030 0.06663126 9.110512e-04 0.01953871 1.155797e-01 1.031544e-09
## 5         M 0.220 0.23191784 1.670142e-05 0.02014122 4.748940e-03 7.900317e-23
## 6         M 0.062 0.16372403 6.692851e-03 0.03055624 9.553552e-03 2.765838e-19
trellis.par.set(caretTheme())
cal_obj <- calibration(Diagnosis ~ RF + adaboost + LogitBoost + GBM + PAM + NB, data = results,
    cuts = 13)
plot(cal_obj, type = "l", auto.key = list(columns = 5, lines = TRUE, points = FALSE))

(Figure: calibration curves for the six models)

There is also a ggplot method, which displays confidence intervals for the proportions within each bin:

ggplot(cal_obj)

(Figure: ggplot version of the calibration curves with confidence intervals)


