'''
@Time : 2022/4/4 11:23
@Author : dongdong
@File : predictive_model.py
@Desc :
'''
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
plt.rcParams['font.sans-serif'] = ['SimHei'] #显示中文
# 1.数据展示:了解数据,可以是csv读取,也可以直接copy进来
years = np.arange(2009,2020)
sales = np.array([0.52,9.36,33.6,132,352,571,912,1207,1682,2135,2684])
# print(years)
# print(sales)
# plt.scatter(years,sales,c = 'red')
# plt.show()
# 2.初步判断:多项式回归(3阶)
'''
y = a*x^3 + b*x^2 + c*x +d
'''
# 3.数据预处理:准备x,y对应的值,方便后续建模并计算系数、截距
model_y = sales
model_x = (years - 2008).reshape(-1,1)
model_x = np.concatenate([model_x**3,mode