def linear_plot():
    """Fit a straight line to a fixed sample by least squares and plot it.

    Parameters:
        None.

    Returns:
        w -- slope (coefficient of the independent variable), rounded to
             two decimal places
        b -- intercept, rounded to two decimal places
        fig -- the matplotlib Figure containing the scatter plot of the data
               and the fitted line
    """
    import numpy as np
    import matplotlib as mpl
    import matplotlib.pyplot as plt

    data = [[5.06, 5.79], [4.92, 6.61], [4.67, 5.48], [4.54, 6.11], [4.26, 6.39],
            [4.07, 4.81], [4.01, 4.16], [4.01, 5.55], [3.66, 5.05], [3.43, 4.34],
            [3.12, 3.24], [3.02, 4.80], [2.87, 4.01], [2.64, 3.17], [2.48, 1.61],
            [2.48, 2.62], [2.02, 2.50], [1.95, 3.59], [1.79, 1.49], [1.54, 2.10], ]

    # --- Least-squares fit of y = b + w * x ---
    data_array = np.array(data)
    x = data_array[:, 0]
    y = data_array[:, 1]
    n = len(x)

    sum_x = x.sum()
    sum_y = y.sum()
    sum_xy = (x * y).sum()
    sum_xx = (x * x).sum()
    denom = n * sum_xx - sum_x * sum_x

    # Both coefficients are computed from the unrounded sums; only the final
    # values are rounded to two decimals.
    w = round(float((n * sum_xy - sum_x * sum_y) / denom), 2)
    b = round(float((sum_xx * sum_y - sum_x * sum_xy) / denom), 2)
    print(w, b)

    # --- Plot ---
    # Use a CJK-capable font so the Chinese legend label renders, and draw
    # the minus sign with ASCII '-' (SimHei lacks the Unicode minus glyph).
    # Note: these are global rcParams changes.
    mpl.rcParams['font.sans-serif'] = ['SimHei']
    mpl.rcParams['axes.unicode_minus'] = False

    fig = plt.figure()  # Keep this line: it creates the figure that is returned
    # A single axes filling the whole figure (a 1x2 grid would leave the
    # right half empty).
    ax = fig.add_subplot(1, 1, 1)
    ax.scatter(x, y)
    # The line is drawn with the same rounded coefficients that are returned.
    ax.plot(x, b + w * x, label="拟合曲线", color="red")
    ax.legend()
    plt.show()
    return w, b, fig  # Return order (w, b, fig) is part of the contract
if __name__ == "__main__":
    # Script entry point: run the fit and show the plot; the returned
    # values are unpacked but not used further here.
    w, b, fig = linear_plot()