nn.PReLU() if isRelu else nn.Sequential(),
)
class VarGBlock_S1(nn.Module):
    """Residual block with an identity shortcut.

    The main branch runs two ``VarGConv`` -> ``VarGPointConv`` stages and
    then a ``SqueezeAndExcite`` module. Its result is added to the
    unmodified block input and the sum is passed through a PReLU.

    Args:
        in_plances: Width given to the first ``VarGConv`` and used as the
            output width of every ``VarGPointConv``; presumably the channel
            count of the input tensor (confirm against ``VarGConv``).
        kernel_size: Forwarded to both ``VarGConv`` layers.
        stride: Forwarded to every ``VarGConv`` / ``VarGPointConv`` layer.
            NOTE(review): the shortcut add needs the branch output to be
            shape-compatible with the input, which presumably only holds
            for the default of 1 -- confirm against the conv helpers.
        S: Forwarded to every ``VarGConv`` / ``VarGPointConv`` layer;
            its meaning is defined by those helpers.
    """

    def __init__(self, in_plances, kernel_size, stride=1, S=8):
        super(VarGBlock_S1, self).__init__()
        # Each VarGConv doubles the width; the following VarGPointConv
        # brings it back to in_plances.
        plances = 2 * in_plances

        self.varGConv1 = VarGConv(in_plances, plances, kernel_size, stride, S)
        self.varGPointConv1 = VarGPointConv(plances, in_plances, stride, S, isRelu=True)

        self.varGConv2 = VarGConv(in_plances, plances, kernel_size, stride, S)
        self.varGPointConv2 = VarGPointConv(plances, in_plances, stride, S, isRelu=False)

        self.se = SqueezeAndExcite(in_plances, in_plances)
        self.prelu = nn.PReLU()

    def forward(self, x):
        """Return ``prelu(x + branch(x))`` without modifying ``x``."""
        shortcut = x

        branch = self.varGPointConv1(self.varGConv1(x))
        branch = self.varGPointConv2(self.varGConv2(branch))
        branch = self.se(branch)

        # Out-of-place add: an in-place `+=` on the shortcut would overwrite
        # the caller's input tensor, which is also the tensor fed to the
        # first convolution.
        return self.prelu(shortcut + branch)
class VarGBlock_S2(nn.Module):
def init(self, in_plances,kernel_size, stride=2, S=8):
super(VarGBlock_S2, self).init()
plances = 2 * in_plances
self.varGConvBlock_branch1 = nn.Sequential(
VarGConv(in_plances, plances, kernel_size, stride, S),
VarGPointConv(plances,