Chinese Population Forecasting with the PaddleTS LSTNet Time-Series Model

1. Introduction ✨

1.1 Project Overview 🎄

This is a machine learning project that walks through the full workflow of building an LSTNet model with PaddleTS on a given dataset (here, a Chinese population dataset), covering data preprocessing, model construction, training, prediction, and visualization of the results.

  • Using eight feature fields from the Chinese population dataset, such as 出生人口(万) (births, 10k), 中国人均GPA(美元计) (per-capita GDP in USD; "GPA" is a typo carried over from the source data), 中国性别比例(按照女生=100) (sex ratio, female = 100), and 自然增长率(%) (natural growth rate, %), we predict the single label field 总人口(万人) (total population, 10k). This is a multi-input, single-output forecasting problem in the LSTM family of neural networks.

  • PaddleTS, the tool used in this project, is an easy-to-use Python library for deep time-series modeling. Built on the PaddlePaddle deep learning framework, it focuses on industry-leading deep models and aims to give domain experts and industry users scalable time-series modeling capabilities and a convenient user experience.

  • In particular, it ships with leading deep learning models, including forecasting models such as NBEATS, NHiTS, LSTNet, TCN, Transformer, DeepAR (probabilistic forecasting), and Informer, as well as representation models such as TS2Vec. This project uses the LSTNet model.

🌈 Since their introduction, LSTM (Long Short-Term Memory) networks have shown strong potential for capturing dependencies in time series. An LSTM controls its cell state with three gates, called the forget gate, the input gate, and the output gate. The cell state runs through the whole cell like a conveyor belt, with only a few branches, which lets information flow through the entire RNN largely unchanged. LSTM is a special kind of RNN designed mainly to address the vanishing and exploding gradients that arise when training on long sequences.
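For reference, the textbook form of the three gates and the cell-state update (standard LSTM notation, not specific to PaddleTS) is:

$$
\begin{aligned}
f_t &= \sigma\!\left(W_f [h_{t-1}, x_t] + b_f\right) && \text{forget gate} \\
i_t &= \sigma\!\left(W_i [h_{t-1}, x_t] + b_i\right) && \text{input gate} \\
\tilde{c}_t &= \tanh\!\left(W_c [h_{t-1}, x_t] + b_c\right) && \text{candidate state} \\
c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c}_t && \text{cell state (the conveyor belt)} \\
o_t &= \sigma\!\left(W_o [h_{t-1}, x_t] + b_o\right) && \text{output gate} \\
h_t &= o_t \odot \tanh(c_t) &&
\end{aligned}
$$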

🌈 LSTNet was formally proposed in 2018. It is a deep learning network designed specifically for time-series forecasting, and its appearance can be seen as an attempt by researchers to push the forecasting ability of LSTM-style models further, including through an attention mechanism in one of its variants. Its components are:

  • Convolutional Component
    • The first layer of LSTNet is a convolutional network without pooling, which extracts short-term patterns along the time dimension as well as local dependencies among variables. The layer consists of multiple filters of width ω and height n, where the height equals the number of variables.
  • Recurrent Component
    • The output of the convolutional layer is fed simultaneously into the Recurrent component and the Recurrent-skip component (described in the next item). The Recurrent component is a recurrent layer of gated recurrent units (GRU) that uses the ReLU function as the hidden update activation.
  • Recurrent-skip Component
    • Recurrent layers with GRU and LSTM cells are carefully designed to remember historical information and thus capture relatively long-term dependencies. In practice, however, GRU and LSTM often fail to capture very long-term correlations because of vanishing gradients; the skip connections between periodic time steps are meant to alleviate this.
  • Temporal Attention Layer
    • An attention mechanism that learns a weighted combination of the hidden representations at each window position of the input matrix.
  • Autoregressive Component
    • The final prediction of LSTNet is decomposed into a linear part, which mainly addresses the local scaling issue, and a nonlinear part that captures the recurring patterns; the LSTNet architecture adopts the classical autoregressive (AR) model as the linear component (see the formula sketch after this list).
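Putting the components together, the final LSTNet forecast is simply the sum of the neural and autoregressive outputs. In the notation of the original paper (a sketch of the paper's decomposition, not code from this project):

$$
\hat{Y}_t = h^{D}_t + h^{L}_t
$$

where $h^{D}_t$ is the output of the nonlinear part (CNN + RNN + skip-RNN) and $h^{L}_t$ is the output of the linear AR component.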

1.2 Dataset 🌲

The dataset used in this project is a Chinese population dataset with 10 fields: 8 feature fields, 1 label field, and 1 row-index field. The data type of each field is shown in the table below.

  • Some fields are of type int64 and must be converted before they can be fed to the model for training.
  • The data contains only 50 samples, so the training, validation, and test splits should be chosen carefully.
| Field | dtype |
| --- | --- |
| 年份 | int64 |
| 出生人口(万) | int64 |
| 总人口(万人) | int64 |
| 中国人均GPA(美元计) | int64 |
| 中国性别比例(按照女生=100) | float64 |
| 自然增长率(%) | float64 |
| 城镇人口(城镇+乡村=100) | float64 |
| 乡村人口 | float64 |
| 美元兑换人民币汇率 | float64 |
| 中国就业人口(万人) | int64 |

Hardware environment used for this project: V100 GPU; 32 GB RAM; 32 GB GPU memory; 100 GB disk; 4-core CPU.

2. Environment Setup

Install and import the required packages.

2.1 Install dependencies

# Install the paddlets library (and seaborn for plotting)
!pip install paddlets -q
!pip install seaborn -q

2.2 Import libraries

%matplotlib inline
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import paddlets
from paddlets import TSDataset
from paddlets import TimeSeries
from paddlets.models.forecasting import MLPRegressor, LSTNetRegressor  # LSTNetRegressor is the model used below
from paddlets.transform import Fill, StandardScaler                    # StandardScaler for normalization
from paddlets.metrics import MSE, MAE                                  # evaluation metrics

3. Data Processing

3.1 Load the data

population = pd.read_csv("data/data140190/人口.csv")
population.head()
| | 年份 | 出生人口(万) | 总人口(万人) | 中国人均GPA(美元计) | 中国性别比例(按照女生=100) | 自然增长率(%) | 城镇人口(城镇+乡村=100) | 乡村人口 | 美元兑换人民币汇率 | 中国就业人口(万人) |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 0 | 1970 | 2710 | 82542 | 113 | 105.90 | 25.95 | 17.38 | 82.62 | 2.4618 | 34432 |
| 1 | 1971 | 2551 | 84779 | 118 | 105.82 | 23.40 | 17.26 | 82.74 | 2.2673 | 35520 |
| 2 | 1972 | 2550 | 86727 | 131 | 105.78 | 22.27 | 17.13 | 82.87 | 2.2401 | 35854 |
| 3 | 1973 | 2447 | 88761 | 157 | 105.86 | 20.99 | 17.20 | 82.80 | 2.0202 | 36652 |
| 4 | 1974 | 2226 | 90409 | 160 | 105.88 | 17.57 | 17.16 | 82.84 | 1.8397 | 37369 |
population['date'] = pd.to_datetime(population['年份'])

3.2 Inspect field types

Some fields in the dataset are of type int64 and must be converted to float before they can be fed to the model for training; otherwise an error is raised.

population.dtypes
年份                           int64
出生人口(万)                      int64
总人口(万人)                      int64
中国人均GPA(美元计)                 int64
中国性别比例(按照女生=100)           float64
自然增长率(%)                   float64
城镇人口(城镇+乡村=100)            float64
乡村人口                       float64
美元兑换人民币汇率                  float64
中国就业人口(万人)                   int64
date                datetime64[ns]
dtype: object

3.3 Data Visualization

3.3.1 Feature line plots

Plot each feature against the year index for an initial look at the data.

Because the dataset contains Chinese field names, the following configuration is needed for the plots to render Chinese characters correctly:

from pylab import mpl
from matplotlib.font_manager import FontProperties
# Use a font that contains CJK glyphs (path specific to the AI Studio environment)
myfont = FontProperties(fname=r'/usr/share/fonts/fangzheng/FZSYJW.TTF', size=12)
sns.set(font=myfont.get_name())

titles = [
    "出生人口(万)",
    "总人口(万人)",
    "中国人均GPA(美元计)",
    "中国性别比例(按照女生=100)",
    "自然增长率(%)",
    "城镇人口(城镇+乡村=100)",
    "乡村人口",
    "美元兑换人民币汇率",
    "中国就业人口(万人)",
]

feature_keys = [
    "出生人口(万)",
    "总人口(万人)",
    "中国人均GPA(美元计)",
    "中国性别比例(按照女生=100)",
    "自然增长率(%)",
    "城镇人口(城镇+乡村=100)",
    "乡村人口",
    "美元兑换人民币汇率",
    "中国就业人口(万人)",
]

colors = [
    "blue",
    "chocolate",
    "green",
    "red",
    "purple",
    "brown",
    "darkblue",
    "black",
    "magenta",
]

date_time_key = "年份"


def show_raw_visualization(data):
    time_data = data[date_time_key]
    fig, axes = plt.subplots(
        nrows=3, ncols=3, figsize=(15, 15), dpi=100, facecolor="w", edgecolor="k"
    )
    # Draw one subplot per feature, indexed by year
    for i in range(len(feature_keys)):
        key = feature_keys[i]
        c = colors[i % len(colors)]
        t_data = data[key]
        t_data.index = time_data

        ax = t_data.plot(
            ax=axes[i // 3, i % 3],
            color=c,
            title=titles[i],
            rot=25,
        )
        ax.legend([titles[i]])
    plt.tight_layout()


show_raw_visualization(population)

[Figure: line plots of each feature against the year index]

3.3.2 Box plots

To inspect the distribution of part of the data, the four fields 出生人口(万), 总人口(万人), 中国人均GPA(美元计), and 中国就业人口(万人) are shown below as box plots.


# Re-apply the Chinese font setting (repeated per notebook cell in the original)
from pylab import mpl
from matplotlib.font_manager import FontProperties
myfont = FontProperties(fname=r'/usr/share/fonts/fangzheng/FZSYJW.TTF', size=12)
sns.set(font=myfont.get_name())
plt.figure(figsize=(15,8),dpi=100)
plt.subplot(1,4,1)
sns.boxplot(y="出生人口(万)", data=population, saturation=0.9)
plt.subplot(1,4,2)
sns.boxplot(y="总人口(万人)", data=population, saturation=0.9)
plt.subplot(1,4,3)
sns.boxplot(y="中国人均GPA(美元计)", data=population, saturation=0.9)
plt.subplot(1,4,4)
sns.boxplot(y="中国就业人口(万人)", data=population, saturation=0.9)
plt.tight_layout()

[Figure: box plots of the four selected fields]

3.3.3 Correlation analysis

Examine the pairwise correlations between the variables. When two features are strongly correlated, one of them could be kept and the other dropped; since this dataset has few feature fields, we do not remove any here (though a pruning sketch is shown after the heatmap).

corr = population.corr()
# Draw a heatmap of the pairwise correlations
plt.figure(figsize=(10,10), dpi=100)
sns.heatmap(corr, square=True, linewidths=0.1, annot=True)
<AxesSubplot:>

[Figure: correlation heatmap of all fields]
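For completeness, here is a minimal sketch of how highly correlated features could be pruned; the 0.9 threshold and the upper-triangle trick are illustrative assumptions, not part of the original project:

import numpy as np

# Absolute correlations, upper triangle only (so each pair is counted once)
corr_abs = population.corr().abs()
upper = corr_abs.where(np.triu(np.ones(corr_abs.shape), k=1).astype(bool))
# Columns whose correlation with some earlier column exceeds the threshold
to_drop = [col for col in upper.columns if (upper[col] > 0.9).any()]
print(to_drop)  # candidate columns to drop, one from each highly correlated pair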

3.4 Data Preprocessing

3.4.1 Print selected features

Because the number of features is limited, we keep every feature field in the dataset. The selected features are printed and displayed below.

print(
    "选取的参数指标是:",
    ", ".join([titles[i] for i in [0, 1, 2, 3, 4, 5, 6, 7, 8]]),
)
selected_features = [feature_keys[i] for i in [0, 1, 2, 3, 4, 5, 6, 7, 8]]
features = population[selected_features]
features.index = population[date_time_key]
features.head()
选取的参数指标是: 出生人口(万), 总人口(万人), 中国人均GPA(美元计), 中国性别比例(按照女生=100), 自然增长率(%), 城镇人口(城镇+乡村=100), 乡村人口, 美元兑换人民币汇率, 中国就业人口(万人)
| 年份 | 出生人口(万) | 总人口(万人) | 中国人均GPA(美元计) | 中国性别比例(按照女生=100) | 自然增长率(%) | 城镇人口(城镇+乡村=100) | 乡村人口 | 美元兑换人民币汇率 | 中国就业人口(万人) |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 1970 | 2710 | 82542 | 113 | 105.90 | 25.95 | 17.38 | 82.62 | 2.4618 | 34432 |
| 1971 | 2551 | 84779 | 118 | 105.82 | 23.40 | 17.26 | 82.74 | 2.2673 | 35520 |
| 1972 | 2550 | 86727 | 131 | 105.78 | 22.27 | 17.13 | 82.87 | 2.2401 | 35854 |
| 1973 | 2447 | 88761 | 157 | 105.86 | 20.99 | 17.20 | 82.80 | 2.0202 | 36652 |
| 1974 | 2226 | 90409 | 160 | 105.88 | 17.57 | 17.16 | 82.84 | 1.8397 | 37369 |

3.4.2 Duplicate check

Check whether the features contain duplicate rows. A return value of False means there are none, so nothing needs to be removed.

features.duplicated().any()
False

3.4.3 Missing-value check

Check whether any samples are missing. True for every column means there are no missing values, so no extra handling is needed.

pd.notnull(features).all()
出生人口(万)             True
总人口(万人)             True
中国人均GPA(美元计)        True
中国性别比例(按照女生=100)    True
自然增长率(%)            True
城镇人口(城镇+乡村=100)     True
乡村人口                True
美元兑换人民币汇率           True
中国就业人口(万人)          True
dtype: bool

3.4.4 Convert field types

The int fields must be converted to float before they can be used for model training; otherwise an error is raised.

  • First, inspect the type of each field with the statement below.
  • Then convert the int64 fields to float64, replacing the original columns.
population.dtypes
年份                    int64
出生人口(万)               int64
总人口(万人)               int64
中国人均GPA(美元计)          int64
中国性别比例(按照女生=100)    float64
自然增长率(%)            float64
城镇人口(城镇+乡村=100)     float64
乡村人口                float64
美元兑换人民币汇率           float64
中国就业人口(万人)            int64
dtype: object
population.head()
| | 年份 | 出生人口(万) | 总人口(万人) | 中国人均GPA(美元计) | 中国性别比例(按照女生=100) | 自然增长率(%) | 城镇人口(城镇+乡村=100) | 乡村人口 | 美元兑换人民币汇率 | 中国就业人口(万人) |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 0 | 1970 | 2710 | 82542 | 113 | 105.90 | 25.95 | 17.38 | 82.62 | 2.4618 | 34432 |
| 1 | 1971 | 2551 | 84779 | 118 | 105.82 | 23.40 | 17.26 | 82.74 | 2.2673 | 35520 |
| 2 | 1972 | 2550 | 86727 | 131 | 105.78 | 22.27 | 17.13 | 82.87 | 2.2401 | 35854 |
| 3 | 1973 | 2447 | 88761 | 157 | 105.86 | 20.99 | 17.20 | 82.80 | 2.0202 | 36652 |
| 4 | 1974 | 2226 | 90409 | 160 | 105.88 | 17.57 | 17.16 | 82.84 | 1.8397 | 37369 |
population['出生人口(万)'] = population['出生人口(万)'].astype('float64')
population['总人口(万人)'] = population['总人口(万人)'].astype('float64')
population['中国人均GPA(美元计)'] = population['中国人均GPA(美元计)'].astype('float64')
population['中国就业人口(万人)'] = population['中国就业人口(万人)'].astype('float64')
# population['date'] = population['年份'].astype('str')
# population['date'] = pd.to_datetime(population['date'])
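As an aside, the four per-column conversions above can be collapsed into a single call; a minimal equivalent sketch, assuming the same column names:

# Convert every int64 column except the year index to float64 in one step
int_cols = population.select_dtypes(include='int64').columns.drop('年份')
population[int_cols] = population[int_cols].astype('float64')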

The command below displays the processed data. If the 出生人口(万) field shows NaN, simply re-read the data and rerun the preprocessing steps above.

population.head()
| | 年份 | 出生人口(万) | 总人口(万人) | 中国人均GPA(美元计) | 中国性别比例(按照女生=100) | 自然增长率(%) | 城镇人口(城镇+乡村=100) | 乡村人口 | 美元兑换人民币汇率 | 中国就业人口(万人) | 出生人口(万)1 |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 0 | 1970 | 2710.0 | 82542.0 | 113.0 | 105.90 | 25.95 | 17.38 | 82.62 | 2.4618 | 34432.0 | 2710.0 |
| 1 | 1971 | 2551.0 | 84779.0 | 118.0 | 105.82 | 23.40 | 17.26 | 82.74 | 2.2673 | 35520.0 | 2551.0 |
| 2 | 1972 | 2550.0 | 86727.0 | 131.0 | 105.78 | 22.27 | 17.13 | 82.87 | 2.2401 | 35854.0 | 2550.0 |
| 3 | 1973 | 2447.0 | 88761.0 | 157.0 | 105.86 | 20.99 | 17.20 | 82.80 | 2.0202 | 36652.0 | 2447.0 |
| 4 | 1974 | 2226.0 | 90409.0 | 160.0 | 105.88 | 17.57 | 17.16 | 82.84 | 1.8397 | 37369.0 | 2226.0 |

(The trailing 出生人口(万)1 column appears to be a duplicate left over from re-running the notebook.)

3.5 Build the TSDataset

TSDataset is one of the core classes in PaddleTS, designed to represent the vast majority of time-series data. In general, time series fall into the following categories:

  • Univariate data: a single target column, optionally with one or more covariate columns
  • Multivariate data: multiple target columns, optionally with one or more covariate columns

A TSDataset must have a time_index attribute; time_index supports both pandas.DatetimeIndex and pandas.RangeIndex.

target_cov_dataset = TSDataset.load_from_dataframe(
    population,
    time_col='年份',                 # column used as the time index
    target_cols='总人口(万人)',       # the prediction target
    observed_cov_cols=['出生人口(万)', '中国人均GPA(美元计)', '中国性别比例(按照女生=100)', '自然增长率(%)',
       '城镇人口(城镇+乡村=100)', '乡村人口', '美元兑换人民币汇率', '中国就业人口(万人)'],  # observed covariates
    fill_missing_dates=True,
    fillna_method='pre'              # forward-fill any gaps created above
)
target_cov_dataset.plot(['总人口(万人)', '出生人口(万)', '中国人均GPA(美元计)', '中国性别比例(按照女生=100)', '自然增长率(%)',
       '城镇人口(城镇+乡村=100)', '乡村人口', '美元兑换人民币汇率', '中国就业人口(万人)'])

<AxesSubplot:>

[Figure: plot of the target series and all observed covariates]

target_cov_dataset
       总人口(万人)  出生人口(万)  中国人均GPA(美元计)  中国性别比例(按照女生=100)  自然增长率(%)  \
1970   82542.0   2710.0         113.0            105.90     25.95   
1971   84779.0   2551.0         118.0            105.82     23.40   
1972   86727.0   2550.0         131.0            105.78     22.27   
1973   88761.0   2447.0         157.0            105.86     20.99   
1974   90409.0   2226.0         160.0            105.88     17.57   
1975   91970.0   2102.0         178.0            106.04     15.77   
1976   93267.0   1849.0         165.0            106.15     12.72   
1977   94774.0   1783.0         185.0            106.17     12.12   
1978   96159.0   1733.0         156.0            106.16     12.00   
1979   97542.0   1715.0         183.0            106.00     11.61   
1980   98705.0   1776.0         194.0            105.98     11.87   
1981  100072.0   2064.0         197.0            106.11     14.55   
1982  101654.0   2230.0         203.0            106.19     15.68   
1983  103008.0   2052.0         225.0            106.61     13.29   
1984  104357.0   2050.0         250.0            106.61     13.08   
1985  105851.0   2196.0         294.0            107.04     14.26   
1986  107507.0   2374.0         281.0            107.04     15.57   
1987  109300.0   2508.0         251.0            106.19     16.61   
1988  111026.0   2445.0         283.0            106.27     15.73   
1989  112704.0   2396.0         310.0            106.40     15.04   
1990  114333.0   2374.0         317.0            106.27     14.39   
1991  115823.0   2250.0         333.0            105.52     12.98   
1992  117171.0   2113.0         366.0            104.27     11.50   
1993  118517.0   2120.0         377.0            104.18     11.45   
1994  119850.0   2098.0         473.0            104.51     11.21   
1995  121121.0   2052.0         609.0            104.21     10.55   
1996  122389.0   2057.0         709.0            103.34     10.42   
1997  123626.0   2028.0         781.0            104.36     10.06   
1998  124761.0   1934.0         828.0            105.13      9.14   
1999  125786.0   1827.0         873.0            105.89      8.18   
2000  126743.0   1765.0         959.0            106.74      7.58   
2001  127627.0   1696.0        1053.0            106.00      6.95   
2002  128453.0   1641.0        1148.0            106.06      6.45   
2003  129227.0   1594.0        1288.0            106.20      6.01   
2004  129988.0   1588.0        1508.0            106.29      5.87   
2005  130756.0   1612.0        1753.0            106.30      5.89   
2006  131448.0   1581.0        2099.0            106.29      5.28   
2007  132129.0   1591.0        2693.0            106.19      5.17   
2008  132802.0   1604.0        3468.0            106.07      5.08   
2009  133450.0   1587.0        3832.0            105.93      4.87   
2010  134091.0   1588.0        4550.0            105.21      4.79   
2011  134735.0   1600.0        5618.0            105.18      4.79   
2012  135404.0   1635.0        6316.0            105.13      4.95   
2013  136072.0   1640.0        7050.0            105.10      4.92   
2014  136782.0   1687.0        7678.0            105.06      5.21   
2015  137462.0   1655.0        8066.0            105.02      4.96   
2016  138271.0   1786.0        8147.0            104.98      5.86   
2017  139008.0   1723.0        8879.0            104.81      5.32   
2018  139538.0   1523.0        9976.0            104.64      3.81   
2019  140005.0   1465.0       10261.0            104.60      3.34   

      城镇人口(城镇+乡村=100)   乡村人口  美元兑换人民币汇率  中国就业人口(万人)  
1970            17.38  82.62     2.4618     34432.0  
1971            17.26  82.74     2.2673     35520.0  
1972            17.13  82.87     2.2401     35854.0  
1973            17.20  82.80     2.0202     36652.0  
1974            17.16  82.84     1.8397     37369.0  
1975            17.34  82.66     1.9663     38168.0  
1976            17.44  82.55     1.8803     38834.0  
1977            17.55  82.45     1.7300     39377.0  
1978            17.92  82.08     1.5771     40152.0  
1979            18.96  81.04     1.4962     41024.0  
1980            19.39  80.61     1.5303     42361.0  
1981            20.16  79.84     1.7051     43725.0  
1982            21.13  78.87     1.8926     45295.0  
1983            21.52  78.38     1.9757     46436.0  
1984            23.01  76.99     2.3270     48197.0  
1985            23.71  76.29     2.9367     49873.0  
1986            24.52  75.48     3.4528     51282.0  
1987            25.32  74.68     3.7221     52783.0  
1988            25.81  74.19     3.7221     54334.0  
1989            25.21  73.79     3.7659     55329.0  
1990            25.41  73.59     4.7838     64749.0  
1991            26.94  73.06     5.3227     65491.0  
1992            27.46  72.54     5.5149     66152.0  
1993            27.99  72.01     5.7619     66808.0  
1994            28.51  71.49     8.6187     67455.0  
1995            29.04  70.95     8.3507     68065.0  
1996            30.48  69.52     8.3142     68950.0  
1997            31.91  68.09     8.2898     69820.0  
1998            33.35  66.55     8.2791     70537.0  
1999            34.78  65.22     8.2796     71394.0  
2000            36.22  63.78     8.2784     72085.0  
2001            37.66  62.34     8.2770     72797.0  
2002            39.09  60.91     8.2770     73280.0  
2003            40.53  59.47     8.2774     73736.0  
2004            41.76  58.24     8.2780     74264.0  
2005            42.99  57.01     8.1013     74547.0  
2006            44.34  55.56     7.8087     74978.0  
2007            45.89  54.11     7.3872     75321.0  
2008            46.99  53.01     6.8500     75564.0  
2009            48.34  51.66     6.8100     75828.0  
2010            49.95  50.05     6.6220     76105.0  
2011            51.27  48.73     6.6100     76420.0  
2012            52.57  47.43     6.2500     76704.0  
2013            53.73  46.27     6.0700     76977.0  
2014            54.77  45.23     6.0500     77253.0  
2015            56.10  43.90     6.2000     77451.0  
2016            57.35  42.65     6.5600     77603.0  
2017            58.52  41.48     6.5000     77640.0  
2018            59.58  40.42     6.8600     77586.0  
2019            60.60  39.40     6.8967     77471.0  
target_cov_dataset.summary()
| | 总人口(万人) | 出生人口(万) | 中国人均GPA(美元计) | 中国性别比例(按照女生=100) | 自然增长率(%) | 城镇人口(城镇+乡村=100) | 乡村人口 | 美元兑换人民币汇率 | 中国就业人口(万人) |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| missing | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| count | 50.000000 | 50.000000 | 50.000000 | 50.000000 | 50.000000 | 50.000000 | 50.000000 | 50.000000 | 50.000000 |
| mean | 116769.640000 | 1943.420000 | 2120.840000 | 105.673600 | 10.741200 | 33.584800 | 66.368800 | 5.219768 | 61200.560000 |
| std | 17393.764444 | 334.886157 | 3010.987419 | 0.814879 | 5.537814 | 14.148318 | 14.125293 | 2.580723 | 15633.538166 |
| min | 82542.000000 | 1465.000000 | 113.000000 | 103.340000 | 3.340000 | 17.130000 | 39.400000 | 1.496200 | 34432.000000 |
| 25% | 101992.500000 | 1640.250000 | 208.500000 | 105.107500 | 5.455000 | 21.227500 | 54.472500 | 2.282225 | 45580.250000 |
| 50% | 120485.500000 | 1838.000000 | 541.000000 | 105.955000 | 10.880000 | 28.775000 | 71.220000 | 6.060000 | 67760.000000 |
| 75% | 131958.750000 | 2177.000000 | 2544.500000 | 106.190000 | 14.357500 | 45.502500 | 78.747500 | 7.703325 | 75235.250000 |
| max | 140005.000000 | 2710.000000 | 10261.000000 | 107.040000 | 25.950000 | 60.600000 | 82.870000 | 8.618700 | 77640.000000 |

3.6 Split the dataset

Train : validation : test = 0.6 : 0.2 : 0.2

# Split off 60% for training, then divide the remainder evenly into validation and test
train_dataset, val_test_dataset = target_cov_dataset.split(0.6)
val_dataset, test_dataset = val_test_dataset.split(0.5)
train_dataset.plot(add_data=[val_dataset, test_dataset], labels=['Val', 'Test'])
<AxesSubplot:>

[Figure: train/validation/test split of the target series]

3.7 Normalization

The scaler is fit on the training split only and then applied to the other splits, so no statistics leak from future data.

scaler = StandardScaler()
scaler.fit(train_dataset)
train_dataset_scaled = scaler.transform(train_dataset)
val_test_dataset_scaled = scaler.transform(val_test_dataset)
val_dataset_scaled = scaler.transform(val_dataset)
test_dataset_scaled = scaler.transform(test_dataset)
train_dataset_scaled
       总人口(万人)   出生人口(万)  中国人均GPA(美元计)  中国性别比例(按照女生=100)  自然增长率(%)  \
1970 -1.785855  2.118316     -0.978375          0.191621  2.796566   
1971 -1.611715  1.512902     -0.955117          0.101799  2.182756   
1972 -1.460072  1.509094     -0.894647          0.056887  1.910754   
1973 -1.301734  1.116907     -0.773707          0.146710  1.602645   
1974 -1.173445  0.275419     -0.759752          0.169165  0.779418   
1975 -1.051928 -0.196728     -0.676024          0.348810  0.346141   
1976 -0.950963 -1.160061     -0.736494          0.472316 -0.388024   
1977 -0.833650 -1.411365     -0.643464          0.494771 -0.532450   
1978 -0.725834 -1.601747     -0.778358          0.483543 -0.561335   
1979 -0.618173 -1.670284     -0.652767          0.303899 -0.655212   
1980 -0.527639 -1.438018     -0.601600          0.281443 -0.592627   
1981 -0.421224 -0.341418     -0.587645          0.427404  0.052475   
1982 -0.298073  0.290650     -0.559736          0.517227  0.324477   
1983 -0.192670 -0.387110     -0.457402          0.988794 -0.250820   
1984 -0.087657 -0.394725     -0.341113          0.988794 -0.301369   
1985  0.028645  0.161190     -0.136445          1.471589 -0.017331   
1986  0.157557  0.838950     -0.196915          1.471589  0.297999   
1987  0.297134  1.349173     -0.336462          0.517227  0.548337   
1988  0.431495  1.109292     -0.187612          0.607049  0.336512   
1989  0.562120  0.922718     -0.062021          0.753010  0.170422   
1990  0.688930  0.838950     -0.029460          0.607049  0.013961   
1991  0.804920  0.366803      0.044965         -0.235035 -0.325440   
1992  0.909855 -0.154844      0.198466         -1.638508 -0.681690   
1993  1.014635 -0.128191      0.249633         -1.739558 -0.693725   
1994  1.118403 -0.211959      0.696181         -1.369042 -0.751496   
1995  1.217345 -0.387110      1.328791         -1.705875 -0.910364   
1996  1.316053 -0.368072      1.793945         -2.682693 -0.941656   
1997  1.412348 -0.478493      2.128857         -1.537458 -1.028312   
1998  1.500702 -0.836411      2.347479         -0.672919 -1.249765   
1999  1.580494 -1.243829      2.556799          0.180393 -1.480846   

      城镇人口(城镇+乡村=100)      乡村人口  美元兑换人民币汇率  中国就业人口(万人)  
1970        -1.118440  1.124925  -0.586832   -1.335278  
1971        -1.140859  1.147199  -0.664354   -1.250396  
1972        -1.165148  1.171330  -0.675195   -1.224339  
1973        -1.152069  1.158336  -0.762839   -1.162081  
1974        -1.159543  1.165761  -0.834781   -1.106143  
1975        -1.125913  1.132349  -0.784322   -1.043808  
1976        -1.107230  1.111931  -0.818599   -0.991848  
1977        -1.086678  1.093369  -0.878503   -0.949485  
1978        -1.017550  1.024689  -0.939444   -0.889022  
1979        -0.823245  0.831644  -0.971688   -0.820992  
1980        -0.742907  0.751827  -0.958097   -0.716683  
1981        -0.599046  0.608898  -0.888428   -0.610268  
1982        -0.417819  0.428846  -0.813696   -0.487782  
1983        -0.344954  0.337892  -0.780576   -0.398765  
1984        -0.066575  0.079879  -0.640559   -0.261377  
1985         0.064208 -0.050056  -0.397553   -0.130621  
1986         0.215542 -0.200409  -0.191853   -0.020695  
1987         0.365008 -0.348905  -0.084519    0.096408  
1988         0.456555 -0.439860  -0.084519    0.217412  
1989         0.344456 -0.514108  -0.067061    0.295039  
1990         0.381823 -0.551232   0.338640    1.029957  
1991         0.667676 -0.649611   0.553427    1.087846  
1992         0.764829 -0.746134   0.630032    1.139415  
1993         0.863850 -0.844513   0.728478    1.190594  
1994         0.961002 -0.941036   1.867103    1.241071  
1995         1.060023 -1.041271   1.760287    1.288661  
1996         1.329062 -1.306709   1.745739    1.357706  
1997         1.596232 -1.572147   1.736014    1.425580  
1998         1.865270 -1.858003   1.731750    1.481518  
1999         2.132440 -2.104879   1.731949    1.548379  

4. Building the Network Model

LSTNet, proposed in 2018, is a forecasting model that exploits the strengths of both convolutional and recurrent layers: it extracts local dependency patterns among the variables of a multivariate series and captures complex long-term dependencies.

We set the maximum number of training epochs to 2000 and both the model's input and output sequence lengths to 5. The remaining parameters are described below:

  • in_chunk_len (int) – length of the input time series fed to the model.

  • out_chunk_len (int) – length of the output time series produced by the model.

  • skip_chunk_len (int) – optional; length of the sequence skipped between the input and output sequences, used neither as features nor as prediction targets. Defaults to 0.

  • sampling_stride (int) – sampling interval between adjacent samples.

  • loss_fn (Callable[…, paddle.Tensor]|None) – loss function.

  • optimizer_fn (Callable[…, Optimizer]) – optimization algorithm.

  • optimizer_params (Dict[str, Any]) – optimizer parameters.

  • eval_metrics (List[str]) – evaluation metrics to monitor during training.

  • callbacks (List[Callback]) – custom callback functions.

  • batch_size (int) – batch size for training or evaluation data.

  • max_epochs (int) – maximum number of training epochs.

  • verbose (int) – interval for printing log messages during training.

  • patience (int) – number of epochs to wait for the evaluation metric to improve before stopping training early.

  • seed (int|None) – global random seed; note: ensures parameter initialization is identical across runs.

  • skip_size (int) – period length used by the recurrent-skip component (Skip RNN) to capture periodicity in the series.

  • channels (int) – number of channels in the first Conv1D layer.

  • kernel_size (int) – kernel size of the first Conv1D layer.

  • rnn_cell_type (str) – type of RNN cell, GRU or LSTM.

  • rnn_num_cells (int) – number of units in the RNN layer.

  • skip_rnn_cell_type (str) – type of Skip RNN cell, GRU or LSTM.

  • skip_rnn_num_cells (int) – number of units in the Skip RNN layer.

  • dropout_rate (float) – dropout probability for neurons.

  • output_activation (str|None) – activation for the output layer: None (no activation), sigmoid, or tanh.

lstm = LSTNetRegressor(
    in_chunk_len = 5,
    out_chunk_len = 5,
    max_epochs=2000
)
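For reference, here is a more fully specified constructor exercising the parameters listed above; the specific values are illustrative assumptions, not tuned settings from this project:

lstm_full = LSTNetRegressor(
    in_chunk_len=5,
    out_chunk_len=5,
    skip_chunk_len=0,       # no gap between the input and output windows
    skip_size=1,            # skip-RNN period length (annual data has no obvious sub-year period)
    channels=16,            # channels of the first Conv1D layer
    kernel_size=3,          # kernel size of the first Conv1D layer
    rnn_cell_type='GRU',    # 'GRU' or 'LSTM'
    dropout_rate=0.2,       # dropout probability
    patience=10,            # stop early after 10 epochs without improvement
    seed=2022,              # fix parameter initialization for reproducibility
    max_epochs=2000,
)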

5. Model Training

Pass the normalized training and validation data to the model to train it.

lstm.fit(train_dataset_scaled, val_dataset_scaled)
[2022-11-30 16:10:43,402] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 000| loss: 1.335296| val_0_mae: 2.026456| 0:00:00s
[2022-11-30 16:10:43,411] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 001| loss: 1.431604| val_0_mae: 2.010468| 0:00:00s
[2022-11-30 16:10:43,419] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 002| loss: 1.398407| val_0_mae: 1.994482| 0:00:00s
[2022-11-30 16:10:43,428] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 003| loss: 1.504742| val_0_mae: 1.978191| 0:00:00s
[2022-11-30 16:10:43,436] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 004| loss: 1.397552| val_0_mae: 1.961923| 0:00:00s
[2022-11-30 16:10:43,445] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 005| loss: 1.433584| val_0_mae: 1.945455| 0:00:00s
[2022-11-30 16:10:43,453] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 006| loss: 1.365905| val_0_mae: 1.929153| 0:00:00s
[2022-11-30 16:10:43,461] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 007| loss: 1.234721| val_0_mae: 1.913647| 0:00:00s
[2022-11-30 16:10:43,469] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 008| loss: 1.225891| val_0_mae: 1.898627| 0:00:00s
[2022-11-30 16:10:43,478] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 009| loss: 1.268852| val_0_mae: 1.883862| 0:00:00s
[2022-11-30 16:10:43,486] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 010| loss: 1.266971| val_0_mae: 1.869215| 0:00:00s
[2022-11-30 16:10:43,494] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 011| loss: 1.260334| val_0_mae: 1.854911| 0:00:00s
[2022-11-30 16:10:43,502] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 012| loss: 1.268996| val_0_mae: 1.840751| 0:00:00s
[2022-11-30 16:10:43,511] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 013| loss: 1.205134| val_0_mae: 1.826880| 0:00:00s
[2022-11-30 16:10:43,519] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 014| loss: 1.238074| val_0_mae: 1.812879| 0:00:00s
[2022-11-30 16:10:43,527] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 015| loss: 1.247744| val_0_mae: 1.798924| 0:00:00s
[2022-11-30 16:10:43,537] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 016| loss: 1.228621| val_0_mae: 1.785040| 0:00:00s
[2022-11-30 16:10:43,546] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 017| loss: 1.187945| val_0_mae: 1.771226| 0:00:00s
[2022-11-30 16:10:43,554] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 018| loss: 1.175421| val_0_mae: 1.757583| 0:00:00s
[2022-11-30 16:10:43,562] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 019| loss: 1.145478| val_0_mae: 1.744175| 0:00:00s
[2022-11-30 16:10:43,571] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 020| loss: 1.162446| val_0_mae: 1.730875| 0:00:00s
[2022-11-30 16:10:43,579] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 021| loss: 1.209757| val_0_mae: 1.717458| 0:00:00s
[2022-11-30 16:10:43,587] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 022| loss: 1.158366| val_0_mae: 1.704297| 0:00:00s
[2022-11-30 16:10:43,596] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 023| loss: 1.198785| val_0_mae: 1.691088| 0:00:00s
[2022-11-30 16:10:43,604] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 024| loss: 1.077143| val_0_mae: 1.678203| 0:00:00s
[2022-11-30 16:10:43,612] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 025| loss: 1.109716| val_0_mae: 1.665533| 0:00:00s
[2022-11-30 16:10:43,620] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 026| loss: 1.094385| val_0_mae: 1.653193| 0:00:00s
[2022-11-30 16:10:43,628] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 027| loss: 1.080788| val_0_mae: 1.641071| 0:00:00s
[2022-11-30 16:10:43,639] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 028| loss: 1.083594| val_0_mae: 1.629065| 0:00:00s
[2022-11-30 16:10:43,647] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 029| loss: 1.116938| val_0_mae: 1.616999| 0:00:00s
[2022-11-30 16:10:43,655] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 030| loss: 1.128399| val_0_mae: 1.604847| 0:00:00s
[2022-11-30 16:10:43,665] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 031| loss: 1.057482| val_0_mae: 1.592970| 0:00:00s
[2022-11-30 16:10:43,673] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 032| loss: 0.988231| val_0_mae: 1.581546| 0:00:00s
[2022-11-30 16:10:43,681] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 033| loss: 1.047795| val_0_mae: 1.573081| 0:00:00s
[2022-11-30 16:10:43,689] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 034| loss: 1.000459| val_0_mae: 1.564725| 0:00:00s
[2022-11-30 16:10:43,697] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 035| loss: 0.981792| val_0_mae: 1.556305| 0:00:00s
[2022-11-30 16:10:43,705] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 036| loss: 1.006375| val_0_mae: 1.547980| 0:00:00s
[2022-11-30 16:10:43,714] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 037| loss: 1.042193| val_0_mae: 1.539739| 0:00:00s
[2022-11-30 16:10:43,722] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 038| loss: 1.035661| val_0_mae: 1.531963| 0:00:00s
[2022-11-30 16:10:43,730] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 039| loss: 0.991913| val_0_mae: 1.524082| 0:00:00s
[2022-11-30 16:10:43,738] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 040| loss: 1.005429| val_0_mae: 1.516243| 0:00:00s
[2022-11-30 16:10:43,746] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 041| loss: 0.994254| val_0_mae: 1.508542| 0:00:00s
[2022-11-30 16:10:43,754] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 042| loss: 0.898892| val_0_mae: 1.500784| 0:00:00s
[2022-11-30 16:10:43,762] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 043| loss: 0.959420| val_0_mae: 1.493016| 0:00:00s
[2022-11-30 16:10:43,770] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 044| loss: 0.922088| val_0_mae: 1.485151| 0:00:00s
[2022-11-30 16:10:43,778] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 045| loss: 0.922496| val_0_mae: 1.477284| 0:00:00s
[2022-11-30 16:10:43,786] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 046| loss: 0.899761| val_0_mae: 1.469546| 0:00:00s
[2022-11-30 16:10:43,794] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 047| loss: 0.887426| val_0_mae: 1.461833| 0:00:00s
[2022-11-30 16:10:43,802] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 048| loss: 0.933051| val_0_mae: 1.454159| 0:00:00s
[2022-11-30 16:10:43,811] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 049| loss: 0.909708| val_0_mae: 1.446734| 0:00:00s
[2022-11-30 16:10:43,819] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 050| loss: 0.858118| val_0_mae: 1.439384| 0:00:00s
[2022-11-30 16:10:43,827] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 051| loss: 0.896318| val_0_mae: 1.432108| 0:00:00s
[2022-11-30 16:10:43,835] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 052| loss: 0.873903| val_0_mae: 1.425059| 0:00:00s
[2022-11-30 16:10:43,843] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 053| loss: 0.904945| val_0_mae: 1.417843| 0:00:00s
[2022-11-30 16:10:43,852] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 054| loss: 0.877615| val_0_mae: 1.410705| 0:00:00s
[2022-11-30 16:10:43,860] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 055| loss: 0.833624| val_0_mae: 1.403709| 0:00:00s
[2022-11-30 16:10:43,868] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 056| loss: 0.854565| val_0_mae: 1.396727| 0:00:00s
[2022-11-30 16:10:43,876] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 057| loss: 0.840873| val_0_mae: 1.389806| 0:00:00s
[2022-11-30 16:10:43,888] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 058| loss: 0.794746| val_0_mae: 1.382859| 0:00:00s
[2022-11-30 16:10:43,896] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 059| loss: 0.812057| val_0_mae: 1.375864| 0:00:00s
[2022-11-30 16:10:43,905] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 060| loss: 0.771674| val_0_mae: 1.368605| 0:00:00s
[2022-11-30 16:10:43,913] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 061| loss: 0.775358| val_0_mae: 1.361753| 0:00:00s
[2022-11-30 16:10:43,921] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 062| loss: 0.785709| val_0_mae: 1.354705| 0:00:00s
[2022-11-30 16:10:43,929] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 063| loss: 0.761827| val_0_mae: 1.347852| 0:00:00s
[2022-11-30 16:10:43,937] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 064| loss: 0.809033| val_0_mae: 1.341352| 0:00:00s
[2022-11-30 16:10:43,945] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 065| loss: 0.780791| val_0_mae: 1.334946| 0:00:00s
[2022-11-30 16:10:43,954] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 066| loss: 0.731239| val_0_mae: 1.328409| 0:00:00s
[2022-11-30 16:10:43,962] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 067| loss: 0.730782| val_0_mae: 1.321399| 0:00:00s
[2022-11-30 16:10:43,970] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 068| loss: 0.809380| val_0_mae: 1.314795| 0:00:00s
[2022-11-30 16:10:43,978] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 069| loss: 0.742965| val_0_mae: 1.308291| 0:00:00s
[2022-11-30 16:10:43,986] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 070| loss: 0.727096| val_0_mae: 1.301758| 0:00:00s
[2022-11-30 16:10:43,994] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 071| loss: 0.697883| val_0_mae: 1.295217| 0:00:00s
[2022-11-30 16:10:44,002] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 072| loss: 0.713083| val_0_mae: 1.288782| 0:00:00s
[2022-11-30 16:10:44,010] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 073| loss: 0.709170| val_0_mae: 1.282597| 0:00:00s
[2022-11-30 16:10:44,019] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 074| loss: 0.696470| val_0_mae: 1.276140| 0:00:00s
[2022-11-30 16:10:44,027] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 075| loss: 0.694704| val_0_mae: 1.269327| 0:00:00s
[2022-11-30 16:10:44,035] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 076| loss: 0.667143| val_0_mae: 1.262488| 0:00:00s
[2022-11-30 16:10:44,043] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 077| loss: 0.712231| val_0_mae: 1.256098| 0:00:00s
[2022-11-30 16:10:44,051] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 078| loss: 0.631457| val_0_mae: 1.249475| 0:00:00s
[2022-11-30 16:10:44,060] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 079| loss: 0.668609| val_0_mae: 1.242805| 0:00:00s
[2022-11-30 16:10:44,069] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 080| loss: 0.668821| val_0_mae: 1.236076| 0:00:00s
[2022-11-30 16:10:44,077] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 081| loss: 0.711819| val_0_mae: 1.229457| 0:00:00s
[2022-11-30 16:10:44,086] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 082| loss: 0.694410| val_0_mae: 1.222744| 0:00:00s
[2022-11-30 16:10:44,094] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 083| loss: 0.665902| val_0_mae: 1.216351| 0:00:00s
[2022-11-30 16:10:44,103] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 084| loss: 0.635740| val_0_mae: 1.209996| 0:00:00s
[2022-11-30 16:10:44,111] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 085| loss: 0.614305| val_0_mae: 1.204036| 0:00:00s
[2022-11-30 16:10:44,120] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 086| loss: 0.620907| val_0_mae: 1.198223| 0:00:00s
[2022-11-30 16:10:44,133] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 087| loss: 0.561215| val_0_mae: 1.192517| 0:00:00s
[2022-11-30 16:10:44,142] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 088| loss: 0.563731| val_0_mae: 1.186862| 0:00:00s
[2022-11-30 16:10:44,150] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 089| loss: 0.661129| val_0_mae: 1.182760| 0:00:00s
[2022-11-30 16:10:44,158] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 090| loss: 0.616183| val_0_mae: 1.178985| 0:00:00s
[2022-11-30 16:10:44,167] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 091| loss: 0.574570| val_0_mae: 1.175775| 0:00:00s
[2022-11-30 16:10:44,175] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 092| loss: 0.602581| val_0_mae: 1.173194| 0:00:00s
[2022-11-30 16:10:44,183] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 093| loss: 0.585698| val_0_mae: 1.170198| 0:00:00s
[2022-11-30 16:10:44,191] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 094| loss: 0.629226| val_0_mae: 1.167235| 0:00:00s
[2022-11-30 16:10:44,200] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 095| loss: 0.555163| val_0_mae: 1.164508| 0:00:00s
[2022-11-30 16:10:44,208] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 096| loss: 0.573185| val_0_mae: 1.161562| 0:00:00s
[2022-11-30 16:10:44,216] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 097| loss: 0.562863| val_0_mae: 1.158835| 0:00:00s
[2022-11-30 16:10:44,224] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 098| loss: 0.574777| val_0_mae: 1.156082| 0:00:00s
[2022-11-30 16:10:44,232] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 099| loss: 0.555662| val_0_mae: 1.153205| 0:00:00s
[2022-11-30 16:10:44,241] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 100| loss: 0.601033| val_0_mae: 1.150310| 0:00:00s
[2022-11-30 16:10:44,249] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 101| loss: 0.580240| val_0_mae: 1.147768| 0:00:00s
[2022-11-30 16:10:44,261] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 102| loss: 0.591648| val_0_mae: 1.145656| 0:00:00s
[2022-11-30 16:10:44,270] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 103| loss: 0.542775| val_0_mae: 1.144088| 0:00:00s
[2022-11-30 16:10:44,278] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 104| loss: 0.539643| val_0_mae: 1.142433| 0:00:00s
[2022-11-30 16:10:44,286] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 105| loss: 0.515184| val_0_mae: 1.141113| 0:00:00s
[2022-11-30 16:10:44,294] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 106| loss: 0.570678| val_0_mae: 1.140380| 0:00:00s
[2022-11-30 16:10:44,302] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 107| loss: 0.559144| val_0_mae: 1.139631| 0:00:00s
[2022-11-30 16:10:44,311] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 108| loss: 0.527249| val_0_mae: 1.139258| 0:00:00s
[2022-11-30 16:10:44,319] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 109| loss: 0.599206| val_0_mae: 1.138611| 0:00:00s
[2022-11-30 16:10:44,327] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 110| loss: 0.535672| val_0_mae: 1.137478| 0:00:00s
[2022-11-30 16:10:44,336] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 111| loss: 0.541393| val_0_mae: 1.136143| 0:00:00s
[2022-11-30 16:10:44,344] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 112| loss: 0.488889| val_0_mae: 1.134465| 0:00:00s
[2022-11-30 16:10:44,353] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 113| loss: 0.517146| val_0_mae: 1.132642| 0:00:00s
[2022-11-30 16:10:44,361] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 114| loss: 0.524021| val_0_mae: 1.130486| 0:00:00s
[2022-11-30 16:10:44,369] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 115| loss: 0.499090| val_0_mae: 1.127806| 0:00:00s
[2022-11-30 16:10:44,378] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 116| loss: 0.449979| val_0_mae: 1.125503| 0:00:00s
[2022-11-30 16:10:44,389] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 117| loss: 0.488240| val_0_mae: 1.123149| 0:00:01s
[2022-11-30 16:10:44,398] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 118| loss: 0.479148| val_0_mae: 1.120922| 0:00:01s
[2022-11-30 16:10:44,406] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 119| loss: 0.444483| val_0_mae: 1.118181| 0:00:01s
[2022-11-30 16:10:44,414] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 120| loss: 0.505287| val_0_mae: 1.115203| 0:00:01s
[2022-11-30 16:10:44,423] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 121| loss: 0.444154| val_0_mae: 1.112558| 0:00:01s
[2022-11-30 16:10:44,431] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 122| loss: 0.453889| val_0_mae: 1.109772| 0:00:01s
[2022-11-30 16:10:44,439] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 123| loss: 0.427750| val_0_mae: 1.106739| 0:00:01s
[2022-11-30 16:10:44,447] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 124| loss: 0.451959| val_0_mae: 1.103423| 0:00:01s
[2022-11-30 16:10:44,455] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 125| loss: 0.560345| val_0_mae: 1.101498| 0:00:01s
[2022-11-30 16:10:44,463] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 126| loss: 0.447812| val_0_mae: 1.099758| 0:00:01s
[2022-11-30 16:10:44,471] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 127| loss: 0.502624| val_0_mae: 1.098869| 0:00:01s
[2022-11-30 16:10:44,479] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 128| loss: 0.419782| val_0_mae: 1.098326| 0:00:01s
[2022-11-30 16:10:44,488] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 129| loss: 0.433613| val_0_mae: 1.098895| 0:00:01s
[2022-11-30 16:10:44,495] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 130| loss: 0.425311| val_0_mae: 1.099732| 0:00:01s
[2022-11-30 16:10:44,502] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 131| loss: 0.447280| val_0_mae: 1.100813| 0:00:01s
[2022-11-30 16:10:44,510] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 132| loss: 0.424777| val_0_mae: 1.102612| 0:00:01s
[2022-11-30 16:10:44,519] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 133| loss: 0.437228| val_0_mae: 1.105103| 0:00:01s
[2022-11-30 16:10:44,527] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 134| loss: 0.430760| val_0_mae: 1.107998| 0:00:01s
[2022-11-30 16:10:44,534] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 135| loss: 0.515581| val_0_mae: 1.111710| 0:00:01s
[2022-11-30 16:10:44,542] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 136| loss: 0.430533| val_0_mae: 1.115131| 0:00:01s
[2022-11-30 16:10:44,549] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 137| loss: 0.463041| val_0_mae: 1.117901| 0:00:01s
[2022-11-30 16:10:44,557] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 138| loss: 0.399685| val_0_mae: 1.120615| 0:00:01s
[2022-11-30 16:10:44,558] [paddlets.models.common.callbacks.callbacks] [INFO] 
Early stopping occurred at epoch 138 with best_epoch = 128 and best_val_0_mae = 1.098326
[2022-11-30 16:10:44,559] [paddlets.models.common.callbacks.callbacks] [INFO] Best weights from best epoch are automatically used!

6. Model Prediction

We now predict on the normalized validation data. Note that after the split, the validation set comes first in time, followed by the test set.

  • Feeding the validation data to the trained LSTNet model yields a prediction for the window that follows it (new values extrapolated from the existing sequence).
  • For the validation set, that following window lies inside the test set, so the predicted values can be compared against the test set's ground truth.
  • To obtain the overlapping portion of the test set and the predictions, the datasets must be sliced accordingly.
val_dataset_scaled
       总人口(万人)  中国人均GPA(美元计)  中国就业人口(万人)  中国性别比例(按照女生=100)      乡村人口  \
2000  1.654992      2.956831    1.602288          1.134755 -2.372173   
2001  1.723807      3.394077    1.657836          0.303899 -2.639467   
2002  1.788108      3.835973    1.695518          0.371266 -2.904905   
2003  1.848360      4.487189    1.731094          0.528455 -3.172199   
2004  1.907601      5.510529    1.772287          0.629505 -3.400513   
2005  1.967386      6.650157    1.794366          0.640732 -3.628827   
2006  2.021255      8.259592    1.827991          0.629505 -3.897977   
2007  2.074268     11.022609    1.854751          0.517227 -4.167128   
2008  2.126658     14.627556    1.873709          0.382493 -4.371310   
2009  2.177102     16.320717    1.894305          0.225304 -4.621899   

       出生人口(万)  城镇人口(城镇+乡村=100)  美元兑换人民币汇率  自然增长率(%)  
2000 -1.479902         2.401478   1.731471 -1.625272  
2001 -1.742629         2.670516   1.730913 -1.776919  
2002 -1.952049         2.937686   1.730913 -1.897274  
2003 -2.131009         3.206725   1.731072 -2.003186  
2004 -2.153854         3.436528   1.731311 -2.036886  
2005 -2.062471         3.666332   1.660884 -2.032072  
2006 -2.180508         3.918555   1.544264 -2.178905  
2007 -2.142431         4.208145   1.376268 -2.205383  
2008 -2.092932         4.413661   1.162158 -2.227046  
2009 -2.157662         4.665884   1.146215 -2.277596  
subset_test_pred_dataset = lstm.predict(val_dataset_scaled)
# Keep only the slice of the test set that overlaps the predicted window
subset_test_dataset, _ = test_dataset_scaled.split(len(subset_test_pred_dataset.target))
subset_test_dataset.plot(add_data=subset_test_pred_dataset, labels=['Pred'])
<AxesSubplot:>




[Figure: predicted vs. actual normalized 总人口(万人) on the test window]
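Note that predict produces a single window of out_chunk_len = 5 steps. To forecast further ahead, PaddleTS models also support recursive multi-step prediction; a one-line sketch, assuming the installed PaddleTS version exposes recursive_predict:

# Recursively roll the model forward for 10 steps (hypothetical horizon)
long_pred = lstm.recursive_predict(val_dataset_scaled, predict_length=10)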

7. Model Evaluation

We evaluate the predictions with the mean absolute error (MAE), computed as

$$\mathrm{MAE} = \frac{1}{n}\sum_{i=1}^{n}\left|y_i - \hat{y}_i\right|$$

where $y_i$ is the observed value and $\hat{y}_i$ the prediction.

mae = MAE()
mae(subset_test_dataset, subset_test_pred_dataset)
{'总人口(万人)': 1.3227954641250552}
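Since MSE was imported alongside MAE earlier, the same pair of datasets can also be scored with mean squared error; a short sketch:

mse = MSE()
mse(subset_test_dataset, subset_test_pred_dataset)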

8. Inverse-Transforming the Predictions

Because the data fed to the model was standardized/normalized, the model's predictions are likewise on the standardized scale and must be inverse-transformed to view them at their original magnitudes.

  • Inverse-transform the predicted data
  • Inverse-transform the true labels from the test set
  • Visualize and compare the two inverse-transformed results

8.1 Inverse transform

subset_test_pred_dataset_new = scaler.inverse_transform(subset_test_pred_dataset)
subset_test_dataset_new = scaler.inverse_transform(subset_test_dataset)

8.2 Visual comparison

subset_test_dataset_new.plot(add_data=subset_test_pred_dataset_new, labels=['Pred'])
<AxesSubplot:>

[Figure: predicted vs. actual 总人口(万人) on the original scale]

9. Summary 🌟

  • This project used the LSTNet model in PaddleTS to run a complete forecasting workflow, and the trained model achieved fairly good results.
  • Because this dataset has a limited number of samples, future experiments could be run on datasets with more abundant data.
  • A future extension is to implement an LSTM forecasting network from scratch.

Once again, many thanks to our project mentor Gu Qian (顾茜) for her guidance.

This project was completed jointly by three members of the PaddlePaddle Pilot Group (飞桨领航团) at the University of Science and Technology Beijing.
