pd_plot函数——bartMachine包内函数详解
R bartMachine包下载
R 所有包下载地址：
https://cran.r-project.org/web/packages/available_packages_by_name.html
R bartMachine包下载地址：
https://cran.r-project.org/web/packages/bartMachine/index.html
一、函数在包内的原文（源代码）
1、pd_plot函数底层代码
注：加载 bartMachine 包后，在 R 控制台中直接输入 pd_plot（不加括号）并回车，即可打印出下面的源代码。
function (bart_machine, j, levs = c(0.05, seq(from = 0.1, to = 0.9,
by = 0.1), 0.95), lower_ci = 0.025, upper_ci = 0.975, prop_data = 1)
{
check_serialization(bart_machine)
if (class(j) == "integer") {
j = as.numeric(j)
}
if (class(j) == "numeric" && (j < 1 || j > bart_machine$p)) {
stop(paste("You must set j to a number between 1 and p =",
bart_machine$p))
}
else if (class(j) == "character" && !(j %in% bart_machine$training_data_features)) {
stop("j must be the name of one of the training features (see \"<bart_model>$training_data_features\")")
}
else if (!(class(j) == "numeric" || class(j) == "character")) {
stop("j must be a column number or column name")
}
x_j = bart_machine$model_matrix_training_data[, j]
if (length(unique(na.omit(x_j))) <= 1) {
warning("There must be more than one unique value in this training feature. PD plot not generated.")
return()
}
x_j_quants = unique(quantile(x_j, levs, na.rm = TRUE))
if (length(unique(x_j_quants)) <= 1) {
warning("There must be more than one unique value among the quantiles selected. PD plot not generated.")
return()
}
n_pd_plot = round(bart_machine$n * prop_data)
bart_predictions_by_quantile = array(NA, c(length(x_j_quants),
n_pd_plot, bart_machine$num_iterations_after_burn_in))
for (q in 1:length(x_j_quants)) {
indices = sample(1:bart_machine$n, n_pd_plot)
test_data = bart_machine$X[indices, ]
test_data