简单的线性回归实现模型的存储和读取

和这篇文章对比https://blog.csdn.net/fanzonghao/article/details/81023730

不希望重复定义图上的运算,也就是在模型恢复过程中,不想sess.run(init)首先看路径

lineRegulation_model.py定义线性回归类:

import tensorflow as tf
"""
类定义一些公共量,方便模型载入用
"""
class LineRegModel:
    def __init__(self):
        with tf.variable_scope('var'):
            self.a_val=tf.Variable(tf.random_normal(shape=[1]),name='a_val')
            self.b_val = tf.Variable(tf.random_normal(shape=[1]),name='b_val')
        self.x_input=tf.placeholder(dtype=tf.float32,name='input_placeholder')
        self.y_label = tf.placeholder(dtype=tf.float32,name='result_placeholder')
        self.y_output = tf.add(tf.multiply(self.x_input,self.a_val),
                               self.b_val,name='output')
        self.loss=tf.reduce_mean(tf.pow(self.y_output-self.y_label,2))
    def get_saver(self):
        return tf.train.Saver()
    def get_op(self):
        return tf.train.GradientDescentOptimizer(0.01).minimize(self.loss)

model_train.py定义模型训练过程

import tensorflow as tf
import numpy as np
from save_and_restore2 import global_variable
from save_and_restore2 import  lineRegulation_model as model
import os
if not os.path.exists('./model'):
    os.makedirs('./model')
"""
训练模型
"""
train_x=np.random.rand(5)
train_y=train_x*5+3
model=model.LineRegModel()#类要加括号
a_val=model.a_val
b_val=model.b_val
x_input=model.x_input
y_label=model.y_label
y_output=model.y_output
loss=model.loss
optimizer=model.get_op()
saver=model.get_saver()
if __name__ == '__main__':
    init=tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        flag=True
        epoch=0
        while flag:
            epoch+=1
            cost,_=sess.run([loss,optimizer],feed_dict={x_input:train_x,y_label:train_y})
            if cost<1e-6:
                flag=False
                print('a={},b={}'.format(a_val.eval(sess),b_val.eval(sess)))
                print('epoch={}'.format(epoch))
                print(a_val)
                # print(a_val.op)
                saver.save(sess,global_variable.save_path)
                print('model save finish')

print(a_val)的形式

print(a_val.op)的形式

model_restore.py恢复模型 ,利用恢复图在恢复权重的方式,可实现更细节的模型恢复

import tensorflow as tf
from save_and_restore import global_variable,lineRegulation_model as model
"""
恢复模型图文件
"""
saver=tf.train.import_meta_graph('./model/weight.meta')
#读取placeholder和最终的输出结果
graph=tf.get_default_graph()
a_val=graph.get_tensor_by_name('var/a_val:0')
b_val=graph.get_tensor_by_name('var/b_val:0')

input_placeholder=graph.get_tensor_by_name('input_placeholder:0')
labels_placeholder=graph.get_tensor_by_name('result_placeholder:0')
y_output=graph.get_tensor_by_name('output:0')


with tf.Session() as sess:
    #具体权重的恢复
    saver.restore(sess,'./model/weight')
    result=sess.run(y_output,feed_dict={input_placeholder:[1]})
    print(result)
    print(sess.run(a_val))
    print(sess.run(b_val))

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值