残差模块
常见的残差模块如下图左边,是由两个卷积层组成 。右边是瓶颈残差模块。这里的1*1的卷积是用来进行升维和降维的。
class ResBlock(nn.Module):
def __init__(self, in_channel=256, out_channel=256, kernel_size=3, stride=1, padding=1,
bias=True):
super(ResBlock, self).__init__()
layers = [nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size,