HRNet-3D-Segmentation: converting the 2D network to 3D (C=64)

Convert HRNet from its 2D form to 3D and adapt it to the volumetric data format (h, w, d); two Upsample operations are added after the last stage to restore the output to the input size.
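For concreteness, the spatial sizes evolve as follows for the single-channel 96x96x96 input used in the example at the end of this post (the numbers follow directly from the code below):

    96^3 input
    -> stem conv1 (stride 2): 48^3
    -> stem conv2 (stride 2): 24^3   (1/4 of the input size, 64 channels)
    -> highest-resolution branch stays at 24^3 through stage2-stage4
    -> final_layer (1x1x1 conv): 24^3 with output_channels maps
    -> Upsample x2: 48^3 -> Upsample x2: 96^3, back to the input size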

A few settings need to be changed for the script to run successfully (torchsummary may otherwise raise an error).

The error occurs because the stage modules (wrapped in nn.Sequential()) pass lists of tensors between branches instead of a single Tensor, which torchsummary's hooks do not handle.

  1. In Anaconda3/envs/your_enviro/Lib/site-packages/torchsummary/torchsummary.py, change a few lines:
         # summary[m_key]["input_shape"] = list(input[0].size())   # original line, commented out; replace with the block below
         if isinstance(input[0], torch.Tensor):
             summary[m_key]["input_shape"] = list(input[0].size())
         elif isinstance(input[0], list):
             # one possible body for the branch that was left blank in the original post:
             # record the shape of every tensor in the list
             summary[m_key]["input_shape"] = [list(t.size()) for t in input[0]]
  2. In Anaconda3/envs/your_env/lib/python3.7/site-packages/torch/tensor.py, change the following line:
         # return self.numpy()   # original line
         return self.cpu().detach().numpy()
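If you prefer not to patch files inside site-packages at all, a plain forward pass gives a similar sanity check without torchsummary. This is only a minimal sketch; quick_check is a hypothetical helper, and the 96^3 single-channel input size is simply the one used in the example at the end of this post:

    import torch

    def quick_check(model, shape=(1, 1, 96, 96, 96), device='cpu'):
        # run one dummy volume through the model and count its parameters
        model = model.to(device).eval()
        with torch.no_grad():
            out = model(torch.randn(*shape, device=device))
        n_params = sum(p.numel() for p in model.parameters())
        print('output shape:', tuple(out.shape), '| parameters:', n_params)

    # usage: quick_check(HighResolutionNet())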

Putting the code below into a .py file gives a script that can be run directly.

The modification targets 3D volumes of shape H x W x D; since nn.Conv3d works on 5D tensors, the model input is actually (batch, channel, depth, height, width), with a single channel here.
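For example, a single (H, W, D) volume loaded as a NumPy array would be wrapped like this before being passed to the network (illustrative only; not part of the .py file below):

    import numpy as np
    import torch

    vol = np.zeros((96, 96, 96), dtype=np.float32)   # one (H, W, D) volume
    x = torch.from_numpy(vol)[None, None]            # -> (batch=1, channel=1, 96, 96, 96)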

import torch.nn as nn
import torch
from torchsummary import summary
BN_MOMENTUM = 0.1

'''
    BasicBlock: conv -> bn -> relu -> conv -> bn -> (+ residual) -> relu
'''
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm3d(planes, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm3d(planes, momentum=BN_MOMENTUM)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

'''
    Bottleneck: 2 x [conv -> bn -> relu] -> conv -> bn -> (+ residual) -> relu
'''
class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm3d(planes, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv3d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm3d(planes, momentum=BN_MOMENTUM)
        self.conv3 = nn.Conv3d(planes, planes * self.expansion, kernel_size=1,
                               bias=False)
        self.bn3 = nn.BatchNorm3d(planes * self.expansion,
                                  momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class StageModule(nn.Module):
    def __init__(self, input_branches, output_branches, c):
        """
        构建对应stage,即用来融合不同尺度的实现
        :param input_branches: 输入的分支数,每个分支对应一种尺度
        :param output_branches: 输出的分支数
        :param c: 输入的第一个分支通道数
        """
        super().__init__()
        self.input_branches = input_branches
        self.output_branches = output_branches

        self.branches = nn.ModuleList()  # holds the block stack of each branch
        for i in range(self.input_branches):  # each branch first goes through several BasicBlocks
            w = c * (2 ** i)  # channel count of branch i; it doubles with every extra downsampling level
            branch = nn.Sequential(
                BasicBlock(w, w),
                BasicBlock(w, w),
                BasicBlock(w, w),
                BasicBlock(w, w)
            )
            self.branches.append(branch)  # the block stack of this branch is ready

        self.fuse_layers = nn.ModuleList()  # used to fuse the outputs of the different branches
        for i in range(self.output_branches):
            self.fuse_layers.append(nn.ModuleList())
            for j in range(self.input_branches):
                if i == j:
                    # no processing when input and output are the same branch
                    self.fuse_layers[-1].append(nn.Identity())
                elif i < j:
                    # Input branch j is deeper than output branch i (its downsampling rate is larger),
                    # so branch j needs a channel adjustment plus upsampling so the results can be added later
                    self.fuse_layers[-1].append(
                        nn.Sequential(
                            nn.Conv3d(c * (2 ** j), c * (2 ** i), kernel_size=1, stride=1, bias=False),
                            nn.BatchNorm3d(c * (2 ** i), momentum=BN_MOMENTUM),
                            nn.Upsample(scale_factor=2.0 ** (j - i), mode='trilinear')
                        )
                    )
                else:  # i > j
                    # Input branch j is shallower than output branch i (its downsampling rate is smaller),
                    # so branch j needs a channel adjustment plus downsampling so the results can be added later.
                    # Note: each 2x downsampling is done by one 3x3x3 stride-2 conv; 4x takes two, 8x three, i - j in total
                    ops = []
                    # the first i - j - 1 conv layers keep the channel count and only downsample
                    for k in range(i - j - 1):
                        ops.append(
                            nn.Sequential(
                                nn.Conv3d(c * (2 ** j), c * (2 ** j), kernel_size=3, stride=2, padding=1, bias=False),
                                nn.BatchNorm3d(c * (2 ** j), momentum=BN_MOMENTUM),
                                nn.ReLU(inplace=True)
                            )
                        )
                    # the last conv layer adjusts the channels and downsamples
                    ops.append(
                        nn.Sequential(
                            nn.Conv3d(c * (2 ** j), c * (2 ** i), kernel_size=3, stride=2, padding=1, bias=False),
                            nn.BatchNorm3d(c * (2 ** i), momentum=BN_MOMENTUM)
                        )
                    )
                    self.fuse_layers[-1].append(nn.Sequential(*ops))

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # pass each branch through its own block stack
        x = [branch(xi) for branch, xi in zip(self.branches, x)]

        # then fuse the information from the different scales
        x_fused = []
        for i in range(len(self.fuse_layers)):
            x_fused.append(
                self.relu(
                    sum([self.fuse_layers[i][j](x[j]) for j in range(len(self.branches))])  # output branch i gathers every input branch j: identity, 2x/4x upsampling or strided downsampling, then sums them
                )
            )

        return x_fused
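
# Worked example of one StageModule with the default base_channel=32, three branches
# and a 96^3 network input (numbers derived from the constructions above):
#   branch 0: 32 channels at 24^3 (1/4 resolution)
#   branch 1: 64 channels at 12^3 (1/8 resolution)
#   branch 2: 128 channels at 6^3 (1/16 resolution)
# For output branch 0, branches 1 and 2 are mapped to 32 channels by 1x1x1 convs and
# upsampled x2 and x4 respectively before the three tensors are summed; for output
# branch 2, branches 0 and 1 are instead downsampled by strided 3x3x3 convs.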


class HighResolutionNet(nn.Module):
    def __init__(self, base_channel: int = 32, output_channels: int = 6):
        super().__init__()
        '''
        Stem: the input volume is downsampled twice by stride-2 convolutions,
        giving 1/4-size feature maps with 64 channels.
        Layer1 then runs at 1/4 size and only changes the channel count (64 -> 256 via Bottlenecks).
        transition1 creates two branches: branch 1 keeps the 1/4 size with base_channel channels,
        branch 2 is downsampled to 1/8 size with base_channel * 2 channels
        (see the worked channel/resolution example after transition1 below).
        '''
        self.conv1 = nn.Conv3d(1, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm3d(64, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv3d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn2 = nn.BatchNorm3d(64, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)

        # Stage1
        downsample = nn.Sequential(
            nn.Conv3d(64, 256, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm3d(256, momentum=BN_MOMENTUM)
        )
        '''
        Layer1: four Bottleneck blocks at 1/4 resolution; only the channel count changes (64 -> 256)
        '''
        self.layer1 = nn.Sequential(
            Bottleneck(64, 64, downsample=downsample),  # ResNet Bottleneck: input c channels, output 4c channels
            Bottleneck(256, 64),
            Bottleneck(256, 64),
            Bottleneck(256, 64)
        )

        self.transition1 = nn.ModuleList([  # two branches: 1/4 size and 1/8 size
            nn.Sequential(
                nn.Conv3d(256, base_channel, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm3d(base_channel, momentum=BN_MOMENTUM),
                nn.ReLU(inplace=True)
            ),
            nn.Sequential(
                nn.Sequential(  # the extra Sequential wrapper is kept to match the weights released by the original project
                    nn.Conv3d(256, base_channel * 2, kernel_size=3, stride=2, padding=1, bias=False),
                    nn.BatchNorm3d(base_channel * 2, momentum=BN_MOMENTUM),
                    nn.ReLU(inplace=True)
                )
            )
        ])
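
        # Channel / resolution flow up to this point, as a worked example with the default
        # base_channel=32 and a 96^3 single-channel input:
        #   stem:         1 ch, 96^3  ->  64 ch, 24^3   (two stride-2 convs)
        #   layer1:      64 ch, 24^3  -> 256 ch, 24^3   (Bottleneck expansion = 4)
        #   transition1: 256 ch, 24^3 -> [32 ch, 24^3] and [64 ch, 12^3]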

        # Stage2
        self.stage2 = nn.Sequential(
            StageModule(input_branches=2, output_branches=2, c=base_channel)
        )

        # transition2: pass the two Stage2 branch outputs through unchanged and create a third branch by downsampling the second one
        self.transition2 = nn.ModuleList([
            nn.Identity(),  # None,  - Used in place of "None" because it is callable
            nn.Identity(),  # None,  - Used in place of "None" because it is callable
            nn.Sequential(
                nn.Sequential(
                    nn.Conv3d(base_channel * 2, base_channel * 4, kernel_size=3, stride=2, padding=1, bias=False),
                    nn.BatchNorm3d(base_channel * 4, momentum=BN_MOMENTUM),
                    nn.ReLU(inplace=True)
                )
            )
        ])

        # Stage3
        self.stage3 = nn.Sequential(
            StageModule(input_branches=3, output_branches=3, c=base_channel),
            StageModule(input_branches=3, output_branches=3, c=base_channel),
            StageModule(input_branches=3, output_branches=3, c=base_channel),
            StageModule(input_branches=3, output_branches=3, c=base_channel)
        )

        # transition3
        self.transition3 = nn.ModuleList([
            nn.Identity(),  # None,  - Used in place of "None" because it is callable
            nn.Identity(),  # None,  - Used in place of "None" because it is callable
            nn.Identity(),  # None,  - Used in place of "None" because it is callable
            nn.Sequential(
                nn.Sequential(
                    nn.Conv3d(base_channel * 4, base_channel * 8, kernel_size=3, stride=2, padding=1, bias=False),
                    nn.BatchNorm3d(base_channel * 8, momentum=BN_MOMENTUM),
                    nn.ReLU(inplace=True)
                )
            )
        ])

        # Stage4
        # Note: the last StageModule only outputs the highest-resolution feature map
        self.stage4 = nn.Sequential(
            StageModule(input_branches=4, output_branches=4, c=base_channel),
            StageModule(input_branches=4, output_branches=4, c=base_channel),
            StageModule(input_branches=4, output_branches=1, c=base_channel)
        )
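
        # With output_branches=1 the fusion loop in StageModule only builds fuse_layers[0],
        # so the last StageModule returns a single-element list holding the highest-resolution
        # feature map (32 channels at 1/4 of the input size for the default config).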

        self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
        # Final layer
        self.final_layer = nn.Conv3d(base_channel, output_channels, kernel_size=1, stride=1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)  # stem

        x = self.layer1(x)
        x = [trans(x) for trans in self.transition1]  # x becomes a list; from here on each stage has several outputs

        x = self.stage2(x)  # feed the list of branch outputs into stage2
        x = [
            self.transition2[0](x[0]),
            self.transition2[1](x[1]),
            self.transition2[2](x[-1])
        ]  # New branch derives from the "upper" branch only

        x = self.stage3(x)
        x = [
            self.transition3[0](x[0]),
            self.transition3[1](x[1]),
            self.transition3[2](x[2]),
            self.transition3[3](x[-1]),
        ]  # New branch derives from the "upper" branch only

        x = self.stage4(x)  # 4 input branches, 1 output branch

        x = self.final_layer(x[0])
        x = self.up(x)
        x = self.up(x)
        print('x shape', x.shape)

        return x

if __name__ == '__main__':
    torch.cuda.set_device(0)
    network = HighResolutionNet()
    net = network.cuda().eval()

    summary(net, (1, 96, 96, 96))  # input_size excludes the batch dimension: (channels, depth, height, width)

References:
https://blog.csdn.net/qq_37541097/article/details/124346626
Paper: Deep High-Resolution Representation Learning for Human Pose Estimation
Paper link: https://arxiv.org/abs/1902.09212
Official code: https://github.com/leoxiaobin/deep-high-resolution-net.pytorch

Video walkthrough on bilibili: https://www.bilibili.com/video/BV1bB4y1y7qP
