先定义一个 k 折交叉验证的数据划分函数，用于取出第 i 折对应的训练集和验证集：
def get_k_fold(k,i,X,y):
assert k> 1
fold_size = X.shape[0]//k
X_train,y_train = None,None
for j in range(k):
ix = slice(j * fold_size ,(j+1) * fold_size)
#print('ix:',ix)
X_part, y_part = X[ix, :], y[ix]
if j == i:
print('[j==i]:j = %d,i = %d'%(j,i))
X_valid,y_valid = X_part,y_part
elif X_train is None:
print('[X_train is None]:j = %d,i = %d'%(j,i))
X_train ,y_train = X_part,y_part
else:
print('[X_train,X_part]:j = %d,i = %d'%(j,i))
X_train = torch.cat((X_train,X_part),dim=0)
y_train = torch.cat((y_train,y_part),dim=0)
print('X_train