原来的帧移帧长会将数据变长很多。但实际上,我们可以做一些操作.改写dataset就可以了
class AADDataset(Dataset):
'''
x: Features.
y: Targets, if none, do prediction.
'''
def __init__(self, x, y):
self.y = torch.FloatTensor(y)
self.x = torch.FloatTensor(x)
def __getitem__(self, idx):
dim = list(self.x.size())
pointrange = dim[1]-train_length
tr = int(idx / pointrange)
point = idx % pointrange
return self.x[tr,point:point+train_length,:], self.y[tr,point:point+test_length,:]
def __len__(self):
dim = list(self.x.size())
return dim[0]*(dim[1]-train_length)
只要改写len就可以找到idx的范围,然后再在getitem里面解决idx的下标问题,困扰许久的帧长帧移问题就解决了。
可以看到改写之后和改写之前结果几乎一样