import pickle
import gzip
import numpy as np
def load_data(path='D://dataset/mnist.pkl.gz'):
    """Load the three pickled dataset splits from a gzip-compressed file.

    Args:
        path: Location of the gzip-compressed pickle. Defaults to the
            original hard-coded location, so existing callers are unaffected.

    Returns:
        A 3-tuple ``(training_data, validation_data, test_data)``: the three
        objects stored in the pickle, in that order.

    Raises:
        ValueError: if the pickle does not contain exactly three items.
    """
    # ``with`` guarantees the file is closed even if unpickling raises.
    # encoding='bytes' tells pickle to keep Python 2 8-bit strings found in
    # the pickle as ``bytes`` instead of decoding them.
    # NOTE(review): unpickling can execute arbitrary code -- only point
    # ``path`` at a trusted file.
    with gzip.open(path, 'rb') as f:
        training_data, validation_data, test_data = pickle.load(f, encoding='bytes')
    return training_data, validation_data, test_data
def vectorized_result(j, num_classes=10):
    """Return a one-hot column vector with 1.0 at index ``j``.

    Args:
        j: Index of the entry set to 1.0 (e.g. a class label).
        num_classes: Length of the vector. Defaults to 10, the original
            hard-coded size.

    Returns:
        A float ``numpy`` array of shape ``(num_classes, 1)`` of zeros,
        except for a 1.0 at row ``j``.
    """
    e = np.zeros((num_classes, 1))
    e[j] = 1.0
    return e
def _as_column_vectors(inputs):
    """Reshape each 784-element input into a (784, 1) column vector."""
    return [np.reshape(x, (784, 1)) for x in inputs]


def load_data_wrapper():
    """Load the data and convert each split into a list of example pairs.

    Returns:
        ``(training_data, validation_data, test_data)``, each a list of
        2-tuples:

        * training / validation: ``(input, one_hot_label)``, where ``input``
          has shape (784, 1) and ``one_hot_label`` is built by
          ``vectorized_result``;
        * test: ``(input, label)``, with the label left exactly as loaded.

    Lists are returned instead of ``zip`` objects because ``zip`` is a
    one-shot iterator in Python 3: a second pass over it (e.g. a second
    training epoch) would silently yield nothing. Callers that iterate the
    result or wrap it in ``list()`` keep working.
    """
    tr_data, va_data, te_data = load_data()

    # Each split is (inputs, labels); pair them element-wise.
    training_data = list(zip(
        _as_column_vectors(tr_data[0]),
        [vectorized_result(y) for y in tr_data[1]],
    ))
    validation_data = list(zip(
        _as_column_vectors(va_data[0]),
        [vectorized_result(y) for y in va_data[1]],
    ))
    # Test labels are not one-hot encoded.
    test_data = list(zip(_as_column_vectors(te_data[0]), te_data[1]))
    return training_data, validation_data, test_data
def get_Dataset():
    """Return the dataset as ``(training_data, validation_data, testing_data)``.

    Each split from ``load_data_wrapper`` is materialized with ``list``, so
    it can be indexed, measured with ``len`` and iterated more than once.
    Training and validation entries are ``(input, one_hot_label)`` pairs;
    testing entries are ``(input, label)`` pairs with the label as loaded.
    The validation split is kept separate from the training split.
    """
    training, validation, testing = load_data_wrapper()
    # ``list`` also guards against load_data_wrapper handing back one-shot
    # iterators (e.g. ``zip`` objects), which would be exhausted after a
    # single pass.
    training_data = list(training)
    validation_data = list(validation)
    testing_data = list(testing)
    return training_data, validation_data, testing_data
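# mnist数据集的处理和获取 (processing and loading the MNIST dataset): load_data, load_data_wrapper
# 最新推荐文章于 2023-09-30 12:44:41 发布 (source page footer: latest recommended article published 2023-09-30 12:44:41)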