def window_partition(x, window_size):
B, D, H, W, C = x.shape
x = x.view(B, D // window_size[0], window_size[0], H // window_size[1], window_size[1], W // window_size[2], window_size[2], C)
windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, reduce(mul, window_size), C)
return windows
将特征图B, D, H, W, C=(1,35,35,35,1),按照window_size(7,7,7)划分窗口,
得到的窗口个数是:
(D/window_size)*(H/window_size)*(W/window_size)
最后返回值是:
(窗口个数(125),窗口大小(7*7*7),1)