简单的线性模型实现tensorflow权重的生成和调用,并且用类的方式实现参数共享

首先看文件路径,line_regression是总文件夹,model文件夹存放权重文件,

global_variable.py写了一句话.

 

save_path='./model/weight'

权重要存放的路径,以weight命名.

lineRegulation_model.py代码

 

import tensorflow as tf
"""
类定义一些公共量,方便模型载入用
"""
class LineRegModel:
    def __init__(self):
        self.a_val=tf.Variable(tf.random_normal(shape=[1]))
        self.b_val = tf.Variable(tf.random_normal(shape=[1]))
        self.x_input=tf.placeholder(dtype=tf.float32)
        self.y_label = tf.placeholder(dtype=tf.float32)
        self.y_output = tf.multiply(self.x_input,self.a_val)+self.b_val
        self.loss=tf.reduce_mean(tf.pow(self.y_output-self.y_label,2))
    def get_op(self):
        return tf.train.GradientDescentOptimizer(0.01).minimize(self.loss)

定义了一个类,方便后面共享权值恢复模型的调用

model_train.py代码:

 

import tensorflow as tf
import numpy as np
from save_and_restore import global_variable
from save_and_restore import  lineRegulation_model as model
"""
训练模型
"""
train_x=np.random.rand(5)
train_y=train_x*5+3
model=model.LineRegModel()#类要加括号
a_val=model.a_val
b_val=model.b_val
x_input=model.x_input
y_label=model.y_label
y_output=model.y_output
loss=model.loss
optimizer=model.get_op()
if __name__ == '__main__':
    saver = tf.train.Saver()
    init=tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        flag=True
        epoch=0
        while flag:
            epoch+=1
            cost,_=sess.run([loss,optimizer],feed_dict={x_input:train_x,y_label:train_y})
            if cost<1e-6:
                flag=False
                print('a={},b={}'.format(a_val.eval(sess),b_val.eval(sess)))
                print('epoch={}'.format(epoch))
                saver.save(sess,global_variable.save_path)
                print('model save finish')

训练模型,并且存放模型的目的,这样前面三段代码就可以实现简单的线性模型权重的生成和存放。

其中checkpoint指的是检查点文件,记录存储文件名称,weight.data_00000-of-00001权重存储文件,weight.index存储权重目录

weight.meta模型的全部图文件,所以weight.data_00000-of-00001和weight.meta是最大的。

model_restore.py代码如下:

import tensorflow as tf
from save_and_restore import global_variable,lineRegulation_model as model
"""
加载模型
"""
model=model.LineRegModel()
x_input=model.x_input
y_output=model.y_output
init=tf.global_variables_initializer()
saver=tf.train.Saver()
with tf.Session() as sess:
    sess.run(init)
    saver.restore(sess,global_variable.save_path)
    result=sess.run(y_output,feed_dict={x_input:[1]})
    print(result)

调用生成的模型打印出预测结果:

结果和8差不多。

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值