模型公式:
y(t)=g(t)+s(t)+h(t)+εt
其中g(t)是趋势项, s(t) 是周期项,如 weekly 和 yearly 等,h(t)是节假日趋势
主要调整参数:
- n_changepoints(拐点数量):默认为25,通常情况下,拐点数量越多,曲线拟合的越好,也容易过拟合。
- yearly_seasonality(季节性标识):默认为'auto',如果数据有季节性规律,如年中销量较低,而年末销量特别高,该值须设为True,这时model.add_seasonality(name='yearly', period=6, fourier_order=1) 函数无效。
- changepoint_prior_scale:默认为0.05,值越大,拟合得越好,大到一定程度曲线不再变化。
- seasonality_prior_scale:默认为10,值越大,季节性越明显。
- interval_width(置信区间):默认为0.8,真实值有80%的可能落在预测值的上下边界之间。
- growth:默认'linear',当growth = 'logistic'时,可以人工设定预测值的最大值、最小值,见下图代码“# train['cap'] = 400
# train['floor'] = 0”,“# future_data['cap'] = 400
# future_data['floor'] = 0”。
def prophet_model(train, pred_periods=12, **kwargs):
model = Prophet(**kwargs)
# model.add_seasonality(name='yearly', period=12, fourier_order=5)
# train['cap'] = 400
# train['floor'] = 0
model.fit(train)
future_data = model.make_future_dataframe(periods=pred_periods, freq='m')
# future_data['cap'] = 400
# future_data['floor'] = 0
forecast_data = model.predict(future_data)
model.plot(forecast_data)
# plt.scatter(test['ds'], test['y'], color='r', s=12)
model.plot_components(forecast_data)
return model, forecast_data
另外:
- 若yearly_seasonality、weekly_seasonality、daily_seasonality等周期规律标识参数均为False,预测曲线近似一条向上或向下(决定于趋势值)的直线,此时模型公式里只含趋势项。
- Prophet适合用于周期性、趋势性较强的时间序列