学习 irevnet,发现这个网络的结构真的是巧妙。加点料估计能节省一大堆显存
irevnet 里的 PSI 层,原来是 pixelshuffle 的反向操作。看了半天。。。
看到里面有 pixelshuffle 和 pixelshuffle 逆向的代码实现,就想着来优化一下,里面的for操作真的看得心痒。
优化后代码,没想到比官方实现还快一点点。神奇!
github仓库:https://github.com/One-sixth/pixelshuffle_invert_pytorch
不想去仓库的话就直接看下面把,又短又简洁
def pixelshuffle(x: torch.Tensor, factor_hw: Tuple[int, int]):
pH = factor_hw[0]
pW = factor_hw[1]
y = x
B, iC, iH, iW = y.shape
oC, oH, oW = iC//(pH*pW), iH*pH, iW*pW
y = y.reshape(B, oC, pH, pW, iH, iW)
y = y.permute(0, 1, 4, 2, 5, 3) # B, oC, iH, pH, iW, pW
y = y.reshape(B, oC, oH, oW)
return y
def pixelshuffle_invert(x: torch.Tensor, factor_hw: Tuple[int, int]):
pH = factor_hw[0]
pW = factor_hw[1]
y = x
B, iC, iH, iW = y.shape
oC, oH, oW = iC*(pH*pW), iH//pH, iW//pW
y = y.reshape(B, iC, oH, pH, oW, pW)
y = y.permute(0, 1, 3, 5, 2, 4) # B, iC, pH, pW, oH, oW
y = y.reshape(B, oC, oH, oW)
return y