An earlier post, 《论文阅读笔记之——《Multi-level Wavelet-CNN for Image Restoration》及基于pytorch的复现》, already covered MWCNN. This post replaces the pooling inside octave convolution with the discrete wavelet transform (DWT).
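The idea in one line: a one-level DWT downsamples by 2 just like the stride-2 average pooling in an octave layer, but it keeps what pooling discards (three detail subbands) and is exactly invertible. A minimal sketch with pytorch_wavelets ('haar' is chosen here only so the sizes match pooling exactly):

import torch
import torch.nn as nn
from pytorch_wavelets import DWTForward, DWTInverse

x = torch.randn(1, 16, 64, 64)
pooled = nn.AvgPool2d(2)(x)               # (1, 16, 32, 32): the detail is gone
yl, yh = DWTForward(J=1, wave='haar')(x)  # yl: (1, 16, 32, 32); yh[0]: (1, 16, 3, 32, 32)
rec = DWTInverse(wave='haar')((yl, yh))   # exact reconstruction from the subbands
print(pooled.shape, yl.shape, yh[0].shape)
print(torch.allclose(rec, x, atol=1e-5))  # True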
Code
Similar to the post 《实验笔记之——octave conv (without pooling)》, the structure of the octave layer is modified as follows:
Implementing the discrete wavelet transform in PyTorch
https://github.com/fbcotter/pytorch_wavelets
git clone https://github.com/fbcotter/pytorch_wavelets
cd pytorch_wavelets
pip install .
pip install -r tests/requirements.txt
Testing
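A quick sanity check that the installed package imports and runs (a minimal sketch; shapes are printed rather than asserted because they depend on the wavelet's filter length):

import torch
from pytorch_wavelets import DWTForward, DWTInverse

xfm = DWTForward(J=1, wave='db3', mode='zero')
ifm = DWTInverse(wave='db3', mode='zero')

x = torch.randn(1, 8, 64, 64)
yl, yh = xfm(x)               # lowpass + a list of highpass tensors, one per level
print(yl.shape, yh[0].shape)  # yh[0] packs LH/HL/HH along dim 2
print(ifm((yl, yh)).shape)    # back to (1, 8, 64, 64)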
The modified code is as follows (note that this first version actually uses the dual-tree complex wavelet transform, DTCWTForward/DTCWTInverse with J=3, rather than the plain DWT):
##################################################################################
##################################################################################
##################################################################################
#DWT octave
# Block for OctConv
####################
import math
import torch
import torch.nn as nn
from pytorch_wavelets import DTCWTForward, DTCWTInverse
# get_valid_padding, act, norm and the module alias B come from the ESRGAN block.py
class DWT_OctaveConv(nn.Module):
def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1, dilation=1, groups=1, \
bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
super(DWT_OctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
self.xfm = DTCWTForward(J=3, biort='near_sym_b', qshift='qshift_b').cuda()
self.ifm = DTCWTInverse( biort='near_sym_b', qshift='qshift_b').cuda()
self.stride = stride
self.l2l = nn.Conv2d(int(alpha * in_nc), int(alpha * out_nc),
kernel_size, 1, padding, dilation, groups, bias)
self.l2h = nn.Conv2d(int(alpha * in_nc), out_nc - int(alpha * out_nc),
kernel_size, 1, padding, dilation, groups, bias)
self.h2l = nn.Conv2d(in_nc - int(alpha * in_nc), int(alpha * out_nc),
kernel_size, 1, padding, dilation, groups, bias)
self.h2h = nn.Conv2d(in_nc - int(alpha * in_nc), out_nc - int(alpha * out_nc),
kernel_size, 1, padding, dilation, groups, bias)
self.a = act(act_type) if act_type else None
self.n_h = norm(norm_type, int(out_nc*(1 - alpha))) if norm_type else None
self.n_l = norm(norm_type, int(out_nc*alpha)) if norm_type else None
def forward(self, x):
X_h, X_l = x
#if self.stride ==2:
#X_h, X_l = self.h2g_pool(X_h), self.h2g_pool(X_l)
X_l,X_ll=self.xfm(X_l)
X_h2h = self.h2h(X_h)
#X_l2h = self.upsample(self.l2h(X_l))
#X_l2h = self.l2h(X_l)
#=self.l2h(X_l)
# print(X_l.shape,"~~~~",X_ll[0].shape,X_ll[1].shape,X_ll[2].shape)
# exit()
X_l2ha = self.ifm((X_l,X_ll))
X_l2h = self.l2h(X_l2ha)
# print(X_l.shape,"~~~~",X_ll[0].shape)
# exit()
# X_l2h = self.ifm((X_l,X_ll))
X_l2l = self.l2l(X_l2ha)
X_h2l = self.h2l(X_h)
X_h2l,X_h2ll=self.xfm(X_h2l)
X_l2l,X_lla=self.xfm(X_l2l)
#print(X_lla[0].shape,"~~~~",X_h2ll[0].shape)
X_h = X_l2h + X_h2h
X_l = X_h2l + X_l2l
# print(X_lla[0].shape,"~~~~",X_h2ll[0].shape)
#exit()
X_ll[0]=X_lla[0]+X_h2ll[0]
X_ll[1]=X_lla[1]+X_h2ll[1]
X_ll[2]=X_lla[2]+X_h2ll[2]
X_l=self.ifm((X_l,X_ll))
if self.n_h and self.n_l:
X_h = self.n_h(X_h)
X_l = self.n_l(X_l)
if self.a:
X_h = self.a(X_h)
X_l = self.a(X_l)
return X_h, X_l
class DWT_FirstOctaveConv(nn.Module):
def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1, dilation=1, groups=1, \
bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
super(DWT_FirstOctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
#self.h2g_pool = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
#self.h2g_pool2 = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
#self.xfm = DTCWTForward(J=3, biort='near_sym_b', qshift='qshift_b').cuda()
#self.ifm = DTCWTInverse(biort='near_sym_b', qshift='qshift_b').cuda()
self.stride = stride
###low frequency
self.h2l = nn.Conv2d(in_nc, int(alpha * out_nc),
kernel_size, 1, padding, dilation, groups, bias)
###high frequency
self.h2h = nn.Conv2d(in_nc, out_nc - int(alpha * out_nc),
kernel_size, 1, padding, dilation, groups, bias)
self.a = act(act_type) if act_type else None
self.n_h = norm(norm_type, int(out_nc*(1 - alpha))) if norm_type else None
self.n_l = norm(norm_type, int(out_nc*alpha)) if norm_type else None
def forward(self, x):
#if self.stride ==2:
#x = self.h2g_pool(x)
X_h = self.h2h(x)
X_l = self.h2l(x)
if self.n_h and self.n_l:##batch norm
X_h = self.n_h(X_h)
X_l = self.n_l(X_l)
if self.a:#Activation layer
X_h = self.a(X_h)
X_l = self.a(X_l)
return X_h, X_l
class DWT_LastOctaveConv(nn.Module):
def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1, dilation=1, groups=1, \
bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
super(DWT_LastOctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
#self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
#self.upsample = nn.Upsample(scale_factor=4, mode='nearest')##double pool
#self.ifm = DTCWTInverse(biort='near_sym_b', qshift='qshift_b').cuda()
self.stride = stride
self.l2h = nn.Conv2d(int(alpha * in_nc), out_nc,
kernel_size, 1, padding, dilation, groups, bias)
self.h2h = nn.Conv2d(in_nc - int(alpha * in_nc), out_nc,
kernel_size, 1, padding, dilation, groups, bias)
self.a = act(act_type) if act_type else None
self.n_h = norm(norm_type, out_nc) if norm_type else None
def forward(self, x):
X_h, X_l = x
#if self.stride ==2:
#X_h, X_l = self.h2g_pool(X_h), self.h2g_pool(X_l)
X_h2h = self.h2h(X_h)
X_l2h=self.l2h(X_l)
#X_l2h = self.l2h(X_l)
X_h = X_h2h + X_l2h
if self.n_h:
X_h = self.n_h(X_h)
if self.a:
X_h = self.a(X_h)
return X_h
class DWT_octave_ResidualDenseBlockTiny_4C(nn.Module):
'''
Residual Dense Block
style: 4 convs
The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
'''
def __init__(self, nc, kernel_size=3, gc=16,alpha=0.5, stride=1, bias=True, pad_type='zero', \
norm_type=None, act_type='leakyrelu', mode='CNA'):
super(DWT_octave_ResidualDenseBlockTiny_4C, self).__init__()
# gc: growth channel, i.e. intermediate channels
self.conv1 =DWT_OctaveConv(nc, gc, kernel_size, alpha, stride, bias=bias, pad_type=pad_type, \
norm_type=norm_type, act_type=act_type, mode=mode)
# conv_block(nc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
# norm_type=norm_type, act_type=act_type, mode=mode)
self.conv2 = DWT_OctaveConv(nc+gc, gc, kernel_size, alpha, stride, bias=bias, pad_type=pad_type, \
norm_type=norm_type, act_type=act_type, mode=mode)
# conv_block(nc+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
# norm_type=norm_type, act_type=act_type, mode=mode)
self.conv3 = DWT_OctaveConv(nc+2*gc, gc, kernel_size, alpha, stride, bias=bias, pad_type=pad_type, \
norm_type=norm_type, act_type=act_type, mode=mode)
# conv_block(nc+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
# norm_type=norm_type, act_type=act_type, mode=mode)
if mode == 'CNA':
last_act = None
else:
last_act = act_type
self.conv4 = DWT_OctaveConv(nc+3*gc, nc, kernel_size, alpha, stride, bias=bias, pad_type=pad_type, \
norm_type=norm_type, act_type=act_type, mode=mode)
# conv_block(nc+3*gc, nc, 3, stride, bias=bias, pad_type=pad_type, \
# norm_type=norm_type, act_type=last_act, mode=mode)
def forward(self, x):
x1 = self.conv1(x)
x2 = self.conv2((torch.cat((x[0], x1[0]), dim=1),(torch.cat((x[1], x1[1]), dim=1))))
x3 = self.conv3((torch.cat((x[0], x1[0],x2[0]), dim=1),(torch.cat((x[1], x1[1],x2[1]), dim=1))))
x4 = self.conv4((torch.cat((x[0], x1[0],x2[0],x3[0]), dim=1),(torch.cat((x[1], x1[1],x2[1],x3[1]), dim=1))))
res = (x4[0].mul(0.2), x4[1].mul(0.2))
x = (x[0] + res[0], x[1] + res[1])
#print(len(x),"~~~",len(res),"~~~",len(x + res))
#return (x[0] + res[0], x[1]+res[1])
return x
class DWT_octave_RRDBTiny(nn.Module):
'''
Residual in Residual Dense Block
(ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
'''
def __init__(self, nc, kernel_size=3, gc=16, stride=1, alpha=0.5, bias=True, pad_type='zero', \
norm_type=None, act_type='leakyrelu', mode='CNA'):
super(DWT_octave_RRDBTiny, self).__init__()
self.RDB1 = DWT_octave_ResidualDenseBlockTiny_4C(nc=nc, kernel_size=kernel_size,alpha=alpha, gc=gc, stride=stride, bias=bias, pad_type=pad_type, \
norm_type=norm_type, act_type=act_type, mode=mode)
self.RDB2 = DWT_octave_ResidualDenseBlockTiny_4C(nc=nc, kernel_size=kernel_size,alpha=alpha, gc=gc, stride=stride, bias=bias, pad_type=pad_type, \
norm_type=norm_type, act_type=act_type, mode=mode)
def forward(self, x):
out = self.RDB1(x)
out = self.RDB2(out)
res = (out[0].mul(0.2), out[1].mul(0.2))
x = (x[0] + res[0], x[1] + res[1])
#print(len(x),"~~~",len(res),"~~~",len(x + res))
#return (x[0] + res[0], x[1]+res[1])
return x
##################this is ESRGAN based on DWT_octave
class DWT_Octave_RRDBNet(nn.Module):
def __init__(self, in_nc, out_nc, nf, nb, gc=32,alpha=0.125, upscale=4, norm_type=None, \
act_type='leakyrelu', mode='CNA', upsample_mode='upconv'):
super(DWT_Octave_RRDBNet, self).__init__()
n_upscale = int(math.log(upscale, 2))
if upscale == 3:
n_upscale = 1
fea_conv1 = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None)
fea_conv = B.DWT_FirstOctaveConv(nf, nf, kernel_size=3,alpha=alpha, stride=1, dilation=1, groups=1, \
bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA')
rb_blocks = [B.DWT_octave_RRDBTiny(nf, kernel_size=3, gc=32,alpha=alpha,stride=1, bias=True, pad_type='zero', \
norm_type=norm_type, act_type=act_type, mode='CNA') for _ in range(nb)]
LR_conv = B.DWT_LastOctaveConv(nf, nf, kernel_size=3, alpha=alpha, stride=1, dilation=1, groups=1, \
bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA')
if upsample_mode == 'upconv':
upsample_block = B.upconv_blcok
elif upsample_mode == 'pixelshuffle':
upsample_block = B.pixelshuffle_block
else:
raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
if upscale == 3:
upsampler = upsample_block(nf, nf, 3, act_type=act_type)
else:
upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]
HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)
self.model = B.sequential(fea_conv1,B.ShortcutBlock(B.sequential(fea_conv,*rb_blocks, LR_conv)),\
*upsampler, HR_conv0, HR_conv1)
def forward(self, x):
x = self.model(x)
return x
##############################################################################################
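To make the tensors in the forward pass above concrete: with J=3, DTCWTForward returns a lowpass map plus a list of three complex highpass tensors (one per scale, each carrying six orientation subbands and a trailing real/imaginary dimension). That is what `X_l, X_ll = self.xfm(X_l)` unpacks, and why X_ll[0], X_ll[1], X_ll[2] are summed scale by scale before the inverse transform. A standalone shape check (a sketch, not part of the network code):

import torch
from pytorch_wavelets import DTCWTForward, DTCWTInverse

xfm = DTCWTForward(J=3, biort='near_sym_b', qshift='qshift_b')
ifm = DTCWTInverse(biort='near_sym_b', qshift='qshift_b')

x = torch.randn(1, 16, 64, 64)
yl, yh = xfm(x)             # lowpass + list of 3 highpass scales
print(yl.shape)
for h in yh:                # each (N, C, 6, h, w, 2)
    print(h.shape)
print(ifm((yl, yh)).shape)  # reconstructs (1, 16, 64, 64)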
Experiments
Improvement
Testing turned up a new finding:
https://blog.csdn.net/qq_40587575/article/details/83154042
As the test above shows, only J=1 (a single decomposition level) is needed.
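Concretely, DWTForward(J=1) returns the LL map plus a one-element list whose entry stacks LH, HL and HH along dim 2, which is exactly the B[0][:, :, 0..2] indexing used in the improved code below (a sketch):

import torch
from pytorch_wavelets import DWTForward, DWTInverse

xfm = DWTForward(J=1, wave='db3', mode='zero')
x = torch.randn(1, 16, 64, 64)
ll, hi = xfm(x)
lh, hl, hh = hi[0][:, :, 0], hi[0][:, :, 1], hi[0][:, :, 2]
print(ll.shape, lh.shape, hl.shape, hh.shape)  # four subbands at reduced resolution

# repacking for the inverse: stack the details back along dim 2
rec = DWTInverse(wave='db3', mode='zero')((ll, [torch.stack((lh, hl, hh), dim=2)]))
print(rec.shape)  # (1, 16, 64, 64)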
The improved code is as follows:
##################################################################################
##################################################################################
##################################################################################
#DWT octave
# Block for OctConv
####################
import torch
import torch.nn as nn
from pytorch_wavelets import DWTForward, DWTInverse
# get_valid_padding, act and norm come from the ESRGAN block.py
class DWT_OctaveConv(nn.Module):
def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1, dilation=1, groups=1, \
bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
super(DWT_OctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
#self.xfm = DWTForward(J=1, wave='db3', mode='zero')
#self.ifm = DWTInverse(wave='db3', mode='zero')
self.stride = stride
# self.l2l = nn.Conv2d(int(alpha * in_nc), int(alpha * out_nc),
# kernel_size, 1, padding, dilation, groups, bias)
# self.l2h = nn.Conv2d(int(alpha * in_nc), out_nc - int(alpha * out_nc),
# kernel_size, 1, padding, dilation, groups, bias)
# self.h2l = nn.Conv2d(in_nc - int(alpha * in_nc), int(alpha * out_nc),
# kernel_size, 1, padding, dilation, groups, bias)
# self.h2h = nn.Conv2d(in_nc - int(alpha * in_nc), out_nc - int(alpha * out_nc),
# kernel_size, 1, padding, dilation, groups, bias)
self.l2l = nn.Conv2d(in_nc, out_nc,
kernel_size, 1, padding, dilation, groups, bias)
self.l2h = nn.Conv2d(in_nc, out_nc,
kernel_size, 1, padding, dilation, groups, bias)
self.h2l = nn.Conv2d(in_nc, out_nc,
kernel_size, 1, padding, dilation, groups, bias)
self.h2h = nn.Conv2d(in_nc, out_nc,
kernel_size, 1, padding, dilation, groups, bias)
self.a = act(act_type) if act_type else None
self.n_h = norm(norm_type, int(out_nc*(1 - alpha))) if norm_type else None
self.n_l = norm(norm_type, int(out_nc*alpha)) if norm_type else None
def forward(self, x):
X_ll,X_lh,X_hl,X_hh = x
# print(X_ll.shape,'~~~',X_lh.shape,'~~~',X_lh.shape,'~~~',X_lh.shape)
# exit()
# A,B=self.xfm(x)
# X_ll=A
# X_lh=B[0][:,:,0]
# X_hl=B[0][:,:,1]
# X_hh=B[0][:,:,2]
#if self.stride ==2:
#X_h, X_l = self.h2g_pool(X_h), self.h2g_pool(X_l)
X_hh2h = self.h2h(X_hh)
#X_l2h = self.upsample(self.l2h(X_l))
X_lh2h = self.l2h(X_lh)
#X_l2h = self.upsample(self.l2h(X_l))
X_ll2l = self.l2l(X_ll)
#X_h2l = self.h2l(self.h2g_pool(X_h))
X_hl2l = self.h2l(X_hl)
#X_h2l = self.h2l(self.h2g_pool2(self.h2g_pool(X_h)))
#print(X_l2h.shape,"~~~~",X_h2h.shape)
X_hh=X_hh2h
X_lh=X_lh2h
X_ll=X_ll2l
X_hl=X_hl2l
        if self.n_h and self.n_l:
            X_hh = self.n_h(X_hh)
            X_lh = self.n_h(X_lh)
            X_ll = self.n_h(X_ll)
            X_hl = self.n_h(X_hl)
        if self.a:
            X_hh = self.a(X_hh)
            X_lh = self.a(X_lh)
            X_ll = self.a(X_ll)
            X_hl = self.a(X_hl)
# A=X_ll
# B[0][:,:,0]=X_lh
# B[0][:,:,1]=X_hl
# B[0][:,:,2]=X_hh
# x=ifm((A,B))
return X_ll,X_lh,X_hl,X_hh
class DWT_FirstOctaveConv(nn.Module):
def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1, dilation=1, groups=1, \
bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
super(DWT_FirstOctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
self.xfm = DWTForward(J=1, wave='db3', mode='zero').cuda()
#self.h2g_pool = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
#self.h2g_pool2 = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
#self.xfm = DTCWTForward(J=3, biort='near_sym_b', qshift='qshift_b').cuda()
#self.ifm = DTCWTInverse(biort='near_sym_b', qshift='qshift_b').cuda()
self.stride = stride
###low frequency
self.h2l = nn.Conv2d(in_nc, int(alpha * out_nc),
kernel_size, 1, padding, dilation, groups, bias)
###high frequency
self.h2h = nn.Conv2d(in_nc, out_nc - int(alpha * out_nc),
kernel_size, 1, padding, dilation, groups, bias)
self.a = act(act_type) if act_type else None
self.n_h = norm(norm_type, int(out_nc*(1 - alpha))) if norm_type else None
self.n_l = norm(norm_type, int(out_nc*alpha)) if norm_type else None
def forward(self, x):
#if self.stride ==2:
#x = self.h2g_pool(x)
#X_h = self.h2h(x)
#X_l = self.h2l(x)
A,B=self.xfm(x)
X_ll=A
X_lh=B[0][:,:,0]
X_hl=B[0][:,:,1]
X_hh=B[0][:,:,2]
# if self.n_h and self.n_l:##batch norm
# X_h = self.n_h(X_h)
# X_l = self.n_l(X_l)
# if self.a:#Activation layer
# X_h = self.a(X_h)
# X_l = self.a(X_l)
return X_ll,X_lh,X_hl,X_hh
class DWT_LastOctaveConv(nn.Module):
def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1, dilation=1, groups=1, \
bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
super(DWT_LastOctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
#self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
#self.upsample = nn.Upsample(scale_factor=4, mode='nearest')##double pool
#self.ifm = DTCWTInverse(biort='near_sym_b', qshift='qshift_b').cuda()
self.ifm = DWTInverse(wave='db3', mode='zero').cuda()
self.stride = stride
self.l2h = nn.Conv2d(int(alpha * in_nc), out_nc,
kernel_size, 1, padding, dilation, groups, bias)
self.h2h = nn.Conv2d(in_nc - int(alpha * in_nc), out_nc,
kernel_size, 1, padding, dilation, groups, bias)
self.a = act(act_type) if act_type else None
self.n_h = norm(norm_type, out_nc) if norm_type else None
def forward(self, x):
X_ll,X_lh,X_hl,X_hh = x
#if self.stride ==2:
#X_h, X_l = self.h2g_pool(X_h), self.h2g_pool(X_l)
        # stack the LH/HL/HH details along dim 2 and invert the DWT
        C = torch.stack((X_lh, X_hl, X_hh), dim=2)
        X_h = self.ifm((X_ll, [C]))
# print(X_h.shape)
# exit()
# X_h2h = self.h2h(X_h)
# X_l2h=self.l2h(X_l)
# #X_l2h = self.l2h(X_l)
# X_h = X_h2h + X_l2h
# if self.n_h:
# X_h = self.n_h(X_h)
# if self.a:
# X_h = self.a(X_h)
return X_h
class DWT_octave_ResidualDenseBlockTiny_4C(nn.Module):
'''
Residual Dense Block
style: 4 convs
The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
'''
def __init__(self, nc, kernel_size=3, gc=16,alpha=0.5, stride=1, bias=True, pad_type='zero', \
norm_type=None, act_type='leakyrelu', mode='CNA'):
super(DWT_octave_ResidualDenseBlockTiny_4C, self).__init__()
# gc: growth channel, i.e. intermediate channels
self.conv1 =DWT_OctaveConv(nc, gc, kernel_size, alpha, stride, bias=bias, pad_type=pad_type, \
norm_type=norm_type, act_type=act_type, mode=mode)
# conv_block(nc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
# norm_type=norm_type, act_type=act_type, mode=mode)
self.conv2 = DWT_OctaveConv(nc+gc, gc, kernel_size, alpha, stride, bias=bias, pad_type=pad_type, \
norm_type=norm_type, act_type=act_type, mode=mode)
# conv_block(nc+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
# norm_type=norm_type, act_type=act_type, mode=mode)
self.conv3 = DWT_OctaveConv(nc+2*gc, gc, kernel_size, alpha, stride, bias=bias, pad_type=pad_type, \
norm_type=norm_type, act_type=act_type, mode=mode)
# conv_block(nc+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
# norm_type=norm_type, act_type=act_type, mode=mode)
if mode == 'CNA':
last_act = None
else:
last_act = act_type
self.conv4 = DWT_OctaveConv(nc+3*gc, nc, kernel_size, alpha, stride, bias=bias, pad_type=pad_type, \
norm_type=norm_type, act_type=act_type, mode=mode)
# conv_block(nc+3*gc, nc, 3, stride, bias=bias, pad_type=pad_type, \
# norm_type=norm_type, act_type=last_act, mode=mode)
def forward(self, x):
# print(x[0].shape,'~~~',x[1].shape,'~~~',x[2].shape,'~~~',x[3].shape)
# exit()
x1 = self.conv1(x)
x2 = self.conv2((torch.cat((x[0], x1[0]), dim=1),torch.cat((x[1], x1[1]), dim=1),torch.cat((x[2], x1[2]), dim=1),torch.cat((x[3], x1[3]), dim=1)))
x3 = self.conv3((torch.cat((x[0], x1[0],x2[0]), dim=1),torch.cat((x[1], x1[1],x2[1]), dim=1),torch.cat((x[2], x1[2],x2[2]), dim=1),torch.cat((x[3], x1[3],x2[3]), dim=1)))
x4 = self.conv4((torch.cat((x[0], x1[0],x2[0],x3[0]), dim=1),torch.cat((x[1], x1[1],x2[1],x3[1]), dim=1),torch.cat((x[2], x1[2],x2[2],x3[2]), dim=1),torch.cat((x[3], x1[3],x2[3],x3[3]), dim=1)))
res = (x4[0].mul(0.2), x4[1].mul(0.2),x4[2].mul(0.2),x4[3].mul(0.2))
x = (x[0] + res[0], x[1] + res[1], x[2] + res[2], x[3] + res[3])
#print(len(x),"~~~",len(res),"~~~",len(x + res))
#return (x[0] + res[0], x[1]+res[1])
return x
class DWT_octave_RRDBTiny(nn.Module):
'''
Residual in Residual Dense Block
(ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
'''
def __init__(self, nc, kernel_size=3, gc=16, stride=1, alpha=0.5, bias=True, pad_type='zero', \
norm_type=None, act_type='leakyrelu', mode='CNA'):
super(DWT_octave_RRDBTiny, self).__init__()
self.RDB1 = DWT_octave_ResidualDenseBlockTiny_4C(nc=nc, kernel_size=kernel_size,alpha=alpha, gc=gc, stride=stride, bias=bias, pad_type=pad_type, \
norm_type=norm_type, act_type=act_type, mode=mode)
self.RDB2 = DWT_octave_ResidualDenseBlockTiny_4C(nc=nc, kernel_size=kernel_size,alpha=alpha, gc=gc, stride=stride, bias=bias, pad_type=pad_type, \
norm_type=norm_type, act_type=act_type, mode=mode)
def forward(self, x):
out = self.RDB1(x)
out = self.RDB2(out)
res = (out[0].mul(0.2), out[1].mul(0.2),out[2].mul(0.2),out[3].mul(0.2))
x = (x[0] + res[0], x[1] + res[1], x[2] + res[2], x[3] + res[3])
#print(len(x),"~~~",len(res),"~~~",len(x + res))
#return (x[0] + res[0], x[1]+res[1])
return x
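Two details are worth noting in this version: every conv now maps the full in_nc to out_nc (the alpha-based channel split of the original octave design is effectively bypassed), and the four subband streams are densely concatenated independently, so conv i receives nc + (i-1)*gc channels per stream. The bookkeeping, with hypothetical values:

# Illustrative only: channels seen by conv1..conv4 in each subband stream
nc, gc = 16, 8  # hypothetical nc/gc
for i in range(4):
    print(f"conv{i + 1}: {nc + i * gc} in, {gc if i < 3 else nc} out per subband")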
Experimental results:
Improvement 2
##################################################################################
##################################################################################
##################################################################################
#DWT octave
# Block for OctConv
####################
import torch
import torch.nn as nn
from pytorch_wavelets import DWTForward, DWTInverse
# get_valid_padding, act and norm come from the ESRGAN block.py
class DWT_OctaveConv(nn.Module):
def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1, dilation=1, groups=1, \
bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
super(DWT_OctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
self.xfm = DWTForward(J=1, wave='db3', mode='zero').cuda()
self.ifm = DWTInverse(wave='db3', mode='zero').cuda()
self.stride = stride
self.l2l = nn.Conv2d(int(alpha * in_nc), int(alpha * out_nc),
kernel_size, 1, padding, dilation, groups, bias)
self.l2h = nn.Conv2d(int(alpha * in_nc), out_nc - int(alpha * out_nc),
kernel_size, 1, padding, dilation, groups, bias)
self.h2l = nn.Conv2d(in_nc - int(alpha * in_nc), int(alpha * out_nc),
kernel_size, 1, padding, dilation, groups, bias)
self.h2h = nn.Conv2d(in_nc - int(alpha * in_nc), out_nc - int(alpha * out_nc),
kernel_size, 1, padding, dilation, groups, bias)
self.a = act(act_type) if act_type else None
self.n_h = norm(norm_type, int(out_nc*(1 - alpha))) if norm_type else None
self.n_l = norm(norm_type, int(out_nc*alpha)) if norm_type else None
def forward(self, x):
X_h, X_l = x
# print(X_ll.shape,'~~~',X_lh.shape,'~~~',X_lh.shape,'~~~',X_lh.shape)
# exit()
# A,B=self.xfm(x)
# X_ll=A
# X_lh=B[0][:,:,0]
# X_hl=B[0][:,:,1]
# X_hh=B[0][:,:,2]
#if self.stride ==2:
#X_h, X_l = self.h2g_pool(X_h), self.h2g_pool(X_l)
X_h2h = self.h2h(X_h)
#X_l2h = self.upsample(self.l2h(X_l))
X_l2h = self.l2h(X_l)
#X_l2h = self.upsample(self.l2h(X_l))
#DWT for X_h
A,B=self.xfm(X_h)
X_hll=A
X_hlh=B[0][:,:,0]
X_hhl=B[0][:,:,1]
X_hhh=B[0][:,:,2]
#transfer
X_hll2l=self.h2l(X_hll)
X_hlh2l=self.h2l(X_hlh)
X_hhl2l=self.h2l(X_hhl)
X_hhh2l=self.h2l(X_hhh)
#DWT for X_l
C,D=self.xfm(X_l)
X_lll=C
X_llh=D[0][:,:,0]
X_lhl=D[0][:,:,1]
X_lhh=D[0][:,:,2]
#transfer
X_lll2l=self.l2l(X_lll)
X_llh2l=self.l2l(X_llh)
X_lhl2l=self.l2l(X_lhl)
X_lhh2l=self.l2l(X_lhh)
#X_ll2l = self.l2l(X_ll)
#X_h2l = self.h2l(self.h2g_pool(X_h))
#X_hl2l = self.h2l(X_hl)
#X_h2l = self.h2l(self.h2g_pool2(self.h2g_pool(X_h)))
#print(X_l2h.shape,"~~~~",X_h2h.shape)
X_h=X_h2h+X_l2h
        # sum the matching subbands from the two branches, stack the details, and invert
        E = X_lll2l + X_hll2l
        F = torch.stack((X_llh2l + X_hlh2l, X_lhl2l + X_hhl2l, X_lhh2l + X_hhh2l), dim=2)
        X_l = self.ifm((E, [F]))
if self.n_h and self.n_l:
X_h = self.n_h(X_h)
X_l = self.n_l(X_l)
if self.a:
X_h = self.a(X_h)
X_l = self.a(X_l)
return X_h, X_l
class DWT_FirstOctaveConv(nn.Module):
def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1, dilation=1, groups=1, \
bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
super(DWT_FirstOctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
self.xfm = DWTForward(J=1, wave='db3', mode='zero').cuda()
#self.h2g_pool = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
#self.h2g_pool2 = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
#self.xfm = DTCWTForward(J=3, biort='near_sym_b', qshift='qshift_b').cuda()
#self.ifm = DTCWTInverse(biort='near_sym_b', qshift='qshift_b').cuda()
self.stride = stride
###low frequency
self.h2l = nn.Conv2d(in_nc, int(alpha * out_nc),
kernel_size, 1, padding, dilation, groups, bias)
###high frequency
self.h2h = nn.Conv2d(in_nc, out_nc - int(alpha * out_nc),
kernel_size, 1, padding, dilation, groups, bias)
self.a = act(act_type) if act_type else None
self.n_h = norm(norm_type, int(out_nc*(1 - alpha))) if norm_type else None
self.n_l = norm(norm_type, int(out_nc*alpha)) if norm_type else None
def forward(self, x):
#if self.stride ==2:
#x = self.h2g_pool(x)
X_h = self.h2h(x)
X_l = self.h2l(x)
if self.n_h and self.n_l:##batch norm
X_h = self.n_h(X_h)
X_l = self.n_l(X_l)
if self.a:#Activation layer
X_h = self.a(X_h)
X_l = self.a(X_l)
return X_h,X_l
class DWT_LastOctaveConv(nn.Module):
def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1, dilation=1, groups=1, \
bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
super(DWT_LastOctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
#self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
#self.upsample = nn.Upsample(scale_factor=4, mode='nearest')##double pool
#self.ifm = DTCWTInverse(biort='near_sym_b', qshift='qshift_b').cuda()
#self.ifm = DWTInverse(wave='db3', mode='zero').cuda()
self.stride = stride
self.l2h = nn.Conv2d(int(alpha * in_nc), out_nc,
kernel_size, 1, padding, dilation, groups, bias)
self.h2h = nn.Conv2d(in_nc - int(alpha * in_nc), out_nc,
kernel_size, 1, padding, dilation, groups, bias)
self.a = act(act_type) if act_type else None
self.n_h = norm(norm_type, out_nc) if norm_type else None
def forward(self, x):
X_h,X_l = x
#if self.stride ==2:
#X_h, X_l = self.h2g_pool(X_h), self.h2g_pool(X_l)
X_h2h = self.h2h(X_h)
X_l2h=self.l2h(X_l)
#X_l2h = self.l2h(X_l)
X_h = X_h2h + X_l2h
if self.n_h:
X_h = self.n_h(X_h)
if self.a:
X_h = self.a(X_h)
return X_h
class DWT_octave_ResidualDenseBlockTiny_4C(nn.Module):
'''
Residual Dense Block
style: 4 convs
The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
'''
def __init__(self, nc, kernel_size=3, gc=16,alpha=0.5, stride=1, bias=True, pad_type='zero', \
norm_type=None, act_type='leakyrelu', mode='CNA'):
super(DWT_octave_ResidualDenseBlockTiny_4C, self).__init__()
# gc: growth channel, i.e. intermediate channels
self.conv1 =DWT_OctaveConv(nc, gc, kernel_size, alpha, stride, bias=bias, pad_type=pad_type, \
norm_type=norm_type, act_type=act_type, mode=mode)
# conv_block(nc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
# norm_type=norm_type, act_type=act_type, mode=mode)
self.conv2 = DWT_OctaveConv(nc+gc, gc, kernel_size, alpha, stride, bias=bias, pad_type=pad_type, \
norm_type=norm_type, act_type=act_type, mode=mode)
# conv_block(nc+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
# norm_type=norm_type, act_type=act_type, mode=mode)
self.conv3 = DWT_OctaveConv(nc+2*gc, gc, kernel_size, alpha, stride, bias=bias, pad_type=pad_type, \
norm_type=norm_type, act_type=act_type, mode=mode)
# conv_block(nc+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
# norm_type=norm_type, act_type=act_type, mode=mode)
if mode == 'CNA':
last_act = None
else:
last_act = act_type
self.conv4 = DWT_OctaveConv(nc+3*gc, nc, kernel_size, alpha, stride, bias=bias, pad_type=pad_type, \
norm_type=norm_type, act_type=act_type, mode=mode)
# conv_block(nc+3*gc, nc, 3, stride, bias=bias, pad_type=pad_type, \
# norm_type=norm_type, act_type=last_act, mode=mode)
def forward(self, x):
# print(x[0].shape,'~~~',x[1].shape,'~~~',x[2].shape,'~~~',x[3].shape)
# exit()
x1 = self.conv1(x)
x2 = self.conv2((torch.cat((x[0], x1[0]), dim=1),torch.cat((x[1], x1[1]), dim=1)))
x3 = self.conv3((torch.cat((x[0], x1[0],x2[0]), dim=1),torch.cat((x[1], x1[1],x2[1]), dim=1)))
x4 = self.conv4((torch.cat((x[0], x1[0],x2[0],x3[0]), dim=1),torch.cat((x[1], x1[1],x2[1],x3[1]), dim=1)))
res = (x4[0].mul(0.2), x4[1].mul(0.2))
x = (x[0] + res[0], x[1] + res[1])
#print(len(x),"~~~",len(res),"~~~",len(x + res))
#return (x[0] + res[0], x[1]+res[1])
return x
class DWT_octave_RRDBTiny(nn.Module):
'''
Residual in Residual Dense Block
(ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
'''
def __init__(self, nc, kernel_size=3, gc=16, stride=1, alpha=0.5, bias=True, pad_type='zero', \
norm_type=None, act_type='leakyrelu', mode='CNA'):
super(DWT_octave_RRDBTiny, self).__init__()
self.RDB1 = DWT_octave_ResidualDenseBlockTiny_4C(nc=nc, kernel_size=kernel_size,alpha=alpha, gc=gc, stride=stride, bias=bias, pad_type=pad_type, \
norm_type=norm_type, act_type=act_type, mode=mode)
self.RDB2 = DWT_octave_ResidualDenseBlockTiny_4C(nc=nc, kernel_size=kernel_size,alpha=alpha, gc=gc, stride=stride, bias=bias, pad_type=pad_type, \
norm_type=norm_type, act_type=act_type, mode=mode)
def forward(self, x):
out = self.RDB1(x)
out = self.RDB2(out)
res = (out[0].mul(0.2), out[1].mul(0.2))
x = (x[0] + res[0], x[1] + res[1])
#print(len(x),"~~~",len(res),"~~~",len(x + res))
#return (x[0] + res[0], x[1]+res[1])
return x
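An end-to-end shape check for this version (a sketch under assumptions: get_valid_padding, act and norm are the ESRGAN block.py helpers, so minimal stand-ins are defined here, and a CUDA device is assumed because the blocks hard-code .cuda()):

import torch
import torch.nn as nn

# hypothetical stand-ins for the ESRGAN block.py helpers used by the classes above
def get_valid_padding(kernel_size, dilation):
    kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
    return (kernel_size - 1) // 2

def act(act_type):
    return nn.LeakyReLU(0.2, inplace=True) if act_type == 'leakyrelu' else nn.PReLU()

def norm(norm_type, nc):
    return nn.BatchNorm2d(nc)

x = torch.randn(1, 16, 64, 64).cuda()
first = DWT_FirstOctaveConv(16, 16, kernel_size=3, alpha=0.5).cuda()
body = DWT_octave_RRDBTiny(16, kernel_size=3, gc=8, alpha=0.5).cuda()
last = DWT_LastOctaveConv(16, 16, kernel_size=3, alpha=0.5).cuda()

X_h, X_l = first(x)
print(X_h.shape, X_l.shape)  # (1, 8, 64, 64) each for alpha=0.5; both stay full resolution
y = last(body((X_h, X_l)))
print(y.shape)               # (1, 16, 64, 64): same size as the input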
Results
Improvement 3
##################################################################################
##################################################################################
##################################################################################
#DWT octave
# Block for OctConv
####################
import torch
import torch.nn as nn
from pytorch_wavelets import DWTForward, DWTInverse
# get_valid_padding, act and norm come from the ESRGAN block.py
class DWT_OctaveConv(nn.Module):
def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1, dilation=1, groups=1, \
bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
super(DWT_OctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
self.xfm = DWTForward(J=1, wave='db3', mode='zero').cuda()
self.ifm = DWTInverse(wave='db3', mode='zero').cuda()
self.stride = stride
self.l2l = nn.Conv2d(int(alpha * in_nc), int(alpha * out_nc),
kernel_size, 1, padding, dilation, groups, bias)
self.l2h = nn.Conv2d(int(alpha * in_nc), out_nc - int(alpha * out_nc),
kernel_size, 1, padding, dilation, groups, bias)
self.h2l = nn.Conv2d(in_nc - int(alpha * in_nc), int(alpha * out_nc),
kernel_size, 1, padding, dilation, groups, bias)
self.h2h = nn.Conv2d(in_nc - int(alpha * in_nc), out_nc - int(alpha * out_nc),
kernel_size, 1, padding, dilation, groups, bias)
self.a = act(act_type) if act_type else None
self.n_h = norm(norm_type, int(out_nc*(1 - alpha))) if norm_type else None
self.n_l = norm(norm_type, int(out_nc*alpha)) if norm_type else None
def forward(self, x):
X_h,X_l_ll,X_l_lh,X_l_hl,X_l_hh = x
#for X_h to h
X_h2h = self.h2h(X_h)
#X_l2h = self.upsample(self.l2h(X_l))
#get X_l
        # reconstruct the spatial low-frequency map from its four subbands
        J = torch.stack((X_l_lh, X_l_hl, X_l_hh), dim=2)
        X_l = self.ifm((X_l_ll, [J]))
#X_l to h
X_l2h = self.l2h(X_l)
#DWT for X_h
A,B=self.xfm(X_h)
X_hll=A
X_hlh=B[0][:,:,0]
X_hhl=B[0][:,:,1]
X_hhh=B[0][:,:,2]
#transfer
X_hll2l=self.h2l(X_hll)
X_hlh2l=self.h2l(X_hlh)
X_hhl2l=self.h2l(X_hhl)
X_hhh2l=self.h2l(X_hhh)
#for X_l series (X_l_ll,X_l_lh,X_l_hl,X_l_hh)
#transfer
X_lll2l=self.l2l(X_l_ll)
X_llh2l=self.l2l(X_l_lh)
X_lhl2l=self.l2l(X_l_hl)
X_lhh2l=self.l2l(X_l_hh)
#for X_h
X_h=X_h2h+X_l2h
#for X_l series (X_l_ll,X_l_lh,X_l_hl,X_l_hh)
X_l_ll=X_lll2l+X_hll2l
X_l_lh=X_llh2l+X_hlh2l
X_l_hl=X_lhl2l+X_hhl2l
X_l_hh=X_lhh2l+X_hhh2l
if self.n_h and self.n_l:
X_h = self.n_h(X_h)
#X_l = self.n_l(X_l)
X_l_ll = self.n_l(X_l_ll)
X_l_lh = self.n_l(X_l_lh)
X_l_hl = self.n_l(X_l_hl)
X_l_hh = self.n_l(X_l_hh)
if self.a:
X_h = self.a(X_h)
#X_l = self.a(X_l)
X_l_ll = self.a(X_l_ll)
X_l_lh = self.a(X_l_lh)
X_l_hl = self.a(X_l_hl)
X_l_hh = self.a(X_l_hh)
return X_h,X_l_ll,X_l_lh,X_l_hl,X_l_hh
class DWT_FirstOctaveConv(nn.Module):
def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1, dilation=1, groups=1, \
bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
super(DWT_FirstOctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
self.xfm = DWTForward(J=1, wave='db3', mode='zero').cuda()
#self.h2g_pool = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
#self.h2g_pool2 = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
#self.xfm = DTCWTForward(J=3, biort='near_sym_b', qshift='qshift_b').cuda()
#self.ifm = DTCWTInverse(biort='near_sym_b', qshift='qshift_b').cuda()
self.stride = stride
###low frequency
self.h2l = nn.Conv2d(in_nc, int(alpha * out_nc),
kernel_size, 1, padding, dilation, groups, bias)
###high frequency
self.h2h = nn.Conv2d(in_nc, out_nc - int(alpha * out_nc),
kernel_size, 1, padding, dilation, groups, bias)
self.a = act(act_type) if act_type else None
self.n_h = norm(norm_type, int(out_nc*(1 - alpha))) if norm_type else None
self.n_l = norm(norm_type, int(out_nc*alpha)) if norm_type else None
def forward(self, x):
#if self.stride ==2:
#x = self.h2g_pool(x)
X_h = self.h2h(x)
X_l = self.h2l(x)
A,B=self.xfm(X_l)
X_l_ll=A
X_l_lh=B[0][:,:,0]
X_l_hl=B[0][:,:,1]
X_l_hh=B[0][:,:,2]
if self.n_h and self.n_l:##batch norm
X_h = self.n_h(X_h)
#X_l = self.n_l(X_l)
X_l_ll = self.n_l(X_l_ll)
X_l_lh = self.n_l(X_l_lh)
X_l_hl = self.n_l(X_l_hl)
X_l_hh = self.n_l(X_l_hh)
if self.a:#Activation layer
X_h = self.a(X_h)
#X_l = self.a(X_l)
X_l_ll = self.a(X_l_ll)
X_l_lh = self.a(X_l_lh)
X_l_hl = self.a(X_l_hl)
X_l_hh = self.a(X_l_hh)
return X_h,X_l_ll,X_l_lh,X_l_hl,X_l_hh
class DWT_LastOctaveConv(nn.Module):
def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1, dilation=1, groups=1, \
bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
super(DWT_LastOctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
#self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
#self.upsample = nn.Upsample(scale_factor=4, mode='nearest')##double pool
self.ifm = DWTInverse(wave='db3', mode='zero').cuda()
self.stride = stride
self.l2h = nn.Conv2d(int(alpha * in_nc), out_nc,
kernel_size, 1, padding, dilation, groups, bias)
self.h2h = nn.Conv2d(in_nc - int(alpha * in_nc), out_nc,
kernel_size, 1, padding, dilation, groups, bias)
self.a = act(act_type) if act_type else None
self.n_h = norm(norm_type, out_nc) if norm_type else None
def forward(self, x):
X_h,X_l_ll,X_l_lh,X_l_hl,X_l_hh = x
#if self.stride ==2:
#X_h, X_l = self.h2g_pool(X_h), self.h2g_pool(X_l)
        # stack the low-branch detail subbands along dim 2 and invert the DWT
        C = torch.stack((X_l_lh, X_l_hl, X_l_hh), dim=2)
        X_l = self.ifm((X_l_ll, [C]))
X_h2h = self.h2h(X_h)
X_l2h=self.l2h(X_l)
#X_l2h = self.l2h(X_l)
X_h = X_h2h + X_l2h
if self.n_h:
X_h = self.n_h(X_h)
if self.a:
X_h = self.a(X_h)
return X_h
class DWT_octave_ResidualDenseBlockTiny_4C(nn.Module):
'''
Residual Dense Block
style: 4 convs
The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
'''
def __init__(self, nc, kernel_size=3, gc=16,alpha=0.5, stride=1, bias=True, pad_type='zero', \
norm_type=None, act_type='leakyrelu', mode='CNA'):
super(DWT_octave_ResidualDenseBlockTiny_4C, self).__init__()
# gc: growth channel, i.e. intermediate channels
self.conv1 =DWT_OctaveConv(nc, gc, kernel_size, alpha, stride, bias=bias, pad_type=pad_type, \
norm_type=norm_type, act_type=act_type, mode=mode)
# conv_block(nc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
# norm_type=norm_type, act_type=act_type, mode=mode)
self.conv2 = DWT_OctaveConv(nc+gc, gc, kernel_size, alpha, stride, bias=bias, pad_type=pad_type, \
norm_type=norm_type, act_type=act_type, mode=mode)
# conv_block(nc+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
# norm_type=norm_type, act_type=act_type, mode=mode)
self.conv3 = DWT_OctaveConv(nc+2*gc, gc, kernel_size, alpha, stride, bias=bias, pad_type=pad_type, \
norm_type=norm_type, act_type=act_type, mode=mode)
# conv_block(nc+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
# norm_type=norm_type, act_type=act_type, mode=mode)
if mode == 'CNA':
last_act = None
else:
last_act = act_type
self.conv4 = DWT_OctaveConv(nc+3*gc, nc, kernel_size, alpha, stride, bias=bias, pad_type=pad_type, \
norm_type=norm_type, act_type=act_type, mode=mode)
# conv_block(nc+3*gc, nc, 3, stride, bias=bias, pad_type=pad_type, \
# norm_type=norm_type, act_type=last_act, mode=mode)
def forward(self, x):
# print(x[0].shape,'~~~',x[1].shape,'~~~',x[2].shape,'~~~',x[3].shape)
# exit()
x1 = self.conv1(x)
x2 = self.conv2((torch.cat((x[0], x1[0]), dim=1),torch.cat((x[1], x1[1]), dim=1),torch.cat((x[2], x1[2]), dim=1),torch.cat((x[3], x1[3]), dim=1),torch.cat((x[4], x1[4]), dim=1)))
x3 = self.conv3((torch.cat((x[0], x1[0],x2[0]), dim=1),torch.cat((x[1], x1[1],x2[1]), dim=1),torch.cat((x[2], x1[2],x2[2]), dim=1),torch.cat((x[3], x1[3],x2[3]), dim=1),torch.cat((x[4], x1[4],x2[4]), dim=1)))
x4 = self.conv4((torch.cat((x[0], x1[0],x2[0],x3[0]), dim=1),torch.cat((x[1], x1[1],x2[1],x3[1]), dim=1),torch.cat((x[2], x1[2],x2[2],x3[2]), dim=1),torch.cat((x[3], x1[3],x2[3],x3[3]), dim=1),torch.cat((x[4], x1[4],x2[4],x3[4]), dim=1)))
res = (x4[0].mul(0.2), x4[1].mul(0.2),x4[2].mul(0.2),x4[3].mul(0.2),x4[4].mul(0.2))
x = (x[0] + res[0], x[1] + res[1], x[2] + res[2], x[3] + res[3], x[4] + res[4])
#print(len(x),"~~~",len(res),"~~~",len(x + res))
#return (x[0] + res[0], x[1]+res[1])
return x
class DWT_octave_RRDBTiny(nn.Module):
'''
Residual in Residual Dense Block
(ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
'''
def __init__(self, nc, kernel_size=3, gc=16, stride=1, alpha=0.5, bias=True, pad_type='zero', \
norm_type=None, act_type='leakyrelu', mode='CNA'):
super(DWT_octave_RRDBTiny, self).__init__()
self.RDB1 = DWT_octave_ResidualDenseBlockTiny_4C(nc=nc, kernel_size=kernel_size,alpha=alpha, gc=gc, stride=stride, bias=bias, pad_type=pad_type, \
norm_type=norm_type, act_type=act_type, mode=mode)
self.RDB2 = DWT_octave_ResidualDenseBlockTiny_4C(nc=nc, kernel_size=kernel_size,alpha=alpha, gc=gc, stride=stride, bias=bias, pad_type=pad_type, \
norm_type=norm_type, act_type=act_type, mode=mode)
def forward(self, x):
out = self.RDB1(x)
out = self.RDB2(out)
res = (out[0].mul(0.2), out[1].mul(0.2),out[2].mul(0.2),out[3].mul(0.2),out[4].mul(0.2))
x = (x[0] + res[0], x[1] + res[1], x[2] + res[2], x[3] + res[3],x[4] + res[4])
#print(len(x),"~~~",len(res),"~~~",len(x + res))
#return (x[0] + res[0], x[1]+res[1])
return x
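The same shape check carries over here (with the stand-in helpers from the sketch after improvement 2); the only difference is that the low branch now travels permanently in the wavelet domain as four subbands:

x = torch.randn(1, 16, 64, 64).cuda()
first = DWT_FirstOctaveConv(16, 16, kernel_size=3, alpha=0.5).cuda()
body = DWT_octave_RRDBTiny(16, kernel_size=3, gc=8, alpha=0.5).cuda()
last = DWT_LastOctaveConv(16, 16, kernel_size=3, alpha=0.5).cuda()

state = first(x)                        # (X_h, X_l_ll, X_l_lh, X_l_hl, X_l_hh)
print([tuple(t.shape) for t in state])  # one full-resolution map + four downsampled subbands
print(last(body(state)).shape)          # (1, 16, 64, 64)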
Results