Load Forecasting with an LSTM in Paddle 2.0

Predicting Future Data

Given a sequence of past observations, predict the next day's value (e.g., stock prices, temperature, or, here, electricity consumption).
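The core trick is recasting one long series as supervised (window, next value) pairs. A toy sketch of the idea with made-up numbers (the real windowing is done by the dataset class further below):

import numpy as np

series = np.array([1, 2, 3, 4, 5, 6], dtype="float32")
window = 3
# Each sample: `window` consecutive values as input, the following value as label.
pairs = [(series[i:i + window], series[i + window]) for i in range(len(series) - window)]
print(pairs[0])  # (array([1., 2., 3.], dtype=float32), 4.0)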
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
data=pd.read_csv("consumption.csv")
# data.head()
data.tail()
        user_id record_date  consumption
885481     1454   2016/8/27       2441.0
885482     1454   2016/8/28       2200.0
885483     1454   2016/8/29       2221.0
885484     1454   2016/8/30       1671.0
885485     1454   2016/8/31       2176.0
data["user_id"] == 1 # 掩码
print(data["user_id"] == 1)
0          True
1          True
2          True
3          True
4          True
          ...  
885481    False
885482    False
885483    False
885484    False
885485    False
Name: user_id, Length: 885486, dtype: bool
data[data["user_id"] == 55]
       user_id record_date  consumption
32886       55    2015/1/1       4225.0
32887       55    2015/1/2       4027.0
32888       55    2015/1/3       3532.0
32889       55    2015/1/4       4687.0
32890       55    2015/1/5       4555.0
...        ...         ...          ...
33490       55   2016/8/27       3664.0
33491       55   2016/8/28       3169.0
33492       55   2016/8/29       5182.0
33493       55   2016/8/30       4720.0
33494       55   2016/8/31       5116.0

609 rows × 3 columns

data.groupby('user_id')['consumption'].max()
user_id
1        4293.0
2         485.0
3        1044.0
4        7833.0
5        3100.0
         ...   
1450     1714.0
1451     3095.0
1452    11579.0
1453     2348.0
1454     3596.0
Name: consumption, Length: 1454, dtype: float64
data.groupby(["user_id"])["consumption"].min().plot()
<matplotlib.axes._subplots.AxesSubplot at 0x7f86e472ed90>

[Figure: line plot of each user's minimum consumption]

data.describe()
            user_id   consumption
count  885486.000000  8.854860e+05
mean      727.500000  2.619980e+03
std       419.733783  3.154743e+04
min         1.000000  1.000000e+00
25%       364.000000  4.200000e+01
50%       727.500000  2.610000e+02
75%      1091.000000  8.250000e+02
max      1454.000000  1.310016e+06

Forecasting with an LSTM

import paddle

# Define the LSTM model

class LSTM(paddle.nn.Layer):
    def __init__(self, input_size=1, hidden_size=1):  # input_size: feature dim per step (1 for univariate); hidden_size: hidden state dim (determines the RNN's output dim)
        super().__init__()
        self.rnn = paddle.nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=2)  # num_layers: number of stacked LSTM layers
        self.linear = paddle.nn.Linear(hidden_size, 1)  # projects the hidden state to a single output value

    def forward(self, inputs):
        y, (hidden, cell) = self.rnn(inputs)  # hidden: [num_layers, batch_size, hidden_size]
        output = self.linear(hidden)          # [num_layers, batch_size, 1]
        output = output[0] + output[1]        # sum the two layers' projections -> [batch_size, 1]
        # output = paddle.squeeze(output)  # would drop the size-1 dimensions
        return output
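For reference: the hidden state returned by paddle.nn.LSTM has shape [num_layers, batch_size, hidden_size], so output[0] + output[1] above sums the projections of both layers' final hidden states. A more conventional head, shown here only as a sketch (not the code used below), projects the last time step of the sequence output y instead:

class LSTMLastStep(LSTM):
    # Alternative head: project the last time step of the sequence output
    # rather than summing the two layers' final hidden states.
    def forward(self, inputs):
        y, (hidden, cell) = self.rnn(inputs)  # y: [batch_size, seq_len, hidden_size]
        return self.linear(y[:, -1, :])       # -> [batch_size, 1]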
# Test the LSTM model
input_size = 1  # input feature dimension (1 for univariate forecasting)
seq_len = 10    # how many past values are used to predict the next one
batch_size = 4  # a batch of 4 samples, each of length 10

x = paddle.randn((batch_size, seq_len, input_size))
model = LSTM(input_size)
print('model:',model)
print(x.shape) # (batch_size, seq_len, input_size)
y = model(x)  # calling the layer invokes forward()
print(y)
print(y.shape)   
model: LSTM(
  (rnn): LSTM(1, 1, num_layers=2
    (0): RNN(
      (cell): LSTMCell(1, 1)
    )
    (1): RNN(
      (cell): LSTMCell(1, 1)
    )
  )
  (linear): Linear(in_features=1, out_features=1, dtype=float32)
)
[4, 10, 1]
Tensor(shape=[4, 1], dtype=float32, place=CPUPlace, stop_gradient=False,
       [[0.23631290],
        [0.05070865],
        [0.09258709],
        [0.11746131]])
[4, 1]
# Dataset class
class MyDataset(paddle.io.Dataset):
    """
    Step 1: subclass paddle.io.Dataset
    """
    def __init__(self, data, num_features=13, num_labels=1):
        """
        Step 2: implement the constructor and build the samples

        data: 1-D numpy array
        """
        super(MyDataset, self).__init__()
        self.data = data
        self.num_features = num_features
        self.num_labels = num_labels

        x = []    # inputs: sliding windows of length num_features
        y = []    # labels: the num_labels values that follow each window
        for i in range(0, len(data) - num_features - num_labels + 1):
            x.append(data[i:i+num_features])
            y.append(data[i+num_features:i+num_features+num_labels])
        self.x = np.vstack(x).reshape(-1, self.num_features, 1)  # np.vstack stacks the windows row-wise into one array
        self.y = np.vstack(y)
        self.x = np.array(self.x, dtype="float32")
        self.y = np.array(self.y, dtype="float32")

        self.num_samples = len(x)

    def __getitem__(self, index):
        """
        Step 3: implement __getitem__, returning one (input, label) pair for a given index
        """
        data = self.x[index]
        label = self.y[index]

        return data, label

    def __len__(self):
        """
        Step 4: implement __len__, returning the total number of samples
        """
        return self.num_samples
data_01 = data[data["user_id"] == 2]['consumption'].to_numpy()
data_02=MyDataset(data_01)
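A quick sanity check of the windowing on a toy array (hypothetical values, default num_features=13):

toy = MyDataset(np.arange(20, dtype="float32"))
print(toy.x.shape, toy.y.shape)    # (7, 13, 1) (7, 1): 20 - 13 - 1 + 1 = 7 windows
print(toy.x[0].ravel(), toy.y[0])  # the window [0 ... 12] predicts [13.]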
# Load the dataset from file and select a single user's series
data = pd.read_csv("consumption.csv")
data.head()
data = data[data["user_id"] == 2]['consumption'].to_numpy()

data_min = data.min()
data_max = data.max()
data = (data - data_min) / (data_max - data_min)  # min-max normalization to [0, 1]

n_train_samples = int(len(data) * 0.85)

train_data = data[:n_train_samples]
test_data = data[n_train_samples:]
plt.plot(range(n_train_samples), train_data)
plt.plot(range(n_train_samples, len(data)), test_data)
[<matplotlib.lines.Line2D at 0x7f86e465e610>]

[Figure: normalized consumption series, split into train (first 85%) and test segments]
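Predictions made on this normalized scale need to be mapped back before they can be read as consumption values. A one-line helper (a sketch, using the data_min / data_max computed above):

# Invert the min-max transform to recover the original units.
def denormalize(y_norm, lo=data_min, hi=data_max):
    return y_norm * (hi - lo) + lo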

# Trial and error suggested num_features = 12 with batch_size = 8 works well
# (a sketch of such a search appears at the end of the post)
train_dataset = MyDataset(train_data, num_features=12)
test_dataset = MyDataset(test_data, num_features=12)
train_loader = paddle.io.DataLoader(train_dataset, batch_size=8, shuffle=False, drop_last=True)

Paddle provides two ways to train a model:

  • the conventional training loop built on the basic API
  • wrapping the model in paddle.Model and training/predicting through high-level APIs such as Model.fit(), Model.evaluate(), and Model.predict()
# Training approach 1: basic API
model = LSTM()
loss_fn = paddle.nn.MSELoss(reduction='mean')  # reduce per-element losses by averaging
optimizer = paddle.optimizer.Adam(learning_rate=0.01, parameters=model.parameters())

for epoch in range(30):
    for batch_x, batch_y in train_loader():
        y_pred = model(batch_x)  # [batch_size, 1], same shape as batch_y
        loss = loss_fn(y_pred, batch_y)
        loss.backward()
        optimizer.step()
        optimizer.clear_grad()
    print("epoch {} loss {:.4f}".format(epoch, float(loss)))

# Evaluate on the held-out test set
test_x = paddle.to_tensor(test_dataset.x, dtype="float32")
y_pred = model(test_x)
plt.plot(test_dataset.y)
plt.plot(y_pred.numpy())
plt.show()
epoch 29 loss 0.0188

[Figure: test-set ground truth vs. model predictions (approach 1)]
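The stated goal was to predict the day after the series ends. A minimal sketch, assuming the model trained above and the 12-step window: feed the last 12 normalized values, take the output as the next day's normalized consumption, and undo the scaling.

# Forecast the day after the last observation (a sketch; reuses the trained
# `model`, the normalized `data` array, and data_min/data_max from above).
last_window = paddle.to_tensor(data[-12:].reshape(1, 12, 1), dtype="float32")
next_day_norm = model(last_window)  # shape [1, 1]
next_day = next_day_norm.numpy().item() * (data_max - data_min) + data_min
print("predicted consumption for the next day:", next_day)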

# Training approach 2: high-level API
model = paddle.Model(LSTM())
model.prepare(optimizer=paddle.optimizer.Adam(learning_rate=0.01, parameters=model.parameters()),
              loss=paddle.nn.MSELoss(reduction='mean'))
model.fit(train_dataset,
          epochs=30,
          batch_size=12,
          shuffle=False,  # shuffle=True vs. False changes the results considerably
          drop_last=True)

# Evaluate on the held-out test set
y_pred = model.predict(test_dataset)  # a nested list of per-batch output arrays
print('test_dataset.x.shape', test_dataset.x.shape)
print('test_dataset.y.shape', test_dataset.y.shape)
plt.plot(test_dataset.y)
plt.plot(np.concatenate(np.concatenate(y_pred)))  # flatten the nested batch outputs into one series
plt.show()
The loss value printed in the log is the current step, and the metric is the average value of previous step.
Epoch 1/30
step 10/42 - loss: 0.0985 - 4ms/step
step 20/42 - loss: 0.0290 - 3ms/step
step 30/42 - loss: 0.0988 - 3ms/step
step 40/42 - loss: 0.0429 - 3ms/step
step 42/42 - loss: 0.0506 - 3ms/step
Epoch 2/30
step 10/42 - loss: 0.0566 - 3ms/step
step 20/42 - loss: 0.0285 - 3ms/step
step 30/42 - loss: 0.0720 - 3ms/step
step 40/42 - loss: 0.0431 - 3ms/step
step 42/42 - loss: 0.0422 - 3ms/step
Epoch 3/30
step 10/42 - loss: 0.0508 - 3ms/step
step 20/42 - loss: 0.0278 - 3ms/step
step 30/42 - loss: 0.0629 - 3ms/step
step 40/42 - loss: 0.0443 - 3ms/step
step 42/42 - loss: 0.0366 - 3ms/step
Epoch 4/30
step 10/42 - loss: 0.0464 - 3ms/step
step 20/42 - loss: 0.0279 - 3ms/step
step 30/42 - loss: 0.0529 - 3ms/step
step 40/42 - loss: 0.0470 - 3ms/step
step 42/42 - loss: 0.0296 - 3ms/step
Epoch 5/30
step 10/42 - loss: 0.0418 - 3ms/step
step 20/42 - loss: 0.0293 - 3ms/step
step 30/42 - loss: 0.0424 - 3ms/step
step 40/42 - loss: 0.0515 - 3ms/step
step 42/42 - loss: 0.0217 - 3ms/step
Epoch 6/30
step 10/42 - loss: 0.0383 - 3ms/step
step 20/42 - loss: 0.0321 - 3ms/step
step 30/42 - loss: 0.0360 - 3ms/step
step 40/42 - loss: 0.0561 - 3ms/step
step 42/42 - loss: 0.0159 - 3ms/step
Epoch 7/30
step 10/42 - loss: 0.0371 - 3ms/step
step 20/42 - loss: 0.0348 - 3ms/step
step 30/42 - loss: 0.0350 - 3ms/step
step 40/42 - loss: 0.0588 - 3ms/step
step 42/42 - loss: 0.0139 - 3ms/step
Epoch 8/30
step 10/42 - loss: 0.0374 - 3ms/step
step 20/42 - loss: 0.0362 - 3ms/step
step 30/42 - loss: 0.0354 - 3ms/step
step 40/42 - loss: 0.0599 - 3ms/step
step 42/42 - loss: 0.0136 - 3ms/step
Epoch 9/30
step 10/42 - loss: 0.0378 - 3ms/step
step 20/42 - loss: 0.0366 - 3ms/step
step 30/42 - loss: 0.0355 - 3ms/step
step 40/42 - loss: 0.0602 - 3ms/step
step 42/42 - loss: 0.0136 - 3ms/step
Epoch 10/30
step 10/42 - loss: 0.0380 - 3ms/step
step 20/42 - loss: 0.0367 - 3ms/step
step 30/42 - loss: 0.0354 - 3ms/step
step 40/42 - loss: 0.0603 - 3ms/step
step 42/42 - loss: 0.0136 - 3ms/step
Epoch 11/30
step 10/42 - loss: 0.0381 - 3ms/step
step 20/42 - loss: 0.0366 - 3ms/step
step 30/42 - loss: 0.0354 - 3ms/step
step 40/42 - loss: 0.0603 - 3ms/step
step 42/42 - loss: 0.0136 - 3ms/step
Epoch 12/30
step 10/42 - loss: 0.0382 - 3ms/step
step 20/42 - loss: 0.0366 - 3ms/step
step 30/42 - loss: 0.0353 - 3ms/step
step 40/42 - loss: 0.0603 - 3ms/step
step 42/42 - loss: 0.0136 - 3ms/step
Epoch 13/30
step 10/42 - loss: 0.0383 - 3ms/step
step 20/42 - loss: 0.0366 - 3ms/step
step 30/42 - loss: 0.0353 - 3ms/step
step 40/42 - loss: 0.0603 - 3ms/step
step 42/42 - loss: 0.0136 - 3ms/step
Epoch 14/30
step 10/42 - loss: 0.0384 - 3ms/step
step 20/42 - loss: 0.0366 - 3ms/step
step 30/42 - loss: 0.0352 - 3ms/step
step 40/42 - loss: 0.0603 - 3ms/step
step 42/42 - loss: 0.0136 - 3ms/step
Epoch 15/30
step 10/42 - loss: 0.0385 - 3ms/step
step 20/42 - loss: 0.0365 - 3ms/step
step 30/42 - loss: 0.0352 - 3ms/step
step 40/42 - loss: 0.0603 - 3ms/step
step 42/42 - loss: 0.0136 - 3ms/step
Epoch 16/30
step 10/42 - loss: 0.0386 - 3ms/step
step 20/42 - loss: 0.0365 - 3ms/step
step 30/42 - loss: 0.0352 - 3ms/step
step 40/42 - loss: 0.0603 - 3ms/step
step 42/42 - loss: 0.0137 - 3ms/step
Epoch 17/30
step 10/42 - loss: 0.0386 - 3ms/step
step 20/42 - loss: 0.0364 - 3ms/step
step 30/42 - loss: 0.0351 - 3ms/step
step 40/42 - loss: 0.0602 - 3ms/step
step 42/42 - loss: 0.0137 - 3ms/step
Epoch 18/30
step 10/42 - loss: 0.0387 - 3ms/step
step 20/42 - loss: 0.0364 - 3ms/step
step 30/42 - loss: 0.0351 - 3ms/step
step 40/42 - loss: 0.0602 - 3ms/step
step 42/42 - loss: 0.0137 - 3ms/step
Epoch 19/30
step 10/42 - loss: 0.0387 - 3ms/step
step 20/42 - loss: 0.0363 - 3ms/step
step 30/42 - loss: 0.0351 - 3ms/step
step 40/42 - loss: 0.0602 - 3ms/step
step 42/42 - loss: 0.0137 - 3ms/step
Epoch 20/30
step 10/42 - loss: 0.0388 - 3ms/step
step 20/42 - loss: 0.0362 - 3ms/step
step 30/42 - loss: 0.0350 - 3ms/step
step 40/42 - loss: 0.0601 - 3ms/step
step 42/42 - loss: 0.0137 - 3ms/step
Epoch 21/30
step 10/42 - loss: 0.0388 - 3ms/step
step 20/42 - loss: 0.0362 - 3ms/step
step 30/42 - loss: 0.0350 - 3ms/step
step 40/42 - loss: 0.0601 - 3ms/step
step 42/42 - loss: 0.0138 - 3ms/step
Epoch 22/30
step 10/42 - loss: 0.0389 - 3ms/step
step 20/42 - loss: 0.0361 - 3ms/step
step 30/42 - loss: 0.0350 - 3ms/step
step 40/42 - loss: 0.0600 - 3ms/step
step 42/42 - loss: 0.0138 - 3ms/step
Epoch 23/30
step 10/42 - loss: 0.0389 - 3ms/step
step 20/42 - loss: 0.0360 - 3ms/step
step 30/42 - loss: 0.0350 - 3ms/step
step 40/42 - loss: 0.0599 - 3ms/step
step 42/42 - loss: 0.0138 - 3ms/step
Epoch 24/30
step 10/42 - loss: 0.0389 - 3ms/step
step 20/42 - loss: 0.0360 - 3ms/step
step 30/42 - loss: 0.0349 - 3ms/step
step 40/42 - loss: 0.0599 - 3ms/step
step 42/42 - loss: 0.0138 - 3ms/step
Epoch 25/30
step 10/42 - loss: 0.0390 - 3ms/step
step 20/42 - loss: 0.0359 - 3ms/step
step 30/42 - loss: 0.0349 - 3ms/step
step 40/42 - loss: 0.0598 - 3ms/step
step 42/42 - loss: 0.0138 - 3ms/step
Epoch 26/30
step 10/42 - loss: 0.0390 - 3ms/step
step 20/42 - loss: 0.0358 - 3ms/step
step 30/42 - loss: 0.0349 - 3ms/step
step 40/42 - loss: 0.0597 - 3ms/step
step 42/42 - loss: 0.0139 - 3ms/step
Epoch 27/30
step 10/42 - loss: 0.0390 - 3ms/step
step 20/42 - loss: 0.0357 - 3ms/step
step 30/42 - loss: 0.0349 - 3ms/step
step 40/42 - loss: 0.0597 - 3ms/step
step 42/42 - loss: 0.0139 - 3ms/step
Epoch 28/30
step 10/42 - loss: 0.0391 - 3ms/step
step 20/42 - loss: 0.0357 - 3ms/step
step 30/42 - loss: 0.0349 - 3ms/step
step 40/42 - loss: 0.0596 - 3ms/step
step 42/42 - loss: 0.0139 - 3ms/step
Epoch 29/30
step 10/42 - loss: 0.0391 - 3ms/step
step 20/42 - loss: 0.0356 - 3ms/step
step 30/42 - loss: 0.0349 - 3ms/step
step 40/42 - loss: 0.0595 - 3ms/step
step 42/42 - loss: 0.0139 - 3ms/step
Epoch 30/30
step 10/42 - loss: 0.0391 - 3ms/step
step 20/42 - loss: 0.0355 - 3ms/step
step 30/42 - loss: 0.0348 - 3ms/step
step 40/42 - loss: 0.0595 - 3ms/step
step 42/42 - loss: 0.0140 - 3ms/step
Predict begin...
step 80/80 [==============================] - 2ms/step        
Predict samples: 80
test_dataset.x.shape (80, 12, 1)
test_dataset.y.shape (80, 1)

[Figure: test-set ground truth vs. Model.predict outputs (approach 2)]
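As an aside, the "num_features = 12, batch_size = 8 works well" remark earlier came from manual trials. A sketch of how such a search might be automated, reusing the classes and splits defined above (hypothetical helper; training is kept short, so treat the scores as rough):

def quick_mse(window, bs, epochs=10):
    # Build datasets/loader for this window size and briefly train a fresh model.
    tr = MyDataset(train_data, num_features=window)
    te = MyDataset(test_data, num_features=window)
    loader = paddle.io.DataLoader(tr, batch_size=bs, shuffle=False, drop_last=True)
    m = LSTM()
    opt = paddle.optimizer.Adam(learning_rate=0.01, parameters=m.parameters())
    loss_fn = paddle.nn.MSELoss()
    for _ in range(epochs):
        for bx, by in loader():
            loss = loss_fn(m(bx), by)
            loss.backward()
            opt.step()
            opt.clear_grad()
    # Score on the held-out windows.
    pred = m(paddle.to_tensor(te.x))
    return float(loss_fn(pred, paddle.to_tensor(te.y)))

for window in (8, 12, 16):
    for bs in (4, 8, 16):
        print(window, bs, quick_mse(window, bs))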
