1、代码实现
# 计算赤峰面积和房价之间的关系
import numpy as np
import matplotlib.pyplot as plt
# 构建数据集
data = []
# (70-90,90-100,100-110,110-130)
for i in range(300):
# 面积(训练集)
area = np.random.uniform(60, 100)
# 房价
eps2 = np.random.uniform(60, 62)
# bias
eps3 = np.random.uniform(200., 700.)
# 总房价(标签)
price = eps2 * area + eps3 # 随机生成一个线性方程,大小为(500,1)
data.append([area, price])
data = np.array(data) # 数据集创建完毕 2维数组 [面积,房价]
# 将参数分出来方便之后的使用
area = data[:, 0]
price = data[:, 1]
# 绘制原始数据
plt.title("Area-Price") # 标题名
plt.scatter(area, price, s=10) # 设置为散点图
plt.xlabel("area") # x轴的标题
plt.ylabel("price") # y轴的标题
plt.show() # 绘制出来
# 创建一个loss值的list
loss_list = []
def mse(b, w, data): # 根据当前的 w,b,参数计算均方差损失
TotalError = 0 # 记录总误差
for i in range(0, len(data)):
x = data[i, 0]
y = data[i, 1]
TotalError += (y - (w * x + b)) ** 2
return TotalError / float(len(data))
def gradient_update(b, w, data, lr):
b_gradient = 0
w_gradient = 0
size = float(len(data))
for i in range(0, len(data)):
x = data[i, 0]
y = data[i, 1]
# 计算梯度
b_gradient += (2 / size) * ((w * x + b) - y)
w_gradient += (2 / size) * x * ((w * x + b) - y)
# 根据梯度更新权重和偏置
b -= lr * b_gradient
w -= lr * w_gradient
return [b, w]
# 梯度下降法
def gradient_descent(data, b, w, lr, num_iterations):
# 因为没有batch,所以num_iterations即为epoch
for num in range(num_iterations):
# 更新参数
b, w = gradient_update(b, w, data, lr)
# 计算损失值并添加到损失列表
loss = mse(b, w, data)
loss_list.append(loss)
print('iteration:[%s] | loss:[%s] | w:[%s] | b:[%s]' % (num, loss, w, b))
return [b, w]
def main():
lr = 0.00001
initial_b = np.random.randn(1)
initial_w = np.random.randn(1)
num_iterations = 100 # 因为没有batch,所以num_iterations即为epoch
[b, w] = gradient_descent(data, initial_b, initial_w, lr, num_iterations)
loss = mse(b, w, data)
print('Final loss:[%s] | w:[%s] | b:[%s]' % (loss, w, b))
# 损失函数
plt.title("Loss Function") # 标题名
plt.plot(np.arange(0, 100), loss_list)
plt.xlabel('Interation')
plt.ylabel('Loss Value')
plt.show()
# 绘制
y2 = w * area + b
print(w * 100 + b)
plt.title("Fit the line graph") # 标题名
plt.scatter(area, price, label='Original Data', s=10) # 设置为散点图
plt.plot(area, y2, color='Red', label='Fitting Line', linewidth=2)
plt.xlabel('m_j')
plt.ylabel('j_g')
plt.legend()
plt.show()
main()
2.代码结果
————————————————
版权声明:本文为CSDN博主「羊之战神」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/m0_68962414/article/details/127326084