1.X版本说明
此版本基于老师所给的MATLAB代码改写,并对隐藏层大小、数据读取方式进行修改,提高收敛速度以及准确率
一、数据处理
import scipy.io as scio
## import torch
import numpy as np
## import torchvision
## 使用GPU资源
## device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
1. 直接读取mat文件
# Path of the .mat dataset to load; it is left empty here and has to be
# filled in before this cell can run.
path=''
# scipy.io.loadmat returns a dict keyed by MATLAB variable name.
data = scio.loadmat(path)
## Inspect the loaded contents (bare expression, displayed by the notebook)
data
{'__header__': b'MATLAB 5.0 MAT-file, Platform: MACI64, Created on: Wed Nov 20 21:23:10 2019',
'__version__': '1.0',
'__globals__': [],
'test_ima': array([[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
...,
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0]], dtype=uint8),
'test_lab': array([[7],
[2],
[1],
...,
[4],
[5],
[6]], dtype=uint8),
'train_ima': array([[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
...,
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0]], dtype=uint8),
'train_lab': array([[5],
[0],
[4],
...,
[5],
[6],
[8]], dtype=uint8)}
2. 将数据集转换为train、valid、test
## Build the training and validation sets
# The labelled training data is cut once: the first 80% of the samples is
# used for training, the remaining 20% for validation. Features are sliced
# along axis 1 and labels along axis 0, i.e. one sample per feature column.
train_ima = np.array(data['train_ima'])
train_lab = np.array(data['train_lab'])
length = len(train_lab)
split_at = int(length * 0.8)
train_feature, valid_feature = train_ima[:, :split_at], train_ima[:, split_at:]
train_label, valid_label = train_lab[:split_at], train_lab[split_at:]
## Build the test set (taken as-is from the file)
test_feature = np.array(data['test_ima'])
test_label = np.array(data['test_lab'])
# Sample counts of the three splits, used by the later cells.
train_sum, valid_sum, test_sum = len(train_label), len(valid_label), len(test_label)
## Print the training-set size to check the split
print(train_sum)
48000
3. 对MATLAB代码进行复现
import numpy as np
## Initialise the model parameters
# Layer sizes of the three-layer network and the number of classes.
num_inputs, num_hidden, num_outputs = 784, 20, 10
label_num = 10
# Weights are drawn from a normal distribution with mean 0 and standard
# deviation 0.01; biases start at zero and are stored as column vectors.
W1 = np.random.normal(loc=0, scale=0.01, size=(num_inputs, num_hidden))
b1 = np.zeros((num_hidden, 1))
W2 = np.random.normal(loc=0, scale=0.01, size=(num_hidden, num_outputs))
b2 = np.zeros((num_outputs, 1))
## Learning rates (rate1 for the first layer, rate2 for the second)
rate1 = rate2 = 0.05
## 定义需要用到的函数
def sigmord(x):
    """Return the element-wise logistic sigmoid ``1 / (1 + exp(-x))``.

    ``x`` may be a scalar or a NumPy array; the result has the same shape.
    """
    # For very negative x, exp(-x) overflows to inf and the quotient becomes
    # exactly 0.0, which is the correct limit. Only the overflow warning is
    # silenced; the returned values are unchanged.
    with np.errstate(over="ignore"):
        output = 1.0 / (1.0 + np.exp(-x))
    return output
def forward(x, w, b):
    """Compute the pre-activation of one layer for a single sample.

    ``x`` is flattened into a row vector, multiplied by the weight matrix
    ``w`` (inputs x outputs), transposed into a column and offset by ``b``.
    With ``b`` given as an (outputs, 1) column, as the biases in this file
    are, the result is an (outputs, 1) column vector.
    """
    y = np.dot(x.reshape(1, -1), w).T + b
    return y
## Training and validation
## The original notes state that two changes were made to the reference code
## in this part, but do not list them here.
## Each step draws `batch_size` random training samples and applies one
## per-sample gradient update for each of them; every 10th step the current
## weights are scored on the validation split.
batch_size = 720
epochs = 1
# Take the validation size from the data rather than hard-coding 60000*0.2.
valid_sum = len(valid_label)
valid_targets = valid_label.ravel()
for epoch in range(epochs):
    print('epoch:', epoch)
    for num in range(int(train_sum / batch_size)):
        ## Draw random sample indices. randint's upper bound is exclusive,
        ## so `train_sum` (not `train_sum - 1`) is needed to reach the last one.
        rand = np.random.randint(0, train_sum, batch_size)
        ## Train
        for i in rand:
            sample = train_feature[:, i]
            # One-hot target column for this sample's class.
            label = np.zeros((label_num, 1))
            label[int(train_label[i, 0])] = 1
            ## Forward pass
            temp1 = forward(sample, W1, b1)
            net = sigmord(temp1).T      # hidden activations as a row (1, num_hidden)
            temp2 = forward(net, W2, b2)
            z = sigmord(temp2)          # output activations as a column (num_outputs, 1)
            ## Error terms (both computed before any weight is changed)
            error = label - z
            deltaZ = error * z * (1 - z)
            deltaNet = (net * (1 - net)).T * np.dot(W2, deltaZ)
            ## Parameter update
            # Broadcasting a (1, n_out) row against an (n_in, 1) column yields
            # the same per-column update as looping over the output units.
            W2 += (rate2 * deltaZ.T) * net.T
            W1 += (rate1 * deltaNet.T) * sample.reshape(-1, 1)
            b2 = b2 + rate2 * deltaZ
            b1 = b1 + rate1 * deltaNet
        ## Validation: score all validation columns in one forward pass.
        if num % 10 == 0:
            hidden = sigmord(np.dot(W1.T, valid_feature) + b1)
            scores = sigmord(np.dot(W2.T, hidden) + b2)
            count = int(np.sum(np.argmax(scores, axis=0) == valid_targets))
            correctRate = count / valid_sum
            print('step: ', num, 'acc:', correctRate)
epoch: 0
step: 0 acc: 0.27841666666666665
step: 10 acc: 0.8729166666666667
step: 20 acc: 0.8755
step: 30 acc: 0.8913333333333333
step: 40 acc: 0.8919166666666667
step: 50 acc: 0.8993333333333333
step: 60 acc: 0.89525
## Test: accuracy of the trained network on the held-out test set.
# All test columns are pushed through both layers in one pass; the predicted
# class of a sample is the row index of its largest output activation.
test_targets = test_label[:test_sum].ravel()
hidden = sigmord(np.dot(W1.T, test_feature[:, :test_sum]) + b1)
scores = sigmord(np.dot(W2.T, hidden) + b2)
count = int(np.sum(np.argmax(scores, axis=0) == test_targets))
correctRate = count / test_sum
print('the test\'s acc:', correctRate)
the test's acc: 0.9013
## Save the hyper-parameters and the trained weights
import csv
import json

# newline='' is what the csv module expects for files it writes, and the
# `with` block closes the file even if a write fails.
with open('mnist_4.0.csv', 'w', encoding='utf-8', newline='') as file:
    csv_writer = csv.writer(file)
    # Header row
    csv_writer.writerow(["batch_size","epochs","w1","w2","b1","b2","num_hidden"])
    # Arrays are stored as JSON lists: str() of a large ndarray is abbreviated
    # with "...", which would silently drop most of W1. The batch size and
    # epoch count are taken from the variables actually used for training.
    # NOTE: the w1 field is long; a reader using the csv module may have to
    # raise csv.field_size_limit() before loading this file.
    csv_writer.writerow([
        batch_size,
        epochs,
        json.dumps(W1.tolist()),
        json.dumps(W2.tolist()),
        json.dumps(b1.tolist()),
        json.dumps(b2.tolist()),
        num_hidden,
    ])
对个别参数进行了“调乱”,直接跑无法得到相应精度,自行完善
欢迎私信或评论区讨论