LSTM中的batch_size到底是什么

 

 真正的LSTM输入数据并不是按照原始数据这样的顺序输入网络的,

(数据太多,没展示完。。) 而是根据time_step=n将数据重构成下列形式进行输入的。这个案例中n=30,可以看到: 第一行数据就是按照t排列的前30条数据,y也就是预测值是第31条数据。依次类推……

"""此为generate_data_by_n_days构建好的数据格式
           c0       c1       c2       c3  ...      c27      c28      c29        y
0     4144.68  4124.73  4126.94  4109.36  ...  3946.43  3945.20  3952.20  3972.53
1     4124.73  4126.94  4109.36  4047.56  ...  3945.20  3952.20  3972.53  3970.11
2     4126.94  4109.36  4047.56  4018.61  ...  3952.20  3972.53  3970.11  3998.10
3     4109.36  4047.56  4018.61  4014.57  ...  3972.53  3970.11  3998.10  3979.66
4     4047.56  4018.61  4014.57  4007.14  ...  3970.11  3998.10  3979.66  3941.48
...       ...      ...      ...      ...  ...      ...      ...      ...      ...
2011  2631.05  2624.32  2618.25  2705.75  ...  2546.03  2534.16  2489.03  2520.76
2012  2624.32  2618.25  2705.75  2681.33  ...  2534.16  2489.03  2520.76  2514.65
2013  2618.25  2705.75  2681.33  2666.43  ...  2489.03  2520.76  2514.65  2486.24
2014  2705.75  2681.33  2666.43  2664.41  ...  2520.76  2514.65  2486.24  2481.66
2015  2681.33  2666.43  2664.41  2645.95  ...  2514.65  2486.24  2481.66  2472.84
"""

那么batch_size=60是什么呢,我们看看构建的模型LSTM网络一个batch中的输入数据格式吧:

我们查看的数据是从DataLoader中提取的X和Y,其中X是构建好的LSTM网络的输入数据,控制台终的输出:

tensor([[4144.6797, 4124.7300, 4126.9399,  ..., 3946.4299, 3945.1997,
         3952.2000],
        [4124.7300, 4126.9399, 4109.3599,  ..., 3945.1997, 3952.2000,
         3972.5298],
        [4126.9399, 4109.3599, 4047.5598,  ..., 3952.2000, 3972.5298,
         3970.1099],
        ...,
        [3716.0698, 3690.6399, 3758.7798,  ..., 4184.4399, 4148.5298,
         4085.1699],
        [3690.6399, 3758.7798, 3736.2500,  ..., 4148.5298, 4085.1699,
         4076.3899],
        [3758.7798, 3736.2500, 3732.6499,  ..., 4085.1699, 4076.3899,
         4077.4500]], device='cuda:0')
torch.Size([60, 30])
tensor([3972.5298, 3970.1099, 3998.0999, 3979.6599, 3941.4800, 3937.6899,
        3921.7000, 3880.7397, 3859.6799, 3842.8699, 3827.2100, 3857.4700,
        3839.3799, 3838.2300, 3853.7998, 3863.4500, 3807.3599, 3829.7998,
        3825.7598, 3769.8799, 3824.7397, 3801.7197, 3793.0000, 3803.6299,
        3738.8298, 3734.5298, 3731.7798, 3716.0698, 3690.6399, 3758.7798,
        3736.2500, 3732.6499, 3627.7598, 3585.7998, 3663.9500, 3647.9897,
        3775.8499, 3786.8899, 3899.8599, 3937.1799, 3987.4299, 4091.2700,
        4092.9998, 4063.0798, 4181.3101, 4215.8501, 4115.2598, 4147.6499,
        4092.5398, 4031.1499, 4110.5298, 4132.7798, 4126.7100, 4149.0098,
        4184.4399, 4148.5298, 4085.1699, 4076.3899, 4077.4500, 4001.5598],
       device='cuda:0')
torch.Size([60])


第一个tensor是X第二个tensor是Y

从X可以看到batchsize就是有多少个sequence(重构后的数据有2015个sequence),一个sequence有time_step=30条数据

Y的数量也是batch_size=60个。

  • 23
    点赞
  • 77
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值