Training a neural network normally requires splitting the data into a training set, a validation set, and a test set. The training set is used to fit the model parameters, the validation set to select the model (parameters) that generalizes best, and the test set to measure the final generalization performance. The validation and test sets therefore take no part in training the model.
Neural networks usually need a large amount of data. When data are scarce, or no complete test data are available, part of the data originally intended for training can be set aside for validation and testing. This article assumes preprocessing is already done — features extracted and labels attached, yielding an input sample set — and shows how to split that sample set 6:2:2 into a training set, a validation set, and a test set.
First, generate a simple sample set of size 50. With the 6:2:2 ratio above, the training set then holds 30 samples, and the validation and test sets hold 10 each.
number = np.arange(51, 151).reshape(50, 2)   # 100 values -> 50 rows x 2 feature columns
name = ['feature_1', 'feature_2']
data = pd.DataFrame(columns=name, data=number)
data.insert(0, 'ID', range(1, 51))           # 1-based sample IDs
num = len(data)
Call the split_data function to partition the samples:
def split_data(data, num, valid_ratio, test_ratio):
    # Shuffle the 1-based IDs, then carve off the validation and test slices.
    shuffled_indices = np.random.permutation(range(1, num + 1))
    valid_set_size = int(num * valid_ratio)
    test_set_size = int(num * test_ratio)
    valid_indices = sorted(shuffled_indices[:valid_set_size])
    test_indices = sorted(shuffled_indices[valid_set_size:(test_set_size + valid_set_size)])
    train_indices = sorted(shuffled_indices[(test_set_size + valid_set_size):])
    # Collect the rows belonging to each ID group and renumber the row index.
    train = pd.concat(data[data['ID'] == i] for i in train_indices).reset_index(drop=True)
    valid = pd.concat(data[data['ID'] == i] for i in valid_indices).reset_index(drop=True)
    test = pd.concat(data[data['ID'] == i] for i in test_indices).reset_index(drop=True)
    return train, valid, test
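As an aside, the three per-ID concat loops each scan the DataFrame once per sample; the same selection can be written as a single boolean mask with `Series.isin`. A minimal, behavior-equivalent sketch (`split_data_isin` is a name added here for illustration):

```python
import numpy as np
import pandas as pd

def split_data_isin(data, num, valid_ratio, test_ratio):
    # Shuffle the 1-based IDs, then split them into valid/test/train groups.
    shuffled = np.random.permutation(range(1, num + 1))
    n_valid = int(num * valid_ratio)
    n_test = int(num * test_ratio)
    # isin selects all matching rows in one pass, keeping the frame's row order.
    valid = data[data['ID'].isin(shuffled[:n_valid])].reset_index(drop=True)
    test = data[data['ID'].isin(shuffled[n_valid:n_valid + n_test])].reset_index(drop=True)
    train = data[data['ID'].isin(shuffled[n_valid + n_test:])].reset_index(drop=True)
    return train, valid, test
```

Because `isin` keeps rows in their original frame order (ascending ID), the `sorted()` calls from the loop version become unnecessary.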
The complete code:
import pandas as pd
import numpy as np

def split_data(data, num, valid_ratio, test_ratio):
    # Shuffle the 1-based IDs, then carve off the validation and test slices.
    shuffled_indices = np.random.permutation(range(1, num + 1))
    valid_set_size = int(num * valid_ratio)
    test_set_size = int(num * test_ratio)
    valid_indices = sorted(shuffled_indices[:valid_set_size])
    test_indices = sorted(shuffled_indices[valid_set_size:(test_set_size + valid_set_size)])
    train_indices = sorted(shuffled_indices[(test_set_size + valid_set_size):])
    # Collect the rows belonging to each ID group and renumber the row index.
    train = pd.concat(data[data['ID'] == i] for i in train_indices).reset_index(drop=True)
    valid = pd.concat(data[data['ID'] == i] for i in valid_indices).reset_index(drop=True)
    test = pd.concat(data[data['ID'] == i] for i in test_indices).reset_index(drop=True)
    return train, valid, test

if __name__ == "__main__":
    number = np.arange(51, 151).reshape(50, 2)   # 100 values -> 50 rows x 2 columns
    name = ['feature_1', 'feature_2']
    data = pd.DataFrame(columns=name, data=number)
    data.insert(0, 'ID', range(1, 51))           # 1-based sample IDs
    num = len(data)
    train_data, valid_data, test_data = split_data(data, num, 0.2, 0.2)
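One practical caveat: `np.random.permutation` draws from NumPy's global random state, so every call to `split_data` produces a different partition. Seeding before the call makes the split reproducible. A self-contained sketch (`split_by_position` is a hypothetical position-based stand-in for `split_data`, added here for illustration):

```python
import numpy as np
import pandas as pd

def split_by_position(df, valid_ratio, test_ratio):
    # Shuffle row positions instead of IDs, then slice into three parts.
    perm = np.random.permutation(len(df))
    n_valid = int(len(df) * valid_ratio)
    n_test = int(len(df) * test_ratio)
    valid = df.iloc[sorted(perm[:n_valid])].reset_index(drop=True)
    test = df.iloc[sorted(perm[n_valid:n_valid + n_test])].reset_index(drop=True)
    train = df.iloc[sorted(perm[n_valid + n_test:])].reset_index(drop=True)
    return train, valid, test

number = np.arange(51, 151).reshape(50, 2)
data = pd.DataFrame(columns=['feature_1', 'feature_2'], data=number)
data.insert(0, 'ID', range(1, 51))

np.random.seed(42)   # fix the global seed so the partition is repeatable
train1, valid1, test1 = split_by_position(data, 0.2, 0.2)
np.random.seed(42)   # same seed -> same permutation -> same split
train2, _, _ = split_by_position(data, 0.2, 0.2)
assert train1.equals(train2)
```

The same seeding applies unchanged to the ID-based `split_data` above, since both rely on the global state that `np.random.seed` controls.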