本篇主要介绍tidymodels
中dials
包的使用,dials
主要用于创建超参数的值。主要有以下3方面的作用:
- 用于创建、展示、查询超参数
- R中的不同包对于同一个超参数有不同的名字,
dials
使用统一的名字 - 和
tidymodels
中的其他R包进行配合
每个算法有很多超参数,在tidymodels
中并不是每个超参数都可以调整,并且这些超参数的名字你也不知道。
如何知道这些超参数在tidymodels
中的名字,并且知道每个算法有哪些超参数可以调整呢?这些我在另一篇推文中有详细的介绍:tidymodels的算法选择包parsnip的使用说明
每个超参数在dials
中都有默认的取值范围,可以直接通过超参数名进行查看:
library(tidymodels)
## ── Attaching packages ────────────────────────────────────── tidymodels 1.0.0 ──
## ✔ broom 1.0.3 ✔ recipes 1.0.4
## ✔ dials 1.1.0 ✔ rsample 1.1.1
## ✔ dplyr 1.1.1 ✔ tibble 3.2.1
## ✔ ggplot2 3.4.1 ✔ tidyr 1.3.0
## ✔ infer 1.0.4 ✔ tune 1.0.1
## ✔ modeldata 1.1.0 ✔ workflows 1.1.2
## ✔ parsnip 1.0.3 ✔ workflowsets 1.0.0
## ✔ purrr 1.0.1 ✔ yardstick 1.1.0
## ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
## ✖ purrr::discard() masks scales::discard()
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag() masks stats::lag()
## ✖ recipes::step() masks stats::step()
## • Use tidymodels_prefer() to resolve common conflicts.
# 比如决策树中的cp超参数
cost_complexity()
## Cost-Complexity Parameter (quantitative)
## Transformer: log-10 [1e-100, Inf]
## Range (transformed scale): [-10, -1]
可以看到在dials
中默认的cp
范围是[-10,-1],这其实是经过log10
转换之后的,原始的cp
的范围应该是[10^-10, 0.1]。
也可以通过以下函数查看默认的范围:
cost_complexity() %>% range_get()
## $lower
## [1] 1e-10
##
## $upper
## [1] 0.1
可以手动设置超参数的范围:
# 2种方法都可以,效果一样
cost_complexity() %>% range_set(c(-5,1))
## Cost-Complexity Parameter (quantitative)
## Transformer: log-10 [1e-100, Inf]
## Range (transformed scale): [-5, 1]
cost_complexity(range = c(-5,1))
## Cost-Complexity Parameter (quantitative)
## Transformer: log-10 [1e-100, Inf]
## Range (transformed scale): [-5, 1]
还可以手动设置超参数的具体值:
# 这个参数默认是log的,所以要用log
cost_complexity() %>% value_set(log10(c(0.001,0.01)))
## Cost-Complexity Parameter (quantitative)
## Transformer: log-10 [1e-100, Inf]
## Range (transformed scale): [-10, -1]
## Values: 2
有些超参数的值一开始并不能被设定好,一定要有数据才行,这种情况在dials
中是通过占位符
表示的,比如一个?
:
mtry()
## # Randomly Selected Predictors (quantitative)
## Range: [1, ?]
sample_size()
## # Observations Sampled (quantitative)
## Range: [?, ?]
同时设置多个超参数的范围也是可以的。
比如在弹性网络模型中有lambda
和alpha
两个超参数可以调整,可通过以下方法查看超参数的范围:
glmnet_set <- parameters(list(lambda = penalty(), alpha = mixture()))
glmnet_set
## Collection of 2 parameters for tuning
##
## identifier type object
## lambda penalty nparam[+]
## alpha mixture nparam[+]
可以被更改:
# 更改alpha的范围
update(glmnet_set, alpha = mixture(c(.3, .6)))
## Collection of 2 parameters for tuning
##
## identifier type object
## lambda penalty nparam[+]
## alpha mixture nparam[+]
tidymodels
中提供两种常见的超参数调优方式,分别是网格搜索和迭代搜索。在进行网格搜索时,我们经常要提前设置超参数的网格范围,dials
提供几个常用的函数快速进行设置。
我最常用的是这个grid_regular
函数:
# 3种方法
grid_regular(mixture(),
penalty(),
levels = 3 # 设置每个超参数取几个值
)
grid_regular(mixture(),
penalty(),
levels = c(3,4)
)
grid_regular(mixture(),
penalty(),
levels = c(mixture = 3, penalty = 4)
)
## # A tibble: 9 × 2
## mixture penalty
## <dbl> <dbl>
## 1 0 0.0000000001
## 2 0.5 0.0000000001
## 3 1 0.0000000001
## 4 0 0.00001
## 5 0.5 0.00001
## 6 1 0.00001
## 7 0 1
## 8 0.5 1
## 9 1 1
通过grid_regular
函数生成的网格在tidymodels
中被称为规则网格,这也是大家常说的网格搜索中的网格。
除了规则网格外,还有不规则网格,常见的生成方法是通过grid_random
实现,这个函数生成的网格又被叫做随机网格,也就是大家常说的随机搜索:
set.seed(1041)
grid_random(
mixture(),
penalty(),
size = 6 # 设置一共有几组值
)
## # A tibble: 6 × 2
## mixture penalty
## <dbl> <dbl>
## 1 0.200 0.0176
## 2 0.750 0.000388
## 3 0.191 0.000000159
## 4 0.929 0.00000176
## 5 0.143 0.0442
## 6 0.973 0.0110
但是grid_random
生成的随机网格有几个问题,首先就是,如果网格的规模比较小,那么参数值组合之间可能会有重叠;其次是随机网格需要覆盖整个超参数空间,这就要求超参数值的数量要足够多,这样会很耗费时间。
所以tidymodels
提供了另外几种、被称为空间填充设计的,生成不规则网格的方法,用来弥补随机网格的缺点。
比如:grid_max_entropy()
通过最大熵设计生成网格, grid_latin_hypercube()
通过拉丁超立方设计生成网格。它们的用法都是一样的。
在实际使用中我们经常要先建立模型,然后选择合适的超参数范围(虽然是有默认范围的,但有时我们可能想要更改)。
比如选择一个随机森林算法:
library(parsnip)
# 建立模型设定
rf_spec <- rand_forest(mode = "classification",
mtry = tune(),
min_n = tune(),
trees = 500
)
此时的算法超参数是有默认值的:
# 查看所有的超参数
rf_spec %>% extract_parameter_set_dials()
## Collection of 2 parameters for tuning
##
## identifier type object
## mtry mtry nparam[?]
## min_n min_n nparam[+]
##
## Model parameters needing finalization:
## # Randomly Selected Predictors ('mtry')
##
## See `?dials::finalize` or `?dials::update.parameters` for more information.
# 查看某一个超参数
rf_spec %>% extract_parameter_dials("min_n")
## Minimal Node Size (quantitative)
## Range: [2, 40]
# 或者像上面介绍的,直接使用参数名
min_n()
## Minimal Node Size (quantitative)
## Range: [2, 40]
可以通过多种方法进行更改,首先是更新:
# 更新默认超参数的值
rf_param <- rf_spec %>%
extract_parameter_set_dials() %>%
update(min_n = min_n(c(10,50)),
mtry = mtry(c(1,5))
)
然后再通过网格函数产生网格:
grid_rf <- grid_regular(rf_param,
levels = 5
)
grid_rf
## # A tibble: 25 × 2
## mtry min_n
## <int> <int>
## 1 1 10
## 2 2 10
## 3 3 10
## 4 4 10
## 5 5 10
## 6 1 20
## 7 2 20
## 8 3 20
## 9 4 20
## 10 5 20
## # … with 15 more rows
或者不用这么麻烦,直接设置超参数的范围:
# 建议用这种,最简单直接!
rf_grid <- grid_regular(mtry(c(1,5)),
min_n(c(10,50)),
levels = 5
)
rf_grid
## # A tibble: 25 × 2
## mtry min_n
## <int> <int>
## 1 1 10
## 2 2 10
## 3 3 10
## 4 4 10
## 5 5 10
## 6 1 20
## 7 2 20
## 8 3 20
## 9 4 20
## 10 5 20
## # … with 15 more rows
或者使用以下方法:
rf_params <- parameters(mtry(c(1,5)),
min_n(c(10,50))
)
gridd <- grid_regular(rf_params,levels = 5)
gridd
## # A tibble: 25 × 2
## mtry min_n
## <int> <int>
## 1 1 10
## 2 2 10
## 3 3 10
## 4 4 10
## 5 5 10
## 6 1 20
## 7 2 20
## 8 3 20
## 9 4 20
## 10 5 20
## # … with 15 more rows
dials
包作为tidymodels
的一部分,主要作用就是设置超参数,可以看到dials
中设置超参数的方法很灵活,在实际使用中需要注意,选择自己最喜欢的一种即可,尤其是设定超参数网格的方法,大家一定不要搞混了!