LSTM-based COVID-19 prediction with PyTorch

import torch
from torch import nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
df = pd.read_csv('./COVID-19_USA.csv')
df.head()
         countryEnglishName  confirmedCount  nowconfirmedCount  curedCount  deadCount  deadRate  curedRate
0  United States of America               0                  0           0          0       0.0        0.0
1  United States of America               0                  0           0          0       0.0        0.0
2  United States of America               0                  0           0          0       0.0        0.0
3  United States of America               0                  0           0          0       0.0        0.0
4  United States of America               0                  0           0          0       0.0        0.0
value = df['confirmedCount'].values[20:140]
value
array([     13,      13,      14,      15,      15,      15,      15,
            15,      15,      15,      15,      35,      35,      35,
            53,      57,      60,      60,      63,      63,      77,
           100,     122,     153,     232,     324,     445,     572,
           704,    1004,    1004,    1635,    2084,    2885,    3700,
          4661,    6420,   10259,   14250,   19624,   26997,   35360,
         46450,   55243,   69194,   85840,  105470,  124686,  143101,
        164603,  189633,  216722,  245658,  278537,  312245,  337971,
        368449,  399979,  432579,  467184,  501615,  530830,  558526,
        583220,  609696,  640014,  672246,  706832,  735366,  760570,
        788920,  826184,  843937,  870468,  907096,  940797,  967585,
        989357, 1014568, 1040608, 1070032, 1107815, 1133069, 1158341,
       1181885, 1206323, 1231943, 1256972, 1286833, 1312099, 1332411,
       1351200, 1371395, 1395265, 1419998, 1446875, 1470199, 1490195,
       1510988, 1528568, 1577758, 1604189, 1626258, 1646495, 1665882,
       1684173, 1702911, 1724873, 1750203, 1773020, 1792512, 1812125,
       1832412, 1854476, 1872660, 1901391, 1920552, 1941748], dtype=int64)
print(len(value))
x = []
y = []
seq = 3  # window length: three consecutive days predict the next day
for i in range(len(value)-seq-1):
    x.append(value[i:i+seq])
    y.append(value[i+seq])
118
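To make the windowing concrete, here is a minimal sketch of what the loop produces, using a hypothetical toy series rather than the COVID data:

# Toy illustration of the sliding window above (hypothetical numbers)
series = [10, 20, 30, 40, 50, 60]
toy_x, toy_y = [], []
for i in range(len(series) - seq - 1):  # seq = 3, as above
    toy_x.append(series[i:i+seq])   # a 3-day window as input
    toy_y.append(series[i+seq])     # the following day as the target
# toy_x == [[10, 20, 30], [20, 30, 40]]
# toy_y == [40, 50]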

Inputs to nn.LSTM: input, (h_0, c_0)

input: the input data, with shape (seq_len: sentence length, batch: number of sentences, input_size: length of each word vector);
h_0: defaults to zeros; shape (num_layers * num_directions, batch, hidden_size), where num_directions is 1 for a unidirectional LSTM and 2 for a bidirectional one, and hidden_size is the number of hidden units;
c_0: defaults to zeros; shape (num_layers * num_directions, batch, hidden_size).

Outputs of nn.LSTM: output, (h_n, c_n)

output: shape (seq_len, batch, num_directions * hidden_size);
h_n: shape (num_layers * num_directions, batch, hidden_size);
c_n: shape (num_layers * num_directions, batch, hidden_size).

import torch
import torch.nn as nn

rnn = nn.LSTM(10, 20, 3)  # input_size=10 (word vector length), hidden_size=20, num_layers=3 (3 stacked LSTM layers)
input = torch.randn(8, 3, 10)  # seq_len=8 (8 words per sentence), batch=3 (3 sentences), input_size=10
h_0, c_0 = torch.randn(3, 3, 20), torch.randn(3, 3, 20)
output, (h_n, c_n) = rnn(input, (h_0, c_0))

print("input.shape:", input.shape)
print("h_0.shape:", h_0.shape)
print("c_0.shape:", c_0.shape)
print("*" * 50)
print("output.shape:", output.shape)
print("h_n.shape:", h_n.shape)
print("c_n.shape:", c_n.shape)

#print(x, '\n', y)
train_x = (torch.tensor(x[0:90]).float()/100000.).reshape(-1, seq, 1)   # scale counts down to O(1) for stable training
train_y = (torch.tensor(y[0:90]).float()/100000.).reshape(-1, 1)
test_x = (torch.tensor(x[90:110]).float()/100000.).reshape(-1, seq, 1)  # last 20 windows held out for testing
test_y = (torch.tensor(y[90:110]).float()/100000.).reshape(-1, 1)
print(test_y)
tensor([[13.9527],
        [14.2000],
        [14.4688],
        [14.7020],
        [14.9019],
        [15.1099],
        [15.2857],
        [15.7776],
        [16.0419],
        [16.2626],
        [16.4650],
        [16.6588],
        [16.8417],
        [17.0291],
        [17.2487],
        [17.5020],
        [17.7302],
        [17.9251],
        [18.1213],
        [18.3241]])
# Model definition
class LSTM(nn.Module):
    def __init__(self):
        super(LSTM, self).__init__()
        self.lstm = nn.LSTM(input_size=1, hidden_size=16, num_layers=1, batch_first=True)
        self.linear = nn.Linear(16 * seq, 1)  # maps the flattened hidden states of all seq steps to one value
    def forward(self, x):
        x, (h, c) = self.lstm(x)       # x: (batch, seq, hidden_size)
        x = x.reshape(-1, 16 * seq)    # concatenate the hidden states of all time steps
        x = self.linear(x)
        return x
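Note that this head flattens the hidden states of all seq time steps into the linear layer. A common alternative, shown below as a sketch (an assumption, not the model used in this post), feeds only the last time step's output to the head, which keeps the linear layer independent of the window length:

# Alternative head using only the last time step (illustrative sketch, not the author's model)
class LSTMLast(nn.Module):
    def __init__(self):
        super(LSTMLast, self).__init__()
        self.lstm = nn.LSTM(input_size=1, hidden_size=16, num_layers=1, batch_first=True)
        self.linear = nn.Linear(16, 1)
    def forward(self, x):
        out, (h, c) = self.lstm(x)          # out: (batch, seq, 16)
        return self.linear(out[:, -1, :])   # prediction from the final time step only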
# Model training
model = LSTM()
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
loss_func = nn.MSELoss()
model.train()

for epoch in range(2000):
    output = model(train_x)
    loss = loss_func(output, train_y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if epoch % 20 == 0:
        test_loss = loss_func(model(test_x), test_y)
        print("epoch:{}, train_loss:{}, test_loss:{}".format(epoch, loss, test_loss))
epoch:0, train_loss:37.56373977661133, test_loss:263.2969055175781
epoch:20, train_loss:18.516111373901367, test_loss:166.8538360595703
epoch:40, train_loss:3.2988765239715576, test_loss:58.48737716674805
epoch:60, train_loss:1.0975019931793213, test_loss:28.006675720214844
epoch:80, train_loss:0.2535986304283142, test_loss:16.773616790771484
epoch:100, train_loss:0.10097146779298782, test_loss:12.012392044067383
epoch:120, train_loss:0.05068375915288925, test_loss:9.413983345031738
epoch:140, train_loss:0.031387731432914734, test_loss:7.986786842346191
epoch:160, train_loss:0.021674150601029396, test_loss:7.1016082763671875
epoch:180, train_loss:0.015867270529270172, test_loss:6.489916801452637
epoch:200, train_loss:0.01209135353565216, test_loss:6.03498649597168
epoch:220, train_loss:0.009504670277237892, test_loss:5.680188179016113
epoch:240, train_loss:0.00768320681527257, test_loss:5.396552562713623
epoch:260, train_loss:0.0064044492319226265, test_loss:5.167068004608154
epoch:280, train_loss:0.005527254659682512, test_loss:4.980656623840332
epoch:300, train_loss:0.004928725305944681, test_loss:4.828995227813721
epoch:320, train_loss:0.004502041265368462, test_loss:4.704446315765381
epoch:340, train_loss:0.004170786123722792, test_loss:4.600121974945068
epoch:360, train_loss:0.0038912463933229446, test_loss:4.5099616050720215
epoch:380, train_loss:0.0036443774588406086, test_loss:4.428786277770996
epoch:400, train_loss:0.0034239296801388264, test_loss:4.352680683135986
epoch:420, train_loss:0.003226183820515871, test_loss:4.279394626617432
epoch:440, train_loss:0.0030467547476291656, test_loss:4.208122730255127
epoch:460, train_loss:0.002882403088733554, test_loss:4.1388139724731445
epoch:480, train_loss:0.0027316079940646887, test_loss:4.071575164794922
epoch:500, train_loss:0.002594325691461563, test_loss:4.006669044494629
epoch:520, train_loss:0.0024710919242352247, test_loss:3.9445641040802
epoch:540, train_loss:0.0023619714193046093, test_loss:3.8857200145721436
epoch:560, train_loss:0.0022660386748611927, test_loss:3.8303375244140625
epoch:580, train_loss:0.0021817032247781754, test_loss:3.77828311920166
epoch:600, train_loss:0.002107286360114813, test_loss:3.729255199432373
epoch:620, train_loss:0.0020413610618561506, test_loss:3.682926654815674
epoch:640, train_loss:0.001982747344300151, test_loss:3.639024496078491
epoch:660, train_loss:0.0019304585875943303, test_loss:3.597336769104004
epoch:680, train_loss:0.001883711782284081, test_loss:3.5576891899108887
epoch:700, train_loss:0.0018417979590594769, test_loss:3.5199549198150635
epoch:720, train_loss:0.0018041135044768453, test_loss:3.4840168952941895
epoch:740, train_loss:0.001770111615769565, test_loss:3.4497766494750977
epoch:760, train_loss:0.0017393192974850535, test_loss:3.41713809967041
epoch:780, train_loss:0.0017113107023760676, test_loss:3.386018753051758
epoch:800, train_loss:0.0016857198206707835, test_loss:3.3563263416290283
epoch:820, train_loss:0.0016621954273432493, test_loss:3.3279783725738525
epoch:840, train_loss:0.0016404724447056651, test_loss:3.3008854389190674
epoch:860, train_loss:0.0016202878905460238, test_loss:3.2749710083007812
epoch:880, train_loss:0.001601424883119762, test_loss:3.2501537799835205
epoch:900, train_loss:0.0015837030950933695, test_loss:3.226363182067871
epoch:920, train_loss:0.0015669644344598055, test_loss:3.2035205364227295
epoch:940, train_loss:0.0015510818921029568, test_loss:3.181558132171631
epoch:960, train_loss:0.001535946037620306, test_loss:3.1604182720184326
epoch:980, train_loss:0.001521461526863277, test_loss:3.140035629272461
epoch:1000, train_loss:0.0015075404662638903, test_loss:3.1203596591949463
epoch:1020, train_loss:0.0014941382687538862, test_loss:3.101344108581543
epoch:1040, train_loss:0.001481189508922398, test_loss:3.082942008972168
epoch:1060, train_loss:0.0014686313224956393, test_loss:3.0651092529296875
epoch:1080, train_loss:0.0014564390294253826, test_loss:3.0478177070617676
epoch:1100, train_loss:0.001444563502445817, test_loss:3.0310306549072266
epoch:1120, train_loss:0.0014329585246741772, test_loss:3.014721632003784
epoch:1140, train_loss:0.0014216136187314987, test_loss:2.998866319656372
epoch:1160, train_loss:0.0014104940928518772, test_loss:2.9834399223327637
epoch:1180, train_loss:0.0013995792251080275, test_loss:2.968425989151001
epoch:1200, train_loss:0.001388857257552445, test_loss:2.953805923461914
epoch:1220, train_loss:0.0013782993191853166, test_loss:2.939566135406494
epoch:1240, train_loss:0.0013679120456799865, test_loss:2.9256908893585205
epoch:1260, train_loss:0.001357687870040536, test_loss:2.9121646881103516
epoch:1280, train_loss:0.0013476117746904492, test_loss:2.8989791870117188
epoch:1300, train_loss:0.0013376886490732431, test_loss:2.8861305713653564
epoch:1320, train_loss:0.001327910111285746, test_loss:2.873605728149414
epoch:1340, train_loss:0.0013182887341827154, test_loss:2.861402750015259
epoch:1360, train_loss:0.0013088179985061288, test_loss:2.8495066165924072
epoch:1380, train_loss:0.0012995086144655943, test_loss:2.837918519973755
epoch:1400, train_loss:0.0012903454480692744, test_loss:2.8266289234161377
epoch:1420, train_loss:0.0012813681969419122, test_loss:2.81563663482666
epoch:1440, train_loss:0.0012725305277854204, test_loss:2.8049354553222656
epoch:1460, train_loss:0.0012638678308576345, test_loss:2.7945163249969482
epoch:1480, train_loss:0.0012553795240819454, test_loss:2.7843799591064453
epoch:1500, train_loss:0.0012470469810068607, test_loss:2.774510622024536
epoch:1520, train_loss:0.001238883938640356, test_loss:2.764907121658325
epoch:1540, train_loss:0.0012308855075389147, test_loss:2.7555596828460693
epoch:1560, train_loss:0.0012230485444888473, test_loss:2.7464544773101807
epoch:1580, train_loss:0.0012153665302321315, test_loss:2.737583637237549
epoch:1600, train_loss:0.0012078447034582496, test_loss:2.7289376258850098
epoch:1620, train_loss:0.0012004825985059142, test_loss:2.720499277114868
epoch:1640, train_loss:0.001193271717056632, test_loss:2.7122607231140137
epoch:1660, train_loss:0.0011862049577757716, test_loss:2.7042083740234375
epoch:1680, train_loss:0.001179289072751999, test_loss:2.6963284015655518
epoch:1700, train_loss:0.00117252126801759, test_loss:2.6886096000671387
epoch:1720, train_loss:0.0011658959556370974, test_loss:2.681042432785034
epoch:1740, train_loss:0.001159411738626659, test_loss:2.673621892929077
epoch:1760, train_loss:0.0011530747869983315, test_loss:2.6663315296173096
epoch:1780, train_loss:0.0011468705488368869, test_loss:2.6591718196868896
epoch:1800, train_loss:0.0011408105492591858, test_loss:2.6521382331848145
epoch:1820, train_loss:0.001134884194470942, test_loss:2.6452348232269287
epoch:1840, train_loss:0.0011290841503068805, test_loss:2.638458728790283
epoch:1860, train_loss:0.001123423338867724, test_loss:2.631808280944824
epoch:1880, train_loss:0.0011178869754076004, test_loss:2.625295877456665
epoch:1900, train_loss:0.0011124806478619576, test_loss:2.618925094604492
epoch:1920, train_loss:0.001107192481867969, test_loss:2.6127049922943115
epoch:1940, train_loss:0.0011020202655345201, test_loss:2.6066269874572754
epoch:1960, train_loss:0.001096969353966415, test_loss:2.600712299346924
epoch:1980, train_loss:0.0010920269414782524, test_loss:2.594951629638672
model.eval()
prediction = list((model(train_x).data.reshape(-1))*100000) + list((model(test_x).data.reshape(-1))*100000)  # de-normalize back to case counts
plt.plot(value[3:], label='True Value')
plt.plot(prediction[:91], label='LSTM fit')   # includes one test point so the two curves join at day 90
plt.plot(np.arange(90, 110, 1), prediction[90:], label='LSTM pred')
print(len(value[3:]))
print(len(prediction[90:]))
plt.legend(loc='best')
plt.title('Cumulative infections prediction(USA)')
plt.xlabel('Day')
plt.ylabel('Cumulative Cases')
plt.show()
115
20
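To put a single number on the fit, the following sketch (assuming the model, test_x and test_y defined above; not in the original post) reports the test RMSE back on the original case-count scale:

# Evaluation sketch (an addition, not from the original post): test RMSE in case counts
model.eval()
with torch.no_grad():
    pred = model(test_x)
rmse = torch.mean((pred - test_y) ** 2).sqrt() * 100000  # undo the /100000 scaling
print("test RMSE (cases):", rmse.item())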

[Figure: Cumulative infections prediction (USA) — true values, LSTM fit, and LSTM prediction]

df_2 = pd.read_csv('./COVID-19_China.csv')
df_2.head()
  countryEnglishName  confirmedCount  nowconfirmedCount  curedCount  deadCount  deadRate  curedRate
0              China             544                499          28         17  3.125000   5.147059
1              China             639                592          30         17  2.660407   4.694836
2              China             901                839          36         26  2.885683   3.995560
3              China            1377               1297          39         41  2.977487   2.832244
4              China            2076               1971          49         56  2.697495   2.360308
value = df_2['confirmedCount'].values[20:140]
value
array([42747, 44765, 59907, 63950, 66581, 68595, 70644, 72533, 74284,
       74680, 75571, 76396, 77048, 77269, 77785, 78195, 78631, 78962,
       79394, 79972, 80175, 80303, 80424, 80581, 80734, 80815, 80868,
       80905, 80932, 80969, 80981, 80995, 81029, 81062, 81099, 81135,
       81202, 81264, 81385, 81457, 81566, 81691, 81806, 81896, 82034,
       82164, 82282, 82420, 82504, 82600, 82690, 82771, 82857, 82898,
       82965, 83038, 83094, 83188, 83263, 83323, 83399, 83522, 83606,
       83699, 83751, 83798, 84155, 84185, 84225, 84239, 84278, 84294,
       84305, 84313, 84330, 84338, 84341, 84367, 84369, 84373, 84387,
       84391, 84393, 84403, 84404, 84407, 84414, 84416, 84416, 84434,
       84450, 84451, 84461, 84465, 84471, 84478, 84487, 84494, 84503,
       84506, 84522, 84522, 84525, 84536, 84543, 84545, 84547, 84561,
       84569, 84572, 84593, 84603, 84602, 84609, 84617, 84624, 84630,
       84634], dtype=int64)
print(len(value))
x = []
y = []
seq = 3
for i in range(len(value)-seq-1):
    x.append(value[i:i+seq])
    y.append(value[i+seq])
118
train_x = (torch.tensor(x[0:90]).float()/100000.).reshape(-1, seq, 1)
train_y = (torch.tensor(y[0:90]).float()/100000.).reshape(-1, 1)
test_x = (torch.tensor(x[90:110]).float()/100000.).reshape(-1, seq, 1)
test_y = (torch.tensor(y[90:110]).float()/100000.).reshape(-1, 1)
print(test_y)
tensor([[0.8446],
        [0.8447],
        [0.8448],
        [0.8449],
        [0.8449],
        [0.8450],
        [0.8451],
        [0.8452],
        [0.8452],
        [0.8453],
        [0.8454],
        [0.8454],
        [0.8454],
        [0.8455],
        [0.8456],
        [0.8457],
        [0.8457],
        [0.8459],
        [0.8460],
        [0.8460]])
# Model training
model = LSTM()
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
loss_func = nn.MSELoss()
model.train()

for epoch in range(2000):
    output = model(train_x)
    loss = loss_func(output, train_y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if epoch % 20 == 0:
        test_loss = loss_func(model(test_x), test_y)
        print("epoch:{}, train_loss:{}, test_loss:{}".format(epoch, loss, test_loss))
epoch:0, train_loss:0.3825637102127075, test_loss:0.3793776333332062
epoch:20, train_loss:0.029308181256055832, test_loss:0.01525117177516222
epoch:40, train_loss:0.0014625154435634613, test_loss:0.001359123969450593
epoch:60, train_loss:0.0011256280122324824, test_loss:0.001867746701464057
epoch:80, train_loss:0.0008417891804128885, test_loss:0.0004369680245872587
epoch:100, train_loss:0.0008008113945834339, test_loss:0.0004433710710145533
epoch:120, train_loss:0.0007676983368583024, test_loss:0.0005258770543150604
epoch:140, train_loss:0.0007398930029012263, test_loss:0.0005360021023079753
epoch:160, train_loss:0.0007106903940439224, test_loss:0.0005205624038353562
epoch:180, train_loss:0.000680193246807903, test_loss:0.0005014035850763321
epoch:200, train_loss:0.0006485295598395169, test_loss:0.0004820456088054925
epoch:220, train_loss:0.0006157823954708874, test_loss:0.0004624216235242784
epoch:240, train_loss:0.0005820132209919393, test_loss:0.0004422594793140888
epoch:260, train_loss:0.0005472666234709322, test_loss:0.0004214201180730015
epoch:280, train_loss:0.000511578400619328, test_loss:0.0003998256870545447
epoch:300, train_loss:0.00047498263302259147, test_loss:0.00037742816493846476
epoch:320, train_loss:0.0004375250719022006, test_loss:0.00035419128835201263
epoch:340, train_loss:0.00039927675970830023, test_loss:0.00033008967875503004
epoch:360, train_loss:0.00036035984521731734, test_loss:0.0003051073872484267
epoch:380, train_loss:0.00032097692019306123, test_loss:0.0002792503801174462
epoch:400, train_loss:0.0002814592735376209, test_loss:0.00025257302331738174
epoch:420, train_loss:0.0002423272526357323, test_loss:0.00022522213112097234
epoch:440, train_loss:0.0002043575659627095, test_loss:0.00019748158229049295
epoch:460, train_loss:0.00016863604832906276, test_loss:0.000169840976013802
epoch:480, train_loss:0.00013653359201271087, test_loss:0.00014304733485914767
epoch:500, train_loss:0.0001095230836654082, test_loss:0.00011809338320745155
epoch:520, train_loss:8.876808715285733e-05, test_loss:9.60855686571449e-05
epoch:540, train_loss:7.458930485881865e-05, test_loss:7.79658803367056e-05
epoch:560, train_loss:6.617035978706554e-05, test_loss:6.417164695449173e-05
epoch:580, train_loss:6.185458914842457e-05, test_loss:5.449241871247068e-05
epoch:600, train_loss:5.988311022520065e-05, test_loss:4.820094909518957e-05
epoch:620, train_loss:5.898631934542209e-05, test_loss:4.4363230699673295e-05
epoch:640, train_loss:5.848926957696676e-05, test_loss:4.211986743030138e-05
epoch:660, train_loss:5.811382652609609e-05, test_loss:4.081864972249605e-05
epoch:680, train_loss:5.7767596445046365e-05, test_loss:4.003203866886906e-05
epoch:700, train_loss:5.742487701354548e-05, test_loss:3.9503516745753586e-05
epoch:720, train_loss:5.707951640943065e-05, test_loss:3.909639417543076e-05
epoch:740, train_loss:5.6730175856500864e-05, test_loss:3.8741745811421424e-05
epoch:760, train_loss:5.637647700496018e-05, test_loss:3.8406542444135994e-05
epoch:780, train_loss:5.601833254331723e-05, test_loss:3.807685061474331e-05
epoch:800, train_loss:5.56554296053946e-05, test_loss:3.774597280425951e-05
epoch:820, train_loss:5.5287899158429354e-05, test_loss:3.741172986337915e-05
epoch:840, train_loss:5.491534466273151e-05, test_loss:3.707344876602292e-05
epoch:860, train_loss:5.4537584219360724e-05, test_loss:3.673039100249298e-05
epoch:880, train_loss:5.4154432291397825e-05, test_loss:3.638227281044237e-05
epoch:900, train_loss:5.3765703341923654e-05, test_loss:3.603072764235549e-05
epoch:920, train_loss:5.3371091780718416e-05, test_loss:3.567387830116786e-05
epoch:940, train_loss:5.297027018968947e-05, test_loss:3.53103423549328e-05
epoch:960, train_loss:5.2563085773726925e-05, test_loss:3.494161501294002e-05
epoch:980, train_loss:5.214917837292887e-05, test_loss:3.456864942563698e-05
epoch:1000, train_loss:5.17280786880292e-05, test_loss:3.418769483687356e-05
epoch:1020, train_loss:5.129948112880811e-05, test_loss:3.3801185054471716e-05
epoch:1040, train_loss:5.086324381409213e-05, test_loss:3.34087380906567e-05
epoch:1060, train_loss:5.041856275056489e-05, test_loss:3.300896059954539e-05
epoch:1080, train_loss:4.9965601647272706e-05, test_loss:3.2601448765490204e-05
epoch:1100, train_loss:4.950328002450988e-05, test_loss:3.2187974284170195e-05
epoch:1120, train_loss:4.9031306843971834e-05, test_loss:3.1764171581016853e-05
epoch:1140, train_loss:4.854957660427317e-05, test_loss:3.133570862701163e-05
epoch:1160, train_loss:4.805713615496643e-05, test_loss:3.089520396315493e-05
epoch:1180, train_loss:4.7553505282849073e-05, test_loss:3.044824188691564e-05
epoch:1200, train_loss:4.7038094635354355e-05, test_loss:2.999177922902163e-05
epoch:1220, train_loss:4.651049312087707e-05, test_loss:2.9524149795179255e-05
epoch:1240, train_loss:4.596983490046114e-05, test_loss:2.904612665588502e-05
epoch:1260, train_loss:4.541549060377292e-05, test_loss:2.8558717531268485e-05
epoch:1280, train_loss:4.484672172111459e-05, test_loss:2.8060558179276995e-05
epoch:1300, train_loss:4.4262782466830686e-05, test_loss:2.75501352007268e-05
epoch:1320, train_loss:4.3662937969202176e-05, test_loss:2.7027786927646957e-05
epoch:1340, train_loss:4.3046358769061044e-05, test_loss:2.649312409630511e-05
epoch:1360, train_loss:4.241224087309092e-05, test_loss:2.5944673325284384e-05
epoch:1380, train_loss:4.175960930297151e-05, test_loss:2.5382616513525136e-05
epoch:1400, train_loss:4.108754365006462e-05, test_loss:2.4807826775941066e-05
epoch:1420, train_loss:4.039523264509626e-05, test_loss:2.4217149984906428e-05
epoch:1440, train_loss:3.968160672229715e-05, test_loss:2.3612001314177178e-05
epoch:1460, train_loss:3.8945843698456883e-05, test_loss:2.2989526769379154e-05
epoch:1480, train_loss:3.818666664301418e-05, test_loss:2.2353377062245272e-05
epoch:1500, train_loss:3.7403442547656596e-05, test_loss:2.1699825083487667e-05
epoch:1520, train_loss:3.659499634522945e-05, test_loss:2.1029607523814775e-05
epoch:1540, train_loss:3.576060043997131e-05, test_loss:2.034137287409976e-05
epoch:1560, train_loss:3.489920709398575e-05, test_loss:1.9637136574601755e-05
epoch:1580, train_loss:3.4010074159596115e-05, test_loss:1.8915245163952932e-05
epoch:1600, train_loss:3.3092539524659514e-05, test_loss:1.8175118384533562e-05
epoch:1620, train_loss:3.214598837075755e-05, test_loss:1.741875894367695e-05
epoch:1640, train_loss:3.1170118745649233e-05, test_loss:1.6644184142933227e-05
epoch:1660, train_loss:3.016489245055709e-05, test_loss:1.5854122466407716e-05
epoch:1680, train_loss:2.9130327675375156e-05, test_loss:1.5049477951833978e-05
epoch:1700, train_loss:2.8066970116924495e-05, test_loss:1.4228941836336162e-05
epoch:1720, train_loss:2.6975792934536003e-05, test_loss:1.3398408555076458e-05
epoch:1740, train_loss:2.5858256776700728e-05, test_loss:1.255551069334615e-05
epoch:1760, train_loss:2.4716307962080464e-05, test_loss:1.170479299617e-05
epoch:1780, train_loss:2.355259857722558e-05, test_loss:1.0850737453438342e-05
epoch:1800, train_loss:2.2370682927430607e-05, test_loss:9.995957043429371e-06
epoch:1820, train_loss:2.117482290486805e-05, test_loss:9.143737770500593e-06
epoch:1840, train_loss:1.9970017092418857e-05, test_loss:8.297822205349803e-06
epoch:1860, train_loss:1.8762442778097466e-05, test_loss:7.465289854735602e-06
epoch:1880, train_loss:1.7558993931743316e-05, test_loss:6.651775947830174e-06
epoch:1900, train_loss:1.6367670468753204e-05, test_loss:5.863402748218505e-06
epoch:1920, train_loss:1.519708621344762e-05, test_loss:5.106710432301043e-06
epoch:1940, train_loss:1.4056418876862153e-05, test_loss:4.386639830045169e-06
epoch:1960, train_loss:1.2955445527040865e-05, test_loss:3.7117310967005324e-06
epoch:1980, train_loss:1.1903614904440474e-05, test_loss:3.087144477831316e-06
model.eval()
prediction = list((model(train_x).data.reshape(-1))*100000) + list((model(test_x).data.reshape(-1))*100000)
plt.plot(value[3:], label='True Value')
plt.plot(prediction[:91], label='LSTM fit')
plt.plot(np.arange(90, 110, 1), prediction[90:], label='LSTM pred')
print(len(value[3:]))
print(len(prediction[90:]))
plt.legend(loc='best')
plt.title('Cumulative infections prediction(China)')
plt.xlabel('Day')
plt.ylabel('Cumulative Cases')
plt.show()
115
20

[Figure: Cumulative infections prediction (China) — true values, LSTM fit, and LSTM prediction]

import torch
from torch.autograd import Variable
import torch.nn as nn
from graphviz import Digraph
def make_dot(var, params=None):
    """ Produces Graphviz representation of PyTorch autograd graph
    Blue nodes are the Variables that require grad, orange are Tensors
    saved for backward in torch.autograd.Function
    Args:
        var: output Variable
        params: dict of (name, Variable) to add names to node that
            require grad (TODO: make optional)
    """
    if params is not None:
        assert isinstance(list(params.values())[0], Variable)  # dict.values() is not subscriptable in Python 3
        param_map = {id(v): k for k, v in params.items()}
 
    node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='12',
                     ranksep='0.1',
                     height='0.2')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
    seen = set()
 
    def size_to_str(size):
        return '('+(', ').join(['%d' % v for v in size])+')'
 
    def add_nodes(var):
        if var not in seen:
            if torch.is_tensor(var):
                dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
            elif hasattr(var, 'variable'):
                u = var.variable
                name = param_map[id(u)] if params is not None else ''
                node_name = '%s\n %s' % (name, size_to_str(u.size()))
                dot.node(str(id(var)), node_name, fillcolor='lightblue')
            else:
                dot.node(str(id(var)), str(type(var).__name__))
            seen.add(var)
            if hasattr(var, 'next_functions'):
                for u in var.next_functions:
                    if u[0] is not None:
                        dot.edge(str(id(u[0])), str(id(var)))
                        add_nodes(u[0])
            if hasattr(var, 'saved_tensors'):
                for t in var.saved_tensors:
                    dot.edge(str(id(t)), str(id(var)))
                    add_nodes(t)
    add_nodes(var.grad_fn)
    return dot
net = LSTM()
x = train_x
y = net(train_x)
g = make_dot(y)
g.view()
'Digraph.gv.pdf'
#g.view(quiet=True,quiet_view=True)
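As an aside, essentially the same visualization utility ships as the third-party torchviz package (an assumption: it must be installed separately, and it is not used in the original post):

# Equivalent graph rendering via torchviz (assumes: pip install torchviz)
from torchviz import make_dot as tv_make_dot
tv_make_dot(net(train_x), params=dict(net.named_parameters())).render('lstm_graph')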
params = list(net.parameters())
k = 0
for i in params:
    l = 1
    print("Layer shape: " + str(list(i.size())))
    for j in i.size():
        l *= j
    print("Layer parameter count: " + str(l))
    k = k + l
print("Total parameter count: " + str(k))
Layer shape: [64, 1]
Layer parameter count: 64
Layer shape: [64, 16]
Layer parameter count: 1024
Layer shape: [64]
Layer parameter count: 64
Layer shape: [64]
Layer parameter count: 64
Layer shape: [1, 48]
Layer parameter count: 48
Layer shape: [1]
Layer parameter count: 1
Total parameter count: 1265
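The same total can be obtained in one line with Tensor.numel():

# One-line parameter count, equivalent to the loop above
print(sum(p.numel() for p in net.parameters()))  # 1265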
31+30+19+29
109

Parameter list

input_size: feature dimension of x
hidden_size: feature dimension of the hidden layer
num_layers: number of stacked LSTM layers; default 1
bias: if False, b_ih and b_hh are set to 0; default True
batch_first: if True, input and output tensors use the format (batch, seq, feature)
dropout: applies dropout to the output of every layer except the last; default 0
bidirectional: if True, the LSTM is bidirectional; default False
Inputs: input, (h0, c0)
Outputs: output, (hn, cn)
Input shapes:
input: (seq_len, batch, input_size)
h0: (num_layers * num_directions, batch, hidden_size)
c0: (num_layers * num_directions, batch, hidden_size)

Output shapes:
output: (seq_len, batch, hidden_size * num_directions)
hn: (num_layers * num_directions, batch, hidden_size)
cn: (num_layers * num_directions, batch, hidden_size)

The inputs to PyTorch's LSTM must be 3-dimensional tensors, and the meaning of each dimension must not be confused.

The first dimension is the sequence length (seq_len). For text, this is the number of words per sentence, usually fixed to a set length before being fed to the network. For other sequential data it is the length of one clearly delimited unit; for stock data, for example, it is how many records fall inside the chosen time window. This dimension determines how many time steps the layer unrolls over.

The second dimension is the batch size: how many sequences are fed to the network at once. For text, that is the number of sentences per batch; for stock data, the number of windows processed together. At any single time step, it is the number of elements (words, or records) processed in parallel.

The third dimension is the size of each input element (input_size): the length of the vector representing one word, or, for stock data, how many values are sampled at each time step (e.g. low, high, mean, 5-day average, 10-day average, and so on).

What do h_0 through h_n mean? h is the hidden state each cell produces at a time step, computed from the current input and the previous time step's state.

This state summarizes everything the network has seen up to that moment, and its shape matches the output at a single time step.

c_0 through c_n is the cell state, which acts as a gate-controlled memory deciding how much of each cell's hidden state carries over to influence the next time step; its shape matches that of h_0 through h_n.

For bidirectional or multi-layer LSTMs, the direction count and the number of layers must be taken into account as well, as the shape sketch below shows.
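To illustrate the batch_first convention that the model in this post relies on, here is a minimal shape sketch (hypothetical sizes chosen to mirror the training data):

# Shape conventions: default layout vs. batch_first=True (hypothetical sizes)
import torch
import torch.nn as nn

lstm_default = nn.LSTM(input_size=1, hidden_size=16)               # expects (seq_len, batch, feature)
lstm_bf = nn.LSTM(input_size=1, hidden_size=16, batch_first=True)  # expects (batch, seq_len, feature)

out_default, _ = lstm_default(torch.randn(3, 90, 1))  # seq_len=3, batch=90
out_bf, _ = lstm_bf(torch.randn(90, 3, 1))            # batch=90, seq_len=3
print(out_default.shape)  # torch.Size([3, 90, 16])
print(out_bf.shape)       # torch.Size([90, 3, 16])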

df = pd.read_csv('./COVID-19_Italy.csv')
df.head()
   Unnamed: 0 countryEnglishName  confirmedCount  nowconfirmedCount  curedCount  deadCount  deadRate  curedRate
0           0              Italy               0                  0           0          0       0.0        0.0
1           1              Italy               0                  0           0          0       0.0        0.0
2           2              Italy               0                  0           0          0       0.0        0.0
3           3              Italy               0                  0           0          0       0.0        0.0
4           4              Italy               0                  0           0          0       0.0        0.0
value = df['confirmedCount'].values[20:140]
value
array([     3,      3,      3,      3,      3,      3,      3,      3,
            3,      3,      3,     20,    117,    230,    283,    374,
          528,    653,    888,   1128,   1694,   2036,   2502,   3089,
         3927,   4636,   5883,   7375,   9172,  10283,  12462,  15113,
        17660,  21270,  24938,  29022,  31506,  37178,  41035,  47021,
        55218,  59514,  64378,  70545,  74386,  81129,  87275,  93051,
        97689, 102106, 105792, 110574, 115895, 120281, 125016, 129481,
       132810, 135893, 139887, 143626, 148217, 152860, 156673, 159516,
       162488, 165155, 168941, 172434, 175925, 178972, 183957, 183957,
       187327, 192994, 192994, 195351, 197675, 199414, 201505, 203591,
       205463, 207428, 209328, 210717, 211938, 213013, 214457, 215858,
       217185, 218268, 219070, 219814, 221216, 222104, 223096, 223885,
       224760, 225435, 225886, 226699, 228006, 228658, 229327, 229858,
       230158, 230555, 231139, 231732, 232248, 232664, 233019, 233197,
       233515, 233836, 234013, 234531, 234801, 234998], dtype=int64)
print(len(value))
x = []
y = []
seq = 3
for i in range(len(value)-seq-1):
    x.append(value[i:i+seq])
    y.append(value[i+seq])
118
train_x = (torch.tensor(x[0:90]).float()/100000.).reshape(-1, seq, 1)
train_y = (torch.tensor(y[0:90]).float()/100000.).reshape(-1, 1)
test_x = (torch.tensor(x[90:110]).float()/100000.).reshape(-1, seq, 1)
test_y = (torch.tensor(y[90:110]).float()/100000.).reshape(-1, 1)
print(test_y)
tensor([[2.2210],
        [2.2310],
        [2.2389],
        [2.2476],
        [2.2543],
        [2.2589],
        [2.2670],
        [2.2801],
        [2.2866],
        [2.2933],
        [2.2986],
        [2.3016],
        [2.3056],
        [2.3114],
        [2.3173],
        [2.3225],
        [2.3266],
        [2.3302],
        [2.3320],
        [2.3352]])
# Model definition (same as above)
class LSTM(nn.Module):
    def __init__(self):
        super(LSTM, self).__init__()
        self.lstm = nn.LSTM(input_size=1, hidden_size=16, num_layers=1, batch_first=True)
        self.linear = nn.Linear(16 * seq, 1)
    def forward(self, x):
        x, (h, c) = self.lstm(x)
        x = x.reshape(-1, 16 * seq)
        x = self.linear(x)
        return x
# Model training
model = LSTM()
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
loss_func = nn.MSELoss()
model.train()

for epoch in range(2000):
    output = model(train_x)
    loss = loss_func(output, train_y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if epoch % 20 == 0:
        test_loss = loss_func(model(test_x), test_y)
        print("epoch:{}, train_loss:{}, test_loss:{}".format(epoch, loss, test_loss))
epoch:0, train_loss:1.7632023096084595, test_loss:5.223465442657471
epoch:20, train_loss:0.2965010404586792, test_loss:0.12659545242786407
epoch:40, train_loss:0.02264116331934929, test_loss:0.05677216500043869
epoch:60, train_loss:0.006982944440096617, test_loss:0.0008666875073686242
epoch:80, train_loss:0.0018317289650440216, test_loss:0.003289521439000964
epoch:100, train_loss:0.0015176519518718123, test_loss:0.0015515109989792109
epoch:120, train_loss:0.001411953242495656, test_loss:0.0017424819525331259
epoch:140, train_loss:0.0013288009213283658, test_loss:0.002439487725496292
epoch:160, train_loss:0.0012708938447758555, test_loss:0.002725407015532255
epoch:180, train_loss:0.0012296135537326336, test_loss:0.003102920949459076
epoch:200, train_loss:0.0011986845638602972, test_loss:0.003337680594995618
epoch:220, train_loss:0.0011738166213035583, test_loss:0.0035245800390839577
epoch:240, train_loss:0.0011521923588588834, test_loss:0.0036384714767336845
epoch:260, train_loss:0.0011321220081299543, test_loss:0.003699743188917637
epoch:280, train_loss:0.001112664700485766, test_loss:0.0037150864955037832
epoch:300, train_loss:0.0010933543089777231, test_loss:0.0036975769326090813
epoch:320, train_loss:0.0010739663848653436, test_loss:0.0036563656758517027
epoch:340, train_loss:0.0010544119868427515, test_loss:0.003598833456635475
epoch:360, train_loss:0.0010346685303375125, test_loss:0.003530224785208702
epoch:380, train_loss:0.0010147254215553403, test_loss:0.003454281948506832
epoch:400, train_loss:0.0009945958154276013, test_loss:0.0033735043834894896
epoch:420, train_loss:0.0009742853580974042, test_loss:0.003289596876129508
epoch:440, train_loss:0.0009538100566715002, test_loss:0.003203604370355606
epoch:460, train_loss:0.0009331702603958547, test_loss:0.003116239095106721
epoch:480, train_loss:0.0009123761556111276, test_loss:0.0030278817284852266
epoch:500, train_loss:0.0008914396748878062, test_loss:0.002938809571787715
epoch:520, train_loss:0.0008703632047399879, test_loss:0.0028492698911577463
epoch:540, train_loss:0.0008491512853652239, test_loss:0.00275938818231225
epoch:560, train_loss:0.0008278173627331853, test_loss:0.0026692496612668037
epoch:580, train_loss:0.0008063663844950497, test_loss:0.002579005667939782
epoch:600, train_loss:0.0007848081295378506, test_loss:0.0024887227918952703
epoch:620, train_loss:0.0007631516200490296, test_loss:0.002398498123511672
epoch:640, train_loss:0.0007414081483148038, test_loss:0.002308483235538006
epoch:660, train_loss:0.0007195929065346718, test_loss:0.0022187510039657354
epoch:680, train_loss:0.0006977192242629826, test_loss:0.0021294010803103447
epoch:700, train_loss:0.0006758036906830966, test_loss:0.0020405412651598454
epoch:720, train_loss:0.0006538643501698971, test_loss:0.0019522873917594552
epoch:740, train_loss:0.0006319198873825371, test_loss:0.0018647522665560246
epoch:760, train_loss:0.0006099938764236867, test_loss:0.0017780527705326676
epoch:780, train_loss:0.0005881070392206311, test_loss:0.001692296122200787
epoch:800, train_loss:0.000566287839319557, test_loss:0.0016075713792815804
epoch:820, train_loss:0.0005445620627142489, test_loss:0.0015240126522257924
epoch:840, train_loss:0.000522959919180721, test_loss:0.0014417333295568824
epoch:860, train_loss:0.0005015107453800738, test_loss:0.0013608544832095504
epoch:880, train_loss:0.00048024847637861967, test_loss:0.0012814635410904884
epoch:900, train_loss:0.0004592059995047748, test_loss:0.0012036816915497184
epoch:920, train_loss:0.0004384215862955898, test_loss:0.0011276379227638245
epoch:940, train_loss:0.0004179326933808625, test_loss:0.0010534359607845545
epoch:960, train_loss:0.0003977797459810972, test_loss:0.0009811957133933902
epoch:980, train_loss:0.0003779999096877873, test_loss:0.0009110327227972448
epoch:1000, train_loss:0.00035864009987562895, test_loss:0.000843059562612325
epoch:1020, train_loss:0.00033973553217947483, test_loss:0.0007773857214488089
epoch:1040, train_loss:0.00032133201602846384, test_loss:0.0007141505484469235
epoch:1060, train_loss:0.0003034717810805887, test_loss:0.0006534302374348044
epoch:1080, train_loss:0.00028619240038096905, test_loss:0.0005953567451797426
epoch:1100, train_loss:0.00026953473570756614, test_loss:0.0005400101072154939
epoch:1120, train_loss:0.0002535344392526895, test_loss:0.00048749061534181237
epoch:1140, train_loss:0.00023822428192943335, test_loss:0.00043787891627289355
epoch:1160, train_loss:0.00022363335301633924, test_loss:0.0003912308602593839
epoch:1180, train_loss:0.00020978778775315732, test_loss:0.0003475920238997787
epoch:1200, train_loss:0.00019670836627483368, test_loss:0.00030700559727847576
epoch:1220, train_loss:0.00018440828716848046, test_loss:0.0002694650029297918
epoch:1240, train_loss:0.0001728977804305032, test_loss:0.00023496514768339694
epoch:1260, train_loss:0.00016217738448176533, test_loss:0.0002034622011706233
epoch:1280, train_loss:0.00015224415983539075, test_loss:0.00017491883772891015
epoch:1300, train_loss:0.00014308829850051552, test_loss:0.00014922290574759245
epoch:1320, train_loss:0.00013469379337038845, test_loss:0.00012628933473024517
epoch:1340, train_loss:0.00012703817628789693, test_loss:0.0001059926871675998
epoch:1360, train_loss:0.00012009469355689362, test_loss:8.818962669465691e-05
epoch:1380, train_loss:0.00011383087985450402, test_loss:7.271893991855904e-05
epoch:1400, train_loss:0.0001082118833437562, test_loss:5.941572817391716e-05
epoch:1420, train_loss:0.00010319782450096682, test_loss:4.810726750292815e-05
epoch:1440, train_loss:9.874825627775863e-05, test_loss:3.86049687222112e-05
epoch:1460, train_loss:9.481970482738689e-05, test_loss:3.0730687285540625e-05
epoch:1480, train_loss:9.136966400546953e-05, test_loss:2.4305656552314758e-05
epoch:1500, train_loss:8.835419430397451e-05, test_loss:1.9148759747622535e-05
epoch:1520, train_loss:8.573079685447738e-05, test_loss:1.5098817129910458e-05
epoch:1540, train_loss:8.345866081072018e-05, test_loss:1.199172947963234e-05
epoch:1560, train_loss:8.149856876116246e-05, test_loss:9.681958545115776e-06
epoch:1580, train_loss:7.981352973729372e-05, test_loss:8.036873623495921e-06
epoch:1600, train_loss:7.836938311811537e-05, test_loss:6.933561053301673e-06
epoch:1620, train_loss:7.713426020927727e-05, test_loss:6.265192041610135e-06
epoch:1640, train_loss:7.607969018863514e-05, test_loss:5.936094112257706e-06
epoch:1660, train_loss:7.5179836130701e-05, test_loss:5.865499588253442e-06
epoch:1680, train_loss:7.441190973622724e-05, test_loss:5.9840194808202796e-06
epoch:1700, train_loss:7.375545828836039e-05, test_loss:6.233521162357647e-06
epoch:1720, train_loss:7.319300493691117e-05, test_loss:6.566989213752095e-06
epoch:1740, train_loss:7.270894275279716e-05, test_loss:6.945073437236715e-06
epoch:1760, train_loss:7.229032053146511e-05, test_loss:7.33830847821082e-06
epoch:1780, train_loss:7.192560588009655e-05, test_loss:7.723288035776932e-06
epoch:1800, train_loss:7.160565291997045e-05, test_loss:8.082806743914261e-06
epoch:1820, train_loss:7.132247992558405e-05, test_loss:8.406545930483844e-06
epoch:1840, train_loss:7.106920384103432e-05, test_loss:8.6846148406039e-06
epoch:1860, train_loss:7.084051321726292e-05, test_loss:8.915197213354986e-06
epoch:1880, train_loss:7.063188240863383e-05, test_loss:9.095112545765005e-06
epoch:1900, train_loss:7.043959340080619e-05, test_loss:9.22342951525934e-06
epoch:1920, train_loss:7.026109233265743e-05, test_loss:9.305442290497012e-06
epoch:1940, train_loss:7.009309774730355e-05, test_loss:9.340597898699343e-06
epoch:1960, train_loss:6.99344091117382e-05, test_loss:9.336043149232864e-06
epoch:1980, train_loss:6.978306191740558e-05, test_loss:9.294308256357908e-06
model.eval()
prediction = list((model(train_x).data.reshape(-1))*100000) + list((model(test_x).data.reshape(-1))*100000)
plt.plot(value[3:], label='True Value')
plt.plot(prediction[:91], label='LSTM fit')
plt.plot(np.arange(90, 110, 1), prediction[90:], label='LSTM pred')
print(len(value[3:]))
print(len(prediction[90:]))
plt.legend(loc='best')
plt.title('Cumulative infections prediction(Italy)')
plt.xlabel('Day')
plt.ylabel('Cumulative Cases')
plt.show()
115
20

[Figure: Cumulative infections prediction (Italy) — true values, LSTM fit, and LSTM prediction]


