python打乱训练集测试集数据的函数
在深度学习的编程中,时常要遇到打乱数据集测试集的场景,tensorflow和其他框架可能都会提供了不同的接口去实现这个功能,不过我都感觉太麻烦了,百度上面也没有比较好的封装函数。所以自己写了一个自己觉得比较好用的函数。可以直接使用。可供参考。
直接上代码
'''
Args:
data:数据
label:标签
deepth:转化成one_hot变量时传入的参数如果为-1则不转化为onehot
tag:标记参数,如果tag为1 则将训练集数据转化成numpy数组并且将标签转化为onehot变量
'''
def data_split(data,label,deepth=-1,tag=1):
len(data)
total_label=len(label)
label_index=np.arange(0,total_label)
np.random.shuffle(label_index)
train_num=int(total_label*0.8)
X_train=[]
y_train=[]
#打乱训练集
for i in label_index[:train_num]:
X_train.append(data[i])
y_train.append(label[i])
X_test=[]
y_test=[]
#打乱测试集
for i in label_index[train_num:]:
X_test.append(data[i])
y_test.append(label[i])
if tag==1 and depth!=-1:
X_train=np.asarray(X_train)
X_test=np.asarray(X_test)
y_train=tf.one_hot(y_train,depth=deepth)
y_test=tf.one_hot(y_test,depth=deepth)
return X_train, X_test, y_train, y_test
可以直接使用代码。
#需要导入的包
import numpy as np
import tensorflow as tf
源数据是50000的长度,按照8:2的比例分割训练集和测试集。
随机打乱的功能在函数中已经实现了,可以直接调用函数传入参数后使用。非常好用。
(大佬勿喷,自己随便写的)