import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from statsmodels.tsa.statespace.sarimax import SARIMAX
# Matplotlib: use the SimHei font so Chinese labels render, and keep the
# minus sign displayable on the axes.
plt.rcParams.update({
    'font.sans-serif': ['SimHei'],
    'axes.unicode_minus': False,
})

# Pandas: print every column and every row, and allow wide cell contents.
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)
pd.set_option('max_colwidth', 400)
# Helpers for forecasting a single series with a SARIMAX model.
def _as_float_exog(exog):
    """Return ``exog`` as float64; pandas containers keep their index."""
    if isinstance(exog, (pd.Series, pd.DataFrame)):
        return exog.astype(float)
    matrix = np.asarray(exog, dtype=float)
    # A 1-D input is a single regressor column, one value per observation.
    return matrix.reshape(-1, 1) if matrix.ndim == 1 else matrix


def predict_sarimax(series, exog=None, steps=7, future_exog=None):
    """Fit a SARIMAX(5, 1, 0) model to ``series`` and forecast ahead.

    Args:
        series: Observations in time order (anything SARIMAX accepts as
            ``endog`` that also supports ``len``).
        exog: Optional regressors with one row per observation, i.e. shape
            ``(len(series), k)``. Values are converted to float64 first,
            because object-dtype arrays (such as ``row[cols].values`` taken
            from ``DataFrame.iterrows``) make ``np.isfinite`` raise a
            ``TypeError`` inside statsmodels. If the row count does not match
            ``len(series)`` -- for example a single row of static
            attributes -- the regressors cannot be aligned with the
            observations; a ``RuntimeWarning`` is issued and the model is
            fitted without them.
        steps: Number of periods to forecast; a positive integer.
        future_exog: Optional regressors for the forecast horizon, shape
            ``(steps, k)``. When omitted, the last ``steps`` in-sample rows
            of ``exog`` are reused, as before; that is only a placeholder, so
            pass real future values when they are known. Ignored when the
            model is fitted without regressors.

    Returns:
        The forecast returned by the fitted model's ``forecast`` method.

    Raises:
        ValueError: If ``steps`` is not a positive integer, ``series`` is
            empty, ``exog``/``future_exog`` hold non-numeric values, or the
            regressors for the forecast horizon do not have exactly
            ``steps`` rows.
    """
    if not isinstance(steps, (int, np.integer)) or steps < 1:
        raise ValueError(f"steps must be a positive integer, got {steps!r}")
    n_obs = len(series)
    if n_obs == 0:
        raise ValueError("series must contain at least one observation")

    exog_train = None if exog is None else _as_float_exog(exog)
    if exog_train is not None and exog_train.shape[0] != n_obs:
        warnings.warn(
            f"exog has {exog_train.shape[0]} row(s) but series has {n_obs} "
            "observations; fitting without exogenous regressors",
            RuntimeWarning,
            stacklevel=2,
        )
        exog_train = None

    exog_future = None
    if exog_train is not None:
        if future_exog is not None:
            exog_future = _as_float_exog(future_exog)
        else:
            exog_future = exog_train[-steps:]
        if exog_future.shape[0] != steps:
            raise ValueError(
                f"need {steps} row(s) of exogenous data for the forecast "
                f"horizon, got {exog_future.shape[0]}; pass future_exog"
            )

    model = SARIMAX(series, order=(5, 1, 0), exog=exog_train)
    model_fit = model.fit(disp=False)
    return model_fit.forecast(steps=steps, exog=exog_future)
# Reload the CSV inputs; both files live in the same input_data folder.
DATA_DIR = "C:\\Users\\ChenCong\\Desktop\\2024年度“火花杯”数学建模精英联赛-C题-附件\\2024年度“火花杯”数学建模精英联赛-C题-附件\\input_data"
FEATURE_COLUMNS = ['Longitude', 'Latitude', 'resident_pop', 'gdp']
FORECAST_STEPS = 7  # number of future days to predict and to plot
MAX_PLOTTED_SERIES = 6  # only the first few forecasts are drawn

df_past_order = pd.read_csv(DATA_DIR + "\\df_past_order.csv")
df_loc = pd.read_csv(DATA_DIR + "\\df_loc.csv")

# Append '-shi' to the location names so they match the order table's 'Name'
# key (presumably those names carry that suffix -- verify against the data).
df_loc['name'] = df_loc['name'] + '-shi'

# Merge the order history with the location attributes, then drop the
# redundant 'name' key column.
df_merged_corrected = pd.merge(
    df_past_order, df_loc, left_on='Name', right_on='name', how='left'
)
df_merged_corrected = df_merged_corrected.drop(columns=['name'])

# Everything after the first two columns of the order table is treated as one
# column per date.
date_columns = df_past_order.columns[2:]
print(date_columns)

# Force the location attributes to numeric (unparseable values become NaN),
# then fill every NaN with 0; another imputation strategy can be used instead.
df_merged_corrected[FEATURE_COLUMNS] = df_merged_corrected[FEATURE_COLUMNS].apply(
    pd.to_numeric, errors='coerce'
)
df_merged_corrected = df_merged_corrected.fillna(0)

# Forecast each row (one city / SKU combination).
predictions = []
for _, row in df_merged_corrected.iterrows():
    series = row[date_columns].values.astype(float)

    # ``row`` mixes strings and numbers, so its dtype is object even though
    # the selected values are floats. Convert explicitly: NumPy ufuncs such as
    # np.isfinite (used inside statsmodels) reject object arrays.
    # NOTE(review): this is one row of static attributes, not one row per
    # observation; SARIMAX expects exog with len(series) rows.
    exog = row[FEATURE_COLUMNS].values.astype(float).reshape(1, -1)

    # Show the data handed to the model.
    print("Exogenous variables (exog):", exog)
    print("Time series data (series):", series)

    # Predict the order volume for the coming days.
    forecast = predict_sarimax(series, exog, steps=FORECAST_STEPS)

    prediction = {'Name': row['Name'], 'SKU': row['SKU']}
    prediction.update(
        {f'forecast_day_{day}': value for day, value in enumerate(forecast, start=1)}
    )
    predictions.append(prediction)

# Collect the forecasts into a DataFrame.
df_predictions = pd.DataFrame(predictions)

# Plot the forecast for the first few city / SKU combinations.
plt.figure(figsize=(12, 8))
days = [f'forecast_day_{j + 1}' for j in range(FORECAST_STEPS)]
for _, row in df_predictions.head(MAX_PLOTTED_SERIES).iterrows():
    plt.plot(days, row[days].values, marker='o', label=f'{row["Name"]} ({row["SKU"]})')

# Title, axis labels, legend and grid.
plt.title('未来7天的订单预测')
plt.xlabel('预测日期')
plt.ylabel('预测订单量')
plt.legend(loc='upper right')
plt.grid(True)

# Show the figure.
plt.show()
错误提示:
Traceback (most recent call last):
File "E:\Catch_pigs\23C\test3.py", line 74, in <module>
forecast = predict_sarimax(series, exog, steps=7)
File "E:\Catch_pigs\23C\test3.py", line 17, in predict_sarimax
model = SARIMAX(series, order=(5, 1, 0), exog=exog)
File "E:\python test\venv\lib\site-packages\statsmodels\tsa\statespace\sarimax.py", line 328, in __init__
self._spec = SARIMAXSpecification(
File "E:\python test\venv\lib\site-packages\statsmodels\tsa\arima\specification.py", line 446, in __init__
self._model = TimeSeriesModel(endog, exog=exog, dates=dates, freq=freq,
File "E:\python test\venv\lib\site-packages\statsmodels\tsa\base\tsa_model.py", line 470, in __init__
super().__init__(endog, exog, missing=missing, **kwargs)
File "E:\python test\venv\lib\site-packages\statsmodels\base\model.py", line 270, in __init__
super().__init__(endog, exog, **kwargs)
File "E:\python test\venv\lib\site-packages\statsmodels\base\model.py", line 95, in __init__
self.data = self._handle_data(endog, exog, missing, hasconst,
File "E:\python test\venv\lib\site-packages\statsmodels\base\model.py", line 135, in _handle_data
data = handle_data(endog, exog, missing, hasconst, **kwargs)
File "E:\python test\venv\lib\site-packages\statsmodels\base\data.py", line 675, in handle_data
return klass(endog, exog=exog, missing=missing, hasconst=hasconst,
File "E:\python test\venv\lib\site-packages\statsmodels\base\data.py", line 88, in __init__
self._handle_constant(hasconst)
File "E:\python test\venv\lib\site-packages\statsmodels\base\data.py", line 133, in _handle_constant
if not np.isfinite(exog_max).all():
TypeError: ufunc 'isfinite' not supported for the input types, and the inputs could not be safely coerced to any supported types according to the casting rule ''safe''
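将错误提示传给ChatGPT:
经过测试发现,输入的每个元素都是 float64 并且没有 NaN 或者 inf,因此最初并没有找到问题,曾怀疑是模型自身的问题。
实际原因是数组本身的 dtype 不是 float64:iterrows() 返回的 row 同时包含字符串列(Name、SKU)和数值列,所以 row[[...]].values 得到的是 dtype=object 的数组,而 np.isfinite 不支持 object 类型,于是抛出上面的 TypeError。对 exog 先调用 .astype(float) 即可消除该错误;此外 exog 的行数必须与 series 的长度一致,形状为 (1, 4) 的静态特征不能直接作为 SARIMAX 的外生变量。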