一、介绍
输入为28x28的图片的浮点型数据
输出0-9的概率值
二、代码
#导入tensorflow库
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
#下载数据集
mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)
#输入
x = tf.placeholder(tf.float32, [None, 784])
#目标输出
y = tf.placeholder(tf.float32, [None, 10])
#第一层实现
w1 = tf.Variable(tf.truncated_normal([784,64]), dtype = tf.float32)
b1 = tf.Variable(tf.zeros([1,64]), dtype = tf.float32)
a1 = tf.nn.relu(tf.matmul(x, w1) + b1)
#第二层实现
w2 = tf.Variable(tf.truncated_normal([64,10]), dtype = tf.float32)
b2 = tf.Variable(tf.zeros([1,10]), dtype = tf.float32)
y_ = tf.nn.softmax(tf.matmul(a1,w2) + b2)
#计算偏差 代入损失函数 梯度下降
#损失函数
loss = tf.reduce_mean(-tf.reduce_sum(y*tf.log(y_), axis = 1))
#梯度下降
train_step = tf.train.AdamOptimizer(0.0001).minimize(loss)
#初始化操作
#初始化参数节点操作
init = tf.global_variables_initializer()
#启动会话
sess = tf.Session()
sess.run(init)
#定义损失值和正确的识别率
#损失=实际输出与目标的差值
correct_prediction = tf.equal(tf.argmax(y_, axis = 1), tf.argmax(y, axis = 1))
#正确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction ,tf.float32))
for i in range(1000):
#从数据集里面读取100张图片和标签
batch_xs, batch_ys = mnist.train.next_batch(100)
sess.run(train_step, feed_dict = {x:batch_xs, y:batch_ys})
if(i % 100 == 0):
print("损失值",sess.run(loss, feed_dict={x: mnist.test.images, y: mnist.test.labels}))
print('识别正确率',sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels}))
#导入图片处理相关库
from PIL import Image
from itertools import chain
#输入一张图片,转换为张量
test_img = Image.open("./2.bmp") #打开图片
test_img_array = np.array(test_img)#将图片数据存放到数组
test_img_array = np.asarray(test_img_array, dtype="float32")#转换数组元素的数据类型
#转换数组成张量
test_img_array = list(chain.from_iterable(test_img_array))#一阶
test_img_array = [test_img_array]#二阶
#还原计算图
layer1 = tf.nn.relu(tf.matmul(test_img_array, w1) + b1)
y_test = tf.nn.softmax(tf.matmul(layer1,w2) + b2)
#启动会话推理
#打印结果
print("识别结果:")
print(sess.run(y_test))
print("保存图片和参数:")
with open("img2.h", "w") as f:
img_save = str(test_img_array)
str1 = img_save.replace("[", "")
str2 = str1.replace("]", "")
f.write("img2[784]={"+str2+"};")
with open("w1.h", "w") as f:
w1_value = w1.eval(session=sess)#获取张量的值
str1 = str(np.transpose(w1_value).tolist())#转换为字符
str2 = str1.replace("[", "")
str3 = str2.replace("]", "")
f.write("w1[784*64]={"+str3+"};")
#f.write(str1)
print("保存完成")