深度残差收缩网络的完整PyTorch代码

深度残差收缩网络的完整PyTorch代码

1、基础理论

深度残差收缩网络是建立在三个部分的基础之上的,包括残差网络、注意力机制和软阈值化。
在这里插入图片描述
其功能特色包括:

1)由于软阈值化是信号降噪算法的常用步骤,所以深度残差收缩网络比较适合强噪、高冗余数据。同时,软阈值化的梯度要么为0,要么为1,这与ReLU激活函数是相似/一致的。

在这里插入图片描述

2)由于软阈值化的阈值是通过类似于SENet的注意力机制自适应地进行设置的,深度残差收缩网络能够根据每个样本的情况,为每个样本单独地设置阈值,因此适用于每个样本内噪声含量不同的情况。

3)当数据噪声很弱、没有噪声时,深度残差收缩网络可能也是适用的。其前提是阈值可以被训练成非常接近于0的值,从而软阈值化就相当于不存在了。

4)值得注意的是,软阈值函数的阈值不能太大,否则会导致所有的输出都是0。所以深度残差收缩网络的注意力模块是经过专门设计的,与一般的SENet是存在明显区别的。

该方法的文献来源:

M. Zhao, S. Zhong, X. Fu, B. Tang, M. Pecht, Deep residual shrinkage networks for fault diagnosis, IEEE Transactions on Industrial Informatics, vol. 16, no. 7, pp. 4681-4690, 2020. (https://ieeexplore.ieee.org/document/8850096/

2、PyTorch代码

本文的PyTorch代码是在这份代码(https://github.com/weiaicunzai/pytorch-cifar100)的基础上修改得到的,所以要下载这份代码到本地。主要是修改了models/resnet.py(https://github.com/weiaicunzai/pytorch-cifar100/blob/master/models/resnet.py)和utils.py(https://github.com/weiaicunzai/pytorch-cifar100/blob/master/utils.py)的代码。

另一方面,残差收缩网络的核心代码,则是来源于知乎上最前线创作的一篇文章《用于故障诊断的残差收缩网络》(https://zhuanlan.zhihu.com/p/337346575)。

具体地,将resnet.py文件的名称,改为了rsnet.py,意思是residual shrinkage network。修改后的rsnet.py代码如下:

import torch
import torch.nn as nn

class BasicBlock(nn.Module):

    expansion = 1
    
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.shrinkage = Shrinkage(out_channels, gap_size=(1, 1))
        #residual function
        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels * BasicBlock.expansion),
            self.shrinkage
        )
        #shortcut
        self.shortcut = nn.Sequential()

        #the shortcut output dimension is not the same with residual function
        #use 1*1 convolution to match the dimension
        if stride != 1 or in_channels != BasicBlock.expansion * out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * BasicBlock.expansion)
            )

    def forward(self, x):
        return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))


class Shrinkage(nn.Module):
    def __init__(self,  channel, gap_size):
        super(Shrinkage, self).__init__()
        self.gap = nn.AdaptiveAvgPool2d(gap_size)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel),
            nn.BatchNorm1d(channel),
            nn.ReLU(inplace=True),
            nn.Linear(channel, channel),
            nn.Sigmoid(),
        )

    def forward(self, x):
        x_raw = x
        x = torch.abs(x)
        x_abs = x
        x = self.gap(x)
        x = torch.flatten(x, 1)
        # average = torch.mean(x, dim=1, keepdim=True)
        average = x
        x = self.fc(x)
        x = torch.mul(average, x)
        x = x.unsqueeze(2).unsqueeze(2)
        # soft thresholding
        sub = x_abs - x
        zeros = sub - sub
        n_sub = torch.max(sub, zeros)
        x = torch.mul(torch.sign(x_raw), n_sub)
        return x

class RSNet(nn.Module):

    def __init__(self, block, num_block, num_classes=100):
        super().__init__()

        self.in_channels = 64

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True))
        #we use a different inputsize than the original paper
        #so conv2_x's stride is 1
        self.conv2_x = self._make_layer(block, 64, num_block[0], 1)
        self.conv3_x = self._make_layer(block, 128, num_block[1], 2)
        self.conv4_x = self._make_layer(block, 256, num_block[2], 2)
        self.conv5_x = self._make_layer(block, 512, num_block[3], 2)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """make rsnet layers(by layer i didnt mean this 'layer' was the
        same as a neuron netowork layer, ex. conv layer), one layer may
        contain more than one residual shrinkage block

        Args:
            block: block type, basic block or bottle neck block
            out_channels: output depth channel number of this layer
            num_blocks: how many blocks per layer
            stride: the stride of the first block of this layer

        Return:
            return a rsnet layer
        """

        # we have num_block blocks per layer, the first block
        # could be 1 or 2, other blocks would always be 1
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_channels, out_channels, stride))
            self.in_channels = out_channels * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x):
        output = self.conv1(x)
        output = self.conv2_x(output)
        output = self.conv3_x(output)
        output = self.conv4_x(output)
        output = self.conv5_x(output)
        output = self.avg_pool(output)
        output = output.view(output.size(0), -1)
        output = self.fc(output)

        return output

def rsnet18():
    """ return a RSNet 18 object
    """
    return RSNet(BasicBlock, [2, 2, 2, 2])

def rsnet34():
    """ return a RSNet 34 object
    """
    return RSNet(BasicBlock, [3, 4, 6, 3])

然后,将utils.py文件中的第62-64行:

    elif args.net == 'resnet18':
        from models.resnet import resnet18
        net = resnet18()

修改为:

    elif args.net == 'rsnet18':
        from models.rsnet import rsnet18
        net = rsnet18()

然后在运行窗口输入:

python train.py -net rsnet18 -gpu

就可以运行程序了。
在这里插入图片描述

3、其他代码

论文原作者在GitHub上提供了TFLearn和Keras代码,见链接:https://github.com/zhao62/Deep-Residual-Shrinkage-Networks

也有网友编写了TensorFlow 2.0的代码:
https://blog.csdn.net/qq_36758914/article/details/109452735

  • 25
    点赞
  • 224
    收藏
    觉得还不错? 一键收藏
  • 37
    评论
以下是深度残差收缩网络PyTorch 代码: ```python import torch import torch.nn as nn class ResidualShrinkageBlock(nn.Module): def __init__(self, in_channels, out_channels, stride=1, padding=1, dilation=1, reduction_ratio=16, last=False): super(ResidualShrinkageBlock, self).__init__() self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=padding, dilation=dilation, bias=False) self.bn1 = nn.BatchNorm2d(out_channels) self.relu = nn.ReLU(inplace=True) self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=padding, dilation=dilation, bias=False) self.bn2 = nn.BatchNorm2d(out_channels) self.in_channels = in_channels self.out_channels = out_channels self.stride = stride self.padding = padding self.dilation = dilation self.last = last self.avg_pool = nn.AdaptiveAvgPool2d(1) self.fc1 = nn.Linear(out_channels, out_channels // reduction_ratio) self.fc2 = nn.Linear(out_channels // reduction_ratio, out_channels) self.sigmoid = nn.Sigmoid() def forward(self, x): identity = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) if self.last: out = out else: out = self.avg_pool(out) out = out.view(out.size(0), -1) out = self.fc1(out) out = self.relu(out) out = self.fc2(out) out = self.sigmoid(out) out = out.view(out.size(0), out.size(1), 1, 1) out = out * identity out += identity out = self.relu(out) return out class ResidualShrinkageNet(nn.Module): def __init__(self, num_classes=1000): super(ResidualShrinkageNet, self).__init__() self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) self.bn1 = nn.BatchNorm2d(64) self.relu = nn.ReLU(inplace=True) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.layer1 = self._make_layer(64, 256, 3) self.layer2 = self._make_layer(256, 512, 4, stride=2) self.layer3 = self._make_layer(512, 1024, 6, stride=2) self.layer4 = self._make_layer(1024, 2048, 3, stride=2) self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.fc = nn.Linear(2048, num_classes) def _make_layer(self, in_channels, out_channels, blocks, stride=1): layers = [] layers.append(ResidualShrinkageBlock(in_channels, out_channels, stride=stride, last=True)) for i in range(1, blocks): layers.append(ResidualShrinkageBlock(out_channels, out_channels)) return nn.Sequential(*layers) def forward(self, x): x = self.conv1(x) x = self.bn1(x) x = self.relu(x) x = self.maxpool(x) x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.layer4(x) x = self.avgpool(x) x = x.view(x.size(0), -1) x = self.fc(x) return x ``` 这是一个四层残差结构的 Residual Shrinkage Net,其中每个残差块有一个特殊的缩减层,用于减少冗余特征。可以根据需要调整层数和通道数。
评论 37
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值