以鸢尾花(iris)数据集为例:
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
# Load the iris dataset bundled with scikit-learn and bind the feature
# matrix and the label vector to the names used throughout this file.
iris = datasets.load_iris()
X, y = iris.data, iris.target

# Bare expressions: in an interactive/notebook session these display the
# array shapes; when run as a plain script they have no visible effect.
X.shape
y.shape
### train_test_split
# Display the label vector before splitting (bare expression: it only shows
# output in an interactive/notebook session).
# NOTE(review): presumably shown to illustrate that the samples are ordered
# by class, which is why the indexes are shuffled below -- confirm.
y
对 0 到 149 这 150 个索引进行随机乱序排列,得到打乱后的索引数组
# Manual train/test split: shuffle the row indexes, reserve the leading
# `test_ratio` fraction for testing and use the remainder for training.
shuffle_indexes = np.random.permutation(len(X))
shuffle_indexes  # notebook display of the shuffled index order

test_ratio = 0.2
# int() truncates, so the test set gets the floor of ratio * sample count.
test_size = int(test_ratio * len(X))
test_size  # notebook display of the number of test samples

# Cut the shuffled indexes once at `test_size`: the first piece selects the
# test rows, the second piece selects the training rows.
test_indexes, train_indexes = np.split(shuffle_indexes, [test_size])

# Index features and labels with the same index arrays so that every row of
# X stays paired with its label in y.
X_train, y_train = X[train_indexes], y[train_indexes]
X_test, y_test = X[test_indexes], y[test_indexes]

# One shape per line, in the order: X_train, y_train, X_test, y_test.
print(X_train.shape, y_train.shape, X_test.shape, y_test.shape, sep="\n")
### sklearn中的train_test_split
from sklearn.model_selection import train_test_split

# Same 80/20 split done by scikit-learn; the fixed random_state makes this
# particular split repeatable from run to run.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=666
)

# One shape per line, in the order: X_train, y_train, X_test, y_test.
print(X_train.shape, y_train.shape, X_test.shape, y_test.shape, sep="\n")
二者划分得到的训练集与测试集形状一致(均为 120 个训练样本、30 个测试样本);但由于随机打乱的结果不同,两种方式各集合中包含的具体样本并不相同。