数据集为简单的表格,包括:年份,GDP,全社会用电量三列数据。
#预测
def predict(data,LinearRegression):
Y_pred = LinearRegression.predict(data)
print(Y_pred)
return Y_pred
#训练模型并画图
def reg_huigui(data,label,test):
regr = LinearRegression() #线性回归
#regr = Ridge(alpha=10) #岭回归
#regr = Lasso(alpha=0.001) #Lasso回归
regr.fit(data.values,label.values)
Y_pred_train = regr.predict(data.values) #训练集的预测值
predict(test.values,regr) #预测
print('参数:',regr.coef_.astype(np.float32))
#对模型进行评估
from sklearn.metrics import mean_squared_error,mean_absolute_error,r2_score
print("MSE:",mean_squared_error(label.values,Y_pred_train)) #误差
print("MAE:",mean_absolute_error(label.values,Y_pred_train))
print("R2:",r2_score(label.values,Y_pred_train))