Pytorch实现k折交叉验证
代码思路参考:https://blog.csdn.net/foneone/article/details/104445320
用pytorch实现k-fold cross validation
# 导入模块
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, TensorDataset
# 创建一个数据集
X = torch.rand(500, 32, 32)
Y = torch.rand(500, 1)
# random shuffle
index = [i for i in range(len(X))]
random.shuffle(index)
X = X[index]
Y = Y[index]
获取k折交叉验证某一折的训练集和验证集
def get_kfold_data(k, i, X, y):
# 返回第 i+1 折 (i = 0 -> k-1) 交叉验证时所需要的训练和验证数据,X_train为训练集,X_valid为验证集
fold_size = X.shape[0] // k # 每份的个数:数据总条数/折数(组数)
val_start = i * fold_size
if i != k - 1:
val_end = (i + 1) * fold_size
X_valid, y_valid = X[val_start:val_end], y[val_start:val_end]
X_train = torch.cat((X[0:val_start], X[val_end:]), dim = 0)
y_train = torch.cat((y[0:val_start], y[val_end:]), dim = 0)
else: