主要说明如何将 Identity（恒等映射）等效为 3×3 的 DW（depthwise）卷积：
Identity → 1×1 conv（每个通道只与自身对应，卷积核值为 1）→ 3×3 conv（把 1×1 卷积核零填充为 3×3，即仅中心位置为 1）。下面 RepCPE 中的卷积核大小为 spatial_shape（默认 7×7），原理相同。
class RepCPE(nn.Module):
    """Reparameterizable conditional positional encoding (CPE).

    For more details refer to paper:
    `Conditional Positional Encodings for Vision Transformers <https://arxiv.org/pdf/2102.10882.pdf>`_

    In training form the block computes ``pe(x) + x``, where ``pe`` is a
    grouped conv with ``groups=embed_dim`` (depthwise when
    ``in_channels == embed_dim``). The skip connection is itself such a conv
    whose kernel is 1 at the centre tap of each channel's own input and 0
    elsewhere (a 1x1 identity kernel zero-padded to ``spatial_shape``).
    :meth:`reparameterize` can therefore fold it into ``pe`` and eliminate the
    skip connection, leaving a single conv ``reparam_conv``.
    """

    def __init__(
        self,
        in_channels: int,
        embed_dim: int = 768,
        spatial_shape: Union[int, Tuple[int, int]] = (7, 7),
        inference_mode: bool = False,
    ) -> None:
        """Build reparameterizable conditional positional encoding.

        Args:
            in_channels: Number of input channels.
            embed_dim: Number of embedding dimensions; also used as the conv's
                output channels and ``groups``. Default: 768
            spatial_shape: Kernel size for positional encoding, either an int
                (square kernel) or a ``(height, width)`` pair (tuple or list).
                Both sizes should be odd so that the "same" padding keeps the
                spatial size and the residual add lines up. Default: (7, 7)
            inference_mode: If ``True``, build the single fused conv
                (``reparam_conv``) directly instead of ``pe`` + skip.
                Default: ``False``
        """
        super().__init__()
        if isinstance(spatial_shape, int):
            spatial_shape = (spatial_shape, spatial_shape)
        assert isinstance(spatial_shape, (tuple, list)), (
            f'"spatial_shape" must be a sequence or int, '
            f"got {type(spatial_shape)} instead."
        )
        # Normalize lists to tuples so attribute comparisons/hashing are stable.
        spatial_shape = tuple(spatial_shape)
        assert len(spatial_shape) == 2, (
            f'Length of "spatial_shape" should be 2, '
            f"got {len(spatial_shape)} instead."
        )
        self.spatial_shape = spatial_shape
        self.embed_dim = embed_dim
        self.in_channels = in_channels
        self.groups = embed_dim
        if inference_mode:
            self.reparam_conv = self._make_conv()
        else:
            self.pe = self._make_conv()

    def _make_conv(self) -> nn.Conv2d:
        """Return a fresh grouped conv with this block's shape and "same" padding."""
        return nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.embed_dim,
            kernel_size=self.spatial_shape,
            stride=1,
            # Pad each axis by its own half-size; a single int would mis-pad
            # non-square kernels and break the residual add in forward().
            padding=tuple(k // 2 for k in self.spatial_shape),
            groups=self.groups,
            bias=True,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``pe(x) + x``, or the fused conv once reparameterized."""
        if hasattr(self, "reparam_conv"):
            return self.reparam_conv(x)
        return self.pe(x) + x

    def reparameterize(self) -> None:
        """Fold the skip connection into ``pe`` and replace it with ``reparam_conv``.

        Convolution is linear in its weights, so ``conv(x, W) + x`` equals
        ``conv(x, W + I)``. Here ``I`` is the identity kernel: 1 at the centre
        tap of each output channel's own input channel, 0 elsewhere. The bias
        carries over unchanged. Afterwards ``pe`` is deleted and the parameters
        of ``reparam_conv`` do not require gradients. Calling this on a block
        that is already fused (or was built with ``inference_mode=True``) is a
        no-op.
        """
        if hasattr(self, "reparam_conv"):
            return

        weight = self.pe.weight
        # Conv2d weight layout: (out_channels, in_channels // groups, kH, kW).
        out_channels, input_dim, kernel_h, kernel_w = weight.shape
        with torch.no_grad():
            id_tensor = torch.zeros(
                weight.shape, dtype=weight.dtype, device=weight.device
            )
            channels = torch.arange(out_channels, device=weight.device)
            # Centre tap of each channel's matching input = 1 (identity map).
            id_tensor[channels, channels % input_dim, kernel_h // 2, kernel_w // 2] = 1
            w_final = id_tensor + weight
            b_final = self.pe.bias.detach().clone()

        self.reparam_conv = self._make_conv().to(
            device=weight.device, dtype=weight.dtype
        )
        with torch.no_grad():
            self.reparam_conv.weight.copy_(w_final)
            self.reparam_conv.bias.copy_(b_final)
        self.reparam_conv.requires_grad_(False)
        del self.pe