tensorflow读取一个模型后多次使用

训练好一个模型后,将其投入使用,会有在项目初始化后多次加载测试数据的需求,可以采用保存graph的思想实现

(在一个项目中需要加载多个模型同样可用)

另:这条博客接我的上一条https://blog.csdn.net/qq_34470213/article/details/104076898,是在上一个代码的基础上改写的。

1、新建文件test.py,建一个类Model_test,用来保存模型,包括一个初始化方法,用来初始化模型(项目中仅需初始化时调用一次),一个测试调用方法,用来调用模型进行测试(每次测试调用一次)。

class Model_test():
    def restore(self):
        self.model = Model.LeNet5(1, 5)
        path = "D:/model/model/model.ckpt"
        self.model.load(path)

    def restore_test(self, image_path):
        image = Process.process_one(image_path)
        sort = self.model.test1(image)
        return sort

2、在model.py的类中添加初始化函数和测试函数,这里和之前的测试函数的差别在于拆分开了加载和测试的部分,并且将graph和session保存为了类属性变量。

    def load(self, model_path):
        self.graph = tf.Graph()
        self.sess = tf.Session(graph=self.graph)
        with self.sess.graph.as_default():
            self.sess.run(tf.global_variables_initializer())
            self.sess.run(tf.local_variables_initializer())
            saver11 = tf.train.import_meta_graph(model_path+'.meta',
                                               clear_devices=True)
            saver11.restore(self.sess, model_path)

    def test1(self, image):
        x = tf.placeholder(tf.float32, [None, 64, 64, 1], name='x-input')
        self.activation = self.graph.get_tensor_by_name('layer6-fc2/add:0')
        image = np.array(image) / 255.0
        image = np.reshape(image, (-1, 64, 64, 1))
        logit = tf.arg_max(self.activation, 1)
        y, label = self.sess.run((self.activation, logit), feed_dict={'x-input:0': image})
        return label

3、以上两步就可以成功实现了,调用方法为:

tm = test.Model_test()
tm.restore()


……


while(True){
    sort = tm.restore_test(pathname[i])
}

……

 

  • 1
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
TensorFlow一个开源的机器学习框架,通常用于创建神经网络模型。在训练模型之前,需要准备好数据集,本文将介绍如何使用TensorFlow读取数据。 TensorFlow了多种读取数据的方法,其中最常用的是使用tf.data模块。首先,我们需要定义一个数据集对象,并通过读取文件的方式将数据加载进来。TensorFlow支持多种文件格式,如csv、txt、json、tfrecord等,可以根据自己的需求选择合适的格式。 加载数据后,我们可以对数据进行一些预处理,比如做数据增强、进行归一化等操作。预处理完数据后,我们需要将数据转化为张量类型,并将其打包成batch。通过这种方式,我们可以在每次训练中同时处理多个数据。 随后,我们可以使用tf.data.Dataset中的shuffle()函数打乱数据集顺序,防止模型只学习到特定顺序下的模式,然后使用batch()函数将数据划分成批次。最后,我们可以使用repeat()函数让数据集每次可以被使用多次,达到更好的效果。 在TensorFlow中,我们可以通过输入函数将数据集传入模型中,使模型能够直接从数据集中读取数据。使用输入函数还有一个好处,即能够在模型训练时动态地修改数据的内容,特别是在使用esimator模块进行模型训练时,输入函数是必须要的。 总结一下,在TensorFlow读取数据的流程如下:定义数据集对象-读取文件-预处理数据-打包数据为batch-打乱数据集-划分批次数据-重复数据集-使用输入函数读取数据。 在实际应用过程中,我们还可以通过其他方式来读取数据,如使用numpy、pandas等工具库,也可以自定义数据集类来处理数据。无论使用何种方式,读取数据都是机器学习训练中重要的一步,需要仔细处理。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值