Linknet网络结构

Linknet网络结构

LinkNet图像语义分割
像素级的图像语义分割,不仅需要精确,还需要高效(例如:自动驾驶)

具体结构

整体结构

一个输入层+4个编码层+4个解码层+1个输出层

结构

编码层

在这里插入图片描述

解码层

在这里插入图片描述

创建Linknet模型

思路:编写不同的block在最后输出阶段将其链接

1、编写 卷积模块 (卷积 + 激活 + BN)

2、编写 反卷积模块 (反卷积 + 激活 + BN)

3、编码器(4*卷积模块)

4、解码器(卷积模块+反卷积模块+卷积模块)

5、实现整体的网络结构 (卷积模型+反卷积模型+解码器+编码器)

卷积模块

class Convblock (nn.Module):
    def __init__(self, in_channels, out_channels, 
                 k_size=3, 
                 stride=1, 
                 padding=1):
        super(Convblock, self).__init__()
        self.conv_relu = nn.Sequential(
                            nn.Conv2d(in_channels, out_channels, 
                                      kernel_size=k_size,
                                      stride=stride,
                                      padding=padding),
                            nn.BatchNorm2d(out_channels),
                            nn.ReLU(inplace=True)
            )
    def forward(self, x):
        x = self.conv_relu(x)
        return x

反卷积模块

#反卷积模块
class Deconvblock (nn.Module):
    def __init__(self,in_channels,out_channels,
                   k_size = 3,
                   stride = 2,
                   padding = 1,
                   output_padding = 1):
        super(Deconvblock,self).__init__()
        self.deconv = nn.ConvTranspose2d(in_channels,out_channels,kernel_size=k_size,stride=stride,padding=padding,output_padding=output_padding)
        self.bn = nn.BatchNorm2d(out_channels)
    
    def forward(self,x,is_act = True):
        x = self.deconv(x)
        if is_act:
            x = torch.relu(self.bn(x))
        return x
      

编码器

#编码器模块
class Encodeblock (nn.Module):
    def __init__(self,in_channels,out_channels):
        super(Encodeblock,self).__init__()
        self.conv1 = Convblock(in_channels,out_channels,stride=2)
        self.conv2 = Convblock(out_channels,out_channels)  
        self.conv3 = Convblock(out_channels,out_channels)
        self.conv4 = Convblock(out_channels,out_channels)
        self.short_cut = Convblock(in_channels,out_channels,stride=2)
        
    def forward(self,x):
        out1 = self.conv1(x)
        out1 = self.conv1(out1)
        short_cut = self.short_cut(x)
        
        out2 = self.conv3(out1+short_cut)
        out2 = self.conv4(out2)
        return out2+out1
        

解码器

#解码器模块
class Encodeblock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Encodeblock, self).__init__()
        self.conv1_1 = Convblock(in_channels, out_channels, stride=2)
        self.conv1_2 = Convblock(out_channels, out_channels)
        self.conv2_1 = Convblock(out_channels, out_channels)
        self.conv2_2 = Convblock(out_channels, out_channels)
        self.shortcut = Convblock(in_channels, out_channels, stride=2)

    def forward(self, x):
        out1 = self.conv1_1(x)
        out1 = self.conv2_1(out1)
        residue = self.shortcut(x)
        out2 = self.conv2_1(out1 + residue)
        out2 = self.conv2_2(out2)
        return out2 + out1

最终模型的编写(注意每一步的输出,中间有类Resnet结构)

#模型编写
class Net (nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.input_conv = Convblock(3,64,k_size=7,stride=2,padding=3)
        self.input_maxpool = nn.MaxPool2d(kernel_size=(2,2))
        
        self.encode1 = Encodeblock(64,64)
        self.encode2 = Encodeblock(64,128)
        self.encode3 = Encodeblock(128,256)
        self.encode4 = Encodeblock(256,512)
        
        self.decode4 = Decodeblock(512,256)
        self.decode3 = Decodeblock(256,128)
        self.decode2 = Decodeblock(128,64)
        self.decode1 = Decodeblock(64,64)
        
        self.deconv_out1 = Deconvblock(64,32)
        self.conv_out = Convblock(32,32)
        self.deconv_out2 = Deconvblock(32,2,k_size=2,padding=0,output_padding=0)
        
        


        
    def forward(self,x):
        x = self.input_conv(x)
        x = self.input_maxpool(x)
        
        e1 = self.encode1(x)
        e2 = self.encode2(e1)
        e3 = self.encode3(e2)
        e4 = self.encode4(e3)
        
        d4 = self.decode4(e4)
        d3 = self.decode3(d4+e3)
        d2 = self.decode2(d3+e2)
        d1 = self.decode1(d2+e1)
        
        f1 = self.deconv_out1(d1)
        f2 = self.conv_out(f1)
        f3 = self.deconv_out2(f2)
        return f3
        
  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值