上周写完了该代码,但是由于没有注意到softmax相关的实现故结果不对,更正后可以得到正确结果,用200幅图片迭代200次可以得到90%以上的正确率,参数设置还有待于优化,另外可以考虑用多线程加速,此处目前还有问题(有待于修改,慎用)。
推导请参考之前的文章http://blog.csdn.net/xuanyuansen/article/details/41214115。
用MSE作为目标函数,也可以得到很好的结果,只是需要迭代的次数较多,400幅图片,迭代2000次,训练的正确率是96.75%。
#coding=utf-8
'''
Created on 2014��11��15��
@author: wangshuai13
'''
import numpy
#import matplotlib.pyplot as plt
import struct
import math
import random
import time
import threading
class MyThread(threading.Thread):
def __init__(self,threadname,tANN,idx_start,idx_end):
threading.Thread.__init__(self,name=threadname)
self.ANN=tANN
self.idx_start=idx_start
self.idx_end=idx_end
def run(self):
cDetaW,cDetaB,cError=self.ANN.backwardPropogation(self.ANN.traindata[self.idx_start],0)
for idx in range(self.idx_start+1,self.idx_end):
DetaWtemp,DetaBtemp,Errortemp=self.ANN.backwardPropogation(self.ANN.traindata[idx],idx)
cError += Errortemp
#cDetaW += DetaWtemp
#cDetaB += DetaBtemp
for idx_W in range(0,len(cDetaW)):
cDetaW[idx_W] += DetaWtemp[idx_W]
for idx_B in range(0,len(cDetaB)):
cDetaB[idx_B] += DetaBtemp[idx_B]
return cDetaW,cDetaB,cError
def sigmoid(inX):
return 1.0/(1.0+math.exp(-inX))
def softmax(inMatrix):
m,n=numpy.shape(inMatrix)
outMatrix=numpy.mat(numpy.zeros((m,n)))
soft_sum=0
for idx in range(0,n):
outMatrix[0,idx] = math.exp(inMatrix[0,idx])
soft_sum += outMatrix[0,idx]
for idx in range(0,n):
outMatrix[0,idx] /= soft_sum
return outMatrix
def tangenth(inX):
return (1.0*math.exp(inX)-1.0*math.exp(-inX))/(1.0*math.exp(inX)+1.0*math.exp(-inX))
def difsigmoid(inX):
return sigmoid(inX)*(1.0-sigmoid(inX))
def sigmoidMatrix(inputMatrix):
m,n=numpy.shape(inputMatrix)
outMatrix=numpy.mat(numpy.zeros((m,n)))
for idx_m in range(0,m):
for idx_n in range(0,n):
outMatrix[idx_m,idx_n]=sigmoid(inputMatrix[idx_m,idx_n])
return outMatrix
def loadMNISTimage(absFilePathandName,datanum=60000):
images=open(absFilePathandName,'rb')
buf=images.read()
index=0
magic, numImages , numRows , numColumns = struct.unpack_from('>IIII' , buf , index)
print magic, numImages , numRows , numColumns
index += struct.calcsize('>IIII')
if magic != 2051: