使用tensorflow实现一个简单的线性回归

'''
本文主要实现线性回归
线性回归的基本概念: y = w1x1 + w2x2 + w3x3 + ... + b

(1) 建立模型: y = w1x1 + w2x2 + w3x3 + ... + b
(2) 损失函数: 均方根误差 MSE
(3) 优化损失:梯度优化
'''

import tensorflow.compat.v1 as tf
import os

tf.disable_eager_execution()
os.environ['TF_CPP_MIN_LOGLEVEL'] = '2'

# -------------------------------------------------------分割线-------------------------------------------------------

def linear_regression():

    # 1.准备数据
    with tf.variable_scope('original_data'):     # 为各个变量创建命名空间是一个良好的习惯
        x = tf.random_normal(shape=[100, 1], name='feature')
        y_true = tf.matmul(x, [[0.8]]) + [[0.7]]

    # 2.构建模型
    with tf.variable_scope('create_model'):
        weights = tf.Variable(initial_value=tf.random_normal(shape=[1, 1]), name='Weights')
        bias = tf.Variable(initial_value=tf.random_normal(shape=[1, 1]), name='Bias')
        y_pred = tf.matmul(x, weights) + bias

    # 3.损失函数
    with tf.variable_scope('loss_function'):
        loss = tf.reduce_mean(tf.square(y_pred - y_true))

    # 4.优化损失
    with tf.variable_scope('optimizer'):
        optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01).minimize(loss)
        '''
        梯度下降优化算法:
        w2 = w1 - learning_rate * 最速梯度下降的方向
        b2 = b1 - learning_rate * 最速梯度下降的方向
        '''

    # 初始化变量
    init = tf.global_variables_initializer()

    # 模型的保存
    saver = tf.train.Saver()

    # 创建会话
    with tf.Session() as sess:

        sess.run(init)
        print('训练前权重为%f,偏置为%f,损失为%f' % (sess.run(weights), sess.run(bias), sess.run(loss)))

        for i in range(100):
            sess.run(optimizer)
            print('Epoch:%d,权重为%f,偏置为%f,损失为%f' % (i + 1, sess.run(weights), sess.run(bias), sess.run(loss)))

        # 模型的保存
        saver.save(sess, './model/linear_regression.ckpt')

        '''
        模型的加载
        if os.path.exists('./model/checkpoint'):  # 注意,这里判断的是checkpoint是否存在,而不是linear_regression.ckpt是否存在
            saver.restore(sess, './model/linear_regression.ckpt')
        print('训练后的权重为%f,偏置为%f,损失为%f' % (sess.run(weights), sess.run(bias), sess.run(loss)))   
        '''

    return None

if __name__ == '__main__':
    linear_regression()
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值