U-Net Segmentation Model: PyTorch Code

```python
import torch
import torch.nn as nn


class pub(nn.Module):
    """Double 3x3 convolution block: conv -> [BN] -> ReLU, twice."""

    def __init__(self, in_channel, out_channel, batch_norm=True, keep_size=False):
        super(pub, self).__init__()
        # padding=1 preserves the spatial size; padding=0 gives the "valid"
        # convolutions of the original U-Net paper (each conv shrinks H and W by 2)
        pad = 1 if keep_size else 0
        layers = [nn.Conv2d(in_channel, out_channel, 3, padding=pad)]
        if batch_norm:
            layers.append(nn.BatchNorm2d(out_channel))
        layers += [nn.ReLU(inplace=True),
                   nn.Conv2d(out_channel, out_channel, 3, padding=pad)]
        if batch_norm:
            layers.append(nn.BatchNorm2d(out_channel))
        layers.append(nn.ReLU(inplace=True))
        self.pub_con = nn.Sequential(*layers)

    def forward(self, x):
        return self.pub_con(x)
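
# Quick shape check for `pub` (an illustrative note, not from the original
# post): with keep_size=False the two unpadded 3x3 convolutions shrink each
# spatial dimension by 4, e.g.
#   pub(1, 64)(torch.randn(1, 1, 572, 572)).shape -> [1, 64, 568, 568]
# With keep_size=True (padding=1) the input resolution is preserved.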


class unet_down(nn.Module):
    """Downsampling step: 2x2 max-pool followed by the double-conv block."""

    def __init__(self, in_channel, out_channel, batch_norm=True, keep_size=False):
        super(unet_down, self).__init__()
        self.pub = pub(in_channel, out_channel, batch_norm, keep_size)
        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x):
        x = self.pool(x)
        x = self.pub(x)
        return x


class unet_up(nn.Module):
    """Upsampling step: upsample x2, concatenate the skip connection, double-conv."""

    def __init__(self, in_channel, out_channel, batch_norm=True, upsample=True, keep_size=False):
        super(unet_up, self).__init__()
        layers = []
        if upsample:
            # 1x1 conv halves the channels, then nearest-neighbor upsampling x2
            layers += [nn.Conv2d(out_channel * 2, out_channel, 1),
                       nn.Upsample(scale_factor=2, mode='nearest')]
        else:
            # learned upsampling via transposed convolution
            layers += [nn.ConvTranspose2d(out_channel * 2, out_channel, 2, stride=2)]
        self.upsample = nn.Sequential(*layers)
        self.pub = pub(in_channel, out_channel, batch_norm, keep_size)
        self.original_size = keep_size

    def forward(self, x1, x2):
        # x1: skip connection from the contracting path; x2: features from below
        x2 = self.upsample(x2)
        # with valid convolutions x1 is larger than x2, so center-crop it;
        # skip the crop when the sizes already match (c == 0 would otherwise
        # produce an empty slice, e.g. with keep_size=True)
        c = (x1.size(2) - x2.size(2)) // 2
        if c > 0:
            x1 = x1[:, :, c:-c, c:-c]
        x = torch.cat((x1, x2), 1)
        return self.pub(x)


class Unet(nn.Module):
    def __init__(self, channels, class_nums, layers=5, upsample=True, batch_norm=True, keep_size=False):
        super(Unet, self).__init__()
        self.layers = layers
        down = []
        up = []
        # contracting path: 64 -> 128 -> 256 -> ... channels
        down.append(pub(channels, 64, batch_norm, keep_size))
        for layer in range(layers - 1):
            down.append(unet_down(64 * (2 ** layer), 128 * (2 ** layer), batch_norm, keep_size))
            # expanding path, built top-down so up[0] consumes the deepest features;
            # the exponent generalizes the hardcoded (3 - layer) of a 5-layer net,
            # and batch_norm/upsample are passed in the order the signature expects
            up.append(unet_up(128 * (2 ** (layers - 2 - layer)), 64 * (2 ** (layers - 2 - layer)),
                              batch_norm, upsample, keep_size))
        # final 1x1 convolution maps to the class scores
        up.append(nn.Conv2d(64, class_nums, 1))
        self.down = nn.ModuleList(down)
        self.up = nn.ModuleList(up)
        self.sigmoid = nn.Sigmoid()
        self._initialize_weights()

    def forward(self, x):
        skips = []
        for i in range(self.layers):
            x = self.down[i](x)
            skips.append(x)
        # walk back up, pairing each stage with its skip connection
        for j in range(self.layers - 1):
            x = self.up[j](skips[self.layers - j - 2], x)
        x = self.up[-1](x)  # final 1x1 conv
        # sigmoid suits binary / multi-label masks; for mutually exclusive
        # classes, drop it and train with CrossEntropyLoss on the raw logits
        return self.sigmoid(x)

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # kaiming_uniform (without the trailing underscore) is deprecated
                nn.init.kaiming_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
```
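
A minimal smoke test for the model above (my own sketch, not part of the original post; the 572x572 input follows the valid-convolution geometry of the original U-Net paper):

```python
# Hypothetical usage: with layers=5 and keep_size=False, a 572x572 input
# yields a 388x388 prediction map, matching Ronneberger et al. (2015).
net = Unet(channels=1, class_nums=2, layers=5, upsample=True,
           batch_norm=True, keep_size=False)
x = torch.randn(1, 1, 572, 572)
print(net(x).shape)  # torch.Size([1, 2, 388, 388])
```

With keep_size=True the output keeps the input resolution, provided the input side length is divisible by 2**(layers-1) so the pooling and upsampling sizes line up.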
Here is an alternative, more compact PyTorch implementation of the U-Net model:

```python
import torch
import torch.nn as nn


class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)


class UNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=1, features=[64, 128, 256, 512]):
        super().__init__()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Down part of UNet
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Up part of UNet
        for feature in reversed(features):
            self.ups.append(nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2))
            self.ups.append(DoubleConv(feature * 2, feature))

        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skip_connections = []
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]

        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx // 2]
            if x.shape != skip_connection.shape:
                x = nn.functional.interpolate(x, size=skip_connection.shape[2:],
                                              mode='bilinear', align_corners=True)
            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx + 1](concat_skip)

        return self.final_conv(x)
```

Hope this helps!
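
A quick smoke test for this compact variant (my own addition; the input size is an arbitrary choice meant to exercise the shape-matching interpolation branch):

```python
# Hypothetical usage: an odd 161x161 input forces the interpolate() branch
# to fire on the way back up. The model returns raw logits, so pair it with
# nn.BCEWithLogitsLoss during training rather than adding a sigmoid here.
model = UNet(in_channels=3, out_channels=1)
x = torch.randn(1, 3, 161, 161)
print(model(x).shape)  # torch.Size([1, 1, 161, 161])
```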