caret教程06:模型评价

介绍caret包中的常见的模型评价指标及可视化方法。

关于机器学习中到底有哪些评价指标,每种指标表示什么意思,大家需要自己学习。

二分类问题的评价指标基本上都是围绕混淆矩阵来的,所以你一定要搞清楚混淆矩阵!另外,临床预测模型领域常见的指标基本都是使用了机器学习领域的指标。

大家可以参考这几篇文章:

  • https://www.cnblogs.com/zongfa/p/9431807.html#top
  • https://blog.csdn.net/sinat_16388393/article/details/91427631
  • https://zhuanlan.zhihu.com/p/359997979

回归模型评价

postResample函数可以估计数值型结果变量的RMSE, R^2, MAE。

library(caret)
## Loading required package: ggplot2
## Loading required package: lattice
library(mlbench)
data(BostonHousing)

set.seed(280)
bh_index <- createDataPartition(BostonHousing$medv, p = .75, list = FALSE)
bh_tr <- BostonHousing[ bh_index, ]
bh_te <- BostonHousing[-bh_index, ]

set.seed(7279)
lm_fit <- train(medv ~ . + rm:lstat,
                data = bh_tr, 
                method = "lm")
bh_pred <- predict(lm_fit, bh_te)

lm_fit
## Linear Regression 
## 
## 381 samples
##  13 predictor
## 
## No pre-processing
## Resampling: Bootstrapped (25 reps) 
## Summary of sample sizes: 381, 381, 381, 381, 381, 381, ... 
## Resampling results:
## 
##   RMSE      Rsquared   MAE     
##   4.374098  0.7724562  2.963927
## 
## Tuning parameter 'intercept' was held constant at a value of TRUE

查看测试集的结果:

postResample(pred = bh_pred, obs = bh_te$medv)
##      RMSE  Rsquared       MAE 
## 4.0927043 0.8234427 2.8163731

预测类别型的模型评价

对于结果变量是分类的,直接给出到底属于哪一个类,这种类型的模型评价最常用的是混淆矩阵。

# 首先虚构一个二分类数据,包含真实结果,预测结果,每一个类的预测概率
set.seed(144)
true_class <- factor(sample(paste0("Class", 1:2), 
                            size = 1000,
                            prob = c(.2, .8), replace = TRUE))
true_class <- sort(true_class)
class1_probs <- rbeta(sum(true_class == "Class1"), 4, 1)
class2_probs <- rbeta(sum(true_class == "Class2"), 1, 2.5)
test_set <- data.frame(obs = true_class,
                       Class1 = c(class1_probs, class2_probs))
test_set$Class2 <- 1 - test_set$Class1
test_set$pred <- factor(ifelse(test_set$Class1 >= .5, "Class1", "Class2"))

psych::headTail(test_set)
##         obs Class1 Class2   pred
## 1    Class1   0.96   0.04 Class1
## 2    Class1   0.97   0.03 Class1
## 3    Class1   0.75   0.25 Class1
## 4    Class1   0.81   0.19 Class1
## ...    <NA>    ...    ...   <NA>
## 997  Class2    0.2    0.8 Class2
## 998  Class2   0.41   0.59 Class2
## 999  Class2   0.07   0.93 Class2
## 1000 Class2   0.84   0.16 Class1

看下每个类别的概率分布:

ggplot(test_set, aes(x = Class1)) + 
  geom_histogram(binwidth = .05) + 
  facet_wrap(~obs) + 
  xlab("Probability of Class #1")

unnamed-chunk-4-170387896

有了预测结果和真实结果,我们可以计算混淆矩阵,caret包给出的混淆矩阵非常详细!我一直很喜欢这个R包,因为它真的很牛逼!

confusionMatrix(data = test_set$pred, reference = test_set$obs)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction Class1 Class2
##     Class1    183    141
##     Class2     13    663
##                                           
##                Accuracy : 0.846           
##                  95% CI : (0.8221, 0.8678)
##     No Information Rate : 0.804           
##     P-Value [Acc > NIR] : 0.0003424       
##                                           
##                   Kappa : 0.6081          
##                                           
##  Mcnemar's Test P-Value : < 2.2e-16       
##                                           
##             Sensitivity : 0.9337          
##             Specificity : 0.8246          
##          Pos Pred Value : 0.5648          
##          Neg Pred Value : 0.9808          
##              Prevalence : 0.1960          
##          Detection Rate : 0.1830          
##    Detection Prevalence : 0.3240          
##       Balanced Accuracy : 0.8792          
##                                           
##        'Positive' Class : Class1          
## 

默认是根据sensitivity and specificity,可以更改:

confusionMatrix(data = test_set$pred, reference = test_set$obs, mode = "prec_recall")
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction Class1 Class2
##     Class1    183    141
##     Class2     13    663
##                                           
##                Accuracy : 0.846           
##                  95% CI : (0.8221, 0.8678)
##     No Information Rate : 0.804           
##     P-Value [Acc > NIR] : 0.0003424       
##                                           
##                   Kappa : 0.6081          
##                                           
##  Mcnemar's Test P-Value : < 2.2e-16       
##                                           
##               Precision : 0.5648          
##                  Recall : 0.9337          
##                      F1 : 0.7038          
##              Prevalence : 0.1960          
##          Detection Rate : 0.1830          
##    Detection Prevalence : 0.3240          
##       Balanced Accuracy : 0.8792          
##                                           
##        'Positive' Class : Class1          
## 

positive参数可以控制计算哪一个类的结果,如果你不明白,可以翻看之前的推文:ROC曲线的两面性

除此之外,还有其他函数可以计算类别型模型的指标:sensitivity, specificity, posPredValue, negPredValue, precision, recall, and F_meas

直接用就行,非常简单,这个简单的逻辑也带到了新的R包:tidymodels中,二者使用逻辑如出一辙,以后也会详细介绍它。

sensitivity(data = test_set$pred, reference = test_set$obs)
## [1] 0.9336735
posPredValue(data = test_set$pred, reference = test_set$obs)
## [1] 0.5648148
F_meas(data = test_set$pred, reference = test_set$obs)
## [1] 0.7038462

train()函数中默认的类别型结果计算模型性能的函数也是postResample()

postResample(pred = test_set$pred, obs = test_set$obs)
##  Accuracy     Kappa 
## 0.8460000 0.6081345

类别概率型模型的评价

上面是针对给出具体的预测分类,这个是给出概率,两个都是针对分类变量的。

其实这两个使用方法和之前说过的yardsticks非常像,不太明白其中规律的可以看这个:tidymodels使用细节

非常多的指标可选,关于具体有哪些指标,以及指标的含义,建议大家自己专门学习,太多了,作为业余选手,很多我都用不到。

twoClassSummary(test_set, lev = levels(test_set$obs))
##       ROC      Sens      Spec 
## 0.9560044 0.9336735 0.8246269
prSummary(test_set, lev = levels(test_set$obs))
##       AUC Precision    Recall         F 
## 0.8582695 0.5648148 0.9336735 0.7038462 
mnLogLoss(test_set, lev = levels(test_set$obs))
##  logLoss 
## 0.370626

lift curve

提升曲线。

首先构建多个模型,我一直觉得caret语法相比于tidymodelsmlr3,真的是简单太多!

# 使用16个线程,加速,caret太慢了
library(doParallel)
## 载入需要的程辑包:foreach
## 载入需要的程辑包:iterators
## 载入需要的程辑包:parallel
## 载入需要的程辑包:earth
## 载入需要的程辑包:Formula
## 载入需要的程辑包:plotmo
## 载入需要的程辑包:plotrix
## 载入需要的程辑包:TeachingDemos
cl <- makePSOCKcluster(16)
registerDoParallel(cl)

# 虚构训练集和测试集
set.seed(2)
lift_training <- twoClassSim(1000)
lift_testing  <- twoClassSim(1000)

# 设置重抽样方法,选择预测结果是“概率”
ctrl <- trainControl(method = "cv", classProbs = TRUE,
                     summaryFunction = twoClassSummary)

# fda
set.seed(1045)
fda_lift <- train(Class ~ ., data = lift_training,
                  method = "fda", metric = "ROC",
                  tuneLength = 20,
                  trControl = ctrl)

# lda
set.seed(1045)
lda_lift <- train(Class ~ ., data = lift_training,
                  method = "lda", metric = "ROC",
                  trControl = ctrl)
# c50
library(C50)
set.seed(1045)
c5_lift <- train(Class ~ ., data = lift_training,
                 method = "C5.0", metric = "ROC",
                 tuneLength = 10,
                 trControl = ctrl,
                 control = C5.0Control(earlyStopping = FALSE))

## 使用测试集评估模型效果
lift_results <- data.frame(Class = lift_testing$Class)
lift_results$FDA <- predict(fda_lift, lift_testing, type = "prob")[,"Class1"]
lift_results$LDA <- predict(lda_lift, lift_testing, type = "prob")[,"Class1"]
lift_results$C5.0 <- predict(c5_lift, lift_testing, type = "prob")[,"Class1"]
head(lift_results)
##    Class        FDA       LDA      C5.0
## 1 Class1 0.99187063 0.8838205 0.8445830
## 2 Class1 0.99115613 0.7572450 0.8882418
## 3 Class1 0.80567440 0.8883830 0.5732098
## 4 Class2 0.05245632 0.0140480 0.1690251
## 5 Class1 0.76175025 0.9320695 0.4824400
## 6 Class2 0.13782751 0.0524154 0.3310495

stopCluster(cl)

画图很简单,就是plot:

# 设置主题,如果你不懂也没关系,这个方法有点过时了,现在是ggplot的天下
trellis.par.set(caretTheme())
# 先构建lift
lift_obj <- lift(Class ~ FDA + LDA + C5.0, data = lift_results)
# 画图
plot(lift_obj, values = 60, auto.key = list(columns = 3,
                                            lines = TRUE,
                                            points = FALSE))

在这里插入图片描述

ggplot2也支持的,是不是很简单?

ggplot(lift_obj, values = 60)

校准曲线

一模一样的流程!

# 设置主题
trellis.par.set(caretTheme())
# 构建校准曲线
cal_obj <- calibration(Class ~ FDA + LDA + C5.0,
                       data = lift_results,
                       cuts = 13)
# 画图
plot(cal_obj, type = "l", auto.key = list(columns = 3,
                                          lines = TRUE,
                                          points = FALSE))

在这里插入图片描述

ggplot2一样支持,tidymodels也支持校准曲线了。tidymodels支持校准曲线了

# 太好看了有木有!
ggplot(cal_obj)+theme_bw()

还有非常多其他的曲线可以画,大家自行探索即可。

今天这篇只是抛砖引玉,caret的模型评价指标非常全面!对于回归和分类的支持相当强大!大家如果能玩转这个包,绝壁也是大佬了!

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值