实验python train.py -opt options/train/train_sr.json
先激活虚拟环境source activate pytorch
tensorboard --logdir tb_logger/ --port 6008
浏览器打开http://172.20.36.203:6008/#scalars
本博文为论文《Multi-scale Location-aware Kernel Representation for Object Detection》和《Higher-order integration of hierarchical convolutional activations for fine-grained visual categorization》的高阶特征的实现笔记。
理论部分可以参考博文《论文阅读笔记之——《Higher-order integration of hierarchical convolutional activations for fine-grained visual categorization》》
两篇论文中提到的kernel representation是通过1*1 Conv来实现的。论文中说通过1*1的Conv来近似kernel representation,那用如此复杂的数学推导后,不轻易公开自己的code,实际上就是conv层可以解决~~~。通过下面两张图,来探究特征的维度选择
1*1 Conv选择channel为8192,但是在《Multi-scale Location-aware Kernel Representation for Object Detection》中貌似用了4096
下面表格是调参过程
从表格中反倒是三阶特征是最好的,上面则是二阶特征是最好的。下面分别调参试试。由于input已经是1024,2048和4096其实相当于input的两倍或者四倍而已。
先给出refine block
#modified octave
# Block for OctConv
####################
class M_OctaveConv(nn.Module):
    """Modified octave convolution with concat + 1x1 fusion and a residual add.

    The input is a ``(high, low)`` feature pair whose channels are split by
    ``alpha`` (low branch gets ``int(alpha * nc)`` channels). The four octave
    paths (h2h, h2l, l2h, l2l) are computed, then each resolution is fused by
    channel concatenation followed by a 1x1 conv (``hh2h`` / ``ll2l``), and the
    raw inputs are added back as a residual.

    NOTE(review): the residual add of the raw inputs is only shape-consistent
    when ``in_nc == out_nc`` and ``stride == 1`` — confirm with callers.
    """

    def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1,
                 dilation=1, groups=1, bias=True, pad_type='zero',
                 norm_type=None, act_type='prelu', mode='CNA'):
        super(M_OctaveConv, self).__init__()
        # Fixed typo in the assertion message ("Wong" -> "Wrong").
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
        padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
        padding1 = get_valid_padding(1, dilation) if pad_type == 'zero' else 0
        self.h2g_pool = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.stride = stride
        # Four octave paths; the low branch carries int(alpha * nc) channels.
        self.l2l = nn.Conv2d(int(alpha * in_nc), int(alpha * out_nc),
                             kernel_size, 1, padding, dilation, groups, bias)
        self.l2h = nn.Conv2d(int(alpha * in_nc), out_nc - int(alpha * out_nc),
                             kernel_size, 1, padding, dilation, groups, bias)
        self.h2l = nn.Conv2d(in_nc - int(alpha * in_nc), int(alpha * out_nc),
                             kernel_size, 1, padding, dilation, groups, bias)
        self.h2h = nn.Conv2d(in_nc - int(alpha * in_nc), out_nc - int(alpha * out_nc),
                             kernel_size, 1, padding, dilation, groups, bias)
        # 1x1 convs that fuse each concatenated pair back to the target width.
        self.hh2h = nn.Conv2d(2 * (out_nc - int(alpha * out_nc)), out_nc - int(alpha * out_nc),
                              1, 1, padding1, dilation, groups, bias)
        self.ll2l = nn.Conv2d(2 * int(alpha * out_nc), int(alpha * out_nc),
                              1, 1, padding1, dilation, groups, bias)
        self.a = act(act_type) if act_type else None
        self.n_h = norm(norm_type, int(out_nc * (1 - alpha))) if norm_type else None
        self.n_l = norm(norm_type, int(out_nc * alpha)) if norm_type else None

    def forward(self, x):
        X_h, X_l = x
        X_h1, X_l1 = x  # raw inputs kept for the residual add below
        if self.stride == 2:
            X_h, X_l = self.h2g_pool(X_h), self.h2g_pool(X_l)
        X_h2h = self.h2h(X_h)
        X_l2h = self.upsample(self.l2h(X_l))
        X_l2l = self.l2l(X_l)
        X_h2l = self.h2l(self.h2g_pool(X_h))
        # Concat + 1x1 fusion (instead of the plain element-wise sum).
        X_h = self.hh2h(torch.cat((X_l2h, X_h2h), dim=1))
        X_l = self.ll2l(torch.cat((X_h2l, X_l2l), dim=1))
        # Residual add of the unpooled inputs — see the class NOTE about shapes.
        X_h = X_h + X_h1
        X_l = X_l + X_l1
        if self.n_h and self.n_l:
            X_h = self.n_h(X_h)
            X_l = self.n_l(X_l)
        if self.a:
            X_h = self.a(X_h)
            X_l = self.a(X_l)
        return X_h, X_l
class M_FirstOctaveConv(nn.Module):
    """First octave convolution: splits a single-resolution input into a
    ``(high, low)`` feature pair.

    The high branch is a plain conv; the low branch average-pools the input
    by 2 first, so the low-frequency features live at half resolution.
    """

    def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1,
                 dilation=1, groups=1, bias=True, pad_type='zero',
                 norm_type=None, act_type='prelu', mode='CNA'):
        super(M_FirstOctaveConv, self).__init__()
        # Fixed typo in the assertion message ("Wong" -> "Wrong").
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
        padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
        self.h2g_pool = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        self.stride = stride
        self.h2l = nn.Conv2d(in_nc, int(alpha * out_nc),
                             kernel_size, 1, padding, dilation, groups, bias)
        self.h2h = nn.Conv2d(in_nc, out_nc - int(alpha * out_nc),
                             kernel_size, 1, padding, dilation, groups, bias)
        self.a = act(act_type) if act_type else None
        self.n_h = norm(norm_type, int(out_nc * (1 - alpha))) if norm_type else None
        self.n_l = norm(norm_type, int(out_nc * alpha)) if norm_type else None

    def forward(self, x):
        if self.stride == 2:
            x = self.h2g_pool(x)
        X_h = self.h2h(x)
        X_l = self.h2l(self.h2g_pool(x))  # low branch at half resolution
        if self.n_h and self.n_l:
            X_h = self.n_h(X_h)
            X_l = self.n_l(X_l)
        if self.a:
            X_h = self.a(X_h)
            X_l = self.a(X_l)
        return X_h, X_l
class M_LastOctaveConv(nn.Module):
    """Last octave convolution: merges a ``(high, low)`` pair back into a
    single high-resolution tensor.

    The low branch is upsampled, both branches are concatenated and fused by
    a 1x1 conv, and a 1x1 projection of the raw high input is added as a
    residual.

    NOTE(review): with ``stride == 2`` the residual ``h12h(X_h1)`` keeps the
    original resolution while the main path is pooled — confirm callers only
    use stride 1.
    """

    def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1,
                 dilation=1, groups=1, bias=True, pad_type='zero',
                 norm_type=None, act_type='prelu', mode='CNA'):
        super(M_LastOctaveConv, self).__init__()
        # Fixed typo in the assertion message ("Wong" -> "Wrong").
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
        padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
        padding1 = get_valid_padding(1, dilation) if pad_type == 'zero' else 0
        self.h2g_pool = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.stride = stride
        self.l2h = nn.Conv2d(int(alpha * in_nc), out_nc,
                             kernel_size, 1, padding, dilation, groups, bias)
        self.h2h = nn.Conv2d(in_nc - int(alpha * in_nc), out_nc,
                             kernel_size, 1, padding, dilation, groups, bias)
        # 1x1 fusion of the concatenated (l2h, h2h) pair.
        self.hh2h = nn.Conv2d(2 * out_nc, out_nc,
                              1, 1, padding1, dilation, groups, bias)
        # 1x1 projection of the raw high input for the residual add.
        self.h12h = nn.Conv2d(in_nc - int(alpha * in_nc), out_nc,
                              1, 1, padding1, dilation, groups, bias)
        self.a = act(act_type) if act_type else None
        self.n_h = norm(norm_type, out_nc) if norm_type else None

    def forward(self, x):
        X_h, X_l = x
        X_h1, _X_l1 = x  # low part of the raw input is unused here
        if self.stride == 2:
            X_h, X_l = self.h2g_pool(X_h), self.h2g_pool(X_l)
        X_h2h = self.h2h(X_h)
        X_l2h = self.upsample(self.l2h(X_l))
        # Concat + 1x1 fusion, then add the projected raw high input.
        X_h = self.hh2h(torch.cat((X_l2h, X_h2h), dim=1))
        X_h = X_h + self.h12h(X_h1)
        if self.n_h:
            X_h = self.n_h(X_h)
        if self.a:
            X_h = self.a(X_h)
        return X_h
class M_OctaveCascadeBlock(nn.Module):
    """
    OctaveCascadeBlock, 3-3 style

    Runs ``nc`` octave residual blocks; after each one, its output is
    concatenated (per frequency branch) onto all previously accumulated
    features and squeezed back to ``gc`` channels by a 1x1 octave conv.
    """

    def __init__(self, nc, gc, kernel_size=3, alpha=0.75, stride=1, dilation=1,
                 groups=1, bias=True, pad_type='zero', norm_type=None,
                 act_type='prelu', mode='CNA', res_scale=1):
        super(M_OctaveCascadeBlock, self).__init__()
        self.nc = nc
        self.ResBlocks = nn.ModuleList(
            [M_OctaveResBlock(gc, gc, gc, kernel_size, alpha, stride, dilation,
                              groups, bias, pad_type, norm_type, act_type,
                              mode, res_scale)
             for _ in range(nc)])
        self.CatBlocks = nn.ModuleList(
            [M_OctaveConv((i + 2) * gc, gc, kernel_size=1, alpha=alpha,
                          bias=bias, pad_type=pad_type, norm_type=norm_type,
                          act_type=act_type, mode=mode)
             for i in range(nc)])

    def forward(self, x):
        cascade = x  # accumulated (high, low) features
        for idx in range(self.nc):
            out = self.ResBlocks[idx](x)
            cascade = (torch.cat((cascade[0], out[0]), dim=1),
                       torch.cat((cascade[1], out[1]), dim=1))
            x = self.CatBlocks[idx](cascade)
        return x
class M_OctaveResBlock(nn.Module):
    '''
    ResNet Block, 3-3 style
    with extra residual scaling used in EDSR
    (Enhanced Deep Residual Networks for Single Image Super-Resolution, CVPRW 17)
    '''

    def __init__(self, in_nc, mid_nc, out_nc, kernel_size=3, alpha=0.75,
                 stride=1, dilation=1, groups=1, bias=True, pad_type='zero',
                 norm_type=None, act_type='prelu', mode='CNA', res_scale=1):
        super(M_OctaveResBlock, self).__init__()
        first_conv = M_OctaveConv(in_nc, mid_nc, kernel_size, alpha, stride,
                                  dilation, groups, bias, pad_type, norm_type,
                                  act_type, mode)
        # The residual path ends without activation in 'CNA' mode and without
        # both activation and normalization in 'CNAC' mode.
        if mode == 'CNA':
            act_type = None
        if mode == 'CNAC':
            act_type = None
            norm_type = None
        second_conv = M_OctaveConv(mid_nc, out_nc, kernel_size, alpha, stride,
                                   dilation, groups, bias, pad_type, norm_type,
                                   act_type, mode)
        self.res = sequential(first_conv, second_conv)
        self.res_scale = res_scale

    def forward(self, x):
        high, low = self.res(x)
        # Scale the residual, then add it to the (high, low) input pair.
        return (x[0] + high.mul(self.res_scale),
                x[1] + low.mul(self.res_scale))
################################################################################################
class M_octave_ResidualDenseBlockTiny_4C(nn.Module):
    '''
    Residual Dense Block
    style: 4 convs
    The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)

    Dense connections: each conv sees the concatenation (per frequency
    branch) of the block input and all previous conv outputs; the final
    output is the input plus the last conv's result scaled by 0.2.
    '''

    def __init__(self, nc, kernel_size=3, gc=16, alpha=0.5, stride=1, bias=True,
                 pad_type='zero', norm_type=None, act_type='leakyrelu', mode='CNA'):
        super(M_octave_ResidualDenseBlockTiny_4C, self).__init__()
        # gc: growth channel, i.e. intermediate channels
        self.conv1 = M_OctaveConv(nc, gc, kernel_size, alpha, stride, bias=bias,
                                  pad_type=pad_type, norm_type=norm_type,
                                  act_type=act_type, mode=mode)
        self.conv2 = M_OctaveConv(nc + gc, gc, kernel_size, alpha, stride,
                                  bias=bias, pad_type=pad_type,
                                  norm_type=norm_type, act_type=act_type, mode=mode)
        self.conv3 = M_OctaveConv(nc + 2 * gc, gc, kernel_size, alpha, stride,
                                  bias=bias, pad_type=pad_type,
                                  norm_type=norm_type, act_type=act_type, mode=mode)
        # BUGFIX: ``last_act`` was computed but never used — conv4 received
        # ``act_type`` instead. The commented-out conv_block version passed
        # ``act_type=last_act`` so that 'CNA' mode ends the dense block
        # without a trailing activation; restore that intent here.
        if mode == 'CNA':
            last_act = None
        else:
            last_act = act_type
        self.conv4 = M_OctaveConv(nc + 3 * gc, nc, kernel_size, alpha, stride,
                                  bias=bias, pad_type=pad_type,
                                  norm_type=norm_type, act_type=last_act, mode=mode)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2((torch.cat((x[0], x1[0]), dim=1),
                         torch.cat((x[1], x1[1]), dim=1)))
        x3 = self.conv3((torch.cat((x[0], x1[0], x2[0]), dim=1),
                         torch.cat((x[1], x1[1], x2[1]), dim=1)))
        x4 = self.conv4((torch.cat((x[0], x1[0], x2[0], x3[0]), dim=1),
                         torch.cat((x[1], x1[1], x2[1], x3[1]), dim=1)))
        # Residual scaling by 0.2, as in ESRGAN.
        return (x[0] + x4[0].mul(0.2), x[1] + x4[1].mul(0.2))
####################################################################################################################
class M_octave_RRDBTiny(nn.Module):
    '''
    Residual in Residual Dense Block
    (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
    '''

    def __init__(self, nc, kernel_size=3, gc=16, stride=1, alpha=0.5, bias=True,
                 pad_type='zero', norm_type=None, act_type='leakyrelu', mode='CNA'):
        super(M_octave_RRDBTiny, self).__init__()
        self.RDB1 = M_octave_ResidualDenseBlockTiny_4C(
            nc=nc, kernel_size=kernel_size, alpha=alpha, gc=gc, stride=stride,
            bias=bias, pad_type=pad_type, norm_type=norm_type,
            act_type=act_type, mode=mode)
        self.RDB2 = M_octave_ResidualDenseBlockTiny_4C(
            nc=nc, kernel_size=kernel_size, alpha=alpha, gc=gc, stride=stride,
            bias=bias, pad_type=pad_type, norm_type=norm_type,
            act_type=act_type, mode=mode)

    def forward(self, x):
        # Chain the two dense blocks, scale the result by 0.2 (empirical
        # residual scaling) and add it back to the (high, low) input.
        high, low = self.RDB2(self.RDB1(x))
        return (x[0] + high.mul(0.2), x[1] + low.mul(0.2))
在此结构下,重塑为:
设计的结构是基于之前博文介绍的octave SRResNet,由于里面正好具有两个scale,只在最后一层融合的时候,加入高阶特征
##################################################################################
##################################################################################
##################################################################################
#modified octave
# Block for OctConv
####################
class M_OctaveConv(nn.Module):
    """Octave convolution on a ``(high, low)`` feature pair.

    Channels are split by ``alpha`` (low branch gets ``int(alpha * nc)``),
    the four octave paths (h2h, h2l, l2h, l2l) are computed, and each output
    resolution is fused by an element-wise sum.
    """

    def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1,
                 dilation=1, groups=1, bias=True, pad_type='zero',
                 norm_type=None, act_type='prelu', mode='CNA'):
        super(M_OctaveConv, self).__init__()
        # Fixed typo in the assertion message ("Wong" -> "Wrong").
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
        padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
        padding1 = get_valid_padding(1, dilation) if pad_type == 'zero' else 0
        self.h2g_pool = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.stride = stride
        # Four octave paths; the low branch carries int(alpha * nc) channels.
        self.l2l = nn.Conv2d(int(alpha * in_nc), int(alpha * out_nc),
                             kernel_size, 1, padding, dilation, groups, bias)
        self.l2h = nn.Conv2d(int(alpha * in_nc), out_nc - int(alpha * out_nc),
                             kernel_size, 1, padding, dilation, groups, bias)
        self.h2l = nn.Conv2d(in_nc - int(alpha * in_nc), int(alpha * out_nc),
                             kernel_size, 1, padding, dilation, groups, bias)
        self.h2h = nn.Conv2d(in_nc - int(alpha * in_nc), out_nc - int(alpha * out_nc),
                             kernel_size, 1, padding, dilation, groups, bias)
        self.a = act(act_type) if act_type else None
        self.n_h = norm(norm_type, int(out_nc * (1 - alpha))) if norm_type else None
        self.n_l = norm(norm_type, int(out_nc * alpha)) if norm_type else None

    def forward(self, x):
        X_h, X_l = x
        if self.stride == 2:
            X_h, X_l = self.h2g_pool(X_h), self.h2g_pool(X_l)
        X_h2h = self.h2h(X_h)
        X_l2h = self.upsample(self.l2h(X_l))
        X_l2l = self.l2l(X_l)
        X_h2l = self.h2l(self.h2g_pool(X_h))
        # Element-wise sum fusion per resolution.
        X_h = X_l2h + X_h2h
        X_l = X_h2l + X_l2l
        if self.n_h and self.n_l:
            X_h = self.n_h(X_h)
            X_l = self.n_l(X_l)
        if self.a:
            X_h = self.a(X_h)
            X_l = self.a(X_l)
        return X_h, X_l
class M_FirstOctaveConv(nn.Module):
    """First octave convolution: turns a single-resolution input into a
    ``(high, low)`` pair, with the low branch at half spatial resolution.
    """

    def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1,
                 dilation=1, groups=1, bias=True, pad_type='zero',
                 norm_type=None, act_type='prelu', mode='CNA'):
        super(M_FirstOctaveConv, self).__init__()
        # Fixed typo in the assertion message ("Wong" -> "Wrong").
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
        padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
        self.h2g_pool = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        self.stride = stride
        self.h2l = nn.Conv2d(in_nc, int(alpha * out_nc),
                             kernel_size, 1, padding, dilation, groups, bias)
        self.h2h = nn.Conv2d(in_nc, out_nc - int(alpha * out_nc),
                             kernel_size, 1, padding, dilation, groups, bias)
        self.a = act(act_type) if act_type else None
        self.n_h = norm(norm_type, int(out_nc * (1 - alpha))) if norm_type else None
        self.n_l = norm(norm_type, int(out_nc * alpha)) if norm_type else None

    def forward(self, x):
        if self.stride == 2:
            x = self.h2g_pool(x)
        X_h = self.h2h(x)
        X_l = self.h2l(self.h2g_pool(x))  # low branch at half resolution
        if self.n_h and self.n_l:
            X_h = self.n_h(X_h)
            X_l = self.n_l(X_l)
        if self.a:
            X_h = self.a(X_h)
            X_l = self.a(X_l)
        return X_h, X_l
class M_LastOctaveConv(nn.Module):
    """Last octave convolution with a higher-order kernel-representation head.

    Merges the ``(high, low)`` pair by sum fusion, then builds first-, second-
    and third-order feature maps via 1x1 convs (channel width ``8 * out_nc``,
    the kernel-representation approximation from the referenced papers),
    concatenates the three orders and projects back to ``out_nc`` channels.
    """

    def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1,
                 dilation=1, groups=1, bias=True, pad_type='zero',
                 norm_type=None, act_type='prelu', mode='CNA'):
        super(M_LastOctaveConv, self).__init__()
        # Fixed typo in the assertion message ("Wong" -> "Wrong").
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
        padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
        padding1 = get_valid_padding(1, dilation) if pad_type == 'zero' else 0
        self.h2g_pool = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.stride = stride
        self.l2h = nn.Conv2d(int(alpha * in_nc), out_nc,
                             kernel_size, 1, padding, dilation, groups, bias)
        self.h2h = nn.Conv2d(in_nc - int(alpha * in_nc), out_nc,
                             kernel_size, 1, padding, dilation, groups, bias)
        # NOTE(review): hh2h and h12h are never used in forward() — dead
        # parameters kept only for state_dict compatibility; consider removing.
        self.hh2h = nn.Conv2d(2 * out_nc, out_nc,
                              1, 1, padding1, dilation, groups, bias)
        self.h12h = nn.Conv2d(in_nc - int(alpha * in_nc), out_nc,
                              1, 1, padding1, dilation, groups, bias)
        # Kernel representation: three 1x1 expansions to 8 * out_nc channels.
        self.con1 = nn.Conv2d(out_nc, 8 * out_nc, 1, 1, padding1, dilation, groups, bias)
        self.con2 = nn.Conv2d(out_nc, 8 * out_nc, 1, 1, padding1, dilation, groups, bias)
        self.con3 = nn.Conv2d(out_nc, 8 * out_nc, 1, 1, padding1, dilation, groups, bias)
        # Projection of the concatenated orders (out_nc + 2 * 8*out_nc channels).
        self.con_out = nn.Conv2d(out_nc + 2 * 8 * out_nc, out_nc,
                                 1, 1, padding1, dilation, groups, bias)
        self.a = act(act_type) if act_type else None
        self.n_h = norm(norm_type, out_nc) if norm_type else None

    def forward(self, x):
        X_h, X_l = x
        if self.stride == 2:
            X_h, X_l = self.h2g_pool(X_h), self.h2g_pool(X_l)
        X_h2h = self.h2h(X_h)
        X_l2h = self.upsample(self.l2h(X_l))
        X_h = X_h2h + X_l2h  # sum fusion of the two branches
        # Higher-order features via element-wise products of 1x1 expansions.
        X_h1 = self.con1(X_h)
        X_h2 = self.con2(X_h)
        X_h3 = self.con3(X_h)
        X_r1 = X_h                          # first order
        X_r2 = X_h1.mul(X_h2)               # second order
        X_r3 = X_h1.mul(X_h2).mul(X_h3)     # third order
        X_h = self.con_out(torch.cat((X_r1, X_r2, X_r3), dim=1))
        if self.n_h:
            X_h = self.n_h(X_h)
        if self.a:
            X_h = self.a(X_h)
        return X_h
class M_OctaveCascadeBlock(nn.Module):
    """
    OctaveCascadeBlock, 3-3 style

    A cascade of ``nc`` octave residual blocks. Every block's output is
    stacked (per frequency branch) onto the running feature pool, which a
    1x1 octave conv then compresses back down to ``gc`` channels.
    """

    def __init__(self, nc, gc, kernel_size=3, alpha=0.75, stride=1, dilation=1,
                 groups=1, bias=True, pad_type='zero', norm_type=None,
                 act_type='prelu', mode='CNA', res_scale=1):
        super(M_OctaveCascadeBlock, self).__init__()
        self.nc = nc
        blocks = [M_OctaveResBlock(gc, gc, gc, kernel_size, alpha, stride,
                                   dilation, groups, bias, pad_type, norm_type,
                                   act_type, mode, res_scale)
                  for _ in range(nc)]
        squeezers = [M_OctaveConv((i + 2) * gc, gc, kernel_size=1, alpha=alpha,
                                  bias=bias, pad_type=pad_type,
                                  norm_type=norm_type, act_type=act_type,
                                  mode=mode)
                     for i in range(nc)]
        self.ResBlocks = nn.ModuleList(blocks)
        self.CatBlocks = nn.ModuleList(squeezers)

    def forward(self, x):
        pool = x  # running (high, low) feature pool
        for res_block, cat_block in zip(self.ResBlocks, self.CatBlocks):
            stage = res_block(x)
            pool = (torch.cat((pool[0], stage[0]), dim=1),
                    torch.cat((pool[1], stage[1]), dim=1))
            x = cat_block(pool)
        return x
class M_OctaveResBlock(nn.Module):
    '''
    ResNet Block, 3-3 style
    with extra residual scaling used in EDSR
    (Enhanced Deep Residual Networks for Single Image Super-Resolution, CVPRW 17)
    '''

    def __init__(self, in_nc, mid_nc, out_nc, kernel_size=3, alpha=0.75,
                 stride=1, dilation=1, groups=1, bias=True, pad_type='zero',
                 norm_type=None, act_type='prelu', mode='CNA', res_scale=1):
        super(M_OctaveResBlock, self).__init__()
        conv_a = M_OctaveConv(in_nc, mid_nc, kernel_size, alpha, stride,
                              dilation, groups, bias, pad_type, norm_type,
                              act_type, mode)
        # Drop the trailing activation for 'CNA'; drop both activation and
        # normalization for 'CNAC' (residual path: |-CNAC-|).
        if mode == 'CNA':
            act_type = None
        if mode == 'CNAC':
            act_type = None
            norm_type = None
        conv_b = M_OctaveConv(mid_nc, out_nc, kernel_size, alpha, stride,
                              dilation, groups, bias, pad_type, norm_type,
                              act_type, mode)
        self.res = sequential(conv_a, conv_b)
        self.res_scale = res_scale

    def forward(self, x):
        branch_h, branch_l = self.res(x)
        scaled = (branch_h.mul(self.res_scale), branch_l.mul(self.res_scale))
        return (x[0] + scaled[0], x[1] + scaled[1])
################################################################################################
class M_octave_ResidualDenseBlockTiny_4C(nn.Module):
    '''
    Residual Dense Block
    style: 4 convs
    The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)

    Each conv receives the concatenation (per frequency branch) of the block
    input and all earlier conv outputs; the output is the input plus the last
    conv's result scaled by 0.2.
    '''

    def __init__(self, nc, kernel_size=3, gc=16, alpha=0.5, stride=1, bias=True,
                 pad_type='zero', norm_type=None, act_type='leakyrelu', mode='CNA'):
        super(M_octave_ResidualDenseBlockTiny_4C, self).__init__()
        # gc: growth channel, i.e. intermediate channels
        self.conv1 = M_OctaveConv(nc, gc, kernel_size, alpha, stride, bias=bias,
                                  pad_type=pad_type, norm_type=norm_type,
                                  act_type=act_type, mode=mode)
        self.conv2 = M_OctaveConv(nc + gc, gc, kernel_size, alpha, stride,
                                  bias=bias, pad_type=pad_type,
                                  norm_type=norm_type, act_type=act_type, mode=mode)
        self.conv3 = M_OctaveConv(nc + 2 * gc, gc, kernel_size, alpha, stride,
                                  bias=bias, pad_type=pad_type,
                                  norm_type=norm_type, act_type=act_type, mode=mode)
        # BUGFIX: ``last_act`` was computed but never used — conv4 received
        # ``act_type`` instead. The commented-out conv_block version passed
        # ``act_type=last_act`` so that 'CNA' mode ends the dense block
        # without a trailing activation; restore that intent here.
        if mode == 'CNA':
            last_act = None
        else:
            last_act = act_type
        self.conv4 = M_OctaveConv(nc + 3 * gc, nc, kernel_size, alpha, stride,
                                  bias=bias, pad_type=pad_type,
                                  norm_type=norm_type, act_type=last_act, mode=mode)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2((torch.cat((x[0], x1[0]), dim=1),
                         torch.cat((x[1], x1[1]), dim=1)))
        x3 = self.conv3((torch.cat((x[0], x1[0], x2[0]), dim=1),
                         torch.cat((x[1], x1[1], x2[1]), dim=1)))
        x4 = self.conv4((torch.cat((x[0], x1[0], x2[0], x3[0]), dim=1),
                         torch.cat((x[1], x1[1], x2[1], x3[1]), dim=1)))
        # Residual scaling by 0.2, as in ESRGAN.
        return (x[0] + x4[0].mul(0.2), x[1] + x4[1].mul(0.2))
####################################################################################################################
class M_octave_RRDBTiny(nn.Module):
    '''
    Residual in Residual Dense Block
    (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
    '''

    def __init__(self, nc, kernel_size=3, gc=16, stride=1, alpha=0.5, bias=True,
                 pad_type='zero', norm_type=None, act_type='leakyrelu', mode='CNA'):
        super(M_octave_RRDBTiny, self).__init__()
        common = dict(nc=nc, kernel_size=kernel_size, alpha=alpha, gc=gc,
                      stride=stride, bias=bias, pad_type=pad_type,
                      norm_type=norm_type, act_type=act_type, mode=mode)
        self.RDB1 = M_octave_ResidualDenseBlockTiny_4C(**common)
        self.RDB2 = M_octave_ResidualDenseBlockTiny_4C(**common)

    def forward(self, x):
        # Pass through both dense blocks, scale by 0.2 (empirical residual
        # scaling) and add back to the (high, low) input pair.
        out = self.RDB2(self.RDB1(x))
        return (x[0] + out[0].mul(0.2), x[1] + out[1].mul(0.2))
实验结果: