多变量线性回归
同理。
注意,这里我对数据做了标准化(z-score)处理,所以模型直接预测出来的是标准化后的小数;要得到真实的预测值,把结果乘以标准差再加上均值反向还原即可。
import pandas as pd
import matplotlib.pyplot as plt
# Cost function for the two-feature linear model.
def compute_loss(x1, x2, y, w1, w2, b):
    """Return the half mean squared error: sum((w1*x1 + w2*x2 + b - y)**2) / (2m).

    x1, x2, y are array-likes with a .shape attribute (numpy arrays or pandas
    Series); w1, w2 are the feature weights and b the bias. Vectorized to
    replace the original per-sample Python loop with one C-speed pass.
    """
    m = x1.shape[0]
    # Elementwise residuals under the given parameters.
    residuals = w1 * x1 + w2 * x2 + b - y
    return (residuals ** 2).sum() / (2 * m)
# Batch gradient descent for the two-feature linear model.
def gradient_descent(x1, x2, y, w1, w2, b, eta, iterations):
    """Run `iterations` full-batch gradient-descent steps and return the fit.

    Parameters:
        x1, x2: array-like feature columns (numpy arrays or pandas Series).
        y: array-like targets.
        w1, w2, b: initial weights and bias.
        eta: learning rate.
        iterations: number of update steps.

    Returns (w1, w2, b, loss_history), where loss_history[i] is the cost
    recorded *after* update step i (same bookkeeping as the original).

    Improvement: the original looped over every sample in Python twice per
    iteration (once for the gradients, once more for the loss). Here the
    residual vector is computed once per step and reduced at C speed.
    """
    m = x1.shape[0]
    loss_history = []
    for _ in range(iterations):
        # Residuals under the current parameters drive all three gradients.
        residuals = w1 * x1 + w2 * x2 + b - y
        w1 -= eta * (residuals * x1).sum() / m
        w2 -= eta * (residuals * x2).sum() / m
        b -= eta * residuals.sum() / m
        # Cost after the update, identical in definition to compute_loss.
        post = w1 * x1 + w2 * x2 + b - y
        loss_history.append((post ** 2).sum() / (2 * m))
    return w1, w2, b, loss_history
def predict(x1, x2, w1, w2, b):
    """Return the linear-model output w1*x1 + w2*x2 + b (elementwise on arrays)."""
    weighted_sum = w1 * x1 + w2 * x2
    return weighted_sum + b
if __name__ == '__main__':
    # Load the training set (columns: size, bedroom count, price).
    data = pd.read_csv(r'D:\BaiduNetdiskDownload\data_sets\ex1data2.txt', names=["x1", "x2", "y"])

    # Record each column's mean and standard deviation BEFORE mutating it,
    # so predictions can later be mapped back to the original price scale.
    stats = {col: (data[col].mean(), data[col].std()) for col in ('x1', 'x2', 'y')}
    mean_x1, std_x1 = stats['x1']
    mean_x2, std_x2 = stats['x2']
    mean_y, std_y = stats['y']

    # Z-score standardize every column in place.
    for col in ('x1', 'x2', 'y'):
        col_mean, col_std = stats[col]
        data[col] = (data[col] - col_mean) / col_std

    x1, x2, y = data['x1'], data['x2'], data['y']

    # Fit the model starting from zero-initialized parameters.
    w1, w2, b, loss_history = gradient_descent(x1, x2, y, 0, 0, 0, 0.01, 1000)
    print(f'Trained model parameters: w1 = {w1}, w2 = {w2}, b = {b}')

    # Visualize how the training loss evolved over the epochs.
    plt.plot(range(len(loss_history)), loss_history, color='red', label='loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Loss Curve')
    plt.legend()
    plt.grid(True)
    plt.show()

    # Predict prices for new houses: standardize the inputs with the
    # training statistics, then de-standardize the model output.
    new_data = pd.DataFrame({'x1': [2104, 1600, 1416], 'x2': [3, 3, 2]})
    new_data['x1'] = (new_data['x1'] - mean_x1) / std_x1
    new_data['x2'] = (new_data['x2'] - mean_x2) / std_x2
    predicted_price_standardized = predict(new_data['x1'], new_data['x2'], w1, w2, b)
    predicted_price = predicted_price_standardized * std_y + mean_y
    print(f'Predicted price: {predicted_price}')
一些图表
损失:
参数(标准化之后):
w1 = 0.8785036522230538, w2 = -0.046916657038053915, b = -1.1073884126473927e-16