用正点原子ZYNQ7020实现mnist数据集的前向推理
项目流程说明
目前的神经网络并不适合在低功耗处理器上进行训练,但是在fpga中单纯用于任务处理还是有点搞头的。。本项目即实现了最最简单网络模型在FPGA上用于识别最最简单手写数字的tiny级任务。
一、生成预先计划的图片数据文件
从网络下载的mnist数据集包括两个csv数据文件,训练集和测试集,需要将其转换为所需的数据文件供PS端加载。
- 原始文件 :网络下载的mnist数据集包括两个csv数据文件,训练集和测试集。每个文件中的每一行有785个数据,第一列为标记,之后的是784个0~255之间的像素值;
- 需要用数据处理程序读取源文件的某一行并生成 具有特定分隔符 的数据文件;
- 仅作为测试用,一个文件只保存一幅图像数据即可,可以随机多生成几个文件,一是防止网络模型识别错误的小概率事件发生,二是相互对照便于比较PS端程序的稳定性;
- 数据均转化为0~1之间的浮点数,每个数据为4字节;
- 示例程序中使用python脚本读取csv文件,随机生成了5个数据文件。上代码:
import csv
import random
def normalize_pixel_value(value):
"""Normalize the pixel value to be between 0 and 1."""
return float(value) / 255.0
def read_random_row_from_csv(file_path):
"""Read a random row between 1 and 10 from the CSV file and return as a list of normalized floats, excluding the first value."""
with open(file_path, newline='') as csvfile:
reader = list(csv.reader(csvfile))
random_row_index = random.randint(1, 10) - 1 # Randomly select a row index between 0 and 9 (corresponding to rows 1 to 10)
random_row = reader[random_row_index] # Get the selected row
normalized_data = [normalize_pixel_value(value) for value in random_row[1:]] # Skip the first value and normalize the rest
return normalized_data
def write_data_to_file(file_path, data):
"""Write the data to a file with ',\n\r' as the separator."""
with open(file_path, 'w') as file:
for value in data:
file.write(f"{value},\n\r")
def main():
# Read a random row between 1 and 10 from src.csv
normalized_data = read_random_row_from_csv('data/mnist_test.csv')
# Write the normalized data to b.dat
write_data_to_file('out/a05.dat', normalized_data)
if __name__ == "__main__":
main()
二、搭建模型,训练神经网络,得到权重参数
- 搭建的神经网络为两个隐藏层、一个输入层、一个输出层,均为全连接网络层线性层,即y = A * x + B ;
- 输入层节点数为784,隐藏层节点数为64、32,输出层节点数为10;
- 共有3个网络权重文件和3个偏置数据文件,在网络训练完成后需要各个单独保存;
- 这6个网络数据文件需要固化在PL端的IP核中;
- 示例程序中使用python脚本实现了上述功能。上代码:
# Importing necessary libraries
import numpy
import scipy.special
import matplotlib.pyplot
# Defining the neural network class
class neuralNetwork:
# Initializing the neural network
def __init__(self, inputnodes, hiddennodes1, hiddennodes2, outputnodes, learningrate):
self.inodes = inputnodes
self.hnodes1 = hidde