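Thanks to the following four articles, which gave me a fairly deep understanding of convolutional neural networks:
Derivation and implementation of CNNs: http://blog.csdn.net/zouxy09/article/details/9993371
A CNN implemented in C++: http://www.codeproject.com/Articles/16650/Neural-Network-for-Recognition-of-Handwritten-Digi
A CNN implemented in Python: http://deeplearning.net/tutorial/lenet.html
A handwriting-recognition example: http://www.csdn.net/article/1970-01-01/2825549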
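I wanted to implement a simple CNN myself, but ran into some problems along the way, so I'm leaving that until I have more time (a convenient excuse for not reinventing the wheel).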
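For reference, the two operations at the core of such a hand-rolled CNN, convolution and max pooling, can be sketched in NumPy. This is a minimal sketch, not Lasagne's implementation: it assumes a single channel, stride 1 and no padding, and it computes cross-correlation (what most deep-learning libraries call convolution). The names conv2d_valid, max_pool2x2 and relu are hypothetical. It reproduces the 28x28 -> 24x24 -> 12x12 shape change of the network's first conv/pool stage.

```python
import numpy as np

def conv2d_valid(image, kernel):
    """Single-channel 'valid' 2-D cross-correlation (no padding, stride 1)."""
    h, w = image.shape
    kh, kw = kernel.shape
    out = np.empty((h - kh + 1, w - kw + 1), dtype=image.dtype)
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            # weighted sum of the kh x kw window starting at (i, j)
            out[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return out

def max_pool2x2(fmap):
    """Non-overlapping 2x2 max pooling; assumes even height and width."""
    h, w = fmap.shape
    return fmap.reshape(h // 2, 2, w // 2, 2).max(axis=(1, 3))

def relu(x):
    return np.maximum(x, 0)

rng = np.random.default_rng(0)
image = rng.random((28, 28), dtype=np.float32)           # stand-in for one MNIST digit
kernel = rng.standard_normal((5, 5)).astype(np.float32)  # one random 5x5 filter

fmap = relu(conv2d_valid(image, kernel))  # 28x28 -> 24x24
pooled = max_pool2x2(fmap)                # 24x24 -> 12x12
print(fmap.shape, pooled.shape)           # (24, 24) (12, 12)
```

A real layer repeats this per filter and sums over input channels (32 filters in conv2d1, 60 filters over 32 channels in conv2d2), which is where the nested loops of a naive implementation start to get slow.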
Today the goal is to use Lasagne for handwritten digit recognition and raise the accuracy from 97% to 99%.
Implementation:
from datetime import datetime
from time import perf_counter as clock  # time.clock() was removed in Python 3.8
import lasagne
import numpy as np
from lasagne import layers
from lasagne.updates import nesterov_momentum
from nolearn.lasagne import NeuralNet
from sklearn.metrics import classification_report
net2 = NeuralNet(
    layers=[('input', layers.InputLayer),
            ('conv2d1', layers.Conv2DLayer),
            ('maxpool1', layers.MaxPool2DLayer),
            ('conv2d2', layers.Conv2DLayer),
            ('maxpool2', layers.MaxPool2DLayer),
            ('dropout1', layers.DropoutLayer),
            ('dense', layers.DenseLayer),
            ('dropout2', layers.DropoutLayer),
            ('output', layers.DenseLayer),
            ],
    # input layer
    input_shape=(None, 1, 28, 28),
    # layer conv2d1
    conv2d1_num_filters=32,
    conv2d1_filter_size=(5, 5),
    conv2d1_nonlinearity=lasagne.nonlinearities.rectify,
    conv2d1_W=lasagne.init.GlorotUniform(),
    # layer maxpool1
    maxpool1_pool_size=(2, 2),
    # layer conv2d2
    conv2d2_num_filters=60,
    conv2d2_filter_size=(5, 5),
    conv2d2_nonlinearity=lasagne.nonlinearities.rectify,
    # layer maxpool2
    maxpool2_pool_size=(2, 2),
    # dropout1
    dropout1_p=0.5,
    # dense
    dense_num_units=500,
    dense_nonlinearity=lasagne.nonlinearities.rectify,
    # dropout2
    dropout2_p=0.5,
    # output
    output_nonlinearity=lasagne.nonlinearities.softmax,
    output_num_units=10,
    # optimization method params
    update=nesterov_momentum,
    update_learning_rate=0.01,
    update_momentum=0.9,
    max_epochs=10,
    verbose=1,
)
def load_source(filename):
    with open(filename, "r") as file:
        lines = file.readlines()
    return lines[1:]  # skip the CSV header row
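data_lines = load_source("./data/train.csv")
for i in range(len(data_lines)):
    # strip the trailing newline before splitting on commas
    data_lines[i] = data_lines[i].strip().split(',')
data_lines = np.array(data_lines).astype(np.float32)
# column 0 is the label, the remaining 784 columns are the 28x28 pixels
x_data = data_lines[:, 1:].reshape((len(data_lines), 1, 28, 28))
y_data = data_lines[:, 0].astype(np.int32)
x_data /= np.float32(256)  # scale pixel values 0..255 into [0, 1)
# hold out the last 3000 samples as a test set
X_train = x_data[:-3000]
y_train = y_data[:-3000]
X_test = x_data[-3000:]
y_test = y_data[-3000:]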
np.set_printoptions(suppress=True, linewidth=175, precision=3)
# Train the network
print(datetime.now().strftime('%b-%d-%y %H:%M:%S'), "start training net2")
start = clock()
# by default nolearn holds out part of X_train for the "val loss" / "valid acc" columns
net2.fit(X_train, y_train)
preds = net2.predict(X_test)
print(classification_report(y_test, preds))
print(datetime.now().strftime('%b-%d-%y %H:%M:%S'), "end training net2")
end = clock()
print("net2 : %.3f s" % (end - start))
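The loop-and-split CSV parsing above can also be done in a single np.loadtxt call. This is a sketch, assuming the same Kaggle-style train.csv layout (one header row, the label in column 0, then 784 pixel columns); load_mnist_csv is a hypothetical helper name, and the /256 scaling and 3000-sample hold-out simply mirror the script above.

```python
import os
import tempfile

import numpy as np

def load_mnist_csv(filename, n_test=3000):
    """Hypothetical loader: header row, then 'label,pixel0,...,pixel783' per row."""
    data = np.loadtxt(filename, delimiter=",", skiprows=1, dtype=np.float32, ndmin=2)
    x = data[:, 1:].reshape(-1, 1, 28, 28) / np.float32(256)  # same scaling as the script
    y = data[:, 0].astype(np.int32)
    # last n_test rows become the test set, as in the script
    return x[:-n_test], y[:-n_test], x[-n_test:], y[-n_test:]

# Usage on a tiny synthetic file (5 rows) instead of the real train.csv
header = "label," + ",".join("pixel%d" % i for i in range(784))
rows = [",".join([str(k)] + ["255"] * 784) for k in range(5)]
with tempfile.NamedTemporaryFile("w", suffix=".csv", delete=False) as f:
    f.write(header + "\n" + "\n".join(rows) + "\n")
X_train, y_train, X_test, y_test = load_mnist_csv(f.name, n_test=2)
os.remove(f.name)
print(X_train.shape, y_test)  # (3, 1, 28, 28) [3 4]
```

np.loadtxt handles the header, whitespace and float conversion itself, so the manual strip/split step disappears; for the full 42,000-row file it is still plain Python parsing, so expect it to take a little while.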
Output:
# Neural Network with 534402 learnable parameters

## Layer information

  #  name      size
---  --------  --------
  0  input     1x28x28
  1  conv2d1   32x24x24
  2  maxpool1  32x12x12
  3  conv2d2   60x8x8
  4  maxpool2  60x4x4
  5  dropout1  60x4x4
  6  dense     500
  7  dropout2  500
  8  output    10
  epoch    trn loss    val loss    trn/val    valid acc  dur
-------  ----------  ----------  ---------  -----------  ------
      1     0.74042     0.15038    4.92376      0.95527  61.06s
      2     0.20932     0.10588    1.97692      0.96732  62.03s
      3     0.15610     0.08767    1.78063      0.97270  60.36s
      4     0.12981     0.07162    1.81254      0.97783  59.70s
      5     0.11279     0.06259    1.80196      0.97975  61.17s
      6     0.09886     0.05699    1.73463      0.98180  59.02s
      7     0.09049     0.05362    1.68761      0.98231  60.36s
      8     0.08104     0.04864    1.66604      0.98488  61.23s
      9     0.07727     0.04682    1.65020      0.98398  60.72s
     10     0.06763     0.04635    1.45896      0.98552  60.29s
             precision    recall  f1-score   support

          0       0.98      0.99      0.99       327
          1       0.99      0.98      0.99       330
          2       0.98      0.96      0.97       284
          3       0.99      0.99      0.99       300
          4       0.99      0.99      0.99       315
          5       1.00      0.98      0.99       242
          6       0.98      1.00      0.99       317
          7       0.97      0.99      0.98       304
          8       0.99      0.98      0.98       283
          9       0.98      0.98      0.98       298

avg / total       0.99      0.99      0.99      3000
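The layer sizes and the 534402 parameter count in the log can be checked by hand. This is a sketch, assuming stride 1 and no padding for both conv layers, pooling with stride equal to the pool size, and one bias per filter or unit (which is what the log's numbers are consistent with); pooling and dropout layers have no parameters.

```python
def conv_out(size, filt):
    # 'valid' convolution with stride 1
    return size - filt + 1

def pool_out(size, pool):
    # non-overlapping pooling, stride == pool size
    return size // pool

c0, s0 = 1, 28                 # input: 1x28x28
c1, s1 = 32, conv_out(s0, 5)   # conv2d1: 32x24x24
s2 = pool_out(s1, 2)           # maxpool1: 32x12x12
c3, s3 = 60, conv_out(s2, 5)   # conv2d2: 60x8x8
s4 = pool_out(s3, 2)           # maxpool2: 60x4x4
flat = c3 * s4 * s4            # 960 inputs to the dense layer

params = {
    "conv2d1": c0 * c1 * 5 * 5 + c1,  # weights + one bias per filter
    "conv2d2": c1 * c3 * 5 * 5 + c3,
    "dense": flat * 500 + 500,
    "output": 500 * 10 + 10,
}
print(s1, s2, s3, s4, flat)    # 24 12 8 4 960
print(sum(params.values()))    # 534402
```

The breakdown also shows where the capacity sits: about 480,500 of the 534,402 parameters belong to the 960 -> 500 dense layer, which is one reason dropout is placed around it.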
Building a different network only requires changing this configuration, which is very convenient.
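Much of that convenience comes from nolearn's naming convention: a keyword such as conv2d1_num_filters is passed as num_filters to the layer named conv2d1. The helper below is a hypothetical illustration of that prefix routing, not nolearn's actual code; keywords without a layer prefix (update_learning_rate, max_epochs) are just collected separately here.

```python
def split_layer_kwargs(layer_names, kwargs):
    """Hypothetical: group 'layername_param' keywords by layer name."""
    per_layer = {name: {} for name in layer_names}
    other = {}
    for key, value in kwargs.items():
        matches = [n for n in layer_names if key.startswith(n + "_")]
        if matches:
            # if names overlap (e.g. 'dense' and 'dense_2'), prefer the longest match
            name = max(matches, key=len)
            per_layer[name][key[len(name) + 1:]] = value
        else:
            other[key] = value
    return per_layer, other

layer_names = ["input", "conv2d1", "maxpool1", "dense", "output"]
kwargs = {
    "input_shape": (None, 1, 28, 28),
    "conv2d1_num_filters": 32,
    "conv2d1_filter_size": (5, 5),
    "maxpool1_pool_size": (2, 2),
    "dense_num_units": 500,
    "output_num_units": 10,
    "update_learning_rate": 0.01,
    "max_epochs": 10,
}
per_layer, other = split_layer_kwargs(layer_names, kwargs)
print(per_layer["conv2d1"])  # {'num_filters': 32, 'filter_size': (5, 5)}
print(other)                 # {'update_learning_rate': 0.01, 'max_epochs': 10}
```

So adding a layer amounts to one more (name, class) tuple in layers plus a few name_-prefixed keywords.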
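Later I printed out the misclassified samples and found that there is still room to improve the accuracy; I'll study this in more detail when I have time.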
# Appended after the training script above (reuses X_test, y_test, preds)
import matplotlib.pyplot as plt

print("error cnt:", int(np.sum(y_test != preds)))
for i in range(len(y_test)):
    if y_test[i] == preds[i]:
        continue
    print(i, "error", y_test[i], preds[i])
    plt.imshow(X_test[i][0], cmap="binary")
    # file name: <index>-<true label>-<predicted label>
    plt.savefig("/XXX/error/%d-%d-%d" % (i, y_test[i], preds[i]))
    plt.clf()  # clear the figure so images do not pile up