1. torch.chunk
torch.chunk(input, chunks, dim=0) → List of Tensors
将input tensor划分成特定的块数,每个块都是input tensor的一个视图,最后一个块可能会小一点,因为不能被dim整除。
- input 输入tensor
- chunks 返回多少个块
- dim 沿着哪个维度进行切分
>>> import torch
>>> a = torch.zeros([6, 8 , 2, 2])
>>> x,y = a.chunk(chunks=2,dim=2)
>>> x.shape
torch.Size([6, 8, 1, 2])
>>> y.shape
torch.Size([6, 8, 1, 2])
>>>
2. nn.GroupNorm
BN是Batch维度上的归一化,GN就是Group维度上的归一化。
y = x − E [ x ] Var [ x ]