Pytorch实现k折交叉验证

本文介绍如何利用PyTorch实现K折交叉验证,详细解析获取每一折的训练集和验证集过程,并展示模型训练的方法。内容参照了https://blog.csdn.net/foneone/article/details/104445320的相关实现。
摘要由CSDN通过智能技术生成

Pytorch实现k折交叉验证

代码思路参考:https://blog.csdn.net/foneone/article/details/104445320
用pytorch实现k-fold cross validation

# 导入模块
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, TensorDataset
# 创建一个数据集
X = torch.rand(500, 32, 32)
Y = torch.rand(500, 1)
# random shuffle
index = [i for i in range(len(X))] 
random.shuffle(index)
X = X[index]
Y = Y[index]

获取k折交叉验证某一折的训练集和验证集

def get_kfold_data(k, i, X, y):  
     
    # 返回第 i+1 折 (i = 0 -> k-1) 交叉验证时所需要的训练和验证数据,X_train为训练集,X_valid为验证集
    fold_size = X.shape[0] // k  # 每份的个数:数据总条数/折数(组数)
    
    val_start = i * fold_size
    if i != k - 1:
        val_end = (i + 1) * fold_size
        X_valid, y_valid = X[val_start:val_end], y[val_start:val_end]
        X_train = torch.cat((X[0:val_start], X[val_end:]), dim = 0)
        y_train = torch.cat((y[0:val_start], y[val_end:]), dim = 0)
    else:
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值