ResNet34+Unet(可以直接用)

import torch
from torch import nn
import torch.nn.functional as F


# ResNet-34 is built from many copies of the same unit, so that unit is
# factored out into the ResidualBlock class below.
class ResidualBlock(nn.Module):
    """Basic two-convolution residual block, the repeated unit of ResNet-34.

    Computes ``relu(basic(x) + shortcut(x))`` where ``basic`` is
    conv3x3 -> BN -> ReLU -> conv3x3 -> BN.

    Args:
        inchannel: number of input channels.
        outchannel: number of output channels.
        stride: stride of the first 3x3 convolution; 2 halves the spatial size.
        shortcut: optional module applied to ``x`` on the residual path.
            ``None`` means identity, which only lines up with ``basic(x)``
            when ``inchannel == outchannel`` and ``stride == 1``.
    """

    def __init__(self, inchannel, outchannel, stride, shortcut=None):
        super().__init__()
        # Down-sampling (if any) happens in this first convolution.
        first_conv = nn.Conv2d(inchannel, outchannel, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        # Stride 1 / padding 1 keeps the spatial size produced by first_conv.
        second_conv = nn.Conv2d(outchannel, outchannel, kernel_size=3,
                                stride=1, padding=1, bias=False)
        # Element order is fixed so state_dict keys stay basic.0 ... basic.4.
        self.basic = nn.Sequential(
            first_conv,
            nn.BatchNorm2d(outchannel),
            nn.ReLU(inplace=True),
            second_conv,
            nn.BatchNorm2d(outchannel),
        )
        self.shortcut = shortcut

    def forward(self, x):
        out = self.basic(x)
        if self.shortcut is not None:
            residual = self.shortcut(x)
        else:
            residual = x
        out += residual
        # Final activation after the residual sum.
        return F.relu(out, inplace=True)


class Conv2dReLU(nn.Module):
    """Convolution -> BatchNorm -> ReLU, the building unit of the decoder.

    The convolution is created without a bias; the BatchNorm that follows
    carries its own learnable shift.

    Args:
        in_channels: input channels.
        out_channels: output channels.
        kernel_size: convolution kernel size.
        padding: zero padding added on each side (default 0).
        stride: convolution stride (default 1).
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding=0, stride=1):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                              stride=stride, padding=padding, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        normalized = self.bn(self.conv(x))
        return self.relu(normalized)


class DecoderBlock(nn.Module):
    """U-Net decoder stage: upsample, optionally concatenate a skip, two convs.

    Args:
        in_channels: channels of the tensor coming from the previous stage.
        skip_channels: channels of the skip tensor concatenated after
            upsampling (0 when the block is used without a skip).
        out_channels: channels produced by the block.
    """

    def __init__(
            self,
            in_channels,
            skip_channels,
            out_channels,
    ):
        super().__init__()
        # conv1 sees the upsampled input concatenated with the skip.
        self.conv1 = Conv2dReLU(
            in_channels + skip_channels,
            out_channels,
            kernel_size=3,
            padding=1,
        )
        self.conv2 = Conv2dReLU(
            out_channels,
            out_channels,
            kernel_size=3,
            padding=1,
        )

    def forward(self, x, skip=None):
        """Upsample ``x`` and fuse it with ``skip``.

        Args:
            x: feature map from the previous (coarser) stage; channels on dim 1.
            skip: optional encoder feature map. When given, ``x`` is resized
                with nearest-neighbour interpolation to exactly the skip's
                spatial size before concatenation. For a skip that is twice
                as large this is identical to ``scale_factor=2``. For odd
                encoder sizes it avoids the size mismatch that made
                ``torch.cat`` fail. When ``None``, ``x`` is upsampled by 2.

        Returns:
            Feature map with ``out_channels`` channels.
        """
        if skip is None:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        else:
            x = F.interpolate(x, size=tuple(skip.shape[-2:]), mode="nearest")
            x = torch.cat([x, skip], dim=1)

        x = self.conv1(x)
        x = self.conv2(x)

        return x


class SegmentationHead(nn.Sequential):
    """Final prediction layer: a convolution followed by optional upsampling.

    No activation is applied, so the output is raw scores.

    Args:
        in_channels: channels of the last decoder feature map.
        out_channels: number of output maps (presumably one per class).
        kernel_size: convolution kernel size. Padding ``kernel_size // 2``
            keeps the spatial size for odd kernels.
        upsampling: bilinear scale factor applied after the convolution.
            Values ``<= 1`` disable upsampling (an ``nn.Identity`` is used).
    """

    def __init__(self,
                 in_channels=16,
                 out_channels=1,
                 kernel_size=3,
                 upsampling=1):
        conv2d = nn.Conv2d(in_channels,
                           out_channels,
                           kernel_size=kernel_size,
                           padding=kernel_size // 2)
        if upsampling > 1:
            # nn.UpsamplingBilinear2d is deprecated; nn.Upsample with
            # mode="bilinear", align_corners=True is its documented equivalent.
            resize = nn.Upsample(scale_factor=upsampling,
                                 mode="bilinear",
                                 align_corners=True)
        else:
            resize = nn.Identity()
        super().__init__(conv2d, resize)


# ResNet-34 encoder + U-Net decoder.
class Resnet34(nn.Module):
    """U-Net style segmentation network with a ResNet-34 encoder.

    Encoder:
        * A stem: 7x7 stride-2 conv, BN, ReLU and MaxPool.
        * Four stages of ``ResidualBlock`` with [3, 4, 6, 3] blocks and
          64/128/256/512 channels. Stages 2-4 start with a stride-2 block.

    Four skip tensors are collected: the stem activation just before
    MaxPool, and the input of every stride-2 block (the point where the
    channel count changes).

    Decoder: five ``DecoderBlock`` that upsample and fuse the skips
    deepest-first (the last one has no skip), then a ``SegmentationHead``.

    Args:
        inchannels: number of input image channels.
        num_classes: channels of the output map. Defaults to 1, the original
            single-channel head. No activation is applied to the output.
    """

    def __init__(self, inchannels, num_classes=1):
        super(Resnet34, self).__init__()
        self.pre = nn.Sequential(
            nn.Conv2d(inchannels, 64, 7, 2, 3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, 2, 1),
        )  # stem
        self.body = self.makelayers([3, 4, 6, 3])  # stages of repeated residual blocks
        # Decoder channel plan. Skips are consumed deepest-first:
        # 256 (stage 3 output), 128 (stage 2), 64 (stage 1), 64 (stem); the
        # last block has none. The conv1 input widths (in + skip = 768, 384,
        # 192, 128, 32) are the same as the original "128 + 0" plan, so
        # parameter shapes, and therefore saved state_dicts, are unchanged.
        in_channels = [512, 256, 128, 64, 32]
        skip_channels = [256, 128, 64, 64, 0]
        out_channels = [256, 128, 64, 32, 16]
        self.blocks = nn.ModuleList(
            DecoderBlock(in_ch, skip_ch, out_ch)
            for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels,
                                              out_channels)
        )
        self.seg = SegmentationHead(out_channels=num_classes)

    def makelayers(self, blocklist):
        """Build the residual stages.

        Args:
            blocklist: number of blocks per stage, as a list such as
                ``[3, 4, 6, 3]``. Stage ``k`` has ``64 * 2**k`` channels.
                Every stage after the first begins with a stride-2 block whose
                1x1 conv + BN shortcut matches the channel and size change.

        Returns:
            ``nn.Sequential`` holding all blocks, flattened in order.
        """
        layers = []
        for stage, num_blocks in enumerate(blocklist):
            width = 64 * 2**stage
            first = 0
            if stage > 0:
                prev_width = 64 * 2**(stage - 1)
                # Projection so the residual path matches channels and size.
                shortcut = nn.Sequential(
                    nn.Conv2d(prev_width, width, 1, 2, bias=False),
                    nn.BatchNorm2d(width),
                )
                layers.append(ResidualBlock(prev_width, width, 2, shortcut))
                first = 1
            layers.extend(
                ResidualBlock(width, width, 1) for _ in range(first, num_blocks))
        self.layers = layers  # kept: the original exposed this list attribute
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run encoder and decoder.

        Args:
            x: input batch with ``inchannels`` channels on dim 1. When H and W
                are divisible by 32, the output has the same H and W (the
                encoder downsamples 5 times, the decoder upsamples 5 times).

        Returns:
            Map with ``num_classes`` channels.
        """
        # Local list: the original stored this on ``self.features``, which
        # kept the activations (and their autograd graph) alive after forward.
        features = []
        for layer in self.pre:
            if isinstance(layer, nn.MaxPool2d):
                features.append(x)  # skip from the stem, before MaxPool
            x = layer(x)

        for block in self.body:
            if block.shortcut is not None:
                # Input of each stride-2 block, i.e. before the channel change.
                features.append(x)
            x = block(x)

        skips = features[::-1]  # deepest first, matching self.blocks
        for i, decoder_block in enumerate(self.blocks):
            skip = skips[i] if i < len(skips) else None
            x = decoder_block(x, skip)

        return self.seg(x)



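四次skip connection分别位于:第一次在stem的ReLU之后、MaxPool之前;另外三次在每次通道数变化(即每个stride=2的下采样残差块)之前。
上采样时采用最近邻插值(torch.nn.functional.interpolate,mode="nearest")放大特征图,再与对应的skip特征在通道维上拼接(torch.cat)后送入两层卷积。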

这里是一个简单的ResNet-34与U-Net结合的代码示例,用于图像分割任务: ``` python import torch import torch.nn as nn import torch.nn.functional as F class ResNetBlock(nn.Module): def __init__(self, in_channels, out_channels): super(ResNetBlock, self).__init__() self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False) self.bn1 = nn.BatchNorm2d(out_channels) self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(out_channels) self.relu = nn.ReLU(inplace=True) def forward(self, x): identity = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) out += identity out = self.relu(out) return out class ResNetEncoder(nn.Module): def __init__(self, in_channels, out_channels): super(ResNetEncoder, self).__init__() self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=7, stride=2, padding=3, bias=False) self.bn1 = nn.BatchNorm2d(out_channels) self.relu = nn.ReLU(inplace=True) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.layer1 = nn.Sequential( ResNetBlock(out_channels, out_channels), ResNetBlock(out_channels, out_channels), ResNetBlock(out_channels, out_channels), ) self.layer2 = nn.Sequential( ResNetBlock(out_channels, out_channels*2), ResNetBlock(out_channels*2, out_channels*2), ResNetBlock(out_channels*2, out_channels*2), ) self.layer3 = nn.Sequential( ResNetBlock(out_channels*2, out_channels*4), ResNetBlock(out_channels*4, out_channels*4), ResNetBlock(out_channels*4, out_channels*4), ResNetBlock(out_channels*4, out_channels*4), ResNetBlock(out_channels*4, out_channels*4), ResNetBlock(out_channels*4, out_channels*4), ) self.layer4 = nn.Sequential( ResNetBlock(out_channels*4, out_channels*8), ResNetBlock(out_channels*8, out_channels*8), ResNetBlock(out_channels*8, out_channels*8), ) def forward(self, x): x = self.conv1(x) x = self.bn1(x) x = self.relu(x) x = self.maxpool(x) x1 = self.layer1(x) x2 = self.layer2(x1) x3 = 
self.layer3(x2) x4 = self.layer4(x3) return x1, x2, x3, x4 class ResNetDecoder(nn.Module): def __init__(self, in_channels, out_channels): super(ResNetDecoder, self).__init__() self.layer1 = nn.Sequential( ResNetBlock(in_channels*8 + out_channels*8, out_channels*4), ResNetBlock(out_channels*4, out_channels*4), ResNetBlock(out_channels*4, out_channels*4), ) self.layer2 = nn.Sequential( ResNetBlock(in_channels*4 + out_channels*4, out_channels*2), ResNetBlock(out_channels*2, out_channels*2), ResNetBlock(out_channels*2, out_channels*2), ) self.layer3 = nn.Sequential( ResNetBlock(in_channels*2 + out_channels*2, out_channels), ResNetBlock(out_channels, out_channels), ResNetBlock(out_channels, out_channels), ) self.layer4 = nn.Sequential( ResNetBlock(in_channels + out_channels, out_channels), ResNetBlock(out_channels, out_channels), ResNetBlock(out_channels, out_channels), ) self.conv = nn.Conv2d(out_channels, 1, kernel_size=1) def forward(self, x1, x2, x3, x4): x = F.interpolate(x4, scale_factor=2) x = torch.cat([x, x3], dim=1) x = self.layer1(x) x = F.interpolate(x, scale_factor=2) x = torch.cat([x, x2], dim=1) x = self.layer2(x) x = F.interpolate(x, scale_factor=2) x = torch.cat([x, x1], dim=1) x = self.layer3(x) x = F.interpolate(x, scale_factor=2) x = self.layer4(x) x = self.conv(x) return x class ResUNet(nn.Module): def __init__(self, in_channels, out_channels): super(ResUNet, self).__init__() self.encoder = ResNetEncoder(in_channels, out_channels) self.decoder = ResNetDecoder(in_channels, out_channels) def forward(self, x): x1, x2, x3, x4 = self.encoder(x) out = self.decoder(x1, x2, x3, x4) return out ``` 在这个示例中,编码器仿照ResNet-34搭建(但layer2只有3个残差块,标准ResNet-34为4个),并与U-Net结构的解码器相结合。原文称该模型接受大小为(in_channels, H, W)的图像作为输入,并输出大小为(1, H, W)的二进制掩模;但最后的1×1卷积输出的是未经sigmoid的单通道logits,需再经sigmoid和阈值处理才是二值掩模。更重要的是,该示例无法直接运行:ResNetBlock在输入、输出通道数不同时没有1×1投影捷径,`out += identity`会因通道数不匹配而报错;各stage也没有stride=2的下采样,x1~x4的空间尺寸相同,解码器中上采样后的特征图与跳跃连接尺寸不一致;解码器各层声明的输入通道数(如in_channels*8 + out_channels*8)也与实际拼接后的通道数不符。
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值