Predicting COVID-19 Case Counts with PaddlePaddle
Since December 2019, COVID-19 has swept across the globe, taking on the character of a pandemic. The disease mainly presents with fever, dry cough, and fatigue; severe cases often develop breathing difficulties after about a week, and the most serious progress rapidly to acute respiratory distress syndrome, septic shock, refractory metabolic acidosis, coagulation dysfunction, and multiple organ failure, posing an extremely serious threat to people's health. At the same time, to contain the spread of the virus, many countries and regions adopted lockdown measures, stalling the global economic recovery and driving government debt ever higher.

Against this backdrop, people everywhere hope for an end to the pandemic and an early return to normal work and life. This article takes up that problem: using the real-time global COVID-19 statistics published by Johns Hopkins University, we model the time series with a temporal convolutional network to predict future case counts.

Of course, the pandemic is shaped by many factors, and this project uses a relatively simple model, so its predictions are not accurate. We all hope the pandemic subsides soon and case counts drop to zero. With outbreaks recently recurring in several parts of the country, please take strict precautions and protect yourself and your family!

Temporal Convolutional Networks (TCN)
A time series is a sequence arranged in chronological order, such as daily power generation or hourly revenue. By analyzing the progression, direction, and trend within a time series, we can forecast what may happen in the next period. In this example, we model the series with a temporal convolutional network (TCN) and feed the learned features into a fully connected layer to produce the prediction. The TCN architecture is shown below:



Figure 1: TCN architecture diagram

The figure shows a temporal convolutional network with filters number=3 and dilated rate=1, which learns features from the previous T timesteps. For more details on TCN, see the paper: An Empirical Evaluation of Generic Convolutional and Recurrent Networks for Sequence Modeling.
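As a rough aid to intuition, a TCN's receptive field grows exponentially with depth because the dilation rate typically doubles at each level. The standalone sketch below computes the receptive field under that standard scheme; it assumes each level contains a residual block of two dilated causal convolutions, as in the paper above (the internals of paddlenlp's TCNEncoder may differ in detail).

def tcn_receptive_field(num_levels, kernel_size):
    # Level i uses dilation 2**i; the two causal convolutions in a residual
    # block each extend the receptive field by (kernel_size - 1) * 2**i steps.
    return 1 + sum(2 * (kernel_size - 1) * 2 ** i for i in range(num_levels))

# With the configuration used later (3 levels, kernel_size=2), the network
# can see 15 past steps, comfortably covering our 9-step input windows.
print(tcn_receptive_field(num_levels=3, kernel_size=2))  # 15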

Code Walkthrough
Setting Up the Environment
First, we import the necessary packages.

Here we use the model built into paddlenlp.seq2vec. For a detailed introduction to seq2vec, see this project: seq2vec是什么? 瞧瞧怎么用它做情感分析

In [1]
!pip install "paddlenlp>=2.0.0b" -i https://pypi.org/simple

import os
import sys

import paddle
import paddle.nn as nn
import numpy as np

import pandas as pd
import seaborn as sns
from pylab import rcParams
import matplotlib.pyplot as plt
from matplotlib import rc
from sklearn.preprocessing import MinMaxScaler
from pandas.plotting import register_matplotlib_converters

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "../..")))
from paddlenlp.seq2vec import TCNEncoder
To present the results more clearly, we configure the plotting settings here.

In [2]
# config matplotlib
%matplotlib inline
%config InlineBackend.figure_format='retina'
sns.set(style='whitegrid', palette='muted', font_scale=1.2)
HAPPY_COLORS_PALETTE = ["#01BEFE", "#FFDD00", "#FF7D00", "#FF006D", "#93D30C", "#8F00FF"]
sns.set_palette(sns.color_palette(HAPPY_COLORS_PALETTE))
rcParams['figure.figsize'] = 14, 10
register_matplotlib_converters()
Downloading the Data
The dataset is provided by the Johns Hopkins University Center for Systems Science and Engineering (CSSE). The latest daily data can be downloaded from the https://github.com/CSSEGISandData/COVID-19 repository; this example uses case data downloaded on November 24, 2020.

In [3]
# Note: download the raw CSV file rather than the GitHub HTML page:
# !wget https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/csse_covid_19_data/csse_covid_19_time_series/time_series_covid19_confirmed_global.csv
Previewing the Data
The dataset contains the country/region, province/state, latitude, longitude, and the case counts for each date from January 22, 2020 onward.

In [4]
df_all = pd.read_csv('time_series_covid19_confirmed_global.csv')
df_all.head()
  Province/State Country/Region       Lat       Long  1/22/20  1/23/20  \
0            NaN    Afghanistan  33.93911  67.709953        0        0   
1            NaN        Albania  41.15330  20.168300        0        0   
2            NaN        Algeria  28.03390   1.659600        0        0   
3            NaN        Andorra  42.50630   1.521800        0        0   
4            NaN         Angola -11.20270  17.873900        0        0   

   1/24/20  1/25/20  1/26/20  1/27/20    ...     11/13/20  11/14/20  11/15/20  \
0        0        0        0        0    ...        42969     43035     43240   
1        0        0        0        0    ...        26701     27233     27830   
2        0        0        0        0    ...        65975     66819     67679   
3        0        0        0        0    ...         5725      5725      5872   
4        0        0        0        0    ...        13228     13374     13451   

   11/16/20  11/17/20  11/18/20  11/19/20  11/20/20  11/21/20  11/22/20  
0     43403     43628     43851     44228     44443     44503     44706  
1     28432     29126     29837     30623     31459     32196     32761  
2     68589     69591     70629     71652     72755     73774     74862  
3      5914      5951      6018      6066      6142      6207      6256  
4     13615     13818     13922     14134     14267     14413     14493  

[5 rows x 310 columns]
We will predict the worldwide case count, so we do not need country-specific details such as latitude and longitude; we only need the global total for each date.

In [5]
df = df_all.iloc[:, 4:]
daily_cases = df.sum(axis=0)
daily_cases.index = pd.to_datetime(daily_cases.index)
daily_cases.head()
2020-01-22     555
2020-01-23     654
2020-01-24     941
2020-01-25    1434
2020-01-26    2118
dtype: int64
In [6]
plt.figure(figsize=(12,12))
plt.plot(daily_cases)
plt.title("Cumulative daily cases");

<Figure size 864x864 with 1 Axes>
To make the sample time series more stationary, we take its first-order difference.

In [7]
daily_cases = daily_cases.diff().fillna(daily_cases[0]).astype(np.int64)
daily_cases.head()
2020-01-22    555
2020-01-23     99
2020-01-24    287
2020-01-25    493
2020-01-26    684
dtype: int64
In [8]
plt.plot(daily_cases)
plt.title("Daily cases");

<Figure size 1008x720 with 1 Axes>
Data Preprocessing
First, we split the data into training and test sets: the last 30 days are held out as the test set, and the remainder is used for training.

In [9]
TEST_DATA_SIZE = 30

train_data = daily_cases[:-TEST_DATA_SIZE]
test_data = daily_cases[-TEST_DATA_SIZE:]

print("The number of the samples in train set is : %i"%train_data.shape[0])
The number of the samples in train set is : 276
To speed up convergence and improve model performance, we normalize the data with scikit-learn.

In [10]
scaler = MinMaxScaler()
train_data = scaler.fit_transform(np.expand_dims(train_data, axis=1)).astype('float32')
test_data = scaler.transform(np.expand_dims(test_data, axis=1)).astype('float32')
Now we build the input sequences: the case counts of the previous 9 days are used to predict the case count of the 10th day. So that every day in the test set can be predicted, we prepend a small amount of data to the test set; these values are used only as model input.

In [11]
SEQ_LEN = 10

def create_sequences(data, seq_length):
    # Slide a window of `seq_length` days over the series: the first
    # seq_length - 1 days form the input, the last day is the target.
    xs = []
    ys = []

    for i in range(len(data) - seq_length + 1):
        x = data[i:i + seq_length - 1]
        y = data[i + seq_length - 1]
        xs.append(x)
        ys.append(y)

    return np.array(xs), np.array(ys)

x_train, y_train = create_sequences(train_data, SEQ_LEN)

# Prepend the last SEQ_LEN - 1 training values so that every test day has a
# full input window; these prepended values serve only as input.
test_data = np.concatenate((train_data[-SEQ_LEN + 1:], test_data), axis=0)
x_test, y_test = create_sequences(test_data, SEQ_LEN)

print("The shape of x_train is: %s"%str(x_train.shape))
print("The shape of y_train is: %s"%str(y_train.shape))
print("The shape of x_test is: %s"%str(x_test.shape))
print("The shape of y_test is: %s"%str(y_test.shape))
The shape of x_train is: (267, 9, 1)
The shape of y_train is: (267, 1)
The shape of x_test is: (30, 9, 1)
The shape of y_test is: (30, 1)
With the dataset processed, we wrap it in a CovidDataset so it can be consumed during training and prediction.

In [12]
class CovidDataset(paddle.io.Dataset):
    def __init__(self, feature, label):
        super(CovidDataset, self).__init__()
        self.feature = feature
        self.label = label

    def __len__(self):
        return len(self.label)

    def __getitem__(self, index):
        return [self.feature[index], self.label[index]]

train_dataset = CovidDataset(x_train, y_train)
test_dataset = CovidDataset(x_test, y_test)
Building the Network
Now we build the model: a TCN serves as the feature extractor, and the extracted temporal features are passed to a fully connected layer to produce the final prediction.

In [13]
class TimeSeriesNetwork(nn.Layer):

    def __init__(self, input_size, next_k=1, num_channels=[64, 128, 256]):
        super(TimeSeriesNetwork, self).__init__()

        self.last_num_channel = num_channels[-1]

        # TCN encoder: turns the input sequence into a fixed-size feature vector
        self.tcn = TCNEncoder(
            input_size=input_size,
            num_channels=num_channels,
            kernel_size=2,
            dropout=0.2
        )

        # Fully connected layer maps the TCN features to the next_k predictions
        self.linear = nn.Linear(in_features=self.last_num_channel, out_features=next_k)

    def forward(self, x):
        tcn_out = self.tcn(x)
        y_pred = self.linear(tcn_out)
        return y_pred

network = TimeSeriesNetwork(input_size=1)
Defining the Optimizer and Loss Function
Here we use the Adam optimizer and the mean squared error loss; this is the final preparation before training.

In [14]
LR = 1e-3

model = paddle.Model(network)

optimizer = paddle.optimizer.Adam(
        learning_rate=LR, parameters=model.parameters())

loss = paddle.nn.MSELoss(reduction='sum')

model.prepare(optimizer, loss)
Training
Configure the necessary hyperparameters and start training.

In [15]
USE_GPU = False
TRAIN_EPOCH = 100
LOG_FREQ = 10
SAVE_DIR = os.path.join(os.getcwd(),"save_dir")
SAVE_FREQ = 10

if USE_GPU:
    paddle.set_device("gpu")
else:
    paddle.set_device("cpu")

model.fit(train_dataset, 
    batch_size=32,
    drop_last=True,
    epochs=TRAIN_EPOCH,
    log_freq=LOG_FREQ,
    save_dir=SAVE_DIR,
    save_freq=SAVE_FREQ,
    verbose=1
    )
The loss value printed in the log is the current step, and the metric is the average value of previous step.
Epoch 1/100
step 8/8 [==============================] - loss: 9.9573 - 69ms/step
save checkpoint at /home/aistudio/save_dir/0
Epoch 2/100
step 8/8 [==============================] - loss: 3.9386 - 47ms/step
Epoch 3/100
step 8/8 [==============================] - loss: 0.4986 - 47ms/step
...
Epoch 99/100
step 8/8 [==============================] - loss: 0.1334 - 46ms/step
Epoch 100/100
step 8/8 [==============================] - loss: 0.0793 - 46ms/step
save checkpoint at /home/aistudio/save_dir/final
Prediction
Use the trained model to predict the case counts on the dates in the test set.

In [16]
preds = model.predict(
        test_data=test_dataset
        )
Predict begin...
step 30/30 [==============================] - 12ms/step
Predict samples: 30
Postprocessing
Convert the normalized values back to the original scale, then plot the curves of the true and predicted values.

In [17]
true_cases = scaler.inverse_transform(
    np.expand_dims(y_test.flatten(), axis=0)
).flatten()

predicted_cases = scaler.inverse_transform(
    np.expand_dims(np.array(preds).flatten(), axis=0)
).flatten()
In [18]
print(type(daily_cases))
daily_cases[1:3]
print(len(daily_cases), len(train_data))
daily_cases.index[:len(train_data)]
<class 'pandas.core.series.Series'>
306 276
DatetimeIndex(['2020-01-22', '2020-01-23', '2020-01-24', '2020-01-25',
               '2020-01-26', '2020-01-27', '2020-01-28', '2020-01-29',
               '2020-01-30', '2020-01-31',
               ...
               '2020-10-14', '2020-10-15', '2020-10-16', '2020-10-17',
               '2020-10-18', '2020-10-19', '2020-10-20', '2020-10-21',
               '2020-10-22', '2020-10-23'],
              dtype='datetime64[ns]', length=276, freq=None)
In [19]
plt.plot(
  daily_cases.index[:len(train_data)], 
  scaler.inverse_transform(train_data).flatten(),
  label='Historical Daily Cases'
)

plt.plot(
  daily_cases.index[len(train_data):len(train_data) + len(true_cases)], 
  true_cases,
  label='Real Daily Cases'
)

plt.plot(
  daily_cases.index[len(train_data):len(train_data) + len(true_cases)], 
  predicted_cases, 
  label='Predicted Daily Cases'
)

plt.legend();

<Figure size 1008x720 with 1 Axes>
Further Improving the Model
As the plot above shows, the model broadly captures the rise and fall of the daily case counts but shows some error in the specific values. Readers can get creative and further improve the model's accuracy and functionality, for example:

Predict the next n days
We currently use the known case counts of 9 days to predict the 10th day. By concatenating the 10th day's prediction with the previous 8 days of real counts, we can predict the 11th day, and so on to predict the next n days; a sketch follows below.
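A minimal sketch of this autoregressive rollout, reusing `network`, `scaler`, `test_data`, and `SEQ_LEN` from above; the helper name `predict_next_n_days` and the 7-day horizon are illustrative assumptions, not part of the original notebook.

def predict_next_n_days(network, history, n_days, seq_len=SEQ_LEN):
    # Autoregressive rollout: append each day's prediction to the input
    # window and feed it back in to predict the following day.
    network.eval()
    window = list(history[-(seq_len - 1):].flatten())  # last 9 normalized values
    preds = []
    for _ in range(n_days):
        x = paddle.to_tensor(
            np.array(window[-(seq_len - 1):], dtype='float32').reshape([1, seq_len - 1, 1]))
        y = float(network(x).numpy().flatten()[0])
        preds.append(y)
        window.append(y)
    # Map the normalized predictions back to case counts.
    return scaler.inverse_transform(np.array(preds).reshape(-1, 1)).flatten()

future_cases = predict_next_n_days(network, test_data, n_days=7)

Note that errors compound over the rollout, since each prediction becomes input for the next step, so accuracy degrades as n grows.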

Improve the network architecture
This article uses a TCN. If inference speed is not a concern, you can try LSTM, GRU, or Transformer models to further improve the model's fitting capacity; an LSTM variant is sketched below.
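For instance, a drop-in LSTM replacement for TimeSeriesNetwork might look like the following sketch; the hidden_size and num_layers values are illustrative assumptions, not tuned settings.

class LSTMTimeSeriesNetwork(nn.Layer):
    # Same interface as TimeSeriesNetwork, with an LSTM as the encoder.
    def __init__(self, input_size, next_k=1, hidden_size=128, num_layers=2):
        super(LSTMTimeSeriesNetwork, self).__init__()
        self.lstm = nn.LSTM(input_size=input_size,
                            hidden_size=hidden_size,
                            num_layers=num_layers)
        self.linear = nn.Linear(in_features=hidden_size, out_features=next_k)

    def forward(self, x):
        # outputs: [batch_size, seq_len, hidden_size]; keep only the last timestep.
        outputs, _ = self.lstm(x)
        return self.linear(outputs[:, -1, :])

lstm_network = LSTMTimeSeriesNetwork(input_size=1)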

Tune the hyperparameters
This article does not explore hyperparameter settings. Readers can search for better values of the learning rate, number of training epochs, TCN channel widths, and so on; a simple sweep is sketched after this paragraph.
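As a starting point, a manual learning-rate sweep could look like this sketch; the candidate values are illustrative, and ideally the comparison would use a separate validation split rather than the final test set.

# Retrain from scratch for each candidate learning rate and compare losses.
for lr in [1e-2, 1e-3, 1e-4]:
    candidate = paddle.Model(TimeSeriesNetwork(input_size=1))
    candidate.prepare(
        paddle.optimizer.Adam(learning_rate=lr, parameters=candidate.parameters()),
        paddle.nn.MSELoss(reduction='sum'))
    candidate.fit(train_dataset, batch_size=32, epochs=TRAIN_EPOCH, verbose=0)
    print(lr, candidate.evaluate(test_dataset, verbose=0))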

Incorporate more data features
This article considers only the dates of the cases, ignoring environmental factors such as policy and vaccine development. Readers can collect more news and other material to add richer data features.

More PaddleNLP Projects
seq2vec是什么? 瞧瞧怎么用它做情感分析
如何通过预训练模型Fine-tune下游任务
使用BiGRU-CRF模型完成快递单信息抽取
使用预训练模型ERNIE优化快递单信息抽取
使用Seq2Seq模型完成自动对联
使用预训练模型ERNIE-GEN实现智能写诗
使用预训练模型完成阅读理解
自定义数据集实现文本多分类任务