tidymodels超参数设置

本篇主要介绍tidymodelsdials包的使用,dials主要用于创建超参数的值。主要有以下3方面的作用:

  • 用于创建、展示、查询超参数
  • R中的不同包对于同一个超参数有不同的名字,dials使用统一的名字
  • tidymodels中的其他R包进行配合

每个算法有很多超参数,在tidymodels中并不是每个超参数都可以调整,并且这些超参数的名字你也不知道。

如何知道这些超参数在tidymodels中的名字,并且知道每个算法有哪些超参数可以调整呢?这些我在另一篇推文中有详细的介绍:tidymodels的算法选择包parsnip的使用说明

每个超参数在dials中都有默认的取值范围,可以直接通过超参数名进行查看:

library(tidymodels)
## ── Attaching packages ────────────────────────────────────── tidymodels 1.0.0 ──
## ✔ broom        1.0.3     ✔ recipes      1.0.4
## ✔ dials        1.1.0     ✔ rsample      1.1.1
## ✔ dplyr        1.1.1     ✔ tibble       3.2.1
## ✔ ggplot2      3.4.1     ✔ tidyr        1.3.0
## ✔ infer        1.0.4     ✔ tune         1.0.1
## ✔ modeldata    1.1.0     ✔ workflows    1.1.2
## ✔ parsnip      1.0.3     ✔ workflowsets 1.0.0
## ✔ purrr        1.0.1     ✔ yardstick    1.1.0
## ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
## ✖ purrr::discard() masks scales::discard()
## ✖ dplyr::filter()  masks stats::filter()
## ✖ dplyr::lag()     masks stats::lag()
## ✖ recipes::step()  masks stats::step()
## • Use tidymodels_prefer() to resolve common conflicts.

# 比如决策树中的cp超参数
cost_complexity()
## Cost-Complexity Parameter (quantitative)
## Transformer: log-10 [1e-100, Inf]
## Range (transformed scale): [-10, -1]

可以看到在dials中默认的cp范围是[-10,-1],这其实是经过log10转换之后的,原始的cp的范围应该是[10^-10, 0.1]。

也可以通过以下函数查看默认的范围:

cost_complexity() %>% range_get()
## $lower
## [1] 1e-10
## 
## $upper
## [1] 0.1

可以手动设置超参数的范围:

# 2种方法都可以,效果一样
cost_complexity() %>% range_set(c(-5,1))
## Cost-Complexity Parameter (quantitative)
## Transformer: log-10 [1e-100, Inf]
## Range (transformed scale): [-5, 1]
cost_complexity(range = c(-5,1))
## Cost-Complexity Parameter (quantitative)
## Transformer: log-10 [1e-100, Inf]
## Range (transformed scale): [-5, 1]

还可以手动设置超参数的具体值:

# 这个参数默认是log的,所以要用log
cost_complexity() %>% value_set(log10(c(0.001,0.01)))
## Cost-Complexity Parameter (quantitative)
## Transformer: log-10 [1e-100, Inf]
## Range (transformed scale): [-10, -1]
## Values: 2

有些超参数的值一开始并不能被设定好,一定要有数据才行,这种情况在dials中是通过占位符表示的,比如一个?

mtry()
## # Randomly Selected Predictors (quantitative)
## Range: [1, ?]
sample_size()
## # Observations Sampled (quantitative)
## Range: [?, ?]

同时设置多个超参数的范围也是可以的。

比如在弹性网络模型中有lambdaalpha两个超参数可以调整,可通过以下方法查看超参数的范围:

glmnet_set <- parameters(list(lambda = penalty(), alpha = mixture()))
glmnet_set
## Collection of 2 parameters for tuning
## 
##  identifier    type    object
##      lambda penalty nparam[+]
##       alpha mixture nparam[+]

可以被更改:

# 更改alpha的范围
update(glmnet_set, alpha = mixture(c(.3, .6)))
## Collection of 2 parameters for tuning
## 
##  identifier    type    object
##      lambda penalty nparam[+]
##       alpha mixture nparam[+]

tidymodels中提供两种常见的超参数调优方式,分别是网格搜索和迭代搜索。在进行网格搜索时,我们经常要提前设置超参数的网格范围,dials提供几个常用的函数快速进行设置。

我最常用的是这个grid_regular函数:

# 3种方法
grid_regular(mixture(),
             penalty(),
             levels = 3 # 设置每个超参数取几个值
             )

grid_regular(mixture(),
             penalty(),
             levels = c(3,4)
             )

grid_regular(mixture(),
             penalty(),
             levels = c(mixture = 3, penalty = 4)
             )
## # A tibble: 9 × 2
##   mixture      penalty
##     <dbl>        <dbl>
## 1     0   0.0000000001
## 2     0.5 0.0000000001
## 3     1   0.0000000001
## 4     0   0.00001     
## 5     0.5 0.00001     
## 6     1   0.00001     
## 7     0   1           
## 8     0.5 1           
## 9     1   1

通过grid_regular函数生成的网格在tidymodels中被称为规则网格,这也是大家常说的网格搜索中的网格。

除了规则网格外,还有不规则网格,常见的生成方法是通过grid_random实现,这个函数生成的网格又被叫做随机网格,也就是大家常说的随机搜索

set.seed(1041)
grid_random(
  mixture(),
  penalty(),
  size = 6 # 设置一共有几组值
)
## # A tibble: 6 × 2
##   mixture     penalty
##     <dbl>       <dbl>
## 1   0.200 0.0176     
## 2   0.750 0.000388   
## 3   0.191 0.000000159
## 4   0.929 0.00000176 
## 5   0.143 0.0442     
## 6   0.973 0.0110

但是grid_random生成的随机网格有几个问题,首先就是,如果网格的规模比较小,那么参数值组合之间可能会有重叠;其次是随机网格需要覆盖整个超参数空间,这就要求超参数值的数量要足够多,这样会很耗费时间。

所以tidymodels提供了另外几种、被称为空间填充设计的,生成不规则网格的方法,用来弥补随机网格的缺点。

比如:grid_max_entropy()通过最大熵设计生成网格, grid_latin_hypercube()通过拉丁超立方设计生成网格。它们的用法都是一样的。

在实际使用中我们经常要先建立模型,然后选择合适的超参数范围(虽然是有默认范围的,但有时我们可能想要更改)。

比如选择一个随机森林算法:

library(parsnip)

# 建立模型设定
rf_spec <- rand_forest(mode = "classification", 
                       mtry = tune(),
                       min_n = tune(),
                       trees = 500
                       )

此时的算法超参数是有默认值的:

# 查看所有的超参数
rf_spec %>% extract_parameter_set_dials()
## Collection of 2 parameters for tuning
## 
##  identifier  type    object
##        mtry  mtry nparam[?]
##       min_n min_n nparam[+]
## 
## Model parameters needing finalization:
##    # Randomly Selected Predictors ('mtry')
## 
## See `?dials::finalize` or `?dials::update.parameters` for more information.

# 查看某一个超参数
rf_spec %>% extract_parameter_dials("min_n")
## Minimal Node Size (quantitative)
## Range: [2, 40]

# 或者像上面介绍的,直接使用参数名
min_n()
## Minimal Node Size (quantitative)
## Range: [2, 40]

可以通过多种方法进行更改,首先是更新:

# 更新默认超参数的值
rf_param <- rf_spec %>% 
  extract_parameter_set_dials() %>% 
  update(min_n = min_n(c(10,50)),
         mtry = mtry(c(1,5))
         )

然后再通过网格函数产生网格:

grid_rf <- grid_regular(rf_param,
                        levels = 5
                        )
grid_rf
## # A tibble: 25 × 2
##     mtry min_n
##    <int> <int>
##  1     1    10
##  2     2    10
##  3     3    10
##  4     4    10
##  5     5    10
##  6     1    20
##  7     2    20
##  8     3    20
##  9     4    20
## 10     5    20
## # … with 15 more rows

或者不用这么麻烦,直接设置超参数的范围:

# 建议用这种,最简单直接!
rf_grid <- grid_regular(mtry(c(1,5)),
                        min_n(c(10,50)),
                        levels = 5
                        )
rf_grid
## # A tibble: 25 × 2
##     mtry min_n
##    <int> <int>
##  1     1    10
##  2     2    10
##  3     3    10
##  4     4    10
##  5     5    10
##  6     1    20
##  7     2    20
##  8     3    20
##  9     4    20
## 10     5    20
## # … with 15 more rows

或者使用以下方法:

rf_params <- parameters(mtry(c(1,5)),
                        min_n(c(10,50))
                        )
gridd <- grid_regular(rf_params,levels = 5)
gridd
## # A tibble: 25 × 2
##     mtry min_n
##    <int> <int>
##  1     1    10
##  2     2    10
##  3     3    10
##  4     4    10
##  5     5    10
##  6     1    20
##  7     2    20
##  8     3    20
##  9     4    20
## 10     5    20
## # … with 15 more rows

dials包作为tidymodels的一部分,主要作用就是设置超参数,可以看到dials中设置超参数的方法很灵活,在实际使用中需要注意,选择自己最喜欢的一种即可,尤其是设定超参数网格的方法,大家一定不要搞混了!

  • 11
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值