神经网络学习小记录15——tf利用训练好的模型进行预测

神经网络学习小记录15——tf利用训练好的模型进行预测

学习前言

神经网络学习小记录14——slim常用函数与如何训练、保存模型文章里已经讲述了如何使用slim训练出来一个模型,这篇文章将会讲述如何预测。
在这里插入图片描述

载入模型思路

载入模型的过程主要分为以下四步:
1、建立会话Session;
2、将img_input的placeholder传入网络,建立网络结构;
3、初始化所有变量;
4、利用saver对象restore载入所有参数。

这里要注意的重点是,在利用saver对象restore载入所有参数之前,必须要建立网络结构,因为网络结构对应着cpkt文件中的参数。
(网络层具有对应的名称scope。)
在这里插入图片描述

实现代码

在运行实验代码前,可以前往我的github下载代码,因为存在许多依赖的文件https://github.com/bubbliiiing/Mnist-recognition-By-Slim。

import tensorflow as tf
import numpy as np
from nets import Net
from tensorflow.examples.tutorials.mnist import input_data

def compute_accuracy(x_data,y_data):
    global prediction
    y_pre = sess.run(prediction,feed_dict={img_input:x_data})
    
    correct_prediction = tf.equal(tf.arg_max(y_data,1),tf.arg_max(y_pre,1))

    accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

    result = sess.run(accuracy,feed_dict = {img_input:x_data})
    return result

mnist = input_data.read_data_sets("MNIST_data",one_hot = "true")

slim = tf.contrib.slim

# img_input的placeholder
img_input = tf.placeholder(tf.float32, shape = (None, 784))
img_reshape = tf.reshape(img_input,shape = (-1,28,28,1))

# 载入模型
sess = tf.Session()

Conv_Net = Net.Conv_Net()
# 将img_input的placeholder传入网络
prediction = Conv_Net.net(img_reshape)

# 载入模型
ckpt_filename = './logs/model.ckpt-20000'

# 初始化所有变量
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()

# 恢复
saver.restore(sess, ckpt_filename)

print(compute_accuracy(mnist.test.images,mnist.test.labels))

运行结果为:

0.9921
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Bubbliiiing

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值