##############################################################################################
#Multi-scale SR
class MSSR(nn.Module):
    """Multi-scale super-resolution (SR) network.

    The input is processed as a three-level image pyramid built with 2x2,
    stride-2 average pooling: full resolution, 1/2 and 1/4. Processing starts
    at the coarsest level. Each level's conv/ResNet trunk output is upsampled
    x2 by a transposed convolution and concatenated with the next finer
    image; the result feeds that finer level's trunk. The full-resolution
    trunk output then goes through ``self.subpixel_up`` (upsampling blocks
    plus two convs), which produces the ``out_nc``-channel result.

    Args:
        in_nc: Number of input image channels.
        out_nc: Number of output image channels.
        nf: Number of feature channels in every trunk.
        ng: Accepted but not used by this class.
        nb: Number of ``B.ResNetBlock`` modules in each of the three trunks.
        reduction: Accepted but not used by this class.
        upscale: Output scale.
            - ``3`` builds a single ``upsample_block`` with a factor
              argument of 3.
            - Any other value stacks ``int(log2(upscale))`` upsample blocks
              with their default factor (presumably x2 -- confirm in ``B``).
              So non-powers of two round down; e.g. 6 builds the same stack
              as 4.
        alpha: Accepted but not used by this class.
        norm_type, act_type, mode, res_scale: Forwarded to ``B.conv_block``
            and ``B.ResNetBlock``.
        upsample_mode: ``'upconv'`` or ``'pixelshuffle'``.

    Raises:
        NotImplementedError: If ``upsample_mode`` is not recognised.
    """

    # Two stride-2 poolings are applied, so H and W must be multiples of 4.
    # Otherwise the pyramid levels do not line up after x2 upsampling.
    _SIZE_MULTIPLE = 4

    def __init__(self, in_nc, out_nc, nf=64, ng=10, nb=20, reduction=16,
                 upscale=4, alpha=0.75, norm_type='batch', act_type='relu',
                 mode='NAC', res_scale=1, upsample_mode='upconv'):
        super(MSSR, self).__init__()
        self.pooling1 = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        self.pooling2 = nn.AvgPool2d(kernel_size=(2, 2), stride=2)

        def make_trunk(trunk_in_nc):
            """Build conv_block(ReLU) -> nb ResNetBlocks -> conv_block (no act)."""
            fea_conv = B.conv_block(trunk_in_nc, nf, kernel_size=3, norm_type=None,
                                    act_type='relu', mode='CNA')
            resnet_blocks = [
                B.ResNetBlock(nf, nf, nf, norm_type=norm_type, act_type=act_type,
                              mode=mode, res_scale=res_scale)
                for _ in range(nb)
            ]
            lr_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type,
                                   act_type=None, mode=mode)
            return B.sequential(fea_conv, *resnet_blocks, lr_conv)

        # Construction order (model1, model2, model3) matches the original, so
        # parameter initialisation and state_dict keys are unchanged.
        # model1: full-res image + upsampled 1/2-level features.
        self.model1 = make_trunk(in_nc + nf)
        # model2: 1/2-res image + upsampled 1/4-level features.
        self.model2 = make_trunk(in_nc + nf)
        # model3: 1/4-res image only.
        self.model3 = make_trunk(in_nc)

        # Doubles spatial size: (H - 1) * 2 - 2 * 4 + 9 + 1 = 2H.
        # The same module (same weights) is applied at both pyramid levels.
        # NOTE(review): confirm this weight sharing is intended.
        self.upconv = nn.ConvTranspose2d(in_channels=nf, out_channels=nf, kernel_size=9,
                                         stride=2, padding=4, output_padding=1)

        if upsample_mode == 'upconv':
            # Attribute name is spelled exactly as referenced here ("blcok");
            # check B before renaming.
            upsample_block = B.upconv_blcok
        elif upsample_mode == 'pixelshuffle':
            upsample_block = B.pixelshuffle_block
        else:
            raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))

        if upscale == 3:
            # A single block; it is unpacked with * below, which presumably
            # yields its sub-modules (e.g. an nn.Sequential) -- confirm in B.
            upsampler = upsample_block(nf, nf, 3, act_type=act_type)
        else:
            # log2 is exact for powers of two, unlike log(x, 2).
            n_upscale = int(math.log2(upscale))
            upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]
        HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
        HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)
        self.subpixel_up = B.sequential(*upsampler, HR_conv0, HR_conv1)

    def _pyramid_forward(self, x):
        """Run the three-level pyramid on ``x``, whose H and W are multiples of 4."""
        x_half = self.pooling1(x)
        # Reuse the 1/2 level instead of pooling x twice.
        x_quarter = self.pooling2(x_half)
        feat = self.upconv(self.model3(x_quarter))
        # Image channels first, then features.
        feat = self.upconv(self.model2(torch.cat((x_half, feat), dim=1)))
        # Features first, then image: the opposite order from the level above.
        # Kept because the trained weights of model1's first conv depend on it.
        feat = self.model1(torch.cat((feat, x), dim=1))
        return self.subpixel_up(feat)

    def forward(self, x):
        """Super-resolve ``x``.

        Args:
            x: Image batch, indexed as ``(..., H, W)``. Presumably
                ``(N, in_nc, H, W)``; confirm against callers.

        Returns:
            The output of ``self.subpixel_up``. When H or W is not a multiple
            of 4, the input is replicate-padded up to one, and the output is
            cropped back to ``(H * s_h, W * s_w)``, where ``s_h``/``s_w`` is the
            network's effective spatial scale.
        """
        h, w = x.shape[-2:]
        pad_h = -h % self._SIZE_MULTIPLE
        pad_w = -w % self._SIZE_MULTIPLE
        if not (pad_h or pad_w):
            return self._pyramid_forward(x)
        # Without padding, odd pooled sizes cannot be matched by the x2
        # transposed conv, and torch.cat fails with a size mismatch.
        # Replicate mode works for any pad width, even on 1-pixel inputs.
        padded = nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='replicate')
        out = self._pyramid_forward(padded)
        scale_h = out.shape[-2] // padded.shape[-2]
        scale_w = out.shape[-1] // padded.shape[-1]
        return out[..., :h * scale_h, :w * scale_w]
################################################################################################