import matplotlib.pyplot as plt
import numpy as np
import csv
def get_a(x):
a = 0.0
for i in x:
a = a + (i * i)
return a
def get_b(x):
a = 0.0
for i in x:
a = a + i
return a
def get_c(x, y):
a = 0.0
for i in range(len(x)):
a = a + x[i] * y[i]
return a
def get_d(y):
a = 0.0
for i in y:
a = a + i
return a
def print_list(ilist):
'打印数据,没啥用'
for i in ilist:
print(i, ",", end = "")
print("\n")
plt.figure()#使用plt.figure定义一个图像窗口
plt.title('regression')#图像标题
plt.xlabel('x')#x轴标题
plt.ylabel('y')#y轴标题
#数据组
listx = [10,8,13,9,11,14,6,4,12,7,5]
listy = [8.04,6.95,7.58,8.81,8.33,9.96,7.24,4.26,10.84,4.82,5.68]
plt.grid(True)#是否打开网格
x = np.linspace(0, 20)#线性回归方程线
#等式计算
A = get_a(listx)
B = get_b(listx)
C = get_c(listx, listy)
D = get_d(listy)
n = len(listx)
a = (B*D-C*n)/(B*B-n*A)
b = (B*C-D*A)/(B*B-n*A)
plt.scatter(listx, listy, c='b') #描点
plt.plot(x, a * x + b, 'b-') #绘制线条
#线性回归方程
a = "%.4f" % a
b = "%.4f" % b
print('y='+a+'*x'+'+'+'('+b+')')
plt.pause(10)#画图延时
Python中scatter函数参数及用法详解:
https://www.jb51.net/article/127806.htm
参考原文:https://blog.csdn.net/qq_22510521/article/details/80058148