阿里天池供应链需求预测(二)

61 篇文章 6 订阅
11 篇文章 1 订阅

阿里天池供应链需求预测第二阶段总结

一、已尝试的模型和存在的问题:

  • LSTM单变量多步预测模型:通过循环迭代预测,实现了通过前42天的历史需求数据来预测未来14天的库存资源需求量;但是目前由于有的Unit的历史数据非常少,导致LSTM往往处于一种过拟合的状态,预测的效果非常的差。

    [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-vc0G8Fyz-1639014876079)(%E9%98%BF%E9%87%8C%E5%A4%A9%E6%B1%A0%E4%BE%9B%E5%BA%94%E9%93%BE%E9%9C%80%E6%B1%82%E9%A2%84%E6%B5%8B%E7%AC%AC%E4%BA%8C%E9%98%B6%E6%AE%B5%E6%80%BB%E7%BB%93.assets/Unit_Length.jpg)]

  • 由第一次的尝试的经过我们选用了比较传统经典的模型ARIMA,这里总结回顾一下ARIMA的一个完整流程:

    1.数据准备与预处理

    Train_dt = read_csv('train.csv')[['unit','ts','qty']]
    Test_dt = read_csv('test.csv')[['unit','ts','qty']]
    Train_dt["ts"] = Train_dt["ts"].apply(lambda x: pd.to_datetime(x))  # 日期对应数据标准化为规范时间  这一步比较耗时间!!!!!
    Test_dt["ts"] = Test_dt["ts"].apply(lambda x: pd.to_datetime(x))  # 日期对应数据标准化为规范时间
    last_dt = pd.to_datetime("20210301")  # 用来限定使用的是历史数据而不是未来数据
    start_dt = pd.to_datetime("20210301")  # 用来划定预测的针对test的起始时间
    end_dt = pd.to_datetime("20210607")  # 预测需求的截止时间
    qty_using = pd.concat([Train_dt, Test_dt])
    for num,chunk in enumerate(qty_using.groupby("unit")):
    unit = chunk[0]
    demand = chunk[1]
    eval = demand.copy()
    demand["log qty"] = np.log(demand['qty'])  #对数处理源数据
    
  
  2.平稳性和非白噪声

  由于ARMA和ARIMA需要时间序列满足平稳性和非白噪声的要求,所以要用查分法和平滑法(滚动平均和滚动标准差)来实现序列的平稳性操作。一般情况下,对时间序列进行一阶差分法就可以实现序列的平稳性,有时需要二阶查分。
(1)差分法实现


```python
demand["diff"] = demand["qty"].diff().values
del demand["diff"]
demand = demand[1:]

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-jOlNMUKw-1639014876079)(%E9%98%BF%E9%87%8C%E5%A4%A9%E6%B1%A0%E4%BE%9B%E5%BA%94%E9%93%BE%E9%9C%80%E6%B1%82%E9%A2%84%E6%B5%8B%E7%AC%AC%E4%BA%8C%E9%98%B6%E6%AE%B5%E6%80%BB%E7%BB%93.assets/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L2ZvbmVvbmU=,size_16,color_FFFFFF,t_70.png)]
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-HfY3ddxu-1639014876080)(%E9%98%BF%E9%87%8C%E5%A4%A9%E6%B1%A0%E4%BE%9B%E5%BA%94%E9%93%BE%E9%9C%80%E6%B1%82%E9%A2%84%E6%B5%8B%E7%AC%AC%E4%BA%8C%E9%98%B6%E6%AE%B5%E6%80%BB%E7%BB%93.assets/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L2ZvbmVvbmU=,size_16,color_FFFFFF,t_70-16375704804282.png)]

一阶差分基本就满足了平稳性需要。
(2)平滑法处理

#滚动平均(平滑法不平稳处理)
demand_log_moving_avg = demand['log qty'].rolling(12).mean()
#滚动标准差
demand_log_std = demand['log qty'].rolling(12).std()

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-at5kYUv1-1639014876080)(%E9%98%BF%E9%87%8C%E5%A4%A9%E6%B1%A0%E4%BE%9B%E5%BA%94%E9%93%BE%E9%9C%80%E6%B1%82%E9%A2%84%E6%B5%8B%E7%AC%AC%E4%BA%8C%E9%98%B6%E6%AE%B5%E6%80%BB%E7%BB%93.assets/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L2ZvbmVvbmU=,size_16,color_FFFFFF,t_70-16375704931414.png)]

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-vyTotVxh-1639014876081)(%E9%98%BF%E9%87%8C%E5%A4%A9%E6%B1%A0%E4%BE%9B%E5%BA%94%E9%93%BE%E9%9C%80%E6%B1%82%E9%A2%84%E6%B5%8B%E7%AC%AC%E4%BA%8C%E9%98%B6%E6%AE%B5%E6%80%BB%E7%BB%93.assets/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L2ZvbmVvbmU=,size_16,color_FFFFFF,t_70-16375704989136.png)]

平滑法不太适合我造出来的数据,一般情况下,这种方法更适合带有周期性稳步上升的数据类型。
(3)ADF检验

除了上述两种对于时间序列的处理方法之外,还有一种以数据的方式呈现的平稳性检验方法:ADF检验。

#ADF检验 
x = np.array(diff1['values'])
adftest = adfuller(x, autolag='AIC')
print (adftest) 

结果如下:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-edWD6Be1-1639014876081)(%E9%98%BF%E9%87%8C%E5%A4%A9%E6%B1%A0%E4%BE%9B%E5%BA%94%E9%93%BE%E9%9C%80%E6%B1%82%E9%A2%84%E6%B5%8B%E7%AC%AC%E4%BA%8C%E9%98%B6%E6%AE%B5%E6%80%BB%E7%BB%93.assets/20190528132323891.png)]

如何确定该序列能否平稳呢?主要看:

(1)1%、5%、10%不同程度拒绝原假设的统计值和ADF Test result的比较,ADF Test result同时小于1%、5%、10%即说明非常好地拒绝该假设,本数据中,adf结果为-6.9, 小于三个level的统计值。

(2)P-value是否非常接近0.本数据中,P-value 为 7.9e-10,接近0。

ADF结果如何查看参考了这篇博客:

https://blog.csdn.net/weixin_42382211/article/details/81332431
(4)非白噪声检验

#纯随机性检验(白噪声检验)
p_value = acorr_ljungbox(timeseries, lags=1) 
print (p_value)

结果如图:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-jkiBQ0Gw-1639014876082)(%E9%98%BF%E9%87%8C%E5%A4%A9%E6%B1%A0%E4%BE%9B%E5%BA%94%E9%93%BE%E9%9C%80%E6%B1%82%E9%A2%84%E6%B5%8B%E7%AC%AC%E4%BA%8C%E9%98%B6%E6%AE%B5%E6%80%BB%E7%BB%93.assets/2019052813285024.png)]

统计量的P值小于显著性水平0.05,则可以以95%的置信水平拒绝原假设,认为序列为非白噪声序列(否则,接受原假设,认为序列为纯随机序列。)

由于P值为0.315远大于0.05所以接受原假设,认为时间序列是白噪声的,即是随机产生的序列,不具有时间上的相关性。(解释一下,由于老师没有给数据,所以只能硬着头皮,假设它是非白噪声的做)
4.时间序列定阶

定阶用到了ACF和PACF判断模型阶数、信息准则定阶(AIC、BIC、HQIC)、热力图定阶。
(1)ACF和PACF定阶

​ 直接采用步骤3的一阶差分后的数据来进行定阶操作。

def determinate_order(timeseries): 
    #利用ACF和PACF判断模型阶数
    plot_acf(timeseries,lags=40) #延迟数
    plot_pacf(timeseries,lags=40)
    plt.show()

结果如图所示:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-4FtiNZLJ-1639014876082)(%E9%98%BF%E9%87%8C%E5%A4%A9%E6%B1%A0%E4%BE%9B%E5%BA%94%E9%93%BE%E9%9C%80%E6%B1%82%E9%A2%84%E6%B5%8B%E7%AC%AC%E4%BA%8C%E9%98%B6%E6%AE%B5%E6%80%BB%E7%BB%93.assets/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L2ZvbmVvbmU=,size_16,color_FFFFFF,t_70-163757103350512.png)]

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-YIa326d0-1639014876083)(%E9%98%BF%E9%87%8C%E5%A4%A9%E6%B1%A0%E4%BE%9B%E5%BA%94%E9%93%BE%E9%9C%80%E6%B1%82%E9%A2%84%E6%B5%8B%E7%AC%AC%E4%BA%8C%E9%98%B6%E6%AE%B5%E6%80%BB%E7%BB%93.assets/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L2ZvbmVvbmU=,size_16,color_FFFFFF,t_70-163757103774414.png)]

上面分别是ACF和PACF的图,至于如何定阶不详细叙述了。一般是通过截尾和拖尾来确定阶数。目前还没有看到总结的比较好的文章。
(2)信息准则定阶

由于要通过ACF和PACF图来定阶,是一种看图的方法,因此可以计算AIC等值,来进行定阶。

#信息准则定阶:AIC、BIC、HQIC
#AIC
AIC = sm.tsa.arma_order_select_ic(timeseries,\
max_ar=6,max_ma=4,ic='aic')['aic_min_order']
#BIC
BIC = sm.tsa.arma_order_select_ic(timeseries,max_ar=6,\
max_ma=4,ic='bic')['bic_min_order']
#HQIC
HQIC = sm.tsa.arma_order_select_ic(timeseries,max_ar=6,\
max_ma=4,ic='hqic')['hqic_min_order']
print('the AIC is{},\nthe BIC is{}\n the HQIC is{}'.format(AIC,BIC,HQIC))

一般都是一个一个运行,最好不要一起运行,结果出来的太慢了。
(3)热力图定阶

其实热力图定阶的方式和(2)信息准则定阶的方式类似,只是用热力图的方式呈现了。

      #设置遍历循环的初始条件,以热力图的形式展示,跟AIC定阶作用一样
      p_min = 0
      q_min = 0
      p_max = 5
      q_max = 5
      d_min = 0
      d_max = 5
      # 创建Dataframe,以BIC准则
      results_aic = pd.DataFrame(index=['AR{}'.format(i) \
                                 for i in range(p_min,p_max+1)],\
              columns=['MA{}'.format(i) for i in range(q_min,q_max+1)])
      # itertools.product 返回p,q中的元素的笛卡尔积的元组
      for p,d,q in itertools.product(range(p_min,p_max+1),\
                                     range(d_min,d_max+1),range(q_min,q_max+1)):
          if p==0 and q==0:
              results_aic.loc['AR{}'.format(p), 'MA{}'.format(q)] = np.nan
              continue
          try:
              model = sm.tsa.ARIMA(timeseries, order=(p, d, q))
              results = model.fit()
              #返回不同pq下的model的BIC值
              results_aic.loc['AR{}'.format(p), 'MA{}'.format(q)] = results.aic
          except:
              continue
      results_aic = results_aic[results_aic.columns].astype(float)
      #print(results_bic)
      
      fig, ax = plt.subplots(figsize=(10, 8))
      ax = sns.heatmap(results_aic,
                   #mask=results_aic.isnull(),
                   ax=ax,
                   annot=True, #将数字显示在热力图上
                   fmt='.2f',
                   )
      ax.set_title('AIC')
      plt.show() 

图如下所示:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-K2dx5ubm-1639014876084)(%E9%98%BF%E9%87%8C%E5%A4%A9%E6%B1%A0%E4%BE%9B%E5%BA%94%E9%93%BE%E9%9C%80%E6%B1%82%E9%A2%84%E6%B5%8B%E7%AC%AC%E4%BA%8C%E9%98%B6%E6%AE%B5%E6%80%BB%E7%BB%93.assets/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L2ZvbmVvbmU=,size_16,color_FFFFFF,t_70-163757108962116.png)]

黑色的位置最好,可以看出p,q取(1,1)(3,1)(1,4)都可以。一般情况下是越小越好。

热力图实现过程参考了下面这篇博客(找不见了=,=)如果有侵权,请告知,会删除博文。
5.构建模型和预测
创建模型的代码基本同上,只不过ARIMA有三个参数:p,d,q。其中p和q可以参考定阶的方法确定。d指的是用了多少阶差分,在我的模型中运用了一阶差分,因此d=1。

由于预测都是针对的差分法后的数据做的预测,但是真实数据并不是那样的,因此我还对差分后的数据进行还原操作。

看一下通过预测与实际真实值对比,模型到底是否很好。

#迭代预测
        date_list = pd.date_range(start=start_dt, end=end_dt - datetime.timedelta(days=1))
        for date in date_list:
            if date.dayofweek != 0:
                # 周一为补货决策日,非周一不做决策
                pass
            else:
                demand_his = demand[(demand["ts"] >= date - datetime.timedelta(days=42)) & (demand["ts"] < date)]['log qty'].values.astype('float32')
                pmax = 4
                qmax = 4
                bic_matrix = []
                for p in range(pmax + 1):
                    tmp = []
                    for q in range(qmax + 1):
                        try:
                            tmp.append(ARIMA(demand_his, order = (p, 1, q)).fit().bic)
                        except:
                            tmp.append(None)
                    bic_matrix.append(tmp)
                bic_matrix = pd.DataFrame(bic_matrix)
                try:
                    p, q = bic_matrix.stack().idxmin()
                    model = ARIMA(demand_his, order=(p, 1, q)).fit()
                    model.summary()
                    forecast = model.forecast(7, alpha=0.05)
                    forecast = np.exp(np.array(forecast))
                    # forecast = diff1_reduction(forecast[1], demand_his)
                    demand.loc[(demand["ts"] > date) & (demand["ts"] <= date + datetime.timedelta(days=7)), 'qty'] = forecast
                except:
                    p = 1
                    q = 4
                    print('ARIMA检验存在问题的unit:{}'.format(unit))

        plt.plot(eval['ts'].iloc[-99:], eval['qty'].iloc[-99:],label = 'True')
        plt.plot(demand['ts'].iloc[-99:],demand['qty'].iloc[-99:],label='test_predict')
        plt.legend(loc='best')
        # plt.savefig('./result/{}_pred_true_compare.jpg'.format(unit), dpi=300)
        plt.show()

        rmse = math.sqrt(mean_squared_error(eval['qty'].iloc[-99:].values, demand['qty'].iloc[-99:].values))
        mape = np.mean(np.abs(demand['qty'].iloc[-99:] - eval['qty'].iloc[-99:].values) / np.abs(eval['qty'].iloc[-99:].values))

        print(f'RMSE: {rmse}')
        print(f'MAPE: {mape}')
        score = r2_score(eval['qty'].iloc[-99:], demand['qty'].iloc[-99:])
        print("模型检验的 r^2: ", score)
        if num == 0:
            result = demand
        else:
            result = pd.concat([result,demand])

    result = result.drop('log qty',axis = 1)
    result = result[result["ts"] > start_dt]
    result.to_csv('./result/result.csv',index=False)

目前可参考的文献:

  • https://www.jianshu.com/p/b4571e63fcfc 这篇文献主要是用四种LSTM进行集成对股票数据的预测,同样是单变量模型,并且还和ARIMA模型进行了对比。
  • 对应的code:https://github.com/amanjain252002/Stock-Price-Prediction

目前模型的定阶还有一系列操作还是不太自适应,考虑把ARIMA全部换成auto_arima进行预测

参考书籍 预测:原理与实践

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-nruPBI1d-1639014876084)(%E9%98%BF%E9%87%8C%E5%A4%A9%E6%B1%A0%E4%BE%9B%E5%BA%94%E9%93%BE%E9%9C%80%E6%B1%82%E9%A2%84%E6%B5%8B%E7%AC%AC%E4%BA%8C%E9%98%B6%E6%AE%B5%E6%80%BB%E7%BB%93.assets/fpp2_cover.jpg)]

二、目前正在进行和需要进一步开展的工作:

  • 分层时间序列预测:

10.1 显示了一个 K=2层级结构。 层次结构的顶部(我们称之为级别 0)是“总计”,即数据的最聚合级别。 该 t总序列的第 th 个观察值表示为 yt对于 t=1,…,T. 总计在第 1 级分解为两个系列,而在层次结构的最底层,它们又分别分为三个和两个系列。 在顶层之下,我们使用 yj,t表示 t节点 对应的系列的第 th 个观察值 j. 例如, yA,t表示 t与节点 A 在级别 1, 对应的系列的第 th 个观察 yAB,t表示 t第 2 级观察节点 AB 对应的序列,依此类推。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-QwVeiBAi-1639014876085)(%E9%98%BF%E9%87%8C%E5%A4%A9%E6%B1%A0%E4%BE%9B%E5%BA%94%E9%93%BE%E9%9C%80%E6%B1%82%E9%A2%84%E6%B5%8B%E7%AC%AC%E4%BA%8C%E9%98%B6%E6%AE%B5%E6%80%BB%E7%BB%93.assets/hts.png)]

在这个小例子中,层级中的系列总数为 n = 1 + 2 + 5 = 8 , 而底层的系列数为 m = 5 . 注意 n > m

在所有层次结构中。

翻译有点问题见谅

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-rI0oLTxF-1639014876085)(%E9%98%BF%E9%87%8C%E5%A4%A9%E6%B1%A0%E4%BE%9B%E5%BA%94%E9%93%BE%E9%9C%80%E6%B1%82%E9%A2%84%E6%B5%8B%E7%AC%AC%E4%BA%8C%E9%98%B6%E6%AE%B5%E6%80%BB%E7%BB%93.assets/image-20211124154033968.png)]

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-esVOEdLc-1639014876086)(%E9%98%BF%E9%87%8C%E5%A4%A9%E6%B1%A0%E4%BE%9B%E5%BA%94%E9%93%BE%E9%9C%80%E6%B1%82%E9%A2%84%E6%B5%8B%E7%AC%AC%E4%BA%8C%E9%98%B6%E6%AE%B5%E6%80%BB%E7%BB%93.assets/image-20211124154044016.png)]

示例:澳大利亚旅游层级

澳大利亚分为八个地理区域(一些称为州,另一些称为领土),每个区域都有自己的政府以及一些经济和行政自治权。 这些中的每一个都可以进一步细分为更小的感兴趣区域,称为区域。 商业规划者和旅游当局对整个澳大利亚、各州和领地以及各区域的预测感兴趣。 在这个例子中,我们专注于季度国内旅游需求,以澳大利亚人离家出游的游客夜数来衡量。 为了简化我们的分析,我们将这两个领土和塔斯马尼亚合并为一个“其他”州。 所以我们有六个州:新南威尔士州 (NSW)、昆士兰州 (QLD)、南澳大利亚州 (SAU)、维多利亚州 (VIC)、西澳大利亚州 (WAU) 和其他 (OTH)。 对于其中的每一个,我们考虑以下区域内的访客之夜。

状态区域
新南威尔士州地铁 (NSWMetro)、北海岸 (NSWNthCo)、南海岸 (NSWSthCo)、南内陆 (NSWSthIn)、北内陆 (NSWNthIn)
昆士兰州地铁 (QLDMetro)、中部 (QLDCntrl)、北海岸 (QLDNthCo)
或者地铁 (SAUMetro)、沿海 (SAUCaast)、内陆 (SAUInner)
维克地铁 (VICMetro)、西海岸 (VICWstCo)、东海岸 (VICEstCo)、内陆 (VICInner)
瓦乌地铁 (WAUMetro)、沿海 (WAUCoast)、内陆 (WAUInner)
奥特仪表 (OTHMetro),非仪表 (OTHNoMet)

我们考虑了新南威尔士州的五个区域,维多利亚州的四个区域,以及昆士兰州、南澳州和西澳州的三个区域。 请注意,都会区包含首府城市和周边地区。 有关这些地理区域的详细信息,请参阅附录C中 Wickramasuriya,Athanasopoulos,与海德门 2019 ) 。

要创建分层时间序列,我们使用 hts()功能如下面的代码所示。 该函数需要两个输入:底层时间序列和有关层次结构的信息。 visnights是包含底层序列的时间序列矩阵。 有多种方法可以输入层次结构。 在这种情况下,我们使用 characters争论。 每个列名的前三个字符 visnights在层次结构的第一层(状态)捕获类别。 以下五个字符捕获了底层类别(区域)。

library(hts)
tourism.hts <- hts(visnights, characters = c(3, 5))
tourism.hts %>% aggts(levels=0:1) %>%
  autoplot(facet=TRUE) +
  xlab("Year") + ylab("millions") + ggtitle("Visitor nights")

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-lFYl4zcc-1639014876086)(%E9%98%BF%E9%87%8C%E5%A4%A9%E6%B1%A0%E4%BE%9B%E5%BA%94%E9%93%BE%E9%9C%80%E6%B1%82%E9%A2%84%E6%B5%8B%E7%AC%AC%E4%BA%8C%E9%98%B6%E6%AE%B5%E6%80%BB%E7%BB%93.assets/tourismStates-1.png)]

图 10.2:按州分列的 1998 年第一季度至 2016 年第四季度期间澳大利亚国内游客过夜数。

图 中的上图 10.2 显示了整个澳大利亚的游客过夜总数,而下图显示了按州分类的数据。 这些揭示了综合国家层面的多样化和丰富的动态,以及每个州的第一级分解。 这 aggts()函数从一个时间序列中提取 hts任何级别的聚合对象。

图 中的图 10.3 显示了底层时间序列,即每个区域的访客夜数。 这些帮助我们可视化每个区域内不同的个体动态,并帮助识别独特且重要的时间序列。 请注意,例如,沿海 WAU 区域在过去几年中显示出显着增长。

library(tidyverse)
cols <- sample(scales::hue_pal(h=c(15,375),
  c=100,l=65,h.start=0,direction = 1)(NCOL(visnights)))
as_tibble(visnights) %>%
  gather(Zone) %>%
  mutate(Date = rep(time(visnights), NCOL(visnights)),
         State = str_sub(Zone,1,3)) %>%
  ggplot(aes(x=Date, y=value, group=Zone, colour=Zone)) +
    geom_line() +
    facet_grid(State~., scales="free_y") +
    xlab("Year") + ylab("millions") +
    ggtitle("Visitor nights by Zone") +
    scale_colour_manual(values = cols)

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-339rx9Bs-1639014876086)(%E9%98%BF%E9%87%8C%E5%A4%A9%E6%B1%A0%E4%BE%9B%E5%BA%94%E9%93%BE%E9%9C%80%E6%B1%82%E9%A2%84%E6%B5%8B%E7%AC%AC%E4%BA%8C%E9%98%B6%E6%AE%B5%E6%80%BB%E7%BB%93.assets/tourismZones-1.png)]

上图按区域分类的 1998 年第一季度至 2016 年第四季度澳大利亚国内游客过夜数。

为了生成这个图,我们使用了 各种函数 tidyverse 包集合中的 。 详细信息超出了本书的范围,但是有许多很好的在线资源可以用来学习如何使用这些包。

import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
import numpy as np
import datetime
from sklearn.preprocessing import MinMaxScaler
from sklearn.preprocessing import LabelEncoder
from chinese_calendar import is_workday, is_holiday
import datetime
import sys
sys.path.append("./src")

from vis import get_nodes_edges_position, make_annotations
import plotly.graph_objects as go

from hts import HTSRegressor
demand_train_A = '../data/demand_train_A.csv'
geo_topo = '../data/geo_topo.csv'
inventory_info_A = '../data/inventory_info_A.csv'
product_topo = '../data/product_topo.csv'
weight_A = '../data/weight_A.csv'

demand_train_A = pd.read_csv(demand_train_A)
geo_topo = pd.read_csv(geo_topo)
inventory_info_A = pd.read_csv(inventory_info_A)
product_topo = pd.read_csv(product_topo)
weight_A = pd.read_csv(weight_A)

demand_test_A = '../data/demand_test_A.csv'
demand_test_A = pd.read_csv(demand_test_A)

demand_train_A["ts"] = demand_train_A["ts"].apply(lambda x:pd.to_datetime(x))  #日期对应数据标准化为规范时间
demand_test_A["ts"] = demand_test_A["ts"].apply(lambda x:pd.to_datetime(x)) #日期对应数据标准化为规范时间
demand_all = pd.concat([demand_train_A,demand_test_A])
demand_all = pd.merge(demand_all,weight_A, left_on = 'unit', right_on = 'unit')
demand_all = demand_all.drop(['Unnamed: 0_x','Unnamed: 0_y'],axis = 1)
demand_all = demand_all.drop(['geography_level','product_level'],axis = 1)
demand_all = pd.merge(demand_all,geo_topo, left_on = 'geography', right_on = 'geography_level_3')
demand_all = pd.merge(demand_all,product_topo, left_on = 'product', right_on = 'product_level_2')

取一部分的product作为试验

demand_all = demand_all.drop(['geography','product'],axis = 1)
prs = demand_all.product_level_2.unique()[0]
demand_all = demand_all[demand_all['product_level_2'].apply(lambda x: True if x in prs else False)]
le = LabelEncoder()
def lb(dataframe,col,name):
    le.fit(dataframe[col])
    dataframe[col] = le.transform(dataframe[col])
    dataframe[col] = dataframe[col].apply(lambda x: name +'_'+ str(x))
    return dataframe


cols = ['geography_level_1','geography_level_2','geography_level_3',
       'product_level_1','product_level_2']
for col in cols:
    demand_all = lb(demand_all, col,col[-7:])
demand_all = lb(demand_all, 'unit','unit')


demand_all["product_unit"] = demand_all.apply(lambda x: f"{x['product_level_2']}_{x['unit']}", axis=1)
demand_all["product_unit"] = demand_all.apply(lambda x: f"{x['product_level_2']}_{x['unit']}", axis=1)


##暂时筛选出商品层级
demand_all = demand_all.drop(['geography_level_1','geography_level_2','geography_level_3'],axis = 1)

可视化层级的树状结构

grouped_sections = demand_all.groupby(["product_level_2", "product_unit"])
edges_hierarchy = list(grouped_sections.groups.keys())
edges_hierarchy[:]

second_level_nodes = demand_all.product_level_2.unique()
root_node = "total"

root_edges = [(root_node, second_level_node) for second_level_node in second_level_nodes]

root_edges += edges_hierarchy

Xn, Yn, Xe, Ye, labels, annot = get_nodes_edges_position(root_edges, root="total")
M = max(Yn)

fig = go.Figure()

fig.add_trace(go.Scatter(x=Xe,
                   y=Ye,
                   mode='lines',
                   line=dict(color='rgb(210,210,210)', width=1),
                   hoverinfo='none'
                   ))

fig.add_trace(go.Scatter(x=Xn,
                  y=Yn,
                  mode='markers',
                  name='bla',
                  marker=dict(symbol='circle-dot',
                                size=30,
                                color='#6175c1',    #'#DB4551',
                                line=dict(color='rgb(50,50,50)', width=1)
                                ),
                  text=labels,
                  hoverinfo='text',
                  opacity=0.8
                  ))

axis = dict(showline=False, # hide axis line, grid, ticklabels and  title
            zeroline=False,
            showgrid=False,
            showticklabels=False,
            )

fig.update_layout(title= '天池供应链需求数据的层次结构树',
              annotations=annot,
              font_size=10,
              showlegend=False,
              xaxis=axis,
              yaxis=axis,
              margin=dict(l=40, r=40, b=85, t=100),
              hovermode='closest',
              plot_bgcolor='rgb(248,248,248)'
              )
fig.show()
demand_train = demand_all[demand_all['ts'] < pd.to_datetime('20210302')]
demand_test = demand_all[demand_all['ts'] >= pd.to_datetime('20210302')]
demand_train_bottom_level = demand_train.pivot(index="ts", columns="product_unit", values="qty")
demand_test_bottom_level = demand_test.pivot(index="ts", columns="product_unit", values="qty")

demand_train_bottom_level = demand_train_bottom_level.fillna(0)
demand_test_bottom_level = demand_test_bottom_level.fillna(0)

def get_state_columns(df, product_level_2):
    return [col for col in df.columns if product_level_2 in col]

products = demand_train["product_level_2"].unique().tolist()

for product in products:
    product_cols = get_state_columns(demand_train_bottom_level, product)
    demand_train_bottom_level[product] = demand_train_bottom_level[product_cols].sum(axis=1)
    product_cols = get_state_columns(demand_test_bottom_level, product)
    demand_test_bottom_level[product] = demand_test_bottom_level[product_cols].sum(axis=1)

demand_train_bottom_level["total"] = demand_train_bottom_level[products].sum(axis=1)
demand_test_bottom_level["total"] = demand_test_bottom_level[products].sum(axis=1)

hierarchy = dict()

for edge in root_edges:
    parent, children = edge[0], edge[1]
    hierarchy.get(parent)
    if not hierarchy.get(parent):
        hierarchy[parent] = [children]
    else:
        hierarchy[parent] += [children]
        
        
# 不改变数据中的任何东西,除了索引是 QS
demand_train_bottom_level.index = pd.to_datetime(demand_train_bottom_level.index)
demand_test_bottom_level.index = pd.to_datetime(demand_test_bottom_level.index)
# demand_all_bottom_level = demand_all_bottom_level.resample("QS").sum()


clf = HTSRegressor(model='auto_arima', revision_method='OLS', n_jobs=0)
model = clf.fit(demand_train_bottom_level.iloc[-42:], hierarchy)
Fitting models: 100%|██████████████████████████████████████████████████████████████████| 53/53 [00:43<00:00,  1.21it/s]
predicted_autoarima = model.predict(steps_ahead=99)

plt.figure(figsize=(20, 10))

def plot_results(col,preds):
    preds[col].plot(label="Predicted")
    demand_test_bottom_level[col].plot(label="Observed")
    plt.legend()
    plt.title(col)
    plt.xlabel("Date")
    plt.ylabel("N of Qty")
    
# plot_results(products[0],predicted_autoarima[-99:])
col = 'level_2_0_unit_6'
plot_results(col,predicted_autoarima[-99:])
plt.tight_layout()

from sklearn.metrics import mean_absolute_error
print(mean_absolute_error(demand_test_bottom_level[col],predicted_autoarima[-98:][col]) / 
     demand_test_bottom_level[col].mean() * 100, '%')
Fitting models:   0%|                                                                           | 0/53 [00:00<?, ?

Fitting models: 100%|█████████████████████████████████████████████████████████████████| 53/53 [00:00<00:00, 187.80it/s]


68.95385134133873 %

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-hFQ6aG5W-1639014876088)(%E9%98%BF%E9%87%8C%E5%A4%A9%E6%B1%A0%E4%BE%9B%E5%BA%94%E9%93%BE%E9%9C%80%E6%B1%82%E9%A2%84%E6%B5%8B%E7%AC%AC%E4%BA%8C%E9%98%B6%E6%AE%B5%E6%80%BB%E7%BB%93.assets/output_9_2.png)]

num = 0
for col in demand_test_bottom_level.columns:
    mape = (mean_absolute_error(demand_test_bottom_level[col],predicted_autoarima[-98:][col]) / 
         demand_test_bottom_level[col].mean() * 100)
    if mape > 0:
        print(mape,'%')
        num +=1
result = pd.read_csv('../data/result_1.csv')
true = pd.read_csv('../data/demand_test_A.csv')
for unit in result['unit'].unique():
    temp_true = true[true['unit'] == unit]
    temp_pred = result[result['unit'] == unit]
    mape = (mean_absolute_error(temp_true['qty'],temp_pred['qty']) / temp_true['qty'].mean() * 100)
    if mape > 5:
        print(mape,'%')
        print(unit)
#根据geograph_level_2与unit对应关系构建的层级预测结果校验,采用的是arima将历史所有数据进行拟合,预测未来的7天qty需求,数据处理与后面的auto_arima一致
6.723370291658045 %
03e2603c3134f61842a393a960693aff
7.512233358725807 %
091e56749c44bef9f58eec0c9ed60280
85.58177310182313 %
15e27721c2e99e7897d613a5f281a92c
55.972178956057576 %
1bfbefa22e0e641366673df7f3977278
17.26465216939635 %
2e1697537caf9cb17d81e49feeb79914
6.496660110169583 %
4773f81c82bacd3fdea9b306d2adae75
6.957986832723311 %
52efc1dfe202fa05c9b7f87711dead95
6.137690718627218 %
56862dbf28a0e3a83a86370e2dcb14ac
7.739665043348025 %
5a78fab4c4c6e747b085c82a4e611080
5.143360549412331 %
754343952ad1589f2c81da0b6c0ed72d
10.441919135423372 %
7b0d3987cb54dd19dbb0cee7f36dbfbd
8.257417957886263 %
7d9cbb373fddba4ce2cddcec96bccbeb
9.922597815037285 %
83bdf03890f543cc0732eb0fa80fb9e0
9.767265360653719 %
8aa97bb705f78f014a9046130df0b03a
5.174707265649723 %
8c12bcd7914a074400cb2d54fb24c360
21.37553433180414 %
8ccf1c02bb050cb3fc4f13789cdfe235
12.229719236285112 %
9c1da4fb74f1ba23ed694225e7b623a0
10.057865858534251 %
a4feb7bd8511f582ffb5023299178ff4
5.3426890917483 %
b035f859cf03840b75abd80dc1cf3e94
5.367696645186517 %
b3f32e02cae9388022f4d06a91311030
8.721015579784645 %
be3aee91b31d8c0350a2f0ed2d21b9a3
6.817164748680692 %
c185032995df20b0c9ef9ae52dd95ccb
6.33182043918111 %
c36d9c12341e6e17a35135144ddafbcd
9.1658129454196 %
c83371eb7a520fb90e220f83f645d789
6.960909796732503 %
ca7d84670a90d845578e07f12debc242
6.6961999075551715 %
d4457d3854301ee3c80d57c506a8bfcc
8.46410976925359 %
dec0b2a698a629fa1f029282029b36e6
9.546219755419834 %
e1f711289e92c38c5fb93d71354634e4
5.2905131606022 %
e5847ca3df209918422da7d5a7b1cd98
5.73205195285176 %
ec6ca57a9ab96da37b77f5cfd7ce6c8e
15.477022916906758 %
f5438e2eae441b33a3b24e5ad65f7e99
10.71941212081028 %
fbb83aefc6f5d6f6bc22ae3ee757d327
5.456101702080666 %
fe994bc0a7241fe686a2eeee39cd2695

##看到这个结果基本可以放弃了,为什么对比了后面的模型可以相应的解释,但是针对层及预测的学习还是很有必要的,也许这个较为复杂的策略短期内不是上分的好技巧,但是研究和可解释性还有待进一步探索和学习
test = pd.merge(result,true, left_on = 'unit', right_on = 'unit')
mape = (mean_absolute_error(test['qty_x'],test['qty_y']) / test['qty_x'].mean() * 100)
mape = (mean_absolute_error(true['qty'],result['qty']) / true['qty'].mean() * 100)
mape
4.277596831546743 %
unittsqty
00025accbb2e3dfbfe6f5b3a4562bdee02021/3/23536.666667
6320025accbb2e3dfbfe6f5b3a4562bdee02021/3/33546.218750
12640025accbb2e3dfbfe6f5b3a4562bdee02021/3/43570.999023
18960025accbb2e3dfbfe6f5b3a4562bdee02021/3/53616.421143
25280025accbb2e3dfbfe6f5b3a4562bdee02021/3/63612.384766
............
587760025accbb2e3dfbfe6f5b3a4562bdee02021/6/33815.427979
594080025accbb2e3dfbfe6f5b3a4562bdee02021/6/43791.432129
600400025accbb2e3dfbfe6f5b3a4562bdee02021/6/53822.543945
606720025accbb2e3dfbfe6f5b3a4562bdee02021/6/63823.613281
613040025accbb2e3dfbfe6f5b3a4562bdee02021/6/73840.084717

98 rows × 3 columns

Unnamed: 0unittsqtygeography_levelgeographyproduct_levelproduct
4666810025accbb2e3dfbfe6f5b3a4562bdee02021-03-023536.666667geography_level_311407e1f167b84f374d9e999a1ed9563product_level_21807ffa2c1c84035f4346c3364104dd1
92913830025accbb2e3dfbfe6f5b3a4562bdee02021-03-033564.666667geography_level_311407e1f167b84f374d9e999a1ed9563product_level_21807ffa2c1c84035f4346c3364104dd1
169624940025accbb2e3dfbfe6f5b3a4562bdee02021-03-043614.333333geography_level_311407e1f167b84f374d9e999a1ed9563product_level_21807ffa2c1c84035f4346c3364104dd1
227333330025accbb2e3dfbfe6f5b3a4562bdee02021-03-053610.000000geography_level_311407e1f167b84f374d9e999a1ed9563product_level_21807ffa2c1c84035f4346c3364104dd1
255537480025accbb2e3dfbfe6f5b3a4562bdee02021-03-063609.000000geography_level_311407e1f167b84f374d9e999a1ed9563product_level_21807ffa2c1c84035f4346c3364104dd1
...........................
58777862130025accbb2e3dfbfe6f5b3a4562bdee02021-06-033790.333333geography_level_311407e1f167b84f374d9e999a1ed9563product_level_21807ffa2c1c84035f4346c3364104dd1
59751876300025accbb2e3dfbfe6f5b3a4562bdee02021-06-043819.666667geography_level_311407e1f167b84f374d9e999a1ed9563product_level_21807ffa2c1c84035f4346c3364104dd1
60632889430025accbb2e3dfbfe6f5b3a4562bdee02021-06-053820.666667geography_level_311407e1f167b84f374d9e999a1ed9563product_level_21807ffa2c1c84035f4346c3364104dd1
61142896840025accbb2e3dfbfe6f5b3a4562bdee02021-06-063836.000000geography_level_311407e1f167b84f374d9e999a1ed9563product_level_21807ffa2c1c84035f4346c3364104dd1
61712905140025accbb2e3dfbfe6f5b3a4562bdee02021-06-073871.333333geography_level_311407e1f167b84f374d9e999a1ed9563product_level_21807ffa2c1c84035f4346c3364104dd1

98 rows × 8 columns

三、目前较优的模型

  1. 基于对数处理的auto_arima,自适应的为每一段时间序列采用aic和bic指标进行评价定阶获得最优的参数组合
最初迭代预测低效定阶版本
if __name__ == '__main__':
    # 加载数据

    Train_dt = read_csv('train.csv')[['unit','ts','qty']]
    Test_dt = read_csv('test.csv')[['unit','ts','qty']]
    Train_dt["ts"] = Train_dt["ts"].apply(lambda x: pd.to_datetime(x))  # 日期对应数据标准化为规范时间  这一步比较耗时间!!!!!
    Test_dt["ts"] = Test_dt["ts"].apply(lambda x: pd.to_datetime(x))  # 日期对应数据标准化为规范时间
    last_dt = pd.to_datetime("20210301")  # 用来限定使用的是历史数据而不是未来数据
    start_dt = pd.to_datetime("20210301")  # 用来划定预测的针对test的起始时间
    end_dt = pd.to_datetime("20210607")  # 预测需求的截止时间

    qty_using = pd.concat([Train_dt, Test_dt])
    for num,chunk in enumerate(qty_using.groupby("unit")):
        unit = chunk[0]
        demand = chunk[1]
        eval = demand.copy()
        # demand["diff"] = demand["qty"].diff().values
        demand["log qty"] = np.log(demand['qty'])
        # demand_log_moving_avg = demand['log qty'].rolling(12).mean()  # 平滑处理也是针对非平稳性的处理方式
        # demand_log_std = demand['log qty'].rolling(12).std()
        # del demand["diff"]
        #         # demand = demand[1:]

#迭代预测
        date_list = pd.date_range(start=start_dt, end=end_dt - datetime.timedelta(days=1))
        for date in date_list:
            if date.dayofweek != 0:
                # 周一为补货决策日,非周一不做决策
                pass
            else:
                demand_his = demand[(demand["ts"] >= date - datetime.timedelta(days=42)) & (demand["ts"] < date)]['log qty'].values.astype('float32')
                pmax = 4
                qmax = 4
                bic_matrix = []
                for p in range(pmax + 1):
                    tmp = []
                    for q in range(qmax + 1):
                        try:
                            tmp.append(ARIMA(demand_his, order = (p, 1, q)).fit().bic)
                        except:
                            tmp.append(None)
                    bic_matrix.append(tmp)
                bic_matrix = pd.DataFrame(bic_matrix)
                try:
                    p, q = bic_matrix.stack().idxmin()
                    model = ARIMA(demand_his, order=(p, 1, q)).fit()
                    model.summary()
                    forecast = model.forecast(7, alpha=0.05)
                    forecast = np.exp(np.array(forecast))
                    # forecast = diff1_reduction(forecast[1], demand_his)
                    demand.loc[(demand["ts"] > date) & (demand["ts"] <= date + datetime.timedelta(days=7)), 'qty'] = forecast
                except:
                    p = 1
                    q = 4
                    print('ARIMA检验存在问题的unit:{}'.format(unit))

        plt.plot(eval['ts'].iloc[-99:], eval['qty'].iloc[-99:],label = 'True')
        plt.plot(demand['ts'].iloc[-99:],demand['qty'].iloc[-99:],label='test_predict')
        plt.legend(loc='best')
        # plt.savefig('./result/{}_pred_true_compare.jpg'.format(unit), dpi=300)
        plt.show()

        rmse = math.sqrt(mean_squared_error(eval['qty'].iloc[-99:].values, demand['qty'].iloc[-99:].values))
        mape = np.mean(np.abs(demand['qty'].iloc[-99:] - eval['qty'].iloc[-99:].values) / np.abs(eval['qty'].iloc[-99:].values))

        print(f'RMSE: {rmse}')
        print(f'MAPE: {mape}')
        score = r2_score(eval['qty'].iloc[-99:], demand['qty'].iloc[-99:])
        print("模型检验的 r^2: ", score)
        if num == 0:
            result = demand
        else:
            result = pd.concat([result,demand])

    result = result.drop('log qty',axis = 1)
    result = result[result["ts"] > start_dt]
    result.to_csv('./result/result.csv',index=False)

共计耗时: 约5小时

最终修改迭代预测版本
if __name__ == '__main__':
    # 加载数据
    t0 = time.time()

    Train_dt = read_csv('train.csv')[['unit','ts','qty']]
    Test_dt = read_csv('test.csv')[['unit','ts','qty']]
    Train_dt["ts"] = Train_dt["ts"].apply(lambda x: pd.to_datetime(x))  # 日期对应数据标准化为规范时间  这一步比较耗时间!!!!!
    Test_dt["ts"] = Test_dt["ts"].apply(lambda x: pd.to_datetime(x))  # 日期对应数据标准化为规范时间
    last_dt = pd.to_datetime("20210301")  # 用来限定使用的是历史数据而不是未来数据
    start_dt = pd.to_datetime("20210301")  # 用来划定预测的针对test的起始时间
    end_dt = pd.to_datetime("20210607")  # 预测需求的截止时间

    qty_using = pd.concat([Train_dt, Test_dt])
    for num,chunk in enumerate(qty_using.groupby("unit")):
        t1 = time.time()
        unit = chunk[0]
        demand = chunk[1]
        eval = demand.copy()
        demand["log qty"] = np.log(demand['qty'])

#迭代预测
        date_list = pd.date_range(start=start_dt, end=end_dt - datetime.timedelta(days=1))
        for date in date_list:
            if date.dayofweek != 0:
                # 周一为补货决策日,非周一不做决策
                pass
            else:
                demand_his = demand[(demand["ts"] >= date - datetime.timedelta(days=42)) & (demand["ts"] < date)]['log qty'].values.astype('float32')
                try:
                    model = auto_arima(demand_his, trace=True, error_action='ignore', suppress_warnings=True)
                    model.fit(demand_his)
                    forecast = model.predict(n_periods=7)
                    forecast = np.exp(np.array(forecast))
                    # forecast = diff1_reduction(forecast[1], demand_his)
                    demand.loc[(demand["ts"] > date) & (demand["ts"] <= date + datetime.timedelta(days=7)), 'qty'] = forecast
                except:
                    print('ARIMA检验存在问题的unit:{}'.format(unit))
                    #出了问题的用前面42天的均值代替
                    demand_future = np.mean(demand_his)    #这里是针对特殊的unit也就是arima无法预测的不能通过平稳性检验的
                    if demand_future == -np.inf:
                        demand_future = 0
                    demand.loc[(demand["ts"] > date) & (demand["ts"] <= date + datetime.timedelta(days=7)), 'qty'] = np.array([demand_future] * 7)

        #可视化预测值与真实值的拟合效果
        # plt.plot(eval['ts'].iloc[-99:], eval['qty'].iloc[-99:],label = 'True')
        # plt.plot(demand['ts'].iloc[-99:],demand['qty'].iloc[-99:],label='test_predict')
        # plt.legend(loc='best')
        # # plt.savefig('./result/{}_pred_true_compare.jpg'.format(unit), dpi=300)
        # plt.show()
        t2 = time.time()
        print('模型进度 : ', '{} / {} \n'.format(str(num), str(632)))
        print('One units time cost:', t2-t1)
        # rmse = math.sqrt(mean_squared_error(eval['qty'].iloc[-99:].values, demand['qty'].iloc[-99:].values))
        # mape = np.mean(np.abs(demand['qty'].iloc[-99:] - eval['qty'].iloc[-99:].values) / np.abs(eval['qty'].iloc[-99:].values))
        #
        # print(f'RMSE: {rmse}')
        # print(f'MAPE: {mape}')
        # score = r2_score(eval['qty'].iloc[-99:], demand['qty'].iloc[-99:])
        # print("模型检验的 r^2: ", score)
        if num == 0:
            result = demand
        else:
            result = pd.concat([result,demand])

    result = result.drop('log qty',axis = 1)
    result = result[result["ts"] > start_dt]
    result.to_csv('./result/result.csv',index=False)
    t3 = time.time()
    print('总消耗时长 :', (t3 - t0)/3600, '小时')

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-NJWqBrS3-1639014876089)(%E9%98%BF%E9%87%8C%E5%A4%A9%E6%B1%A0%E4%BE%9B%E5%BA%94%E9%93%BE%E9%9C%80%E6%B1%82%E9%A2%84%E6%B5%8B%E7%AC%AC%E4%BA%8C%E9%98%B6%E6%AE%B5%E6%80%BB%E7%BB%93.assets/image-20211125085719922.png)]

RMSE 均方根误差:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-SWeMtpOS-1639014876089)(%E9%98%BF%E9%87%8C%E5%A4%A9%E6%B1%A0%E4%BE%9B%E5%BA%94%E9%93%BE%E9%9C%80%E6%B1%82%E9%A2%84%E6%B5%8B%E7%AC%AC%E4%BA%8C%E9%98%B6%E6%AE%B5%E6%80%BB%E7%BB%93.assets/U1$T7_HZAP7PRACNSK%25SV0.jpg)]

目前预测的部分我尝试了层级的方法,效果不是太理想,可以算出来相当一部分unit的MAPE已经超过5%。

最好的预测方法还是auto_arima我改了一版大概算出来整体的相对误差在1.08%大部分的unit在0.5%上下浮动,少部分的unit确实存在>5%和10%的情况,后期我们会主要针对这些unit进行分析

总耗时:1.473小时

相对误差:1.08%

MAE = 71.667

  1. 论坛中的notebook–auto_X自适应的时间序列预测:

参考github项目auto-x

完整colab代码

# -*- coding: utf-8 -*-
# 安装autox
!git clone https://github.com/4paradigm/AutoX.git
!pip install pytorch_tabnet
!pip install ./AutoX

import os
import pandas as pd

"""## 数据预处理"""

data_name = '../../data/'
path = f'./{data_name}'

# 赛题数据demand_test_A中给了标签,我们需要将它删掉。同时我们顺便删掉无用的'Unnamed: 0'列

demand_train_A = pd.read_csv(f'{path}/demand_train_A.csv')
demand_test_A = pd.read_csv(f'{path}/demand_test_A.csv')

demand_train_A.drop('Unnamed: 0', axis=1, inplace=True)
demand_test_A.drop(['Unnamed: 0', 'qty'], axis=1, inplace=True)

# 将 demand_train_A, demand_test_A 保存为train.csv, test.csv
demand_train_A.to_csv(path + '/train.csv', index = False)
demand_test_A.to_csv(path + '/test.csv', index = False)

"""## 导入所需的包"""

from autox import AutoX

"""## 初始化AutoX类"""

# 数据集是多表数据集,需要配置表关系
relations = [
    {
            "related_to_main_table": "true", # 是否为和主表的关系
            "left_entity": "train.csv",  # 左表名字
            "left_on": ["product"],  # 左表拼表键
            "right_entity": "product_topo.csv",  # 右表名字
            "right_on": ["product_level_2"], # 右表拼表键
            "type": "1-1" # 左表与右表的连接关系
        },  # train.csv和product_topo.csv两张表是1对1的关系,拼接键为train.csv中的product列 和 product_topo.csv中的product_level_2列
    {
            "related_to_main_table": "true", # 是否为和主表的关系
            "left_entity": "test.csv",  # 左表名字
            "left_on": ["product"],  # 左表拼表键
            "right_entity": "product_topo.csv",  # 右表名字
            "right_on": ["product_level_2"], # 右表拼表键
            "type": "1-1" # 左表与右表的连接关系
        },  # test.csv和product_topo.csv两张表是1对1的关系,拼接键为test.csv中的product列 和 product_topo.csv中的product_level_2列
    {
            "related_to_main_table": "true", # 是否为和主表的关系
            "left_entity": "train.csv",  # 左表名字
            "left_on": ["geography"],  # 左表拼表键
            "right_entity": "geo_topo.csv",  # 右表名字
            "right_on": ["geography_level_3"], # 右表拼表键
            "type": "1-1" # 左表与右表的连接关系
        },  # train.csv和geo_topo.csv两张表是1对1的关系,拼接键为train.csv中的geography列 和 geo_topo.csv中的geography_level_3列
    {
            "related_to_main_table": "true", # 是否为和主表的关系
            "left_entity": "test.csv",  # 左表名字
            "left_on": ["geography"],  # 左表拼表键
            "right_entity": "geo_topo.csv",  # 右表名字
            "right_on": ["geography_level_3"], # 右表拼表键
            "type": "1-1" # 左表与右表的连接关系
        } # test.csv和geo_topo.csv两张表是1对1的关系,拼接键为test.csv中的geography列 和 geo_topo.csv中的geography_level_3列
]

autox = AutoX(target = 'qty', train_name = 'train.csv', test_name = 'test.csv', 
               id = ['unit'], path = path, time_series=True, ts_unit='D',time_col = 'ts',
               relations = relations
              )  #feature_type = feature_type,

sub = autox.get_submit_ts()

# 检查预测结果和真实结果的差距
sub.rename({'qty': 'qty_pre'}, axis=1, inplace=True)
demand_test_A = pd.read_csv(f'{path}/demand_test_A.csv', usecols = ['unit','ts','qty'])

analyze = demand_test_A.merge(sub, on = ['unit', 'ts'], how = 'left')


# 查看mae
from sklearn.metrics import mean_absolute_error
y_true = analyze['qty']
y_pred = analyze['qty_pre']

print(mean_absolute_error(y_true, y_pred))


[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-LE8fpMIy-1639014876090)(%E9%98%BF%E9%87%8C%E5%A4%A9%E6%B1%A0%E4%BE%9B%E5%BA%94%E9%93%BE%E9%9C%80%E6%B1%82%E9%A2%84%E6%B5%8B%E7%AC%AC%E4%BA%8C%E9%98%B6%E6%AE%B5%E6%80%BB%E7%BB%93.assets/image-20211125160555903.png)]

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-4H6TUCQq-1639014876091)(%E9%98%BF%E9%87%8C%E5%A4%A9%E6%B1%A0%E4%BE%9B%E5%BA%94%E9%93%BE%E9%9C%80%E6%B1%82%E9%A2%84%E6%B5%8B%E7%AC%AC%E4%BA%8C%E9%98%B6%E6%AE%B5%E6%80%BB%E7%BB%93.assets/image-20211125160612201.png)]

模型评价指标分析:

平均绝对误差MAE = 474.893

RMSE:尚未计算,会逊色于目前我们的baseline

总耗时约:48分钟

四、进一步的工作安排:

  1. 我们需要至少一位队员针对auto-X进行研究和了解,深入分析其内部类及过程,主要针对其时间序列预测部分;

  2. 运筹优化库存补货策略部分,可能需要一人参与协助,目前这项工作将在预测工作大体完成之后其决定性作用;

  3. 针对多变量多步的时间序列方法上次提及的相关的内容仍有待进一步探索,即是尝试提升也是学习;

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-a7UtTo8q-1639014876091)(%E9%98%BF%E9%87%8C%E5%A4%A9%E6%B1%A0%E4%BE%9B%E5%BA%94%E9%93%BE%E9%9C%80%E6%B1%82%E9%A2%84%E6%B5%8B%E7%AC%AC%E4%BA%8C%E9%98%B6%E6%AE%B5%E6%80%BB%E7%BB%93.assets/image-20211127173123991.png)]

目前:没有使用运筹优化策略的成绩

  • 5
    点赞
  • 36
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值