class DeformConv(nn.Module):
    """Deformable 3x3 convolution followed by batch normalization and ReLU.

    Args:
        chi: Forwarded to ``DCN`` as its first positional argument
            (presumably the input channel count -- confirm against the
            ``DCN`` implementation in use).
        cho: Forwarded to ``DCN`` as its second positional argument and
            used as the ``BatchNorm2d`` feature count.
    """

    def __init__(self, chi, cho):
        super(DeformConv, self).__init__()
        # ``actf`` is registered before ``conv`` on purpose: keeping the
        # registration order leaves parameter/state-dict ordering untouched.
        post_conv_layers = [
            nn.BatchNorm2d(cho, momentum=BN_MOMENTUM),
            nn.ReLU(inplace=True),
        ]
        self.actf = nn.Sequential(*post_conv_layers)
        self.conv = DCN(
            chi,
            cho,
            kernel_size=(3, 3),
            stride=1,
            padding=1,
            dilation=1,
            deformable_groups=1,
        )

    def forward(self, x):
        """Run ``x`` through the deformable conv, then batch norm + ReLU."""
        return self.actf(self.conv(x))
class DeformConv_no_relu(nn.Module):
    """Deformable 3x3 convolution followed by batch normalization only.

    Unlike a conv + norm + ReLU stack, no activation is applied after the
    normalization step.

    Args:
        chi: Forwarded to ``DCN`` as its first positional argument
            (presumably the input channel count -- confirm against the
            ``DCN`` implementation in use).
        cho: Forwarded to ``DCN`` as its second positional argument and
            used as the ``BatchNorm2d`` feature count.
    """

    def __init__(self, chi, cho):
        super(DeformConv_no_relu, self).__init__()
        # Kept inside a one-element Sequential (and registered before
        # ``conv``) so state-dict keys such as ``actf.0.*`` and the
        # parameter ordering stay exactly as they were.
        batch_norm = nn.BatchNorm2d(cho, momentum=BN_MOMENTUM)
        self.actf = nn.Sequential(batch_norm)
        dcn_options = dict(
            kernel_size=(3, 3),
            stride=1,
            padding=1,
            dilation=1,
            deformable_groups=1,
        )
        self.conv = DCN(chi, cho, **dcn_options)

    def forward(self, x):
        """Run ``x`` through the deformable conv, then batch norm."""
        convolved = self.conv(x)
        return self.actf(convolved)
# DeformConv
# (Web page footer: latest recommended article published on 2023-11-17 12:15:43.)