代码主要核心思想来自:https://www.cnblogs.com/JadenFK3326/p/12164519.html
K折交叉交叉验证的过程如下:
以200条数据,十折交叉验证为例子,十折也就是将数据分成10组,进行10组训练,每组用于测试的数据为:数据总条数/组数,即每组20条用于valid,180条用于train,每次valid的都是不同的。
(1)将200条数据,分成按照 数据总条数/组数(折数),进行切分。然后取出第i份作为第i次的valid,剩下的作为train
(2)将每组中的train数据利用DataLoader和Dataset,进行封装。
(3)将train数据用于训练,epoch可以自己定义,然后利用valid做验证。得到一次的train_loss和 valid_loss。
(4)重复(2)(3)步骤,得到最终的 averge_train_loss和averge_valid_loss
上述过程如下图所示:
上述的代码如下:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader,Dataset
import torch.nn.functional as F
from torch.autograd import Variable
#####构造的训练集####
x = torch.rand(