# 循环卷积模块 (recurrent convolution block)
class Recurrent_block(nn.Module):
    """Recurrent convolution block for R2U-Net (循环卷积模块).

    A single shared 3x3 Conv -> BatchNorm -> ReLU unit (``self.conv``) is
    applied recurrently. Its output is fed back and summed with the block's
    original input at every step, so the same weights are used ``t + 1``
    times in total.

    Args:
        out_ch: Number of input and output channels. The convolution maps
            ``out_ch`` -> ``out_ch`` channels, so the input must already
            have ``out_ch`` channels.
        t: Number of recurrent steps after the initial convolution.
            ``t == 0`` reduces to a single Conv-BN-ReLU pass.

    Raises:
        ValueError: If ``t`` is negative.
    """

    def __init__(self, out_ch, t=2):
        super().__init__()
        if t < 0:
            raise ValueError(f"t must be non-negative, got {t}")
        self.t = t
        self.out_ch = out_ch
        # kernel 3 / stride 1 / padding 1 keeps the spatial size unchanged,
        # which the element-wise ``x + x1`` sum in ``forward`` relies on.
        self.conv = nn.Sequential(
            nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply the recurrent convolution.

        Args:
            x: Input tensor with ``out_ch`` channels (assumed
                ``(N, out_ch, H, W)``, as ``nn.BatchNorm2d`` expects).

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Keep ``x`` unmodified so the block input is re-injected at every
        # step. Overwriting ``x`` and summing ``x + x`` would drop the input
        # and make every iteration compute the same value.
        x1 = self.conv(x)
        for _ in range(self.t):
            x1 = self.conv(x + x1)
        return x1
# 残差模块 (residual block)
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes