训练seq2seq模型时,训练数据一般都不能刚好和batchsize成整数倍数关系。
那么在每个epoch训练中,最后会剩余一组数据量<batchsize的数据。
此时这些数据可能会不适合编写的网络形状,或者代码中reshape形状的部分,在rnn中还会不匹配隐状态形状。
因为我的训练数据量很大,所以直接把最后一个不足batch的数据组抛弃就好。
Data = Dataset_Pred
timeenc = 0 if args.embed!='timeF' else 1
flag = 'pred'; shuffle_flag = False; drop_last = False; batch_size = 1
freq = args.detail_freq
data_set = Data(
root_path=args.root_path,
data_path=args.data_path,
flag=flag,
size=[args.seq_len, args.label_len, args.pred_len],
features=args.features,
target=args.target,
timeenc=timeenc,
freq=freq
)
data_loader = DataLoader(
data_set,
batch_size=batch_size,
shuffle=shuffle_flag,
num_workers=args.num_workers,
drop_last=drop_last)
解决这个问题,在pytorch中使用dataloader加载数据,只需要在参数里加一个
drop_last=True
就可以了,十分方便!
————————————————
版权声明:本文为CSDN博主「y hat」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/zhangqiqiyihao/article/details/118088321