前言
实现一个算法主要从以下三步入手:
找到这个算法的预测函数, 比如线性回归的预测函数形式为:y = wx + b,
找到这个算法的损失函数 , 比如线性回归算法的损失函数为最小二乘法
找到让损失函数求得最小值的时候的系数, 这时一般使用梯度下降法.
使用TensorFlow实现算法的基本套路:
使用TensorFlow中的变量将算法的预测函数, 损失函数定义出来.
使用梯度下降法优化器求损失函数最小时的系数
分批将样本数据投喂给优化器,找到最佳系数
实战
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
# 生成线性数据
x = np.linspace(0, 10, 20) + np.random.randn(20)
y = np.linspace(0, 10, 20) + np.random.randn(20)
# W, B 定义为变量
W = tf.Variable(initial_value=np.random.randn())
B = tf.Variable(initial_value=np.random.randn())
# 定义线性模型, TensorFlow2.x没有占位符了.把线性模型封装为一个函数, x作为参数传入
def linear_regression(x):
return W * x + B
# 定义损失函数
def mean_square_loss(y_pred, y_true):
return tf.reduce_sum(tf.square(y_pred - y_true)) / 20
# 优化器 随机梯度下降法
optimizer = tf.optimizers.SGD(0.01)
# 定义优化过程
def run_optimization():
# 把计算过程放在梯度带中执行,可以实现自动微分
with tf.GradientTape() as g:
pred = linear_regression(x)
loss = mean_square_loss(pred, y)
# 计算梯度
gradients = g.gradient(loss, [W, B])
# 更新W和B
optimizer.apply_gradients(zip(gradients, [W, B]))
# 训练
for step in range(1, 5001):
# 每次训练都要更新W和B
run_optimization()
# 展示结果
if step % 100 == 0:
pred = linear_regression(x)
loss = mean_square_loss(pred, y)
print(f'step: {step}, loss: {loss}, W: {W.numpy()}, B: {B.numpy()}')
plt.scatter(x, y)
x_test = np.linspace(0, 10, 20).reshape(-1, 1)
plt.xlabel('X')
plt.ylabel('Y')
plt.plot(x_test, W.numpy() * x_test + B.numpy(), c='r', lw=1.8, alpha=0.5)
plt.show()
结果如下:
step: 100, loss: 1.1600629091262817, W: 0.9292533993721008, B: 1.1462109088897705
step: 200, loss: 1.1559219360351562, W: 0.9378345012664795, B: 1.0833985805511475
step: 300, loss: 1.1547658443450928, W: 0.94236820936203, B: 1.0502125024795532
step: 400, loss: 1.154443383216858, W: 0.9447634816169739, B: 1.0326790809631348
step: 500, loss: 1.1543530225753784, W: 0.946029007434845, B: 1.0234158039093018
step: 600, loss: 1.1543279886245728, W: 0.9466976523399353, B: 1.0185215473175049
step: 700, loss: 1.1543208360671997, W: 0.947050929069519, B: 1.0159354209899902
step: 800, loss: 1.1543190479278564, W: 0.9472375512123108, B: 1.0145692825317383
step: 900, loss: 1.1543183326721191, W: 0.9473361968994141, B: 1.0138472318649292
step: 1000, loss: 1.1543184518814087, W: 0.9473882913589478, B: 1.0134658813476562
step: 1100, loss: 1.15431809425354, W: 0.9474157691001892, B: 1.0132646560668945
step: 1200, loss: 1.1543182134628296, W: 0.9474303722381592, B: 1.0131583213806152
step: 1300, loss: 1.15431809425354, W: 0.9474380016326904, B: 1.0131020545959473
step: 1400, loss: 1.15431809425354, W: 0.9474419951438904, B: 1.0130726099014282
step: 1500, loss: 1.15431809425354, W: 0.9474440813064575, B: 1.0130577087402344
step: 1600, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 1700, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 1800, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 1900, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 2000, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 2100, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 2200, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 2300, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 2400, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 2500, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 2600, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 2700, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 2800, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 2900, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 3000, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 3100, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 3200, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 3300, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 3400, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 3500, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 3600, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 3700, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 3800, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 3900, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 4000, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 4100, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 4200, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 4300, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 4400, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 4500, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 4600, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 4700, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 4800, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 4900, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285
step: 5000, loss: 1.1543182134628296, W: 0.947445273399353, B: 1.0130486488342285