pytorch:解决训练数据不能被batchsize整除

在训练seq2seq模型时,如果数据量不能被batch size整除,最后一个批次的数据可能会导致形状不匹配问题。为解决此问题,可以在PyTorch的DataLoader中设置`drop_last=True`,这样可以自动丢弃最后一个不足batch size的数据,确保每个epoch内数据的完整性和批处理的一致性。这是一个简单而有效的方法,尤其适用于大量数据的情况。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

训练seq2seq模型时,训练数据一般都不能刚好和batchsize成整数倍数关系。
那么在每个epoch训练中,最后会剩余一组数据量<batchsize的数据。
此时这些数据可能会不适合编写的网络形状,或者代码中reshape形状的部分,在rnn中还会不匹配隐状态形状。
因为我的训练数据量很大,所以直接把最后一个不足batch的数据组抛弃就好。

train_loader = Data.DataLoader(MyDataSet(train_enc_inputs, train_dec_inputs, train_dec_outputs), BATCH_SIZE, True,drop_last=True)

解决这个问题,在pytorch中使用dataloader加载数据,只需要在参数里加一个

drop_last=True

就可以了,十分方便!

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值