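Implementing Bayesian Random-Parameter Models in R
This article works in R throughout: the tidyverse packages handle data cleaning and plotting, and the brms package fits the Bayesian regression models.
1. Basic steps of Bayesian data analysis
- Identify the data relevant to the research question. What are the measurement scales? Which variables are outcomes, and which should be explanatory variables?
- Define a descriptive model for the data, giving its mathematical form and describing its parameters.
- Specify prior distributions for the parameters.
- Use Bayesian inference to reallocate credibility across parameter values. Interpret the posterior distributions of the theoretically meaningful quantities (assuming the model is a reasonable description of the data).
- Check that the posterior predictions mimic the data with reasonable accuracy (a "posterior predictive check"). If they do not, consider a different descriptive model.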
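The five steps can be traced on a toy problem before turning to brms. The following is a minimal sketch in base R using a grid approximation; the simulated data, the known `sigma`, and the Normal(0, 10) prior are all assumptions made for illustration and are unrelated to the salary data used later.

```r
# Toy illustration of the five steps with a grid approximation (base R only).
# The data are simulated for illustration, not taken from the salary file.
set.seed(1)
y <- rnorm(30, mean = 5, sd = 2)            # step 1: an outcome on a metric scale

# step 2: descriptive model y ~ Normal(mu, sigma); sigma treated as known (assumption)
sigma <- 2
mu_grid <- seq(0, 10, length.out = 1001)

# step 3: prior for mu, here Normal(0, 10)
log_prior <- dnorm(mu_grid, mean = 0, sd = 10, log = TRUE)

# step 4: the posterior is proportional to likelihood x prior
log_lik <- sapply(mu_grid, function(m) sum(dnorm(y, m, sigma, log = TRUE)))
log_post <- log_prior + log_lik
post <- exp(log_post - max(log_post))
post <- post / sum(post)                    # normalise over the grid
post_mean <- sum(mu_grid * post)

# step 5: posterior predictive check on the sample mean
mu_draws <- sample(mu_grid, 2000, replace = TRUE, prob = post)
yrep_means <- sapply(mu_draws, function(m) mean(rnorm(length(y), m, sigma)))
ppc_interval <- quantile(yrep_means, c(0.025, 0.975))
```

If the observed `mean(y)` falls well inside `ppc_interval`, the model reproduces at least this feature of the data. brms automates the same logic with MCMC sampling instead of a grid.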
2. Example
2.1 Importing the data
# bruceR bundles many packages commonly used in data analysis; details: https://zhuanlan.zhihu.com/p/80732610
library(bruceR)
##
## ⭐ bruceR: BRoadly Useful Convenient and Efficient R functions
##
## Loaded R packages:
## [Data]: rio / dplyr / tidyr / stringr / forcats / data.table
## [Stat]: psych / emmeans / effectsize / performance
## [Plot]: ggplot2 / ggtext / cowplot / see
##
## Frequently used functions in `bruceR`:
## set.wd() / Describe() / Freq() / Corr() / Alpha() / MEAN()
## MANOVA() / EMMEANS() / model_summary() / theme_bruce()
# import() from bruceR reads csv, Excel, and many other formats
df <- import("/Users/cpf/Downloads/salary.txt")
# Inspect the data
glimpse(df)
## Rows: 62
## Columns: 6
## $ id <int> 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18…
## $ time <int> 3, 6, 3, 8, 9, 6, 16, 10, 2, 5, 5, 6, 7, 11, 18, 6, 9, 7, 7, …
## $ pub <int> 18, 3, 2, 17, 11, 6, 38, 48, 9, 22, 30, 21, 10, 27, 37, 8, 13…
## $ sex <int> 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1…
## $ citation <int> 50, 26, 50, 34, 41, 37, 48, 56, 19, 29, 28, 31, 25, 40, 61, 3…
## $ salary <int> 51876, 54511, 53425, 61863, 52926, 47034, 66432, 61100, 41934…
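# id: identifier; time: years since obtaining the PhD; pub: number of publications;
# sex: 1 = female, 0 = male; citation: number of citations; salary: current salary
# Drop the first column (id); descriptive statistics and correlation plot 1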
Describe(df[, -1],plot = TRUE)
## Descriptive Statistics:
## ────────────────────────────────────────────────────────────────────────────
## N Mean SD | Median Min Max Skewness Kurtosis
## ────────────────────────────────────────────────────────────────────────────
## time 62 6.79 4.28 | 6.00 1.00 21.00 1.23 1.29
## pub 62 18.18 14.00 | 13.00 1.00 69.00 1.31 1.61
## sex 62 0.44 0.50 | 0.00 0.00 1.00 0.25 -1.97
## citation 62 40.23 17.17 | 35.00 1.00 90.00 0.65 0.32
## salary 62 54815.76 9706.02 | 53681.00 37939.00 83503.00 0.61 0.25
## ────────────────────────────────────────────────────────────────────────────
# Correlation plot 2
pairs.panels(df[, -1], ellipses = FALSE)
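2.2 Simple linear regression: outcome salary, predictor pub
2.2.1 A quick visualization of the variables
The plots in this article are drawn with the ggplot2 package. Before plotting, we first show how to display Chinese fonts correctly in figures.
# Fix garbled Chinese characters in plots
library(showtext)
font_add("kaiti", "/Users/cpf/Library/Fonts/楷体_GB2312.ttf")
showtext_auto()
# Draw the plot
p1 <- ggplot(df, # data
  # map variables to the canvas: pub on the x axis, salary on the y axis
  aes(x = pub, y = salary)) +
  # geoms can draw points, lines, shapes, and more; here a scatter plot
  geom_point() +
  # add a smoothed estimate of mean salary on top of the scatter plot
  geom_smooth() +
  # set the tick positions on the continuous x axis
  scale_x_continuous(breaks = seq(0, 80, by = 10)) +
  # same for the y axis
  scale_y_continuous(breaks = seq(40000, 100000, by = 10000)) +
  # set the axis labels
  labs(x = "pub 出版数量", y = "salary 薪水") +
  theme_bw()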
m1 <- lm( salary ~ 1 + pub, data = df)
summary(m1)
##
## Call:
## lm(formula = salary ~ 1 + pub, data = df)
##
## Residuals:
## Min 1Q Median 3Q Max
## -20660.0 -7397.5 333.7 5313.9 19238.7
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 48439.09 1765.42 27.438 < 2e-16 ***
## pub 350.80 77.17 4.546 2.71e-05 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 8440 on 60 degrees of freedom
## Multiple R-squared: 0.2562, Adjusted R-squared: 0.2438
## F-statistic: 20.67 on 1 and 60 DF, p-value: 2.706e-05
p1 +
  # nonparametric smooth (loess is the ggplot2 default for small samples)
  geom_smooth(se = FALSE) +
  # linear regression fit (red line)
  geom_smooth(method = "lm", col = "red")
2.3 Bayesian regression models: basic, random intercept, random slope
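The sections that follow refer to three fitted objects, fit1, fit2, and fit3, but the fitting calls themselves are not shown. Going by the Formula and Samples lines in the summary output, they were probably produced along the following lines. This is a sketch under assumptions: default brms priors and sampler settings are assumed, and the `seed` value is an arbitrary addition for reproducibility.

```r
library(brms)

# Assumption: default priors and default sampler settings
# (4 chains, iter = 2000, warmup = 1000), matching the summary headers.
# `seed` is an arbitrary value added here for reproducibility.

# Basic model: one population-level intercept and slope
fit1 <- brm(salary ~ 1 + pub, data = df, seed = 123)

# Random-intercept model: the intercept varies by id
fit2 <- brm(salary ~ 1 + pub + (1 | id), data = df, seed = 123)

# Random-slope model: both the intercept and the pub slope vary by id
fit3 <- brm(salary ~ 1 + pub + (1 + pub | id), data = df, seed = 123)
```

Note that the grouping factor id has 62 levels for 62 observations, so each group contributes a single row. In that situation the group-level standard deviations are hard to separate from sigma, which may be one reason the diagnostics for fit2 and fit3 look worse than those for fit1.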
2.3.1 Calling summary() to inspect the model fits
summary(fit1)
## Family: gaussian
## Links: mu = identity; sigma = identity
## Formula: salary ~ 1 + pub
## Data: df (Number of observations: 62)
## Samples: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
## total post-warmup samples = 4000
##
## Population-Level Effects:
## Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## Intercept 48417.10 1792.15 44846.95 51820.71 1.00 3752 2966
## pub 351.57 78.68 200.29 507.17 1.00 3861 2731
##
## Family Specific Parameters:
## Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## sigma 8564.52 770.84 7232.69 10211.00 1.00 3516 2849
##
## Samples were drawn using sampling(NUTS). For each parameter, Bulk_ESS
## and Tail_ESS are effective sample size measures, and Rhat is the potential
## scale reduction factor on split chains (at convergence, Rhat = 1).
summary(fit2)
## Family: gaussian
## Links: mu = identity; sigma = identity
## Formula: salary ~ 1 + pub + (1 | id)
## Data: df (Number of observations: 62)
## Samples: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
## total post-warmup samples = 4000
##
## Group-Level Effects:
## ~id (Number of levels: 62)
## Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## sd(Intercept) 4653.58 2544.46 237.26 8797.24 1.02 93 224
##
## Population-Level Effects:
## Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## Intercept 48331.10 1754.94 44987.38 52007.40 1.01 1553 1882
## pub 350.10 77.83 201.68 506.99 1.00 1208 2176
##
## Family Specific Parameters:
## Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## sigma 6571.68 1967.71 2604.79 9572.40 1.03 90 101
##
## Samples were drawn using sampling(NUTS). For each parameter, Bulk_ESS
## and Tail_ESS are effective sample size measures, and Rhat is the potential
## scale reduction factor on split chains (at convergence, Rhat = 1).
summary(fit3)
## Family: gaussian
## Links: mu = identity; sigma = identity
## Formula: salary ~ 1 + pub + (1 + pub | id)
## Data: df (Number of observations: 62)
## Samples: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
## total post-warmup samples = 4000
##
## Group-Level Effects:
## ~id (Number of levels: 62)
## Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## sd(Intercept) 3892.24 2244.35 259.73 8313.88 1.02 257 698
## sd(pub) 227.14 119.49 28.38 487.90 1.06 67 121
## cor(Intercept,pub) -0.00 0.53 -0.93 0.94 1.03 333 1001
##
## Population-Level Effects:
## Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## Intercept 48822.85 1749.96 45346.27 52383.86 1.01 1587 2276
## pub 325.23 101.22 128.44 524.34 1.01 287 554
##
## Family Specific Parameters:
## Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## sigma 5832.06 1626.41 2666.04 8793.85 1.02 159 172
##
## Samples were drawn using sampling(NUTS). For each parameter, Bulk_ESS
## and Tail_ESS are effective sample size measures, and Rhat is the potential
## scale reduction factor on split chains (at convergence, Rhat = 1).
Estimate is the posterior mean of each parameter, and Est.Error is the posterior standard deviation.
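To make the columns concrete, the sketch below computes the same quantities by hand from a vector of posterior draws. The draws are simulated stand-ins, not draws from fit1; their mean and spread were chosen only to resemble the pub coefficient.

```r
# Stand-in posterior draws for one coefficient (simulated, not taken from fit1).
set.seed(42)
draws <- rnorm(4000, mean = 350, sd = 78)

estimate  <- mean(draws)                          # "Estimate": posterior mean
est_error <- sd(draws)                            # "Est.Error": posterior standard deviation
ci <- quantile(draws, probs = c(0.025, 0.975))    # "l-95% CI" and "u-95% CI"

round(c(Estimate = estimate, Est.Error = est_error, ci), 2)
```

For a fitted model, `posterior_summary(fit1)` in brms should return the same four quantities for every parameter.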
2.3.2 Plotting posterior densities to inspect the model fits
plot(fit1)
plot(fit2)
plot(fit3)
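2.3.3 How to judge whether the parameters have been estimated well?
- Basic requirements
  - Rhat < 1.01 for all parameters
  - No divergent transitions
  - Bulk effective sample size (ESS) and tail ESS large enough, typically above 1500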
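To show what Rhat measures, here is a minimal base-R sketch of the classic split-Rhat: each chain is cut in half, and the between-chain variance is compared with the within-chain variance. The function name `split_rhat` and the simulated chains are illustrative assumptions. brms reports a rank-normalised variant, so its values can differ slightly from this hand computation.

```r
# Classic split-Rhat for a matrix of draws (rows = iterations, columns = chains).
split_rhat <- function(draws) {
  n_half <- floor(nrow(draws) / 2)
  # Treat the first and second half of every chain as separate chains
  halves <- cbind(draws[1:n_half, , drop = FALSE],
                  draws[(nrow(draws) - n_half + 1):nrow(draws), , drop = FALSE])
  W <- mean(apply(halves, 2, var))          # within-chain variance
  B <- n_half * var(colMeans(halves))       # between-chain variance
  var_plus <- (n_half - 1) / n_half * W + B / n_half
  sqrt(var_plus / W)
}

set.seed(7)
good <- matrix(rnorm(4000), ncol = 4)             # four chains exploring the same target
bad  <- good + rep(c(0, 0, 0, 3), each = 1000)    # the fourth chain is stuck elsewhere

rhat_good <- split_rhat(good)
rhat_bad  <- split_rhat(bad)
```

Here `rhat_good` stays close to 1, while `rhat_bad` is far above the 1.01 threshold. For a fitted model, `rhat(fit3)` in brms returns the per-parameter values.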
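If the chains have not converged sufficiently, you can
- run more iterations (for example, raise iter from 2000 to 4000)
- raise adapt_delta closer to 1 (.9, .95, .99, .999, and so on)
- use stronger priors (especially with small samples)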
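These remedies can be applied to an existing fit with `update()`. The sketch below assumes fit3 is the random-slope model from this section; the specific values (4000 iterations, adapt_delta = 0.99) and the normal(0, 1000) prior are illustrative choices, not tuned recommendations.

```r
library(brms)

# Sketch: refit the random-slope model with more iterations and a higher adapt_delta.
# The values 4000 and 0.99 are illustrative, not tuned settings.
fit3_refit <- update(
  fit3,
  iter = 4000,
  control = list(adapt_delta = 0.99)
)

# Sketch: a more informative prior on the regression slopes (class "b").
# normal(0, 1000) is a hypothetical choice on the salary scale.
fit3_prior <- update(
  fit3,
  prior = set_prior("normal(0, 1000)", class = "b")
)
```

Changing the prior alters the generated Stan code, so the second call recompiles the model before sampling.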
2.3.4 Posterior predictive checks
library(patchwork)
pp1 <- pp_check(fit1)
pp2 <- pp_check(fit2)
pp3 <- pp_check(fit3)
pp1 / pp2 / pp3
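The plots above show that the model predictions (the lighter lines, labelled yrep) deviate somewhat from the observed outcome (the darker line, labelled y). This suggests the model could be improved by relaxing the normality assumption. For a misspecification this mild, the inferences remain acceptable, because fixed-effect coefficients are generally robust to slight violations of normality. Comparing the three panels from top to bottom, the random-slope model reproduces the data best.
2.3.5 Model comparison
Models can be compared with the loo() function (called below through LOO()), which computes leave-one-out cross-validation criteria (LOO; similar in spirit to AIC and BIC, but better suited to evaluating Bayesian models). For now it is enough to know that the model with the smaller LOOIC (equivalently, the larger elpd_loo) is preferred. Here we compare the basic, random-intercept, and random-slope models.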
LOO(fit1,fit2,fit3)
## Warning: Found 39 observations with a pareto_k > 0.7 in model 'fit2'. It is
## recommended to set 'moment_match = TRUE' in order to perform moment matching for
## problematic observations.
## Warning: Found 31 observations with a pareto_k > 0.7 in model 'fit3'. It is
## recommended to set 'moment_match = TRUE' in order to perform moment matching for
## problematic observations.
## Output of model 'fit1':
##
## Computed from 4000 by 62 log-likelihood matrix
##
## Estimate SE
## elpd_loo -650.6 5.3
## p_loo 2.9 0.8
## looic 1301.1 10.6
## ------
## Monte Carlo SE of elpd_loo is 0.0.
##
## Pareto k diagnostic values:
## Count Pct. Min. n_eff
## (-Inf, 0.5] (good) 61 98.4% 1531
## (0.5, 0.7] (ok) 1 1.6% 1227
## (0.7, 1] (bad) 0 0.0% <NA>
## (1, Inf) (very bad) 0 0.0% <NA>
##
## All Pareto k estimates are ok (k < 0.7).
## See help('pareto-k-diagnostic') for details.
##
## Output of model 'fit2':
##
## Computed from 4000 by 62 log-likelihood matrix
##
## Estimate SE
## elpd_loo -646.0 4.6
## p_loo 29.3 3.1
## looic 1292.0 9.3
## ------
## Monte Carlo SE of elpd_loo is NA.
##
## Pareto k diagnostic values:
## Count Pct. Min. n_eff
## (-Inf, 0.5] (good) 0 0.0% <NA>
## (0.5, 0.7] (ok) 23 37.1% 5
## (0.7, 1] (bad) 38 61.3% 5
## (1, Inf) (very bad) 1 1.6% 28
## See help('pareto-k-diagnostic') for details.
##
## Output of model 'fit3':
##
## Computed from 4000 by 62 log-likelihood matrix
##
## Estimate SE
## elpd_loo -644.0 4.3
## p_loo 32.8 3.1
## looic 1288.1 8.6
## ------
## Monte Carlo SE of elpd_loo is NA.
##
## Pareto k diagnostic values:
## Count Pct. Min. n_eff
## (-Inf, 0.5] (good) 0 0.0% <NA>
## (0.5, 0.7] (ok) 31 50.0% 20
## (0.7, 1] (bad) 30 48.4% 2
## (1, Inf) (very bad) 1 1.6% 3
## See help('pareto-k-diagnostic') for details.
##
## Model comparisons:
## elpd_diff se_diff
## fit3 0.0 0.0
## fit2 -2.0 1.9
## fit1 -6.5 2.2
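The random-slope model has the lowest LOOIC (1288.1), which favours including the random slope. Note, however, that the warnings in the output flag many observations with pareto_k > 0.7 for fit2 and fit3, so their LOO estimates should be read with caution. In general, Bayesian modelling recommends including all random slopes, because estimation usually does not run into the convergence problems seen with frequentist methods.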
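The numbers in the comparison table can be checked with a few lines of arithmetic. The sketch below copies the values from the LOO output above, confirms that LOOIC is -2 times elpd_loo (up to rounding), and expresses each elpd difference in units of its standard error. The threshold of roughly 2 standard errors is a common rule of thumb, not a formal test.

```r
# Values copied from the LOO output above.
elpd <- c(fit1 = -650.6, fit2 = -646.0, fit3 = -644.0)
looic <- -2 * elpd                       # LOOIC is -2 times elpd_loo

# Differences to the best model and their standard errors ("Model comparisons").
elpd_diff <- c(fit2 = -2.0, fit1 = -6.5)
se_diff   <- c(fit2 = 1.9,  fit1 = 2.2)
ratio <- abs(elpd_diff) / se_diff        # difference in units of its standard error

best <- names(which.min(looic))          # model with the lowest LOOIC
round(ratio, 2)
```

By this rough measure, fit3 is clearly ahead of fit1, while the gap between fit3 and fit2 is only about one standard error, so the data do not sharply distinguish the random-slope model from the random-intercept model.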
Session info
## R version 4.0.3 (2020-10-10)
## Platform: x86_64-apple-darwin17.0 (64-bit)
## Running under: macOS Catalina 10.15.7
##
## Matrix products: default
## BLAS: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRblas.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRlapack.dylib
##
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
##
## attached base packages:
## [1] stats graphics grDevices utils datasets methods base
##
## other attached packages:
## [1] patchwork_1.1.1 brms_2.15.0 Rcpp_1.0.6 showtext_0.9-2
## [5] showtextdb_3.0 sysfonts_0.8.3 see_0.6.3 cowplot_1.1.1
## [9] ggtext_0.1.1 ggplot2_3.3.3 performance_0.7.1 effectsize_0.4.4-1
## [13] emmeans_1.5.5-1 psych_2.0.12 data.table_1.14.0 forcats_0.5.1
## [17] stringr_1.4.0 tidyr_1.1.3 dplyr_1.0.5 rio_0.5.16
## [21] bruceR_0.6.4
##
## loaded via a namespace (and not attached):
## [1] backports_1.2.1 readxl_1.3.1 plyr_1.8.6
## [4] igraph_1.2.6 splines_4.0.3 crosstalk_1.1.1
## [7] TH.data_1.0-10 rstantools_2.1.1 inline_0.3.17
## [10] digest_0.6.27 htmltools_0.5.1.1 rsconnect_0.8.16
## [13] fansi_0.4.2 magrittr_2.0.1 openxlsx_4.2.3
## [16] RcppParallel_5.1.1 matrixStats_0.58.0 xts_0.12.1
## [19] sandwich_3.0-0 prettyunits_1.1.1 colorspace_2.0-0
## [22] haven_2.3.1 xfun_0.22 callr_3.6.0
## [25] crayon_1.4.1 jsonlite_1.7.2 lme4_1.1-26
## [28] survival_3.2-7 zoo_1.8-8 glue_1.4.2
## [31] gtable_0.3.0 V8_3.4.0 pkgbuild_1.2.0
## [34] rstan_2.21.2 abind_1.4-5 scales_1.1.1
## [37] mvtnorm_1.1-1 DBI_1.1.1 GGally_2.1.1
## [40] miniUI_0.1.1.1 xtable_1.8-4 gridtext_0.1.4
## [43] tmvnsim_1.0-2 foreign_0.8-81 StanHeaders_2.21.0-7
## [46] stats4_4.0.3 DT_0.17 htmlwidgets_1.5.3
## [49] threejs_0.3.3 RColorBrewer_1.1-2 ellipsis_0.3.1
## [52] pkgconfig_2.0.3 reshape_0.8.8 loo_2.4.1
## [55] farver_2.0.3 sass_0.3.1 utf8_1.2.1
## [58] tidyselect_1.1.0 labeling_0.4.2 rlang_0.4.10
## [61] reshape2_1.4.4 later_1.1.0.1 munsell_0.5.0
## [64] cellranger_1.1.0 tools_4.0.3 cli_2.4.0
## [67] generics_0.1.0 pacman_0.5.1 ggridges_0.5.3
## [70] evaluate_0.14 fastmap_1.1.0 yaml_2.2.1
## [73] processx_3.5.1 knitr_1.33 zip_2.1.1
## [76] purrr_0.3.4 nlme_3.1-152 mime_0.10
## [79] projpred_2.0.2 xml2_1.3.2 compiler_4.0.3
## [82] bayesplot_1.8.0 shinythemes_1.2.0 rstudioapi_0.13
## [85] curl_4.3 gamm4_0.2-6 tibble_3.1.1
## [88] statmod_1.4.35 bslib_0.2.4 stringi_1.5.3
## [91] ps_1.5.0 highr_0.9 parameters_0.13.0
## [94] Brobdingnag_1.2-6 lattice_0.20-41 Matrix_1.2-18
## [97] nloptr_1.2.2.2 markdown_1.1 shinyjs_2.0.0
## [100] vctrs_0.3.7 pillar_1.6.0 lifecycle_1.0.0
## [103] jquerylib_0.1.3 bridgesampling_1.0-0 estimability_1.3
## [106] insight_0.13.2 httpuv_1.5.5 R6_2.5.0
## [109] promises_1.1.1 gridExtra_2.3 codetools_0.2-16
## [112] boot_1.3-25 colourpicker_1.1.0 MASS_7.3-53
## [115] gtools_3.8.2 assertthat_0.2.1 withr_2.4.1
## [118] shinystan_2.5.0 mnormt_2.0.2 multcomp_1.4-16
## [121] mgcv_1.8-33 bayestestR_0.9.0 parallel_4.0.3
## [124] hms_1.0.0 grid_4.0.3 coda_0.19-4
## [127] minqa_1.2.4 rmarkdown_2.7 shiny_1.6.0
## [130] base64enc_0.1-3 dygraphs_1.1.1.6