# Usage notes (kept as comments so this module stays importable):
#   python train.py -opt options/train/train_sr.json
#   python3 test.py -opt options/test/test_sr.json
#   source activate pytorch
#   tensorboard --logdir tb_logger/ --port 6008
#   http://172.20.36.203:6008/#scalars
##############################################################################################
#octave_srresnet
class highorder_SRResNet(nn.Module):
    """SRResNet variant whose trunk is built from Octave convolutions.

    The input is split into a high-frequency and a low-frequency feature
    branch (FirstOctaveConv), processed by ``nb`` octave residual blocks
    (active variant: HR branch added to LR branch, HR kept), merged back
    into a single-frequency tensor (LastOctaveConv) and upsampled.

    Args:
        in_nc: number of input image channels.
        out_nc: number of output image channels.
        nf: number of trunk feature channels.
        nb: number of octave residual blocks.
        upscale: super-resolution factor (powers of 2, or 3).
        norm_type / act_type / mode / res_scale: passed through to the
            residual blocks (see module ``B``).
        upsample_mode: 'upconv' or 'pixelshuffle'.
    """

    def __init__(self, in_nc, out_nc, nf, nb, upscale=4, norm_type='batch', act_type='relu', \
                 mode='NAC', res_scale=1, upsample_mode='upconv'):
        super(highorder_SRResNet, self).__init__()
        # Number of x2 upsampling stages; upscale==3 is handled by a single
        # dedicated x3 stage instead.
        n_upscale = int(math.log(upscale, 2))
        if upscale == 3:
            n_upscale = 1
        alpha = 0.5  # fraction of channels routed to the low-frequency branch

        # Split input features into (high, low) frequency branches.
        self.fea_conv = B.FirstOctaveConv(in_nc, nf, kernel_size=3, alpha=alpha,
                                          norm_type=None, act_type='relu', mode='CNA')

        # NOTE(review): many alternative trunk configurations were explored here
        # (plain-conv baseline, LRaddHR/keepLR, no-connection, two-branch,
        # RD-block-output variants); the commented-out experiments were removed —
        # see version control history for the full list.
        fea_conv1 = B.OctaveConv(nf, nf, kernel_size=3, alpha=alpha, stride=1,
                                 dilation=1, groups=1, bias=True, pad_type='zero',
                                 norm_type=None, act_type='relu', mode='CNA')
        # Active variant: HR residual added to LR branch, HR branch kept.
        resnet_blocks = [B.HRaddLR_keepHR_OctaveResBlock(nf, nf, nf, kernel_size=3,
                                                         alpha=alpha, norm_type=norm_type,
                                                         act_type=act_type, mode=mode,
                                                         res_scale=res_scale)
                         for _ in range(nb)]
        LR_conv = B.OctaveConv(nf, nf, kernel_size=3, alpha=alpha, stride=1,
                               dilation=1, groups=1, bias=True, pad_type='zero',
                               norm_type=None, act_type='relu', mode='CNA')

        if upsample_mode == 'upconv':
            upsample_block = B.octave_upconv_blcok  # (sic) project-wide spelling
        elif upsample_mode == 'pixelshuffle':
            upsample_block = B.octave_pixelshuffle_block
        else:
            raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
        if upscale == 3:
            # Single x3 upsampling stage (returned module is iterable, so the
            # *upsampler unpack below still works).
            upsampler = upsample_block(nf, nf, 3, alpha=alpha, act_type=act_type)
        else:
            upsampler = [upsample_block(nf, nf, alpha=alpha, act_type=act_type)
                         for _ in range(n_upscale)]
        # Merge the two frequency branches back into one tensor, then project
        # to the output channel count.
        HR_conv0 = B.LastOctaveConv(nf, nf, kernel_size=3, alpha=alpha, stride=1,
                                    dilation=1, groups=1, bias=True, pad_type='zero',
                                    norm_type=None, act_type=act_type)
        HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)

        # Trunk with a global residual (shortcut) connection around it.
        self.model = B.Octave_ShortcutBlock(B.sequential(fea_conv1, *resnet_blocks, LR_conv))
        self.subpixel_up = B.sequential(*upsampler, HR_conv0, HR_conv1)

    def forward(self, x):
        """Map a LR image tensor to a SR image tensor."""
        x = self.fea_conv(x)      # -> (high, low) branch pair
        x = self.model(x)         # octave residual trunk with shortcut
        x = self.subpixel_up(x)   # merge branches + upsample + project
        return x
#######################################################################
# def forward(self, x):
# x=self.fea_conv(x)
# res=x
# x=self.fea_conv1(x)
# x_h,x_l=x
# x_h=self.model_HR1 (x_h)
# x_l=self.model_LR1 (x_l)
# x=x_h,x_l
# x=self.LR_conv(x)
# x=(res[0]+x[0],res[1]+x[1])
# x=self.subpixel_up(x)
# return x
############################################################################################
# RCAN (reproduced version)
###############################################################################################################
#RCAN
class RCAN(nn.Module):
    """Residual Channel Attention Network (RCAN) reproduction.

    ``ng`` residual groups, each containing ``nb`` channel-attention residual
    blocks, wrapped in a global shortcut, followed by upsampling.

    Args:
        in_nc / out_nc: input / output image channels.
        nf: feature channels.
        ng: number of residual groups.
        nb: residual blocks per group.
        reduction: channel-attention squeeze ratio.
        upscale: super-resolution factor (powers of 2, or 3).
        alpha: unused here; kept for signature compatibility with the octave
            variants in this file.
        upsample_mode: 'upconv' or 'pixelshuffle'.
    """

    def __init__(self, in_nc, out_nc, nf=64, ng=10, nb=20, reduction=16, upscale=4, alpha=0.75,
                 norm_type='batch', act_type='relu', \
                 mode='NAC', res_scale=1, upsample_mode='upconv'):
        super(RCAN, self).__init__()
        n_upscale = int(math.log(upscale, 2))
        if upscale == 3:
            n_upscale = 1  # single dedicated x3 stage
        fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None,
                                act_type='relu', mode='CNA')
        CA_blocks = [B.ResidualGroupBlock(nf, nb, kernel_size=3, reduction=reduction,
                                          norm_type=norm_type, act_type=act_type,
                                          mode=mode, res_scale=res_scale)
                     for _ in range(ng)]
        LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=None,
                               act_type=None, mode=mode)
        if upsample_mode == 'upconv':
            upsample_block = B.upconv_blcok  # (sic) project-wide spelling
        elif upsample_mode == 'pixelshuffle':
            upsample_block = B.pixelshuffle_block
        else:
            raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
        if upscale == 3:
            upsampler = upsample_block(nf, nf, 3, act_type=act_type)
        else:
            upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]
        HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
        HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)
        # Global residual around the group trunk, then upsample + project.
        self.model = B.sequential(fea_conv,
                                  B.ShortcutBlock(B.sequential(*CA_blocks, LR_conv)),
                                  *upsampler, HR_conv0, HR_conv1)

    def forward(self, x):
        """Map a LR image tensor to a SR image tensor."""
        x = self.model(x)
        return x
#################################################################################################################################
# RRDBNet
class RRDB(nn.Module):
    """Residual in Residual Dense Block.

    Three chained ResidualDenseBlock_5C blocks with an outer residual
    connection scaled by 0.2 (the residual-scaling trick from ESRGAN:
    Enhanced Super-Resolution Generative Adversarial Networks).

    Args:
        nc: number of input/output channels.
        gc: growth channels inside each dense block.
        Remaining args are forwarded to ResidualDenseBlock_5C.
    """

    def __init__(self, nc, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
                 norm_type=None, act_type='leakyrelu', mode='CNA'):
        super(RRDB, self).__init__()
        self.RDB1 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type,
                                          norm_type, act_type, mode)
        self.RDB2 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type,
                                          norm_type, act_type, mode)
        self.RDB3 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type,
                                          norm_type, act_type, mode)

    def forward(self, x):
        out = self.RDB1(x)
        out = self.RDB2(out)
        out = self.RDB3(out)
        # Scale the residual by 0.2 before adding the identity (stabilizes
        # training of very deep generators).
        return out.mul(0.2) + x
class ResidualDenseBlock_5C(nn.Module):
    """Residual Dense Block, 5-conv style.

    Core module of "Residual Dense Network for Image Super-Resolution"
    (CVPR 2018). Each conv sees the concatenation of the block input and
    all previous conv outputs; the final output is residually added to the
    input with a 0.2 scale.

    Args:
        nc: number of input/output channels.
        gc: growth channels, i.e. intermediate channels added per conv.
    """

    def __init__(self, nc, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
                 norm_type=None, act_type='leakyrelu', mode='CNA'):
        super(ResidualDenseBlock_5C, self).__init__()
        # gc: growth channel, i.e. intermediate channels
        self.conv1 = conv_block(nc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
                                norm_type=norm_type, act_type=act_type, mode=mode)
        self.conv2 = conv_block(nc + gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
                                norm_type=norm_type, act_type=act_type, mode=mode)
        self.conv3 = conv_block(nc + 2 * gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
                                norm_type=norm_type, act_type=act_type, mode=mode)
        self.conv4 = conv_block(nc + 3 * gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
                                norm_type=norm_type, act_type=act_type, mode=mode)
        if mode == 'CNA':
            # In Conv-Norm-Act mode the last conv is left linear (no activation),
            # matching the reference ESRGAN implementation.
            last_act = None
        else:
            last_act = act_type
        # NOTE(review): the last conv uses a hard-coded kernel size of 3
        # regardless of ``kernel_size`` — this matches the reference code,
        # so it is preserved here.
        self.conv5 = conv_block(nc + 4 * gc, nc, 3, stride, bias=bias, pad_type=pad_type,
                                norm_type=norm_type, act_type=last_act, mode=mode)

    def forward(self, x):
        # Dense connectivity: each conv consumes everything produced so far.
        x1 = self.conv1(x)
        x2 = self.conv2(torch.cat((x, x1), 1))
        x3 = self.conv3(torch.cat((x, x1, x2), 1))
        x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        # Residual scaling by 0.2, then identity shortcut.
        return x5.mul(0.2) + x
##########################################################################################################
class channel_attention_OctaveConv(nn.Module):
    """Octave convolution with a joint channel-attention gate.

    The high- and low-frequency branches are convolved independently
    (h2h / l2l, no cross-frequency exchange here), then a squeeze-and-
    excitation style gate is computed over the pooled concatenation of both
    branches and split back to re-weight each branch's channels.

    Args:
        in_nc / out_nc: total channels across both branches.
        alpha: fraction of channels assigned to the low-frequency branch.
        reduction: squeeze ratio of the attention bottleneck.
        Remaining args follow the project-wide conv_block conventions.
    """

    def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1, dilation=1, groups=1, \
                 bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA', reduction=8):
        super(channel_attention_OctaveConv, self).__init__()
        # Fixed typo in the assertion message ('Wong' -> 'Wrong').
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
        padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
        self.stride = stride  # NOTE(review): stored but not used by forward()
        self.out = out_nc
        self.alpha = alpha
        # Per-branch convolutions (stride hard-coded to 1, as in the original).
        self.l2l = nn.Conv2d(int(alpha * in_nc), int(alpha * out_nc),
                             kernel_size, 1, padding, dilation, groups, bias)
        self.h2h = nn.Conv2d(in_nc - int(alpha * in_nc), out_nc - int(alpha * out_nc),
                             kernel_size, 1, padding, dilation, groups, bias)
        self.a = act(act_type) if act_type else None
        self.n_h = norm(norm_type, int(out_nc * (1 - alpha))) if norm_type else None
        self.n_l = norm(norm_type, int(out_nc * alpha)) if norm_type else None
        self.average = nn.AdaptiveAvgPool2d(1)
        # Squeeze-and-excitation over the concatenated (HR, LR) pooled features.
        self.attention = sequential(
            conv_block(in_nc, in_nc // reduction, 1, stride, dilation, groups, bias, pad_type,
                       norm_type, act_type, mode),
            conv_block(in_nc // reduction, out_nc, 1, stride, dilation, groups, bias, pad_type,
                       norm_type, None, mode),
            nn.Sigmoid())

    def forward(self, x):
        X_h, X_l = x
        X_h2h = self.h2h(X_h)
        X_l2l = self.l2l(X_l)
        # Global-average-pool each branch, gate jointly, then split the gate
        # back into high-frequency and low-frequency channel weights.
        X_hpooling = self.average(X_h2h)
        X_lpooling = self.average(X_l2l)
        X_attention = torch.cat((X_hpooling, X_lpooling), dim=1)
        X_attention = self.attention(X_attention)
        X_h_attention = X_attention[:, :self.out - int(self.alpha * self.out), :, :]
        X_l_attention = X_attention[:, self.out - int(self.alpha * self.out):, :, :]
        X_h = X_h2h * X_h_attention
        X_l = X_l2l * X_l_attention
        if self.n_h and self.n_l:
            X_h = self.n_h(X_h)
            X_l = self.n_l(X_l)
        if self.a:
            X_h = self.a(X_h)
            X_l = self.a(X_l)
        return X_h, X_l
################################################################################################
# CARN
#################################################################################################################################
class CARN(nn.Module):
    """Cascading Residual Network (CARN) for super-resolution.

    ``nb`` cascade blocks; after each block the new features are concatenated
    with all previous outputs and fused back to ``nf`` channels by a 1x1
    conv (global cascading), then the result is upsampled.

    Args:
        in_nc / out_nc: input / output image channels.
        nf: feature channels.
        nc: residual blocks per cascade block.
        nb: number of cascade blocks.
        upscale: super-resolution factor (powers of 2, or 3).
        upsample_mode: 'upconv' or 'pixelshuffle'.
    """

    def __init__(self, in_nc=3, out_nc=3, nf=64, nc=4, nb=3, upscale=2, norm_type=None, \
                 act_type='prelu', mode='NAC', res_scale=1.0, upsample_mode='upconv'):
        super(CARN, self).__init__()
        n_upscale = int(math.log(upscale, 2))
        if upscale == 3:
            n_upscale = 1  # single dedicated x3 stage
        self.nb = nb
        self.fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None,
                                     act_type='relu', mode='CNA')
        self.fea_conv1 = B.conv_block(nf, nf, kernel_size=3, norm_type=None,
                                      act_type='relu', mode='CNA')
        # NOTE(review): octave / without-connection variants of the cascade and
        # fusion blocks were explored here; the commented-out experiments were
        # removed — see version control history.
        self.CascadeBlocks = nn.ModuleList([B.CascadeBlock(nc, nf, kernel_size=3,
                                                           norm_type=norm_type,
                                                           act_type=act_type, mode=mode,
                                                           res_scale=res_scale)
                                            for _ in range(nb)])
        # 1x1 fusion convs; block i fuses (i + 2) * nf concatenated channels.
        self.CatBlocks = nn.ModuleList([B.conv_block((i + 2) * nf, nf, kernel_size=1,
                                                     norm_type=norm_type, act_type=act_type,
                                                     mode=mode)
                                        for i in range(nb)])
        if upsample_mode == 'upconv':
            upsample_block = B.octave_upconv_blcok  # (sic) project-wide spelling
        elif upsample_mode == 'pixelshuffle':
            upsample_block = B.pixelshuffle_block
        else:
            raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
        if upscale == 3:
            upsampler = upsample_block(nf, nf, 3, act_type=act_type)
        else:
            upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]
        HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
        HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)
        self.subpixel_up = B.sequential(*upsampler, HR_conv0, HR_conv1)

    def forward(self, x):
        """Map a LR image tensor to a SR image tensor via global cascading."""
        x = self.fea_conv(x)
        x = self.fea_conv1(x)
        pre_fea = x
        for i in range(self.nb):
            res = self.CascadeBlocks[i](x)
            # Accumulate every cascade output, then fuse back to nf channels.
            pre_fea = torch.cat((pre_fea, res), dim=1)
            x = self.CatBlocks[i](pre_fea)
        x = self.subpixel_up(x)
        return x
#############################################################################################
####################################################################################################
class CascadeBlock(nn.Module):
    """CascadeBlock, 3-3 style (local cascading unit of CARN).

    ``nc`` residual blocks; after each one the new features are concatenated
    with all previous outputs and fused back to ``gc`` channels by a 1x1
    conv — the same cascading pattern CARN applies globally.

    Args:
        nc: number of residual blocks in the cascade.
        gc: feature channels.
        Remaining args follow the project-wide conv_block conventions.
    """

    def __init__(self, nc, gc, kernel_size=3, stride=1, dilation=1, groups=1, \
                 bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA', res_scale=1):
        super(CascadeBlock, self).__init__()
        self.nc = nc
        self.ResBlocks = nn.ModuleList([ResNetBlock(gc, gc, gc, kernel_size, stride, dilation,
                                                    groups, bias, pad_type, norm_type, act_type,
                                                    mode, res_scale)
                                        for _ in range(nc)])
        # 1x1 fusion convs; step i fuses (i + 2) * gc concatenated channels.
        self.CatBlocks = nn.ModuleList([conv_block((i + 2) * gc, gc, kernel_size=1, bias=bias,
                                                   pad_type=pad_type, norm_type=norm_type,
                                                   act_type=act_type, mode=mode)
                                        for i in range(nc)])

    def forward(self, x):
        pre_fea = x
        for i in range(self.nc):
            res = self.ResBlocks[i](x)
            # Accumulate every residual output, then fuse back to gc channels.
            pre_fea = torch.cat((pre_fea, res), dim=1)
            x = self.CatBlocks[i](pre_fea)
        return x