Non-local in deep learning: Non-local Neural Network, PANet (Pyramid Attention Network for Image Restoration)

NON-LOCAL

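In the earlier introduction to kernel prediction networks, a filter is generated for every pixel, so each pixel is processed with its own custom filter. That filter only covers the pixel's neighborhood, however: more distant regions cannot be fused in unless the filter is very large, for example one that spans the whole image.

The non-local concept is widely used in image denoising, including in traditional (non-learning) algorithms. Deep learning models can borrow the same idea by introducing and designing non-local modules.

Non-local and self-attention are closely related. In the end, both address how different regions are correlated and how to connect them.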

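1. Fully connected layer

The image has size h, w, c and is flattened to 1D.
(h x w x c) matmul (h x w x c, h x w x c) gives a result of dimension (h x w x c). Every value in the result uses information from all pixels.

(h x w x c, h x w x c) is the weight size. The image height h and width w are usually not small, so both the parameter count and the computation are very large.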
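
To make the size concrete, the short calculation below counts the weights of such a layer. The 64 x 64 x 3 input is an illustrative assumption, not a value from the text, and the bias is ignored.

```python
def fc_weight_count(h: int, w: int, c: int) -> int:
    """Weights of a dense layer mapping h*w*c inputs to h*w*c outputs (bias ignored)."""
    n = h * w * c
    return n * n


# Illustrative input size (assumed): a small 64 x 64 RGB image.
print(fc_weight_count(64, 64, 3))
```

Even for this small image the layer needs roughly 1.5e8 weights, and the count grows with the square of the number of pixels.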

2.Non-local Neural Network | CVPR2018

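https://juejin.cn/post/6914526262992044046 gives a good introduction.

  1. The input of size h,w,c goes through separate 1x1 convolutions to produce two feature maps, theta(h,w,c1) and phi(h,w,c1), plus a feature g(hw,c2)
  2. theta(hw,c1) matmul phi(c1,hw) gives the similarity scores, and a softmax over them gives p: hw x hw
  3. p(hw,hw) matmul g(hw,c2) gives (hw,c2), which is reshaped to h,w,c2 and passed through a 1x1 convolution to restore the size h,w,c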
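
The three steps above can be sketched in plain Python as follows. This is a minimal, dependency-free illustration rather than the paper's implementation: the feature map is assumed to be already flattened to (hw, c), each 1x1 convolution is written as a bias-free matrix, and the function and weight names are hypothetical.

```python
import math


def matmul(a, b):
    """Plain-Python matrix product of two lists of rows."""
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]


def softmax_rows(m):
    """Softmax over each row, shifted by the row max for numerical stability."""
    out = []
    for row in m:
        top = max(row)
        exps = [math.exp(v - top) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def non_local(x, w_theta, w_phi, w_g, w_out):
    """x: (hw, c) flattened feature map; each w_* is a 1x1 convolution as a matrix."""
    theta = matmul(x, w_theta)                  # step 1: (hw, c1)
    phi = matmul(x, w_phi)                      # step 1: (hw, c1)
    g = matmul(x, w_g)                          # step 1: (hw, c2)
    phi_t = [list(col) for col in zip(*phi)]    # (c1, hw)
    p = softmax_rows(matmul(theta, phi_t))      # step 2: (hw, hw)
    y = matmul(p, g)                            # step 3: (hw, c2)
    return matmul(y, w_out), p                  # back to (hw, c)


# Toy input (assumed values): h*w = 4 positions, c = 2, c1 = c2 = 1.
x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]]
out, p = non_local(x, [[1.0], [0.0]], [[0.0], [1.0]], [[1.0], [1.0]], [[1.0, 2.0]])
print(len(out), len(out[0]), len(p), len(p[0]))
```

Each row of p sums to 1, so every output position is a weighted average of g over all positions in the image.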

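Compared with the fully connected layer above, both the computation and the parameter count are much smaller, but the computation is still much larger than for an ordinary convolution.

Take an ordinary convolution with kernel size = 3, 1 input channel, and c output channels.
Its computation is (hw,9) matmul (9,c), giving (hw,c), which is comparatively small.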
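
The comparison can be put into rough numbers by counting multiply-accumulate operations (MACs). The sketch below is an estimate under stated assumptions: biases and the softmax are ignored, and the sizes (64 x 64 map, c = 64, c1 = c2 = 32) are chosen only for illustration.

```python
def fc_macs(h, w, c):
    """Dense layer on the flattened image: (hwc) x (hwc, hwc)."""
    n = h * w * c
    return n * n


def non_local_macs(h, w, c, c1, c2):
    """1x1 projections, the hw x hw similarity, the aggregation, and the output 1x1 conv."""
    hw = h * w
    proj = hw * c * (2 * c1 + c2)   # theta, phi, g
    sim = hw * hw * c1              # theta matmul phi^T
    agg = hw * hw * c2              # p matmul g
    out = hw * c2 * c               # 1x1 conv back to c channels
    return proj + sim + agg + out


def conv3x3_macs(h, w, c_in, c_out):
    """Ordinary 3x3 convolution: (hw, 9*c_in) x (9*c_in, c_out)."""
    return h * w * 9 * c_in * c_out


# Illustrative sizes (assumed); the convolution uses 1 input channel as in the text.
print(fc_macs(64, 64, 64))
print(non_local_macs(64, 64, 64, 32, 32))
print(conv3x3_macs(64, 64, 1, 64))
```

The hw x hw terms dominate the non-local cost, which is why it grows quadratically with the number of pixels while the convolution grows linearly.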

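Changing the 1x1 convolutions to 3x3 is equivalent to block matching. It would be worth checking whether the learned 3x3 filters resemble a box filter.
More detailed explanation: https://cloud.tencent.com/developer/article/1582047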

3. PANet Pyramid Attention Network for Image Restoration

https://blog.csdn.net/weixin_42096202/article/details/106240801
Characteristics:

  1. Matching is done on pixel blocks (patches)
  2. It uses a multi-level pyramid

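One concern: it uses hw filters of size 3,3, which looks computationally heavy. If the image is h,w = 400,400, there are 160000 filters of size 3,3,c_in in total, so the matching output has 160,000 channels. That is a huge amount of computation.
log (printed by the code below):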
index scale , wishape raw_w i shape: 0 torch.Size([1, 2500, 4, 3, 3]) torch.Size([1, 2500, 8, 3, 3])
xi, wi_normed shape: torch.Size([1, 4, 50, 50]) torch.Size([2500, 4, 3, 3])
yi , raw_wi shape: torch.Size([1, 2500, 50, 50]) torch.Size([2500, 8, 3, 3])
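
The filter counts can be checked with a short calculation. This is a sketch under two assumptions: level i is resized by (10 - i) / 10, following `self.scale` in the code below, and resized sizes round down. With stride 1 and 'same' padding, each level contributes one patch per pixel. The 2500 filters in the log are consistent with a 50 x 50 input and a single pyramid level.

```python
def pyramid_patch_counts(h, w, level=5):
    """Patches per pyramid level (stride 1, 'same' padding: one patch per pixel).

    Assumes level i is resized by (10 - i) / 10 and that sizes round down.
    """
    counts = []
    for i in range(level):
        hi, wi = h * (10 - i) // 10, w * (10 - i) // 10
        counts.append(hi * wi)
    return counts


def matching_macs(h, w, c_match, n_filters, ksize=3):
    """MACs of correlating an h x w map with n_filters patches of size c_match x k x k."""
    return h * w * n_filters * c_match * ksize * ksize


print(pyramid_patch_counts(50, 50, level=1))    # the setting that matches the log
print(matching_macs(50, 50, 4, 2500))
print(sum(pyramid_patch_counts(400, 400)))      # all five levels of a 400 x 400 input
```

For a 400 x 400 input, the first level alone gives the 160000 filters mentioned above, and the lower pyramid levels add more on top of that.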
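This can be improved to reduce the computation while keeping both characteristics 1 and 2: first apply a depthwise box filter, then use the method from section 2.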
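
A rough cost estimate for this idea is sketched below. It is one possible reading of the suggestion, not a tested design: the box filter is assumed to be a k x k depthwise average applied to both the query map and the key maps, followed by per-pixel (1x1) matching as in section 2. The function names are hypothetical. Note that the two similarity scores are not mathematically identical, since a dot product of averaged features differs from a patch correlation.

```python
def patch_matching_macs(hw, n_keys, c_match, ksize=3):
    """Patch correlation: every query position against every k x k key patch."""
    return hw * n_keys * c_match * ksize * ksize


def box_then_pixel_matching_macs(hw, n_keys, c_match, ksize=3):
    """Depthwise k x k box filter on queries and keys, then per-pixel matching."""
    box = (hw + n_keys) * c_match * ksize * ksize
    match = hw * n_keys * c_match
    return box + match


# Sizes taken from the log: 50 x 50 map, 4 matching channels, 2500 keys.
full = patch_matching_macs(2500, 2500, 4)
cheap = box_then_pixel_matching_macs(2500, 2500, 4)
print(full, cheap, round(full / cheap, 2))
```

Under these assumptions the matching cost drops by close to a factor of k x k, because the patch aggregation is paid once per pixel instead of once per query-key pair.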

https://github.com/SHI-Labs/Pyramid-Attention-Networks/tree/master

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision import utils as vutils
import common
from utils.tools import extract_image_patches,\
    reduce_mean, reduce_sum, same_padding

class PyramidAttention(nn.Module):
    def __init__(self, level=5, res_scale=1, channel=64, reduction=2, ksize=3, stride=1, softmax_scale=10, average=True, conv=common.default_conv):
        super(PyramidAttention, self).__init__()
        self.ksize = ksize
        self.stride = stride
        self.res_scale = res_scale
        self.softmax_scale = softmax_scale
        self.scale = [1-i/10 for i in range(level)]
        self.average = average
        escape_NaN = torch.FloatTensor([1e-4])
        self.register_buffer('escape_NaN', escape_NaN)
        self.conv_match_L_base = common.BasicBlock(conv,channel,channel//reduction, 1, bn=False, act=nn.PReLU())
        self.conv_match = common.BasicBlock(conv,channel, channel//reduction, 1, bn=False, act=nn.PReLU())
        self.conv_assembly = common.BasicBlock(conv,channel, channel,1,bn=False, act=nn.PReLU())

    def forward(self, input):
        res = input
        #theta
        match_base = self.conv_match_L_base(input)
        shape_base = list(res.size())
        input_groups = torch.split(match_base,1,dim=0)
        # patch size for matching 
        kernel = self.ksize
        # raw_w is for reconstruction
        raw_w = []
        # w is for matching
        w = []
        #build feature pyramid
        for i in range(len(self.scale)):    
            ref = input
            if self.scale[i]!=1:
                ref  = F.interpolate(input, scale_factor=self.scale[i], mode='bicubic')
            #feature transformation function f
            base = self.conv_assembly(ref)
            shape_input = base.shape
            #sampling
            raw_w_i = extract_image_patches(base, ksizes=[kernel, kernel],
                                      strides=[self.stride,self.stride],
                                      rates=[1, 1],
                                      padding='same') # [N, C*k*k, L]
            raw_w_i = raw_w_i.view(shape_input[0], shape_input[1], kernel, kernel, -1)
            raw_w_i = raw_w_i.permute(0, 4, 1, 2, 3)    # raw_shape: [N, L, C, k, k]
            raw_w_i_groups = torch.split(raw_w_i, 1, dim=0)
            raw_w.append(raw_w_i_groups)

            #feature transformation function g
            ref_i = self.conv_match(ref)
            shape_ref = ref_i.shape
            #sampling
            w_i = extract_image_patches(ref_i, ksizes=[self.ksize, self.ksize],
                                  strides=[self.stride, self.stride],
                                  rates=[1, 1],
                                  padding='same')
            w_i = w_i.view(shape_ref[0], shape_ref[1], self.ksize, self.ksize, -1)
            w_i = w_i.permute(0, 4, 1, 2, 3)    # w shape: [N, L, C, k, k]
            w_i_groups = torch.split(w_i, 1, dim=0)
            w.append(w_i_groups)
            print(' index scale , wishape raw_w i shape:', i, w_i.shape, raw_w_i.shape)
        y = []
        for idx, xi in enumerate(input_groups):
            #group in a filter
            wi = torch.cat([w[i][idx][0] for i in range(len(self.scale))],dim=0)  # [L, C, k, k]
            #normalize
            max_wi = torch.max(torch.sqrt(reduce_sum(torch.pow(wi, 2),
                                                     axis=[1, 2, 3],
                                                     keepdim=True)),
                               self.escape_NaN)
            wi_normed = wi/ max_wi
            print("xi, wi_normed shape: ", xi.shape, wi_normed.shape)
            #matching
            xi = same_padding(xi, [self.ksize, self.ksize], [1, 1], [1, 1])  # xi: 1*c*H*W
            yi = F.conv2d(xi, wi_normed, stride=1)   # [1, L, H, W] L = shape_ref[2]*shape_ref[3]
            yi = yi.view(1,wi.shape[0], shape_base[2], shape_base[3])  # (B=1, C=32*32, H=32, W=32)
            # softmax matching score
            yi = F.softmax(yi*self.softmax_scale, dim=1)
            
            if not self.average:
                yi = (yi == yi.max(dim=1,keepdim=True)[0]).float()
            
            # deconv for patch pasting
            raw_wi = torch.cat([raw_w[i][idx][0] for i in range(len(self.scale))],dim=0)
            print("yi , raw_wi shape: ",yi.shape, raw_wi.shape)
            yi = F.conv_transpose2d(yi, raw_wi, stride=self.stride,padding=1)/4.
            y.append(yi)
      
        y = torch.cat(y, dim=0)+res*self.res_scale  # back to the mini-batch
        return y
################## The code below is a separate file
import torch

import common
import attention
import torch.nn as nn

from measure_time import measure_inference_speed


def make_model(args, parent=False):
    return PANET(args)

class PANET(nn.Module):
    def __init__(self, args, conv=common.default_conv):
        super(PANET, self).__init__()

        n_resblocks = args.n_resblocks
        n_feats = args.n_feats
        kernel_size = 3 
        scale = 1 #args.scale[0]

        rgb_mean = (0.4488, 0.4371, 0.4040)
        rgb_std = (1.0, 1.0, 1.0)
        #self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)
        msa = attention.PyramidAttention()
        # define head module
        m_head = [conv(args.n_colors, n_feats, kernel_size)]

        # define body module
        m_body = [
            common.ResBlock(
                conv, n_feats, kernel_size, nn.PReLU(), res_scale=args.res_scale
            ) for _ in range(n_resblocks//2)
        ]
        m_body.append(msa)
        for i in range(n_resblocks//2):
            m_body.append(common.ResBlock(conv,n_feats,kernel_size,nn.PReLU(),res_scale=args.res_scale))
      
        m_body.append(conv(n_feats, n_feats, kernel_size))

        # define tail module
        #m_tail = [
        #    common.Upsampler(conv, scale, n_feats, act=False),
        #    conv(n_feats, args.n_colors, kernel_size)
        #]
        m_tail = [
            conv(n_feats, args.n_colors, kernel_size)
        ]

        #self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)

        self.head = nn.Sequential(*m_head)
        self.body = nn.Sequential(*m_body)
        self.tail = nn.Sequential(*m_tail)

    def forward(self, x):
        #x = self.sub_mean(x)
        x = self.head(x)
        
        res = self.body(x)
        
        res += x

        x = self.tail(res)
        #x = self.add_mean(x)

        return x 

    def load_state_dict(self, state_dict, strict=True):
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name in own_state:
                if isinstance(param, nn.Parameter):
                    param = param.data
                try:
                    own_state[name].copy_(param)
                except Exception:
                    if name.find('tail') == -1:
                        raise RuntimeError('While copying the parameter named {}, '
                                           'whose dimensions in the model are {} and '
                                           'whose dimensions in the checkpoint are {}.'
                                           .format(name, own_state[name].size(), param.size()))
            elif strict:
                if name.find('tail') == -1:
                    raise KeyError('unexpected key "{}" in state_dict'
                                   .format(name))

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description='g')
    # Model specifications
    parser.add_argument('--model', default='PANET',
                        help='model name')

    parser.add_argument('--act', type=str, default='relu',
                        help='activation function')
    parser.add_argument('--pre_train', type=str, default='.',
                        help='pre-trained model directory')
    parser.add_argument('--extend', type=str, default='.',
                        help='pre-trained model directory')
    parser.add_argument('--n_resblocks', type=int, default=16,
                        help='number of residual blocks')
    parser.add_argument('--n_feats', type=int, default=64,
                        help='number of feature maps')
    parser.add_argument('--res_scale', type=float, default=1,
                        help='residual scaling')
    parser.add_argument('--shift_mean', default=True,
                        help='subtract pixel mean from the input')
    parser.add_argument('--dilation', action='store_true',
                        help='use dilated convolution')
    parser.add_argument('--precision', type=str, default='single',
                        choices=('single', 'half'),
                        help='FP precision for test (single | half)')
    args = parser.parse_args()
    args.n_colors = 3


    from ptflops import get_model_complexity_info
    from torchinfo import summary
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    height = 48

    net = PANET(args).to(device)

    data = [torch.rand(1, 3, height, height).to(device)]
    fps = measure_inference_speed(net, data)
    out = net(*data)
    print(out.shape)
    #summary(net, input_size=(1, 3, height, height), col_names=["kernel_size", "output_size", "num_params", "mult_adds"])

    macs, params = get_model_complexity_info(net, (3, height, height), verbose=True, print_per_layer_stat=True)
    print(macs, params, out.shape, 1000 / fps)