How to use a trained TensorFlow model for testing (inference)

Method 1

Saving
1. Define the variables
2. Save them with saver.save()

import tensorflow as tf
import numpy as np

# Define the variables that make up the model.
W = tf.Variable([[1, 1, 1], [2, 2, 2]], dtype=tf.float32, name='w')
b = tf.Variable([[0, 1, 2]], dtype=tf.float32, name='b')

init = tf.global_variables_initializer()  # initialize_all_variables() is deprecated
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(init)
    # Write the variable values to the checkpoint prefix save/model.ckpt.
    save_path = saver.save(sess, "save/model.ckpt")
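One detail worth knowing: the value returned by saver.save() is the checkpoint prefix ("save/model.ckpt"), not a single file; recent TF 1.x versions actually write several files sharing that prefix (.index, .data-00000-of-00001, .meta, plus a checkpoint file), and the prefix is what saver.restore() expects later. A minimal check:

print(save_path)  # -> save/model.ckpt  (a prefix shared by the .index/.data/.meta files)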

Loading
1. Define the variables
2. Restore them with saver.restore()

import tensorflow as tf
import numpy as np

# The graph must be redefined with the same variable names as when it was saved.
W = tf.Variable(tf.truncated_normal(shape=(2, 3)), dtype=tf.float32, name='w')
b = tf.Variable(tf.truncated_normal(shape=(1, 3)), dtype=tf.float32, name='b')

saver = tf.train.Saver()
with tf.Session() as sess:
    # restore() assigns the saved values, so no explicit initialization is needed.
    saver.restore(sess, "save/model.ckpt")
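To confirm that restore() actually overwrote the random truncated_normal values with the saved ones, printing the variables is enough. A minimal check, run inside the with-block above:

    print(sess.run(W))  # expected: [[1. 1. 1.] [2. 2. 2.]], the values saved by the first script
    print(sess.run(b))  # expected: [[0. 1. 2.]]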

The inconvenience of this approach is that, to use the model, you have to redefine its structure from scratch and then load the saved values into variables with the corresponding names. In many cases, though, we would rather just read a file and use the model directly, without redefining it. That is what the second method is for.
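If the variable names in the new code do not match the names stored in the checkpoint, method 1 can still be used by passing an explicit name-to-variable mapping to tf.train.Saver. A minimal sketch (W_new is a hypothetical variable introduced here only for illustration):

W_new = tf.Variable(tf.zeros((2, 3)), name='w_new')
# Map the name stored in the checkpoint ('w') to the variable in the current graph.
saver = tf.train.Saver({'w': W_new})
with tf.Session() as sess:
    saver.restore(sess, "save/model.ckpt")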

Method 2

A method that does not require redefining the network structure
This method loads every node of the saved graph from file into the current default graph and returns a saver. In other words, when we save a model we store not only the variable values but also the nodes of the corresponding graph, so the model's structure is saved along with them.

Saving

### Define the model (in_dim, h1_dim, out_dim and train_op are assumed to be defined elsewhere)
input_x = tf.placeholder(tf.float32, shape=(None, in_dim), name='input_x')
input_y = tf.placeholder(tf.float32, shape=(None, out_dim), name='input_y')

w1 = tf.Variable(tf.truncated_normal([in_dim, h1_dim], stddev=0.1), name='w1')
b1 = tf.Variable(tf.zeros([h1_dim]), name='b1')
w2 = tf.Variable(tf.zeros([h1_dim, out_dim]), name='w2')
b2 = tf.Variable(tf.zeros([out_dim]), name='b2')
keep_prob = tf.placeholder(tf.float32, name='keep_prob')
hidden1 = tf.nn.relu(tf.matmul(input_x, w1) + b1)
hidden1_drop = tf.nn.dropout(hidden1, keep_prob)
### Define the prediction target
y = tf.nn.softmax(tf.matmul(hidden1_drop, w2) + b2)
# Create the saver
saver = tf.train.Saver(...variables...)  # (the variables to save, elided in the original)
# If y is needed at prediction time, add it to a collection so it can be fetched later by name
tf.add_to_collection('pred_network', y)
sess = tf.Session()
for step in range(1000000):
    sess.run(train_op)
    if step % 1000 == 0:
        # Save a checkpoint; by default this also exports a meta graph
        # named 'my-model-{global_step}.meta'.
        saver.save(sess, 'my-model', global_step=step)
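Each such call writes files whose names carry the step (for example my-model-1000.meta together with the matching .index/.data files) and updates a checkpoint file that records the newest prefix, so the step does not have to be hard-coded at load time. A minimal sketch using tf.train.latest_checkpoint (the directory '.' assumes the checkpoints were written to the current directory, as in the loop above):

ckpt = tf.train.latest_checkpoint('.')            # e.g. 'my-model-999000'
new_saver = tf.train.import_meta_graph(ckpt + '.meta')
new_saver.restore(sess, ckpt)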

Loading

with tf.Session() as sess:
  new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta')
  new_saver.restore(sess, 'my-save-dir/my-model-10000')
  # tf.get_collection() returns a list; here we only need its first element.
  y = tf.get_collection('pred_network')[0]

  graph = tf.get_default_graph()

  # Because y depends on placeholders, sess.run(y) still needs the actual samples to be
  # predicted (and the dropout keep probability) fed into those placeholders, which we
  # retrieve from the graph with get_operation_by_name().
  input_x = graph.get_operation_by_name('input_x').outputs[0]
  keep_prob = graph.get_operation_by_name('keep_prob').outputs[0]

  # Run y to make predictions
  sess.run(y, feed_dict={input_x: ...., keep_prob: 1.0})
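Equivalently, the two placeholders can be looked up as tensors by name with get_tensor_by_name; the ':0' suffix selects the first output of the op, which is exactly what .outputs[0] returns above. Inside the same with-block:

  input_x = graph.get_tensor_by_name('input_x:0')
  keep_prob = graph.get_tensor_by_name('keep_prob:0')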

The above is adapted from: http://blog.csdn.net/thriving_fcl/article/details/71423039

Below is my own inference code, provided for reference only (note that "influence" in the code should read "inference").

import tensorflow as tf
import cv2
import alexnet as AN
import genarate_trainning_data as gtd
import create_tfRecord as tfRec
import numpy as np
import os
import time

from skimage import io as sio
from skimage.segmentation import slic
from skimage.segmentation import felzenszwalb as felseg  # graph-based image segmentation

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('image_dir', './SED2/',
                           """Directory of images to run saliency detection on.""")
tf.app.flags.DEFINE_string('model_dir', './model/',
                           """Directory containing the checkpoint files.""")
tf.app.flags.DEFINE_string('mean_image', './SED2/tfRecord_file/',
                           """Directory containing the mean image used for mean subtraction.""")
# Note: this placeholder is not used below; the input placeholder is obtained from the imported graph.
x = tf.placeholder(dtype=tf.float16, shape=[None, tfRec.IMAGE_HEIGHT, tfRec.IMAGE_WIDTH, 3])

def caculate_multi_level(images):
    # Helper sketch for building the multi-level segment crops; not called from main().
    fparams = np.load('./seg_para.npy').item()
    mean_image = np.load(FLAGS.mean_image + 'mean_image.npy')
    for curr_img in images:
        img = sio.imread(FLAGS.image_dir + curr_img)
        img = cv2.resize(img, (tfRec.IMAGE_HEIGHT, tfRec.IMAGE_WIDTH))
        for g in range(15):
            f_seg = felseg(img, sigma=fparams['sigma'][g], scale=np.float(fparams['scale'][g]),
                           min_size=np.int(fparams['min_size'][g]))
            segments_temp = np.unique(f_seg)

            for segment in segments_temp:
                multi_scale_image = tfRec.merge_multiscale_image(img, mean_image, f_seg, segment)

def influence():
    # Inference: restore the latest checkpoint and run the model over every image in FLAGS.image_dir.
    if 'out_result' not in os.listdir(FLAGS.image_dir):
        os.mkdir(FLAGS.image_dir + 'out_result')

    gt,images = gtd.dirtomdfbatchmsra(FLAGS.image_dir)

    with tf.Session() as sess:
        meta = [fn for fn in os.listdir(FLAGS.model_dir) if fn.endswith('meta')]
        saver = tf.train.import_meta_graph(FLAGS.model_dir+meta[0])
        saver.restore(sess,tf.train.latest_checkpoint(FLAGS.model_dir))
        predict = tf.get_collection('predict')[0]
        graph = tf.get_default_graph()

        input_x = graph.get_operation_by_name('input_x').outputs[0]
        keep_pro = graph.get_operation_by_name('keep_pro').outputs[0]

        fparams = np.load('./seg_para.npy').item()
        mean_image = np.load(FLAGS.mean_image + 'mean_image.npy')
        for index,curr_img in enumerate(images):
            img = sio.imread(FLAGS.image_dir + curr_img)
            img = cv2.resize(img, (tfRec.IMAGE_HEIGHT, tfRec.IMAGE_WIDTH))
            gt_map = sio.imread(FLAGS.image_dir +gt[index])
            gt_map = cv2.resize(gt_map, (tfRec.IMAGE_HEIGHT, tfRec.IMAGE_WIDTH))
            cv2.imwrite((FLAGS.image_dir+'/out_result/' + curr_img[0:-4] + '-0' + '.jpg'), gt_map)
            for g in range(15):
                start_time = time.time()
                f_seg = felseg(img, sigma=fparams['sigma'][g], scale=np.float(fparams['scale'][g]),
                               min_size=np.int(fparams['min_size'][g]))
                segments_temp = np.unique(f_seg)
                sp_batch = []
                new_sp =[]
                for segment in segments_temp:
                    multi_scale_image = tfRec.merge_multiscale_image(img, mean_image,f_seg,segment)
                    if multi_scale_image is None:
                        f_seg[f_seg==segment]=0
                        continue
                    new_sp.append(segment)
                    sp_batch.append(multi_scale_image)
                sp_predict = sess.run(predict, feed_dict={input_x: sp_batch, keep_pro: 1.0})
                # argmax over the class axis gives one predicted label per superpixel
                sp_label = np.argmax(sp_predict, 1)
                #recover a picture
                #for (sp, sp_num) in (new_sp, sp_label):
                for num in range(len(sp_label)):
                    if sp_label[num]:
                        f_seg[f_seg==new_sp[num]]= 255
                    else:
                        f_seg[f_seg == new_sp[num]] = 0
                f_seg = f_seg.astype(np.uint8)  # astype() returns a new array; assign it back
                duration=time.time()-start_time
                #sio.imsave(('./out_result'+curr_img[0:-4]+'%d'+'.jpg')%g,f_seg)
                cv2.imwrite((FLAGS.image_dir+'/out_result/'+curr_img[0:-4]+'-%d'+'.jpg')%(g+1),f_seg)
                print((curr_img[0:-4]+'-%d'+'.jpg')%(g+1)+' has created...(%.3f sec)'%duration)


def main(_):
    influence()


if __name__ == '__main__':
    tf.app.run()
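Since the paths are defined with tf.app.flags and main() is dispatched through tf.app.run(), the defaults can be overridden from the command line, e.g. python influence.py --image_dir=./SED2/ --model_dir=./model/ --mean_image=./SED2/tfRecord_file/.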