这段代码是用ANN进行OCR的,它包含一个隐藏层,输入的是28x28大小的图像。代码运行时没有任何错误,但是即使在训练了5000多张图像之后,输出也不准确。我使用的是jpg图像形式的mnist数据集。请告诉我我的逻辑出了什么问题。在import numpy as np
from PIL import Image
import random
from random import randint
y = [[0,0,0,0,0,0,0,0,0,0]]
W1 = [[ random.uniform(-1, 1) for q in range(40)] for p in range(784)]
W2 = [[ random.uniform(-1, 1) for q in range(10)] for p in range(40)]
def sigmoid(x):
global b
return (1.0 / (1.0 + np.exp(-x)))
#run the neural net forward
def run(X, W):
return sigmoid(np.matmul(X,W)) #1x2 * 2x2 = 1x1 matrix
#cost function
def cost(X, y, W):
nn_output = run(X, W)
return ((nn_output - y))
def gradient_Descent(X,y,W1,W2):
alpha = 0.12 #learning rate
epochs = 15000 #num iterations
for i in range(epochs):
Z2=sigmoid(np.matmul(run(X,W1),W2)) #final activation function(1X10))
Z1=run(X,W1) #first activation function(1X40)
phi1=Z1*(1-Z1) #differentiation of Z1
phi2=Z2*(1-Z2) #differentiation of Z2
delta2 = phi2*cost(Z1,y,W2) #delta for outer layer(1X10)
delta1 = np.transpose(np.transpose(phi1)*np.matmul(W2,np.transpose(delta2)))
deltaW2 = alpha*(np.matmul(np.transpose(Z1),delta2))
deltaW1 = alpha*(np.matmul(np.transpose(X),delta1))
W1=W1+deltaW1
W2=W2+deltaW2
def Training():
for j in range(8):
y[0][j]=1
k=1
while k<=15: #5421
print(k)
q=0
img = Image.open('mnist_jpgfiles/train/mnist_'+str(j)+'_'+str(k)+'.jpg')
iar = np.array(img) #image array
ar=np.reshape(iar,(1,np.product(iar.shape)))
ar=np.array(ar,dtype=float)
X = ar
'''
for p in range(784):
if X[0][p]>0:
X[0][p]=1
else:
X[0][p]=0
'''
k+=1
gradient_Descent(X,y,W1,W2)
print(np.argmin(cost(run(X,W1),y,W2)))
#print(W1)
y[0][j]=0
Training()
def test():
global W1,W2
for j in range(3):
k=1
while k<=5: #890
img = Image.open('mnist_jpgfiles/test/mnist_'+str(j)+'_'+str(k)+'.jpg')
iar = np.array(img) #image array
ar=np.reshape(iar,(1,np.product(iar.shape)))
ar=np.array(ar,dtype=float)
X = ar/256
'''
for p in range(784):
if X[0][p]>0:
X[0][p]=1
else:
X[0][p]=0
'''
k+=1
print("Should be "+str(j))
print((run(run(X,W1),W2)))
print((np.argmax(run(run(X,W1),W2))))
print("Testing.....")
test()