背景
最近做学习训练时,需要加入时间维度。
将原本是[B,C,H,W]格式的输入变成[B,T,C,H,W]的输入,方便使用Timedistribute封装层或者ConvLSTM。
方法
其实并不复杂,利用滑动窗口来对原本batch size进行分割,直接看代码吧
class dimsexpand(nn.Module):
def __init__(self,slide_num=1,timesteps=3):
super(dimsexpand,self).__init__()
self.slide_num = slide_num
self.timesteps = timesteps
def forward(self,x):
window_num = int((x.shape[0] - self.timesteps)/self.slide_num)
x = list(x[i * self.slide_num:i * self.slide_num+self.timesteps] for i in range(window_num))
x = torch.stack(x,0)
return x
处理后数据:
[B1,C,H,W] → [B2,T,C,H,W]
其中:
B1 = B2 + T
做一个简单的测试
if __name__ == '__main__':
net = dimsexpand()
input = torch.rand((5, 3, 512, 512))
output = net(input)
print('result:',output.size())
输出:result: torch.Size([2, 3, 3, 512, 512])
做个记录,以免遗忘
以上