Pytorch实现k折交叉验证
代码思路参考:https://blog.csdn.net/foneone/article/details/104445320
用pytorch实现k-fold cross validation
# 导入模块
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, TensorDataset
# 创建一个数据集
X = torch.rand(500, 32, 32)
Y = torch.rand(500, 1)
# random shuffle
index = [i for i in range(len(X))]
random.shuffle(index)
X = X[index]
Y = Y[index]
获取k折交叉验证某一折的训练集和验证集
def get_kfold_data(k, i, X, y):
# 返回第 i+1 折 (i = 0 -> k-1) 交叉验证时所需要的训练和验证数据,X_train为训练集,X_valid为验证集
fold_size = X.shape[0] // k # 每份的个数:数据总条数/折数(组数)
val_start = i * fold_size
if i != k - 1:
val_end = (i + 1) * fold_size
X_valid, y_valid = X[val_start:val_end], y[val_start:val_end]
X_train = torch.cat((X[0:val_start], X[val_end:]), dim = 0)
y_train = torch.cat((y[0:val_start], y[val_end:]), dim = 0)
else:

本文介绍如何利用PyTorch实现K折交叉验证,详细解析获取每一折的训练集和验证集过程,并展示模型训练的方法。内容参照了https://blog.csdn.net/foneone/article/details/104445320的相关实现。
最低0.47元/天 解锁文章

1070

被折叠的 条评论
为什么被折叠?



