Implementing Linear Regression with TensorFlow

Preface

Implementing an algorithm mainly comes down to the following three steps:
1. Find the algorithm's prediction function; for linear regression it takes the form y = wx + b.
2. Find the algorithm's loss function; for linear regression this is the least-squares (mean squared error) loss.
3. Find the coefficients that minimize the loss function, which is usually done with gradient descent (written out below).
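
Written out for n samples (x_i, y_i), the least-squares loss and the gradient-descent updates (with learning rate \eta) are:

L(w, b) = \frac{1}{n} \sum_{i=1}^{n} (w x_i + b - y_i)^2

w \leftarrow w - \eta \frac{\partial L}{\partial w}, \qquad b \leftarrow b - \eta \frac{\partial L}{\partial b}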

The basic recipe for implementing an algorithm with TensorFlow:

1. Define the algorithm's prediction function and loss function using TensorFlow variables.
2. Use a gradient-descent optimizer to find the coefficients that minimize the loss.
3. Feed the sample data to the optimizer in batches to find the best coefficients (see the sketch after this list).
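
For the small dataset used below, all 20 samples are fed to the optimizer at every step. With larger datasets, step 3 is usually done with tf.data mini-batches; a minimal sketch of that pattern (the toy data, batch size, and epoch count here are made up for illustration):

import numpy as np
import tensorflow as tf

# Toy data standing in for real samples (illustrative only)
x_train = np.linspace(0, 10, 200).astype(np.float32)
y_train = (3 * x_train + 2 + np.random.randn(200)).astype(np.float32)

w = tf.Variable(0.0)
b = tf.Variable(0.0)
optimizer = tf.optimizers.SGD(0.01)

# Shuffle the samples and feed them to the optimizer in mini-batches
dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(200).batch(32)

for epoch in range(20):
    for x_batch, y_batch in dataset:
        with tf.GradientTape() as tape:
            loss = tf.reduce_mean(tf.square(w * x_batch + b - y_batch))
        grads = tape.gradient(loss, [w, b])
        optimizer.apply_gradients(zip(grads, [w, b]))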

Hands-on

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf

# Generate roughly linear data
x = np.linspace(0, 10, 20) + np.random.randn(20)
y = np.linspace(0, 10, 20) + np.random.randn(20)

# Define W and B as trainable variables
W = tf.Variable(initial_value=np.random.randn())
B = tf.Variable(initial_value=np.random.randn())

# Define the linear model. TensorFlow 2.x no longer has placeholders, so the model is wrapped in a function that takes x as an argument.
def linear_regression(x):
    return W * x + B

# Define the loss function: mean squared error over the 20 samples
def mean_square_loss(y_pred, y_true):
    return tf.reduce_sum(tf.square(y_pred - y_true)) / 20

# Optimizer: gradient descent (tf.optimizers.SGD; the full dataset is fed each step here)
optimizer = tf.optimizers.SGD(0.01)

# Define one optimization step
def run_optimization():
    # Run the forward pass inside a gradient tape so it can be differentiated automatically
    with tf.GradientTape() as g:
        pred = linear_regression(x)
        loss = mean_square_loss(pred, y)
    # Compute the gradients of the loss with respect to W and B
    gradients = g.gradient(loss, [W, B])
    
    # Apply the gradients to update W and B
    optimizer.apply_gradients(zip(gradients, [W, B]))

# Training loop: run 5000 optimization steps
for step in range(1, 5001):
    # Each step runs one optimization update of W and B
    run_optimization()
    # Print progress every 100 steps
    if step % 100 == 0:
        pred = linear_regression(x)
        loss = mean_square_loss(pred, y)
        print(f'step: {step}, loss: {loss}, W: {W.numpy()}, B: {B.numpy()}')

# Plot the raw data and the fitted line
plt.scatter(x, y)
x_test = np.linspace(0, 10, 20).reshape(-1, 1)
plt.xlabel('X')
plt.ylabel('Y')
plt.plot(x_test, W.numpy() * x_test + B.numpy(), c='r', lw=1.8, alpha=0.5)
plt.show()

The output is as follows:

step: 100, loss: 1.1600629091262817, W: 0.9292533993721008, B: 1.1462109088897705
step: 200, loss: 1.1559219360351562, W: 0.9378345012664795, B: 1.0833985805511475
step: 300, loss: 1.1547658443450928, W: 0.94236820936203, B: 1.0502125024795532
step: 400, loss: 1.154443383216858, W: 0.9447634816169739, B: 1.0326790809631348
step: 500, loss: 1.1543530225753784, W: 0.946029007434845, B: 1.0234158039093018
step: 600, loss: 1.1543279886245728, W: 0.9466976523399353, B: 1.0185215473175049
step: 700, loss: 1.1543208360671997, W: 0.947050929069519, B: 1.0159354209899902
step: 800, loss: 1.1543190479278564, W: 0.9472375512123108, B: 1.0145692825317383
step: 900, loss: 1.1543183326721191, W: 0.9473361968994141, B: 1.0138472318649292
step: 1000, loss: 1.1543184518814087, W: 0.9473882913589478, B: 1.0134658813476562
step: 1100, loss: 1.15431809425354, W: 0.9474157691001892, B: 1.0132646560668945
step: 1200, loss: 1.1543182134628296, W: 0.9474303722381592, B: 1.0131583213806152
step: 1300, loss: 1.15431809425354, W: 0.9474380016326904, B: 1.0131020545959473
step: 1400, loss: 1.15431809425354, W: 0.9474419951438904, B: 1.0130726099014282
step: 1500, loss: 1.15431809425354, W: 0.9474440813064575, B: 1.0130577087402344
step: 1600, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 1700, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 1800, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 1900, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 2000, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 2100, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 2200, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 2300, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 2400, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 2500, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 2600, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 2700, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 2800, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 2900, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 3000, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 3100, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 3200, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 3300, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 3400, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 3500, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 3600, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 3700, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 3800, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 3900, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 4000, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 4100, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 4200, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 4300, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 4400, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 4500, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 4600, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 4700, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 4800, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 4900, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 5000, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285

[Figure: scatter plot of the sample points with the fitted regression line in red]
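
As a quick sanity check, the W and B found by gradient descent can be compared against the closed-form least-squares fit, e.g. via np.polyfit; a minimal sketch reusing the x, y, W, B defined above:

# Closed-form least-squares fit on the same data (for comparison only)
w_ls, b_ls = np.polyfit(x, y, 1)
print(f'least squares:    w = {w_ls:.4f}, b = {b_ls:.4f}')
print(f'gradient descent: w = {W.numpy():.4f}, b = {B.numpy():.4f}')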
