【网络结构复现】之ACNet

网络出处请参见这篇博客：《基于注意力机制的深度学习路面裂缝检测》

下面直接给出代码，有不明白的地方欢迎留言。

AFM.py

from torch import nn

class ConvG1(nn.Module):
    """Two parallel stride-2 convolutions (7x7 and 3x3) whose outputs are summed.

    Both branches use "same"-style padding (k // 2), so for any input height H
    each produces ceil(H / 2) rows and the element-wise sum is always valid.

    Args:
        in_channels: input channel count (default 512, the original hard-coded value).
        out_channels: output channel count (default 128, the original hard-coded value).
    """

    def __init__(self, in_channels=512, out_channels=128):
        super(ConvG1, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 7, 2, 3)  # 7x7, stride 2
        self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 2, 1)  # 3x3, stride 2

    def forward(self, x):
        """Return ``conv1(x) + conv2(x)``."""
        return self.conv1(x) + self.conv2(x)

class ConvG2(nn.Module):
    """Two parallel stride-2 convolutions (5x5 and 3x3) whose outputs are summed.

    Both branches pad by k // 2, so each yields ceil(H / 2) x ceil(W / 2) and
    the sum is always shape-compatible. The channel count is preserved.

    Args:
        channels: input and output channel count (default 128, the original value).
    """

    def __init__(self, channels=128):
        super(ConvG2, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, 5, 2, 2)  # 5x5, stride 2
        self.conv2 = nn.Conv2d(channels, channels, 3, 2, 1)  # 3x3, stride 2

    def forward(self, x):
        """Return ``conv1(x) + conv2(x)``."""
        return self.conv1(x) + self.conv2(x)

class AFM(nn.Module):
    """Attention fusion module applied to the deepest encoder feature map.

    Three branches are combined:

    * global context: global average pool + 1x1 conv, giving a per-channel
      vector that is broadcast over the spatial grid;
    * a 1x1 conv branch at the input resolution;
    * a pyramid: ``convG1`` -> ``convG2`` -> ``convG3`` each take the spatial
      size to ceil(size / 2); 7x7 / 5x5 / 3x3 convs lift each level to 512
      channels, and the levels are fused coarse-to-fine by upsample-and-add.

    Output: ``pyramid * conv1_branch + global_context``, same shape as the input.

    Args:
        in_channel: NOTE(review): currently unused -- every layer here (and in
            ``ConvG1``) is hard-wired to 512 input channels, so the input must
            have 512 channels regardless of this value.
    """

    def __init__(self, in_channel):
        super(AFM, self).__init__()
        # Global-context branch. The pooled 1x1 map is broadcast in forward()
        # instead of being upsampled by a fixed x16 factor, which only worked
        # for 16x16 inputs.
        self.gp = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(512, 512, 1, 1),
        )
        self.conv1 = nn.Conv2d(512, 512, 1, 1)
        self.conv3 = nn.Conv2d(128, 512, 3, 1, 1)
        self.conv5 = nn.Conv2d(128, 512, 5, 1, 2)
        self.conv7 = nn.Conv2d(128, 512, 7, 1, 3)
        self.convG1 = ConvG1()
        self.convG2 = ConvG2()
        self.convG3 = nn.Conv2d(128, 128, 3, 2, 1)
        # Parameter-free; kept for backward compatibility. forward() resizes to
        # each target's exact size rather than by a fixed x2 factor.
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    @staticmethod
    def _upsample_to(t, ref):
        """Bilinearly resize ``t`` to the spatial size of ``ref`` (align_corners=True)."""
        return nn.functional.interpolate(
            t, size=ref.shape[-2:], mode='bilinear', align_corners=True
        )

    def forward(self, x):
        """Fuse the global, 1x1 and pyramid branches; output shape equals input shape."""
        x_gp = self.gp(x)          # (b, 512, 1, 1); broadcast over H x W below
        x_conv1 = self.conv1(x)    # (b, 512, H, W)

        # Pyramid: each stage maps a spatial size s to ceil(s / 2).
        x_convG1 = self.convG1(x)
        x_conv7 = self.conv7(x_convG1)
        x_convG2 = self.convG2(x_convG1)
        x_conv5 = self.conv5(x_convG2)
        x_convG3 = self.convG3(x_convG2)
        x_conv3 = self.conv3(x_convG3)

        # Coarse-to-fine fusion; resizing to the exact finer size (instead of
        # a fixed x2) keeps shapes aligned when H or W is not divisible by 8.
        x_conv3_5 = self._upsample_to(x_conv3, x_conv5) + x_conv5
        x_conv5_7 = self._upsample_to(x_conv3_5, x_conv7) + x_conv7
        x_conv7_up = self._upsample_to(x_conv5_7, x_conv1)

        # Gate the 1x1 branch with the pyramid, then add the global context.
        out = x_conv7_up * x_conv1 + x_gp
        return out

ADM.py

import torch
from torch import nn

# 通道注意力机制(只保留通道):
# 输入分别经过(全局)最大池化和平均池化 -> b, c, 1, 1
# 再将两次池化的结果进行相同的全连接计算
# 将经过全连接计算的结果相加 取 Sigmoid()激活得到权值
# 再乘以输入
class Channel_attention(nn.Module):
    """Channel attention: reweight each channel of a ``(b, c, h, w)`` input.

    Global max- and average-pooled channel descriptors are passed through a
    shared two-layer MLP, summed, squashed with a sigmoid, and multiplied back
    onto the input. The output has the same shape as the input.

    Args:
        channel: number of input channels ``c``.
        ratio: bottleneck reduction ratio of the shared MLP.
    """

    def __init__(self, channel, ratio=16):
        super(Channel_attention, self).__init__()
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Clamp the bottleneck to at least one unit: for channel < ratio the
        # integer division yields 0, which would turn the MLP into an
        # input-independent constant.
        hidden = max(channel // ratio, 1)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden),
            nn.ReLU(),
            nn.Linear(hidden, channel),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled per channel by weights in (0, 1)."""
        b, c, _, _ = x.size()
        max_pool_out = self.max_pool(x).view([b, c])
        avg_pool_out = self.avg_pool(x).view([b, c])

        # The same MLP (shared weights) processes both descriptors.
        max_fc_out = self.fc(max_pool_out)
        avg_fc_out = self.fc(avg_pool_out)

        out = max_fc_out + avg_fc_out
        out = self.sigmoid(out).view([b, c, 1, 1])
        return out * x

# 空间注意力机制(只保留空间)
# 在通道这个维度上分别进行最大池化和平均池化 -> b, 1, h, w
# 将两次池化后的结果进行 cat 连接,再卷积(输入为2, 输出为1)
# 再对卷积后的结果的每一个像素进行 Sigmoid() 得到输出
class Spatial_attention(nn.Module):
    """Spatial attention: reweight each spatial position of a ``(b, c, h, w)`` input.

    Channel-wise max and mean maps are concatenated into a 2-channel map,
    convolved down to one channel, passed through a sigmoid, and multiplied
    onto the input (broadcast over channels). Output shape equals input shape.

    Args:
        kernel_size: int or (kh, kw) kernel size; every entry must be odd so
            that k // 2 padding keeps the attention map at the input's h x w.

    Raises:
        ValueError: if any kernel dimension is even, or the input is not 4-D.
    """

    def __init__(self, kernel_size):
        super(Spatial_attention, self).__init__()
        sizes = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
        if any(k % 2 == 0 for k in sizes):
            raise ValueError(f"kernel_size must be odd, got {kernel_size}")
        # Padding derived from the kernel (it was hard-coded to 2, which only
        # preserved the spatial size for kernel_size=5).
        padding = tuple(k // 2 for k in sizes)
        self.conv = nn.Conv2d(2, 1, kernel_size, 1, padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled per spatial position by weights in (0, 1)."""
        if x.dim() != 4:
            raise ValueError(f"expected a 4-D (b, c, h, w) tensor, got {x.dim()}-D")
        # torch.max along a dim returns (values, indices); only values are needed.
        max_pool_out, _ = torch.max(x, dim=1, keepdim=True)  # (b, 1, h, w)
        avg_pool_out = torch.mean(x, dim=1, keepdim=True)    # (b, 1, h, w)
        pool_out = torch.cat([max_pool_out, avg_pool_out], dim=1)
        out = self.sigmoid(self.conv(pool_out))
        return out * x

class ADM(nn.Module):
    """Attention decoder module: parallel channel + spatial attention, 1x1 conv, x2 upsample.

    Args:
        channel: number of input channels.
        out_channel: number of output channels produced by the 1x1 convolution.
        ratio: reduction ratio forwarded to ``Channel_attention``.
        kernel_size: kernel size forwarded to ``Spatial_attention``.

    The output has ``out_channel`` channels and twice the input's height and
    width (bilinear upsampling with ``align_corners=True``).
    """

    def __init__(self, channel, out_channel, ratio=16, kernel_size=5):
        super().__init__()
        self.channel_attention = Channel_attention(channel, ratio)
        self.spatial_attention = Spatial_attention(kernel_size)
        self.conv = nn.Conv2d(channel, out_channel, 1, 1)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def forward(self, x):
        # Both attentions read the same input and keep its shape, so their
        # results are fused by element-wise addition.
        fused = self.channel_attention(x) + self.spatial_attention(x)
        # The 1x1 conv changes only the channel count; upsampling doubles H and W.
        projected = self.conv(fused)
        return self.up(projected)

# cbam = ADM(512,256)
# x4 = torch.randn([2, 512, 16, 16])
# x5 = torch.randn([2, 512, 16, 16])
# output = cbam(x4+x5)
# print(output.shape)

model.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from AFM import *
from ADM import *

# 用于ResNet18和34的残差块,用的是2个3x3的卷积
class BasicBlock(nn.Module):
    """Two-convolution (3x3 + 3x3) residual block, as used by ResNet-18/34.

    Args:
        in_planes: input channel count.
        planes: channel count of both convolutions (output is ``planes * expansion``).
        stride: stride of the first convolution (and of the projection shortcut).
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = planes * self.expansion
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        # Identity shortcut when shapes already match; otherwise a 1x1 conv + BN
        # projection brings x to the residual branch's channel count and size.
        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        residual = F.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))
        return F.relu(residual + self.shortcut(x))

class ACnet(nn.Module):
    """ACNet: residual encoder, AFM on the deepest features, four ADM decoder stages.

    The stem (stride-2 conv + stride-2 max-pool) and encoder stage strides
    (1, 2, 2, 2) put x1..x4 at 1/4, 1/8, 1/16 and 1/32 of the input size. AFM
    refines x4; each ADM then fuses the matching encoder skip by addition,
    maps C -> C/2 channels and doubles H and W.

    Args:
        block: residual block class (e.g. ``BasicBlock``); must define ``expansion``.
        num_blocks: number of blocks in each of the four encoder stages.
        num_classes: output width of ``self.linear``.
            NOTE(review): ``self.linear`` is never used in ``forward``.
        verbose: keyword-only. When True (the default, i.e. the original
            behaviour), ``forward`` prints the shape of every intermediate tensor.
    """

    def __init__(self, block, num_blocks, num_classes=10, *, verbose=True):
        super(ACnet, self).__init__()
        self.verbose = verbose
        self.in_planes = 64
        # Stem, e.g. for 512x512x3: 7x7/2 conv -> 256x256x64, 3x3/2 max-pool -> 128x128x64.
        self.pre = nn.Sequential(
            nn.Conv2d(3, 64, 7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, 2, 1)
        )
        # Encoder stages: ``block`` selects the residual block type and
        # ``num_blocks[i]`` how many blocks stage i stacks.
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        # NOTE(review): unused by forward(); kept to preserve the state_dict layout.
        self.linear = nn.Linear(512*block.expansion, num_classes)
        self.afm = AFM(512)
        # Decoder: channels 512 -> 256 -> 128 -> 64 -> 32, resolution doubled each time.
        self.adm1 = ADM(512,256)
        self.adm2 = ADM(256,128)
        self.adm3 = ADM(128,64)
        self.adm4 = ADM(64,32)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``.

        Updates ``self.in_planes`` so the next stage starts from this stage's
        output width.
        """
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for s in strides:
            layers.append(block(self.in_planes, planes, s))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def _show(self, label, tensor):
        """Print ``label`` and ``tensor.shape`` when verbose output is enabled."""
        # getattr default: instances unpickled from before ``verbose`` existed
        # keep printing as they used to.
        if getattr(self, "verbose", True):
            print(label, tensor.shape)

    def forward(self, x):
        """Run encoder, AFM and decoder.

        Returns:
            The encoder features ``(x1, x2, x3, x4)``.
            NOTE(review): the decoder output x9 is computed but not returned;
            the return value is kept unchanged for existing callers.
        """
        x = self.pre(x)
        self._show("经过卷积和池化的输出x:", x)
        x1 = self.layer1(x)
        self._show("经过第一个残差块的输出x1:", x1)
        x2 = self.layer2(x1)
        self._show("经过第二个残差块的输出x2:", x2)
        x3 = self.layer3(x2)
        self._show("经过第三个残差块的输出x3:", x3)
        x4 = self.layer4(x3)
        self._show("经过第四个残差块的输出x4:", x4)
        x5 = self.afm(x4)
        self._show("经过AFM模块的输出x5:", x5)
        # Decoder: each stage adds the matching encoder skip before its ADM.
        x6 = self.adm1(x5+x4)
        self._show("经过第一个ADM的输出x6:", x6)
        x7 = self.adm2(x3+x6)
        self._show("经过第二个ADM的输出x7:", x7)
        x8 = self.adm3(x2+x7)
        self._show("经过第三个ADM的输出x8:", x8)
        x9 = self.adm4(x1+x8)
        self._show("经过第四个ADM的输出x9:", x9)
        return x1, x2, x3, x4

def ACNet():
    """Return an ``ACnet`` built from ``BasicBlock`` with ResNet-34 stage depths (3, 4, 6, 3)."""
    stage_depths = [3, 4, 6, 3]
    model = ACnet(BasicBlock, stage_depths)
    return model

def run():
    """Smoke test: push one random 1x3x512x512 tensor through ``ACNet()``.

    The model's forward pass prints each intermediate shape; the returned
    features are discarded.
    """
    model = ACNet()
    sample = torch.randn(1, 3, 512, 512)
    model(sample)


if __name__ == "__main__":
    run()

输出:

经过卷积和池化的输出x: torch.Size([1, 64, 128, 128])
经过第一个残差块的输出x1: torch.Size([1, 64, 128, 128])
经过第二个残差块的输出x2: torch.Size([1, 128, 64, 64])
经过第三个残差块的输出x3: torch.Size([1, 256, 32, 32])
经过第四个残差块的输出x4: torch.Size([1, 512, 16, 16])
经过AFM模块的输出x5: torch.Size([1, 512, 16, 16])
经过第一个ADM的输出x6: torch.Size([1, 256, 32, 32])
经过第二个ADM的输出x7: torch.Size([1, 128, 64, 64])
经过第三个ADM的输出x8: torch.Size([1, 64, 128, 128])
经过第四个ADM的输出x9: torch.Size([1, 32, 256, 256])

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值