训练集数据
http://www.pjreddie.com/media/files/mnist_train.csv
测试集数据
说明:每一行数据集第一个数据是正确的数字,其它均为像素值(28 * 28)。一般情况下,对于像素值,0表示黑,255表示白。但是mnist的数据恰好相反,所以后期对于自己手写的字体,要进行处理。
1、第一步,加载数据集,将其映射到0.01 - 1
data = np.loadtxt("mnist_train.csv", delimiter=',')
X_train = data[:, 1:]
X_train_scaled = (X_train / 255.0 * 0.99) + 0.01
y_train = data[:, 0].flatten()
2、第二步,训练
mlp = MLPClassifier(solver='lbfgs', activation='tanh', random_state=0, hidden_layer_sizes=[300])
mlp.fit(X_train_scaled, y_train)
3、第三步,测试训练集
test = np.loadtxt("mnist_test_10.csv", delimiter=',')
x_test = test[:, 1:]
x_test_scaled = (x_test / 255.0 * 0.99) + 0.01
y_test = test[:, 0].flatten()
print("predict:\n", mlp.predict(x_test_scaled))
print("true:\n", y_test)
4、第四步,识别自己手写的训练集
for i in range(0, 10):
img_array = imageio.imread("pic/" + str(i) + ".png", as_gray=True)
img_data = 255.0 - img_array.reshape(1, -1)
img_data_scaled = (img_data / 255.0 * 0.99) + 0.01
print("predict:" + str(i) + "res:", mlp.predict(img_data_scaled))
5、输出结果(存在一定误差)