k_fold_cv函数——bartMachine包内函数详解
R bartMachine包下载
R 所有包下载地址:
https://cran.r-project.org/web/packages/available_packages_by_name.html
R bartMachine包下载地址:
https://cran.r-project.org/web/packages/bartMachine/index.html
一、 函数包内原文
1、k_fold_cv函数底层代码
注:加载 bartMachine 包后,在 R 控制台中直接输入函数名 k_fold_cv(不加括号)并回车,即可打印出如下函数源代码
function (X, y, k_folds = 5, folds_vec = NULL, verbose = FALSE,
...)
{
args = list(...)
args$serialize = FALSE
if (class(X) != "data.frame") {
stop("The training data X must be a data frame.")
}
if (!(class(y) %in% c("numeric", "integer", "factor"))) {
stop("Your response must be either numeric, an integer or a factor with two levels.\n")
}
if (!is.null(folds_vec) & class(folds_vec) != "integer")
stop("folds_vec must be an a vector of integers specifying the indexes of each folds.")
y_levels = levels(y)
if (class(y) == "numeric" || class(y) == "integer") {
pred_type = "regression"
}
else if (class(y) == "factor" & length(y_levels) ==
2) {
pred_type = "classification"
}
n = nrow(X)
Xpreprocess = pre_process_training_data(X)$data
p = ncol(Xpreprocess)
if (is.null(folds_vec)) {
if (k_folds == Inf) {
k_folds = n
}
if (k_folds <= 1 || k_folds > n) {
stop("The number of folds must be at least 2 and less than or equal to n, use \"Inf\" for leave one out")
}
temp = rnorm(n)
folds_vec = cut(temp, breaks = quantile(temp, seq(0,
1, length.out = k_folds + 1)), include.lowest = T,
labels = F)
}
else {
k_folds = length(unique(folds_vec))
}
if (pred_type == "regression") {
L1_err = 0
L2_err = 0
yhat_cv = numeric(n)
}
else {
phat_cv = numeric(n)
yhat_cv = factor(n, levels = y_levels)
confusion_matrix = matrix(0, nrow = 3, ncol = 3)
rownames(confusion_matrix) = c(paste("actual",
y_levels), "use errors")
colnames(confusion_matrix) = c(paste("predicted",
y_levels), "model errors")
}
Xy = data.frame(Xpreprocess, y)
for (k in 1:k_folds) {
cat(".")
train_idx = which(folds_vec != k)
test_idx = setdiff(1:n, train_idx)
test_data_k = Xy[test_idx, ]
training_data_k = Xy[train_idx, ]
bart_machine_cv = do.call(build_bart_machine, c(list(X = training_data_k[,