tensorflow加载训练好的模型测试数据

总感觉耐不下心来好好学这些东西。
计算图:TensorFlow中的所有计算都会被转化成计算图上的节点。
张量:TensorFlow中,所有的数据都通过张量的形式来表示,可以简单地理解成多维数组。
会话:TensorFlow中使用会话来执行定义好的计算。
使用训练好的模型,最主要就是要得到输入和输出的接口。
方法一:首先是输入。在训练网络时要给输入赋予name,比如:

inputs_ = tf.placeholder(tf.float32, shape=[None, codes.shape[1]],name="input")

最后的 name=“input” 一定要加上,这样训练完成后就可以使用:

inputs_ = tf.get_default_graph().get_operation_by_name('input').outputs[0]

来得到输入的接口(输入的tensor)。
输出同理。比如一个卷积神经网络做分类任务,需要的输出为:

predicted = tf.nn.softmax(logits)

可以改写成:

predicted = tf.nn.softmax(logits,name="predicted")

这样就可以在加载模型后使用:

predicted = tf.get_default_graph().get_operation_by_name('predicted').outputs[0]
sess.run(predicted)

得到预测结果。

方法二:在训练网络时使用tf.add_to_collection将需要的节点加入到一个自己定义名字的列表中,比如:

tf.add_to_collection('input', inputs_)

这样在加载模型后,使用:

inputs_ = tf.get_collection('input')[0]

就可以得到输入的接口(输入的tensor)。
输出同理。使用:

tf.add_to_collection('predicted', predicted)

这样在加载模型后就可以使用:

predicted = tf.get_collection('predicted')[0]
sess.run(predicted)

得到预测结果。

tf.get_default_graph().get_operation_by_name
tf.add_to_collection
tf.get_collection
他们三个的具体参数就不写了,去网上找吧。(其实是我也没弄明白)

以上均为一个初学者记录以便用到时易于查询,有什么不对的地方,希望不要误导了各位看官。

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值