训练seq2seq模型时,训练数据一般都不能刚好和batchsize成整数倍数关系。
那么在每个epoch训练中,最后会剩余一组数据量<batchsize的数据。
此时这些数据可能会不适合编写的网络形状,或者代码中reshape形状的部分,在rnn中还会不匹配隐状态形状。
因为我的训练数据量很大,所以直接把最后一个不足batch的数据组抛弃就好。
train_loader = Data.DataLoader(MyDataSet(train_enc_inputs, train_dec_inputs, train_dec_outputs), BATCH_SIZE, True,drop_last=True)
解决这个问题,在pytorch中使用dataloader加载数据,只需要在参数里加一个
drop_last=True
就可以了,十分方便!