【转载】U-Net A PyTorch Implementation in 60 lines of Code

转载并翻译自:https://amaarora.github.io/posts/2020-09-13-unet.html

1. 介绍

今天的博客文章会很简短。今天,我们将学习如何在PyTorch中用60行代码实现U-Net架构。 这篇博客是一步一步解释如何从头开始在PyTorch中实现U-Net。 在这篇博客中,首先我们会了解U-Net架构,特别是每个模块的输入和输出形状。我们将用工厂生产线的类比来简化和易于理解U-Net架构。接下来,我们将把对U-Net架构的理解转化为简洁的PyTorch代码。

2. 理解U-Net中的输入和输出形状

在这里插入图片描述

从图1可以看出,这个架构是“U形”的,因此被称为“U-Net”。完整的架构由两部分组成 - 编码器和解码器。图1左侧的部分(图2中的黄色高亮部分)是编码器,而右侧的部分是解码器(图2中的橙色高亮部分)。 根据论文:网络架构如图1所示。它由一个收缩路径(左侧)和一个扩张路径(右侧)组成。收缩路径遵循卷积网络的典型架构。

编码器类似于任何标准的卷积神经网络(如ResNet),从输入图像中提取有意义的特征图。与卷积神经网络的标准做法一样,编码器在每一步都将通道数加倍,并将空间维度减半。 接下来,解码器实际上是对特征图进行上采样,每一步都将空间维度加倍,并将通道数减半(与编码器相反)。

3. 工厂生产线类比

现在让我们来看一下U-Net,用一个工厂生产线的类比来解释,就像图2中所示。

在这里插入图片描述

我们可以把整个架构看作是一个工厂生产线,黑点代表装配站,路径本身是一个传送带,根据传送带是黄色还是橙色,对传送带上的图像进行不同的操作。 如果是黄色传送带,我们使用最大池化2x2(Max Pooling 2x2)操作对图像进行下采样,将图像的高度和宽度都减半。如果是橙色传送带,我们使用反卷积操作ConvTranspose2d,将图像的高度和宽度加倍,同时将通道数减半。因此,橙色传送带执行的操作与黄色传送带相反。 此外,需要注意的是,在解码器的每个装配站(黑点)上,编码器装配站的输出也会与输入连接起来。 现在让我们开始将这个简单的理解转化为PyTorch代码。

4. 黑点/ 块

所有的装配站 - 在图2中的黑点 - 有两个Conv2D操作,它们之间有ReLU激活。卷积操作的核大小为3,没有填充。因此,输出特征图的高度和宽度与输入特征图不同。 根据论文:收缩路径由两个3x3卷积(无填充卷积)的重复应用组成,每个卷积后面跟着一个修正线性单元(ReLU)和一个2x2的最大池化操作,步长为2,用于下采样。在每个下采样步骤中,我们将特征通道的数量加倍。 让我们用代码来写一个Block

class Block(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3)
        self.relu  = nn.ReLU()
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3)
    
    def forward(self, x):
        return self.relu(self.conv2(self.relu(self.conv1(x))))

很简单 - 两个卷积操作,一个将输入通道数从in_ch加倍到out_ch,另一个从out_ch转换为out_ch。两者都是2D卷积,核大小为3,没有填充,如论文中所述,然后是ReLU激活。 我们来确保这个方法有效。

enc_block = Block(1, 64)
x         = torch.randn(1, 1, 572, 572)
enc_block(x).shape

>> torch.Size([1, 64, 568, 568])

看起来不错,输出的尺寸与图1左上角的尺寸相匹配。给定一个形状为1x572x572的输入图像,输出的形状为64x568x568

5. 编码器

现在我们已经在图2中实现了块(Block)或黑点,我们准备实现编码器。编码器是U-Net架构的收缩路径。 到目前为止,我们已经实现了卷积操作,但还没有实现下采样部分。正如论文中所提到的:>每个块后面都跟着一个2x2的最大池化操作,步长为2,用于下采样。 所以这就是我们需要做的,我们需要在两个块操作之间添加最大池化操作(图2中的黄色传送带)。

class Encoder(nn.Module):
    def __init__(self, chs=(3,64,128,256,512,1024)):
        super().__init__()
        self.enc_blocks = nn.ModuleList([Block(chs[i], chs[i+1]) for i in range(len(chs)-1)])
        self.pool       = nn.MaxPool2d(2)
    
    def forward(self, x):
        ftrs = []
        for block in self.enc_blocks:
            x = block(x)
            ftrs.append(x)
            x = self.pool(x)
        return ftrs

我们所做的就是这些。在编码器端,编码器块或self.enc_blocks是一个块操作的列表。接下来,我们对每个块的输出执行MaxPooling操作。由于我们还需要存储块的输出,我们将它们存储在一个名为ftrs的列表中,并返回该列表。 让我们确保这个实现是有效的。

encoder = Encoder()
# input image
x    = torch.randn(1, 3, 572, 572)
ftrs = encoder(x)
for ftr in ftrs: print(ftr.shape)

>> 
torch.Size([1, 64, 568, 568])
torch.Size([1, 128, 280, 280])
torch.Size([1, 256, 136, 136])
torch.Size([1, 512, 64, 64])
torch.Size([1, 1024, 28, 28])

输出的形状与图1中提到的形状完全匹配 - 到目前为止,一切都很好。实现了编码器后,我们现在准备进入解码器的部分。

6. 解码器

解码器是U-Net架构的扩展路径。 根据论文中的描述:> 扩展路径中的每一步都包括对特征图进行上采样,然后进行2x2的卷积(“上卷积”),将特征通道数量减半,然后将其与收缩路径中相应裁剪的特征图进行连接,并进行两次3x3的卷积,每次卷积后跟一个ReLU激活函数。由于每次卷积都会导致边界像素的丢失,所以裁剪是必要的。

请注意,我们已经在Block中实现了两次3x3卷积后跟ReLU激活函数的部分。我们只需要实现解码器中的“上卷积”(图2中的橙色高亮部分)以及与收缩路径中相应裁剪的特征图进行连接(图2中的灰色箭头)。 请注意,在PyTorch中,ConvTranspose2d操作执行“上卷积”。它接受参数,如in_channelsout_channelskernel_sizestride等。

由于解码器中的in_channelsout_channels值取决于执行此操作的位置,因此在实现中,“上卷积”操作也被存储为列表。如论文中所述,stride和kernel size始终为2。

现在,我们只需要执行特征连接操作。让我们看一下解码器的实现,以更清楚地了解所有这些的工作原理-

class Decoder(nn.Module):
    def __init__(self, chs=(1024, 512, 256, 128, 64)):
        super().__init__()
        self.chs         = chs
        self.upconvs    = nn.ModuleList([nn.ConvTranspose2d(chs[i], chs[i+1], 2, 2) for i in range(len(chs)-1)])
        self.dec_blocks = nn.ModuleList([Block(chs[i], chs[i+1]) for i in range(len(chs)-1)]) 
        
    def forward(self, x, encoder_features):
        for i in range(len(self.chs)-1):
            x        = self.upconvs[i](x)
            enc_ftrs = self.crop(encoder_features[i], x)
            x        = torch.cat([x, enc_ftrs], dim=1)
            x        = self.dec_blocks[i](x)
        return x
    
    def crop(self, enc_ftrs, x):
        _, _, H, W = x.shape
        enc_ftrs   = torchvision.transforms.CenterCrop([H, W])(enc_ftrs)
        return enc_ftrs

所以self.dec_blocks是一个Decoder Block的列表,它执行论文中提到的两个conv + ReLU操作。self.upconvs是一个ConvTranspose2d操作的列表,它执行“上卷积”操作。最后,在forward函数中,解码器接受由编码器输出的encoder_features,执行连接操作,然后将结果传递给Block

这就是U-Net解码器内部的全部内容。让我们确保这个实现是有效的:

decoder = Decoder()
x = torch.randn(1, 1024, 28, 28)
decoder(x, ftrs[::-1][1:]).shape

>> (torch.Size([1, 64, 388, 388])

这就是了,最终的特征图的大小是64x388x388,与图1的大小相匹配。到目前为止,我们已经成功实现了编码器和解码器。

你可能会问为什么要使用 ftrs[::-1][1:]? 还记得编码器的输出形状吗?它们是:

torch.Size([1, 64, 568, 568]) #0
torch.Size([1, 128, 280, 280]) #1
torch.Size([1, 256, 136, 136]) #2
torch.Size([1, 512, 64, 64]) #3
torch.Size([1, 1024, 28, 28]) #4

从图1中可以看出,形状为torch.Size([1, 1024, 28, 28]) 的特征图实际上并没有被连接起来,只进行了“上卷积”操作。此外,图1中的第一个解码器块接受来自第三个位置的编码器块的输入。类似地,第二个解码器块接受来自第二个位置的编码器块的输入,依此类推。因此,在将它们传递给解码器之前,编码器特征被反转,而形状为torch.Size([1, 1024, 28, 28]) 的特征图不会被传递。 因此,解码器的输入是 ftrs[::-1][1:]

7. U-Net

太棒了,我们目前已经实现了U-Net架构的编码器和解码器。现在让我们把它们整合起来,完成我们的U-Net实现吧。

class UNet(nn.Module):
    def __init__(self, enc_chs=(3,64,128,256,512,1024), dec_chs=(1024, 512, 256, 128, 64), num_class=1, retain_dim=False, out_sz=(572,572)):
        super().__init__()
        self.encoder     = Encoder(enc_chs)
        self.decoder     = Decoder(dec_chs)
        self.head        = nn.Conv2d(dec_chs[-1], num_class, 1)
        self.retain_dim  = retain_dim

    def forward(self, x):
        enc_ftrs = self.encoder(x)
        out      = self.decoder(enc_ftrs[::-1][0], enc_ftrs[::-1][1:])
        out      = self.head(out)
        if self.retain_dim:
            out = F.interpolate(out, out_sz)
        return out

让我们确保这个实现是有效的:

unet = UNet()
x    = torch.randn(1, 3, 572, 572)
unet(x).shape

>> torch.Size([4, 1, 388, 388])

输出形状与图1相匹配。 如前所述,由于卷积操作是3x3且没有填充,输出特征图的大小与输入特征图的大小不同。同时,如图1所示,最终输出的形状是1x388x388,而输入图像的尺寸是572x572。这在计算PyTorch中的BCELoss时可能会产生问题,因为它期望输入和输出特征图具有相同的形状。 因此,如果我们想保留维度,我在U-Net中添加了F.interpolate操作,使输出大小与输入图像大小相同。 太棒了,我们刚刚成功地在PyTorch中实现了U-Net架构。

将所有内容放在一起,看起来像这样:

class Block(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3)
        self.relu  = nn.ReLU()
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3)
    
    def forward(self, x):
        return self.conv2(self.relu(self.conv1(x)))


class Encoder(nn.Module):
    def __init__(self, chs=(3,64,128,256,512,1024)):
        super().__init__()
        self.enc_blocks = nn.ModuleList([Block(chs[i], chs[i+1]) for i in range(len(chs)-1)])
        self.pool       = nn.MaxPool2d(2)
    
    def forward(self, x):
        ftrs = []
        for block in self.enc_blocks:
            x = block(x)
            ftrs.append(x)
            x = self.pool(x)
        return ftrs


class Decoder(nn.Module):
    def __init__(self, chs=(1024, 512, 256, 128, 64)):
        super().__init__()
        self.chs         = chs
        self.upconvs    = nn.ModuleList([nn.ConvTranspose2d(chs[i], chs[i+1], 2, 2) for i in range(len(chs)-1)])
        self.dec_blocks = nn.ModuleList([Block(chs[i], chs[i+1]) for i in range(len(chs)-1)]) 
        
    def forward(self, x, encoder_features):
        for i in range(len(self.chs)-1):
            x        = self.upconvs[i](x)
            enc_ftrs = self.crop(encoder_features[i], x)
            x        = torch.cat([x, enc_ftrs], dim=1)
            x        = self.dec_blocks[i](x)
        return x
    
    def crop(self, enc_ftrs, x):
        _, _, H, W = x.shape
        enc_ftrs   = torchvision.transforms.CenterCrop([H, W])(enc_ftrs)
        return enc_ftrs


class UNet(nn.Module):
    def __init__(self, enc_chs=(3,64,128,256,512,1024), dec_chs=(1024, 512, 256, 128, 64), num_class=1, retain_dim=False, out_sz=(572,572)):
        super().__init__()
        self.encoder     = Encoder(enc_chs)
        self.decoder     = Decoder(dec_chs)
        self.head        = nn.Conv2d(dec_chs[-1], num_class, 1)
        self.retain_dim  = retain_dim

    def forward(self, x):
        enc_ftrs = self.encoder(x)
        out      = self.decoder(enc_ftrs[::-1][0], enc_ftrs[::-1][1:])
        out      = self.head(out)
        if self.retain_dim:
            out = F.interpolate(out, out_sz)
        return out
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值