1.基于TensorFlow简单线性回归模型的训练,模型的保存,以及恢复使用

#coding=utf-8
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
def get_bact(num, x_train, y_train):  #打乱顺序,抽取batch
    examples_num = np.arange(len(x_train))
    np.random.shuffle(examples_num)
    examples_num = examples_num[0:num]
    x_data = []
    y_lable = []
    for i in examples_num:
        x_data.append(x_train[i])
        y_lable.append(y_train[i])
    return np.asarray(x_data),np.asarray(y_lable)

if __name__ == "__main__":
    x_train = np.random.random((100, 1))
    y_train = 0.3*x_train+0.1
    w = tf.Variable(tf.zeros(1), name="W")
    basi = tf.Variable(tf.zeros(1), name="basi")
    x = tf.placeholder(tf.float32, shape=[None, 1],name="input_data")
    y_ = tf.placeholder(tf.float32, shape=[None, 1])
    y = tf.add(tf.multiply(w, x), basi,name= "y") #预测
    cross_entry = tf.reduce_mean(tf.square(y - y_)) #平方误差
    optimize = tf.train.GradientDescentOptimizer(0.001).minimize(cross_entry)#最小化平方误差
    validation_feed = {x: x_validation, y_: y_validation}
    init_op = tf.global_variables_initializer() #初始化所有变量    
    correct_prediction = tf.equal(tf.argmax(y_, 1), tf.argmax(y, 1))  # 正确值的一个返回
    accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))  # 求平均正确率
    Save = tf.train.Saver()#定义Save
    with tf.Session() as sess:
        sess.run(init_op)
        x_plot = []
        y_plot = []
        for i in range(40000):
            xx, yy = get_bact(90, x_train, y_train)
            _, ww, bbasi, xxx = sess.run([optimize, w, basi, cross_entry], feed_dict={x: xx, y_: yy})
            print("ww:%f,bais:%floss:%f" % (ww, bbasi, xxx))
            if i % 1000== 0:
                x_plot.append(i)
                y_plot.append(xxx)
            if xxx <= 0.00006:
                Save.save(sess, "model/model.ckpt")#保存模型
                break

    plt.xlabel("epoch")
    plt.ylabel("loss")
    plt.plot(x_plot, y_plot)
    plt.show()
  

对应的使用保存的模型进行测试:

#利用训练好的模型进行测试
import tensorflow as tf
saver = tf.train.import_meta_graph('./model/model.ckpt.meta') #加载图
sess = tf.Session()
saver.restore(sess,tf.train.latest_checkpoint('./model/'))#获得最新的数据,输入是模型的路径
print(sess.run("W:0"))  #打印出权重值
print(sess.run("basi:0"))  #偏置
graph = tf.get_default_graph() #加载默认图
x = graph.get_tensor_by_name("input_data:0") #通过名字,获得张量
y = graph.get_tensor_by_name("y:0")
feed = {x:[[2.0]]} #很重要,当输入为[None,1]的时候,不是单纯的x:2.0或者
print(sess.run(y,feed_dict=feed)) #测试,获取结果
注意:在训练的时候最好把每个张量的名字带上,以免后期保存模型后无法找到对应的张量

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值