我们处理的数据,可能是seg-y文件、mat文件或者npy文件。
在读取数据到把数据输入网路之前,我们经常需要变换数据的维度,数据类型等。
在这个阶段,通常数据是ndarry类型。
以下总结常用的数据处理操作。
from scipy.io import loadmat
mnist_data = loadmat('../data/mnist.mat')
mnist_train = mnist_data['train_28']
mnist_test = mnist_data['test_28']
mnist_train = np.expand_dims(mnist_train, 3) # 增加矩阵维度
mnist_test = np.expand_dims(mnist_test, 3)
mnist_train = mnist_train.transpose((0, 3, 1, 2)).astype(np.float32) #矩阵各维度数据交换
mnist_test = mnist_test.transpose((0, 3, 1, 2)).astype(np.float32) #改变数据类型
mnist_labels_train = mnist_data['label_train']
mnist_labels_test = mnist_data['label_test']
train_label = np.argmax(mnist_labels_train, axis=1) # 降维 6000*1 变为6000
inds = np.random.permutation(mnist_train.shape[0]) # 打乱后随机排列
mnist_train = mnist_train[inds]
train_label = train_label[inds]
test_label = np.argmax(mnist_labels_test, axis=1)