NON-LOCAL
When kernel prediction networks were introduced earlier, the idea was to generate a filter for every pixel, so each pixel gets its own custom filter. But that filter only processes neighboring pixels; regions farther away cannot be fused in unless the filter is made very large, e.g. covering the whole image.
The non-local concept is widely used in image denoising; classic traditional examples are non-local means (NLM) and BM3D. Deep learning models can borrow the same concept and introduce purpose-built non-local modules.
Non-local and self-attention are closely related: at bottom, both are about the correlation between different regions and how to build connections between them.
1. Fully connected layer
Take an image of size h,w,c and flatten it to 1-D.
(h x w x c) matmul (h x w x c, h x w x c) gives a result of size (h x w x c); every value in the result uses information from all pixels.
(h x w x c, h x w x c) is the weight size. Image height and width h,w are usually not small, so both the parameter count and the compute are enormous.
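A minimal sketch of why this blows up (the sizes here are only illustrative):

import torch
import torch.nn as nn

h, w, c = 64, 64, 3                  # even a small image
n = h * w * c                        # flattened length: 12288
fc = nn.Linear(n, n, bias=False)     # weight shape (hwc, hwc)
print(fc.weight.numel())             # 12288^2, about 1.5e8 parameters
y = fc(torch.rand(1, n))             # each output mixes all pixels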
2. Non-local Neural Networks | CVPR 2018
https://juejin.cn/post/6914526262992044046 gives a good introduction.
- The h,w,c input goes through 1x1 convolutions to produce two feature maps theta (h,w,c1) and phi (h,w,c1), plus a feature g (hw,c2).
- theta (hw,c1) matmul phi (c1,hw) gives the similarity scores; softmax then yields p: hw x hw.
- p (hw,hw) matmul g (hw,c2) gives (h,w,c2); a final 1x1 convolution restores the size to h,w,c.
Compared with the fully connected layer above, this uses far fewer parameters and much less compute, though it is still much more expensive than an ordinary convolution.
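A minimal sketch of such a block in PyTorch (embedded-Gaussian form; the channel names c1/c2 follow the shapes above, and the residual connection is taken from the paper):

import torch
import torch.nn as nn
import torch.nn.functional as F

class NonLocalBlock(nn.Module):
    def __init__(self, c, c1=None, c2=None):
        super().__init__()
        c1 = c1 or c // 2
        c2 = c2 or c // 2
        self.theta = nn.Conv2d(c, c1, 1)
        self.phi = nn.Conv2d(c, c1, 1)
        self.g = nn.Conv2d(c, c2, 1)
        self.out = nn.Conv2d(c2, c, 1)                       # restore channels to c
    def forward(self, x):
        n, c, h, w = x.shape
        theta = self.theta(x).view(n, -1, h * w)             # (n, c1, hw)
        phi = self.phi(x).view(n, -1, h * w)                 # (n, c1, hw)
        g = self.g(x).view(n, -1, h * w)                     # (n, c2, hw)
        p = F.softmax(theta.transpose(1, 2) @ phi, dim=-1)   # (n, hw, hw)
        y = (p @ g.transpose(1, 2)).transpose(1, 2)          # (n, c2, hw)
        return x + self.out(y.view(n, -1, h, w))             # residual connection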
For comparison, take an ordinary convolution with kernel size 3, 1 input channel and c output channels:
the compute is (hw,9) matmul (9,c) giving (hw,c), which is comparatively small.
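A back-of-envelope comparison of the two, with purely illustrative numbers:

h, w, c, c1 = 128, 128, 16, 8
hw = h * w
conv3x3 = hw * 9 * c              # ordinary 3x3 conv, 1 -> c channels: ~2.4e6 MACs
nonlocal_mm = 2 * hw * hw * c1    # theta@phi and p@g, each about hw*hw*c1: ~4.3e9
print(conv3x3, nonlocal_mm)       # the non-local term grows as (hw)^2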
Turning the 1x1 convolutions into 3x3 ones makes the matching equivalent to block match; it would be worth checking whether the learned 3x3 filter ends up looking like a box filter.
A more detailed explanation: https://cloud.tencent.com/developer/article/1582047
3. PANet: Pyramid Attention Networks for Image Restoration
https://blog.csdn.net/weixin_42096202/article/details/106240801
Key points:
- matching is done on pixel patches rather than single pixels
- a multi-level feature pyramid is used
One concern: it uses hw different 3,3 filters, which is expensive. For an image with h,w = 400,400 that is 160,000 filters of shape 3,3,c_in, i.e. 160,000 output channels. Huge compute.
Log printed by the code below:
index scale , wishape raw_w i shape: 0 torch.Size([1, 2500, 4, 3, 3]) torch.Size([1, 2500, 8, 3, 3])
xi, wi_normed shape: torch.Size([1, 4, 50, 50]) torch.Size([2500, 4, 3, 3])
yi , raw_wi shape: torch.Size([1, 2500, 50, 50]) torch.Size([2500, 8, 3, 3])
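Reading those shapes: at a 50x50 input there are already 2500 matching filters per pyramid level, so the matching convolution alone costs roughly:

h = w = 50
c1, k = 4, 3                  # matched channels and patch size, per the log
L = h * w                     # one 3x3 matching filter per pixel: 2500
macs = h * w * L * c1 * k * k
print(macs)                   # 2.25e8 MACs at 50x50; this term grows as (hw)^2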
This can be improved to cut the compute while still keeping both key points above: first apply a depthwise box filter, then use the method from section 2; see the sketch right below.
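A hedged sketch of that idea, reusing the NonLocalBlock sketch from section 2 (the fixed depthwise 3x3 box filter is my reading of "boxfilter depthwise"):

import torch.nn as nn

class BoxThenNonLocal(nn.Module):
    # The depthwise 3x3 box filter aggregates each pixel's patch first, so the
    # pixel-wise non-local matching afterwards effectively compares patches.
    def __init__(self, c, c1=None):
        super().__init__()
        box = nn.Conv2d(c, c, 3, padding=1, groups=c, bias=False)
        nn.init.constant_(box.weight, 1.0 / 9)    # box filter weights
        box.weight.requires_grad_(False)          # keep the box filter fixed
        self.box = box
        self.match = NonLocalBlock(c, c1)         # sketch from section 2
    def forward(self, x):
        return self.match(self.box(x))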
Reference implementation: https://github.com/SHI-Labs/Pyramid-Attention-Networks/tree/master
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision import utils as vutils
import common
from utils.tools import extract_image_patches,\
reduce_mean, reduce_sum, same_padding
class PyramidAttention(nn.Module):
def __init__(self, level=5, res_scale=1, channel=64, reduction=2, ksize=3, stride=1, softmax_scale=10, average=True, conv=common.default_conv):
super(PyramidAttention, self).__init__()
self.ksize = ksize
self.stride = stride
self.res_scale = res_scale
self.softmax_scale = softmax_scale
        self.scale = [1-i/10 for i in range(level)]  # pyramid scales: 1.0, 0.9, ..., 1-(level-1)/10
self.average = average
escape_NaN = torch.FloatTensor([1e-4])
self.register_buffer('escape_NaN', escape_NaN)
self.conv_match_L_base = common.BasicBlock(conv,channel,channel//reduction, 1, bn=False, act=nn.PReLU())
self.conv_match = common.BasicBlock(conv,channel, channel//reduction, 1, bn=False, act=nn.PReLU())
self.conv_assembly = common.BasicBlock(conv,channel, channel,1,bn=False, act=nn.PReLU())
def forward(self, input):
res = input
#theta
match_base = self.conv_match_L_base(input)
shape_base = list(res.size())
input_groups = torch.split(match_base,1,dim=0)
# patch size for matching
kernel = self.ksize
# raw_w is for reconstruction
raw_w = []
# w is for matching
w = []
#build feature pyramid
for i in range(len(self.scale)):
ref = input
if self.scale[i]!=1:
ref = F.interpolate(input, scale_factor=self.scale[i], mode='bicubic')
#feature transformation function f
base = self.conv_assembly(ref)
shape_input = base.shape
#sampling
raw_w_i = extract_image_patches(base, ksizes=[kernel, kernel],
strides=[self.stride,self.stride],
rates=[1, 1],
padding='same') # [N, C*k*k, L]
raw_w_i = raw_w_i.view(shape_input[0], shape_input[1], kernel, kernel, -1)
raw_w_i = raw_w_i.permute(0, 4, 1, 2, 3) # raw_shape: [N, L, C, k, k]
raw_w_i_groups = torch.split(raw_w_i, 1, dim=0)
raw_w.append(raw_w_i_groups)
#feature transformation function g
ref_i = self.conv_match(ref)
shape_ref = ref_i.shape
#sampling
w_i = extract_image_patches(ref_i, ksizes=[self.ksize, self.ksize],
strides=[self.stride, self.stride],
rates=[1, 1],
padding='same')
w_i = w_i.view(shape_ref[0], shape_ref[1], self.ksize, self.ksize, -1)
w_i = w_i.permute(0, 4, 1, 2, 3) # w shape: [N, L, C, k, k]
w_i_groups = torch.split(w_i, 1, dim=0)
w.append(w_i_groups)
print(' index scale , wishape raw_w i shape:', i, w_i.shape, raw_w_i.shape)
y = []
for idx, xi in enumerate(input_groups):
#group in a filter
wi = torch.cat([w[i][idx][0] for i in range(len(self.scale))],dim=0) # [L, C, k, k]
#normalize
max_wi = torch.max(torch.sqrt(reduce_sum(torch.pow(wi, 2),
axis=[1, 2, 3],
keepdim=True)),
self.escape_NaN)
wi_normed = wi/ max_wi
print("xi, wi_normed shape: ", xi.shape, wi_normed.shape)
#matching
xi = same_padding(xi, [self.ksize, self.ksize], [1, 1], [1, 1]) # xi: 1*c*H*W
yi = F.conv2d(xi, wi_normed, stride=1) # [1, L, H, W] L = shape_ref[2]*shape_ref[3]
yi = yi.view(1,wi.shape[0], shape_base[2], shape_base[3]) # (B=1, C=32*32, H=32, W=32)
# softmax matching score
yi = F.softmax(yi*self.softmax_scale, dim=1)
if self.average == False:
                yi = (yi == yi.max(dim=1,keepdim=True)[0]).float()  # hard one-hot match: keep only the best patch instead of a soft average
# deconv for patch pasting
raw_wi = torch.cat([raw_w[i][idx][0] for i in range(len(self.scale))],dim=0)
print("yi , raw_wi shape: ",yi.shape, raw_wi.shape)
            yi = F.conv_transpose2d(yi, raw_wi, stride=self.stride,padding=1)/4.  # paste patches back; overlapping patches are summed, /4. rescales
y.append(yi)
y = torch.cat(y, dim=0)+res*self.res_scale # back to the mini-batch
return y
################## The second file: the PANET model definition
import torch
import common
import attention
import torch.nn as nn
from measure_time import measure_inference_speed
def make_model(args, parent=False):
return PANET(args)
class PANET(nn.Module):
def __init__(self, args, conv=common.default_conv):
super(PANET, self).__init__()
n_resblocks = args.n_resblocks
n_feats = args.n_feats
kernel_size = 3
scale = 1 #args.scale[0]
rgb_mean = (0.4488, 0.4371, 0.4040)
rgb_std = (1.0, 1.0, 1.0)
#self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)
msa = attention.PyramidAttention()
# define head module
m_head = [conv(args.n_colors, n_feats, kernel_size)]
# define body module
m_body = [
common.ResBlock(
conv, n_feats, kernel_size, nn.PReLU(), res_scale=args.res_scale
) for _ in range(n_resblocks//2)
]
m_body.append(msa)
for i in range(n_resblocks//2):
m_body.append(common.ResBlock(conv,n_feats,kernel_size,nn.PReLU(),res_scale=args.res_scale))
m_body.append(conv(n_feats, n_feats, kernel_size))
# define tail module
#m_tail = [
# common.Upsampler(conv, scale, n_feats, act=False),
# conv(n_feats, args.n_colors, kernel_size)
#]
m_tail = [
conv(n_feats, args.n_colors, kernel_size)
]
#self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)
self.head = nn.Sequential(*m_head)
self.body = nn.Sequential(*m_body)
self.tail = nn.Sequential(*m_tail)
def forward(self, x):
#x = self.sub_mean(x)
x = self.head(x)
res = self.body(x)
res += x
x = self.tail(res)
#x = self.add_mean(x)
return x
def load_state_dict(self, state_dict, strict=True):
own_state = self.state_dict()
for name, param in state_dict.items():
if name in own_state:
if isinstance(param, nn.Parameter):
param = param.data
try:
own_state[name].copy_(param)
except Exception:
if name.find('tail') == -1:
raise RuntimeError('While copying the parameter named {}, '
'whose dimensions in the model are {} and '
'whose dimensions in the checkpoint are {}.'
.format(name, own_state[name].size(), param.size()))
elif strict:
if name.find('tail') == -1:
raise KeyError('unexpected key "{}" in state_dict'
.format(name))
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description='g')
# Model specifications
parser.add_argument('--model', default='PANET',
help='model name')
parser.add_argument('--act', type=str, default='relu',
help='activation function')
parser.add_argument('--pre_train', type=str, default='.',
help='pre-trained model directory')
parser.add_argument('--extend', type=str, default='.',
help='pre-trained model directory')
parser.add_argument('--n_resblocks', type=int, default=16,
help='number of residual blocks')
parser.add_argument('--n_feats', type=int, default=64,
help='number of feature maps')
parser.add_argument('--res_scale', type=float, default=1,
help='residual scaling')
parser.add_argument('--shift_mean', default=True,
help='subtract pixel mean from the input')
parser.add_argument('--dilation', action='store_true',
help='use dilated convolution')
parser.add_argument('--precision', type=str, default='single',
choices=('single', 'half'),
help='FP precision for test (single | half)')
args = parser.parse_args()
args.n_colors = 3
from ptflops import get_model_complexity_info
from torchinfo import summary
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
height = 48
net = PANET(args).to(device)
data = [torch.rand(1, 3, height, height).to(device)]
fps = measure_inference_speed(net, data)
out = net(*data)
print(out.shape)
#summary(net, input_size=(1, 3, height, height), col_names=["kernel_size", "output_size", "num_params", "mult_adds"])
macs, params = get_model_complexity_info(net, (3, height, height), verbose=True, print_per_layer_stat=True)
print(macs, params, out.shape, 1000 / fps)