import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
# Load MNIST via input_data.read_data_sets, using 'data/' as its data
# directory and requesting labels in one-hot form (one_hot=True).
# NOTE(review): input_data comes from tensorflow.examples.tutorials.mnist,
# which shipped only with TensorFlow 1.x — confirm the pinned TF version.
mnist = input_data.read_data_sets('data/', one_hot=True)

# Training split: images and their labels.
train_imgs, train_labels = mnist.train.images, mnist.train.labels
# Test split. NOTE: despite its name, test_label_imgs holds the test *labels*.
test_imgs, test_label_imgs = mnist.test.images, mnist.test.labels
# Hold out a fraction of the training data as a validation set.
# Fraction of training examples moved into the validation set (0.2 = 20%).
# NOTE: despite its name, this is a ratio, not a dataset.
validate_datasets = 0.2

# Random ordering of all training-example indices. np.random.permutation
# draws from NumPy's global RNG, so seed it (np.random.seed(...)) beforehand
# if a reproducible split is needed.
permutation = np.random.permutation(train_labels.shape[0])

# Size of the validation split, computed once so the two slices below always
# partition the permutation exactly (no overlap, no gap).
n_validate = int(train_labels.shape[0] * validate_datasets)
validate_indexs = permutation[:n_validate]
train_indexs = permutation[n_validate:]

# Select rows along axis 0 with the shuffled index arrays; advanced
# indexing returns copies, so the original arrays are left untouched.
x_train_imgs = train_imgs[train_indexs, :]
y_train_labels = train_labels[train_indexs, :]
validate_imgs = train_imgs[validate_indexs, :]
validate_labels = train_labels[validate_indexs, :]
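# 数据预处理之打乱数据集 (Data preprocessing: shuffling the dataset)
# 最新推荐文章于 2022-07-02 11:58:18 发布 (scraped page residue: "latest recommended article published 2022-07-02 11:58:18")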