前言
K折交叉验证有助于模型选择和超参数调整。我们首先需要定义一个函数,在K折交叉验证过程中返回第i折的数据。具体而言,就是选择第i个切片作为验证数据,其余部分作为训练数据。
获取第i折数据
def get_k_fold_data(k,i,X,y):
assert k>1
fold_size=X.shape[0]//k # 每一折的大小
X_train,y_train=None,None
for j in range(k):
idx=slice(j*fold_size,(j+1)*fold_size) # 生成每一折内的序号
X_part,y_part=X[idx,:],y[idx]
if j==i:
X_valid,y_valid=X_part,y_part
elif X_train is None:
X_train,y_train=X_part,y_part
else:
X_train=torch.cat([X_train,X_part],0)
y_train=torch.cat([y_train,y_part],0)
return X_train,y_train,X_valid,y_valid
K折交叉训练
def k_fold(k,X_train,y_train,num_epochs,learning_rate,weight_decay,batch_size):
train_l_sum,valid_l_sum=0,0
for i in range(k):
data=get_k_fold_data(k,i,X_train,y_train)
net=get_net()
train_ls,valid_ls=train(net,*data,num_epochs,learning_rate,weight_decay,batch_size)
train_l_sum+=train_ls[-1]
valid_l_sum+=valid_ls[-1]
if i==0:
plt.plot(range(1,num_epochs+1),train_ls,label='train')
plt.plot(range(1,num_epochs+1),valid_ls,label='valid')
plt.yscale('log')
plt.xlabel('epoch')
plt.ylabel('rmse')
plt.xlim([1,num_epochs])
plt.legend()
plt.show()
print(f'折{i+1},训练log rmse{float(train_ls[-1]):f}, '
f'验证log rmse{float(valid_ls[-1]):f}')
return train_l_sum/k,valid_l_sum/k
总结
- 一组超参数的训练误差可能非常低,但K折交叉验证的误差高的多,这表明模型过拟合。
- 较小的过拟合表示现有数据可以支撑更复杂的模型,较大的过拟合则表明我们可以采用正则化技术来改善模型。