import cv2
import math
import numpy as np
import os
import queue
import threading
import torch
from basicsr.utils.download_util import load_file_from_url
from torch.nn import functional as F
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))


class RealESRGANer():
    """A helper class for upsampling images with RealESRGAN.
    Args:
        scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
        model_path (str): The path to the pretrained model. It can also be a URL, in which case the model is
            downloaded automatically first.
        model (nn.Module): The defined network. Default: None.
        tile (int): Since overly large images can cause GPU out-of-memory issues, this tile option first crops
            the input image into tiles, processes each tile, and finally merges them back into one image.
            0 means no tiling. Default: 0.
        tile_pad (int): The pad size for each tile, used to remove border artifacts. Default: 10.
        pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
        half (bool): Whether to use half precision during inference. Default: False.
    """

    def __init__(self,
scale,
model_path,
dni_weight=None,
model=None,
tile=0,
tile_pad=10,
pre_pad=10,
half=False,
device=None,
gpu_id=None):
self.scale = scale
self.tile_size = tile
self.tile_pad = tile_pad
self.pre_pad = pre_pad
        self.mod_scale = None
        self.half = half

        # initialize model
        if gpu_id:
            self.device = torch.device(
                f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device
        else:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device

        if isinstance(model_path, list):
            # dni
            assert len(model_path) == len(dni_weight), 'model_path and dni_weight should have the same length.'
            loadnet = self.dni(model_path[0], model_path[1], dni_weight)
        else:
            # if the model_path starts with https, first download the model to the folder: weights
            if model_path.startswith('https://'):
                model_path = load_file_from_url(
                    url=model_path, model_dir=os.path.join(ROOT_DIR, 'weights'), progress=True, file_name=None)
            loadnet = torch.load(model_path, map_location=torch.device('cpu'))

        # prefer to use params_ema
        if 'params_ema' in loadnet:
            keyname = 'params_ema'
        else:
            keyname = 'params'
        model.load_state_dict(loadnet[keyname], strict=True)
        model.eval()
        self.model = model.to(self.device)
        if self.half:
            self.model = self.model.half()

    def dni(self, net_a, net_b, dni_weight, key='params', loc='cpu'):
        """Deep network interpolation.
``Paper: Deep Network Interpolation for Continuous Imagery Effect Transition``
"""
net_a = torch.load(net_a, map_location=torch.device(loc))
        net_b = torch.load(net_b, map_location=torch.device(loc))
        for k, v_a in net_a[key].items():
            net_a[key][k] = dni_weight[0] * v_a + dni_weight[1] * net_b[key][k]
        return net_a
    def pre_process(self, img):
        """Pre-process, such as pre-padding and mod-padding, so that the image sizes are divisible."""
img = torch.from_numpy(np.transpose(img,(2,0,1))).float()
        self.img = img.unsqueeze(0).to(self.device)
        if self.half:
            self.img = self.img.half()

        # pre_pad
        if self.pre_pad != 0:
            self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
        # mod pad for divisible borders
        if self.scale == 2:
            self.mod_scale = 2
        elif self.scale == 1:
            self.mod_scale = 4
        if self.mod_scale is not None:
            self.mod_pad_h, self.mod_pad_w = 0, 0
            _, _, h, w = self.img.size()
            if (h % self.mod_scale != 0):
                self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
            if (w % self.mod_scale != 0):
                self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
            self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')

    def process(self):
        # model inference
        # self.output = self.model(self.img)
        # cv2.imwrite("/home/sunyingli/y_channel.jpg", ((self.img[:, 2, :, :].cpu().numpy()*255).astype(np.uint8)[0]))
        # channel index 2 corresponds to taking only the Y channel here
        self.output = self.model(self.img[:, 2, :, :].unsqueeze(1))
        print('Input_size: ' + str(self.img[:, 2, :, :].unsqueeze(1).shape))

        # profile the model; redirect stdout to suppress thop's own printing
        import sys
        from thop import profile
        f = open(os.devnull, 'w')
        sys.stdout = f
        flops, params = profile(self.model, inputs=(self.img[:, 2, :, :].unsqueeze(1), ))
        sys.stdout = sys.__stdout__
        f.close()
        # convert FLOPs to TOPs at 24 frames per second (32-bit precision)
        tops = flops * 24 / 10**12
        # print the results
        print('24 frames per second computing power: ')
        print('FLOPs:', flops)
        print('TOPs:', tops)
        print('Params:', params)
        print('-' * 60)

    def tile_process(self):
        """It will first crop the input image into tiles, and then process each tile.
        Finally, all the processed tiles are merged into one image.

        Modified from: https://github.com/ata4/esrgan-launcher
        """
batch, channel, height, width = self.img.shape
output_height = height * self.scale
output_width = width * self.scale
        output_shape = (batch, channel, output_height, output_width)

        # start with black image
        self.output = self.img.new_zeros(output_shape)
        tiles_x = math.ceil(width / self.tile_size)
        tiles_y = math.ceil(height / self.tile_size)

        # loop over all tiles
        for y in range(tiles_y):
            for x in range(tiles_x):
                # extract tile from input image
                ofs_x = x * self.tile_size
                ofs_y = y * self.tile_size
                # input tile area on total image
                input_start_x = ofs_x
                input_end_x = min(ofs_x + self.tile_size, width)
                input_start_y = ofs_y
                input_end_y = min(ofs_y + self.tile_size, height)
                # input tile area on total image with padding
                input_start_x_pad = max(input_start_x - self.tile_pad, 0)
                input_end_x_pad = min(input_end_x + self.tile_pad, width)
                input_start_y_pad = max(input_start_y - self.tile_pad, 0)
                input_end_y_pad = min(input_end_y + self.tile_pad, height)
                # input tile dimensions
                input_tile_width = input_end_x - input_start_x
                input_tile_height = input_end_y - input_start_y
                tile_idx = y * tiles_x + x + 1
                input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]

                # upscale tile
                try:
                    with torch.no_grad():
                        output_tile = self.model(input_tile)
                except RuntimeError as error:
                    print('Error', error)
                print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')

                # output tile area on total image
                output_start_x = input_start_x * self.scale
                output_end_x = input_end_x * self.scale
                output_start_y = input_start_y * self.scale
                output_end_y = input_end_y * self.scale

                # output tile area without padding
                output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
                output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
                output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
                output_end_y_tile = output_start_y_tile + input_tile_height * self.scale

                # put tile into output image
                self.output[:, :, output_start_y:output_end_y,
                            output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
                                                                       output_start_x_tile:output_end_x_tile]

    def post_process(self):
        # remove extra pad
        if self.mod_scale is not None:
            _, _, h, w = self.output.size()
            self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]
        # remove pre_pad
        if self.pre_pad != 0:
            _, _, h, w = self.output.size()
            self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale]
        return self.output
    @torch.no_grad()
    def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'):
        h_input, w_input = img.shape[0:2]
        # img: numpy
        img = img.astype(np.float32)
        if np.max(img) > 256:  # 16-bit image
            max_range = 65535
            print('\tInput is a 16-bit image')
        else:
            max_range = 255
        img = img / max_range
        if len(img.shape) == 2:  # gray image
            img_mode = 'L'
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        elif img.shape[2] == 4:  # RGBA image with alpha channel
            img_mode = 'RGBA'
            alpha = img[:, :, 3]
            img = img[:, :, 0:3]
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            if alpha_upsampler == 'realesrgan':
                alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
        else:
            img_mode = 'RGB'
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            # # convert the image to the YUV color space
            # img_yuv = cv2.cvtColor(img, cv2.COLOR_BGR2YUV)
            # # split the YUV channels
            # y, u, v = cv2.split(img_yuv)

        # ------------------- process image (without the alpha channel) ------------------- #
        self.pre_process(img)
        if self.tile_size > 0:
            self.tile_process()
        else:
            self.process()
        output_img = self.post_process()
        output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy()
        # output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
        if img_mode == 'L':
            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)

        # ------------------- process the alpha channel if necessary ------------------- #
        if img_mode == 'RGBA':
            if alpha_upsampler == 'realesrgan':
                self.pre_process(alpha)
                if self.tile_size > 0:
                    self.tile_process()
                else:
                    self.process()
                output_alpha = self.post_process()
                output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
                output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
                output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
            else:  # use cv2.resize for the alpha channel
                h, w = alpha.shape[0:2]
                output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)

            # merge the alpha channel
            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
            output_img[:, :, 3] = output_alpha

        # ------------------------------ return ------------------------------ #
        if max_range == 65535:  # 16-bit image
            output = (output_img * 65535.0).round().astype(np.uint16)
        else:
            output = (output_img * 255.0).round().astype(np.uint8)

        if outscale is not None and outscale != float(self.scale):
            output = cv2.resize(
                output, (int(w_input * outscale), int(h_input * outscale)), interpolation=cv2.INTER_LANCZOS4)

        return output, img_mode
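

# A minimal usage sketch for the class above, runnable only when this module is
# executed directly. The checkpoint path, input image, and the 1-channel x2
# RRDBNet configuration are illustrative assumptions, not values taken from the
# original pipeline.
if __name__ == '__main__':
    from basicsr.archs.rrdbnet_arch import RRDBNet

    demo_model = RRDBNet(num_in_ch=1, num_out_ch=1, num_feat=32, num_block=1, num_grow_ch=16, scale=2)
    # hypothetical local checkpoint; replace with a real model file
    upsampler = RealESRGANer(scale=2, model_path='weights/net_g_demo.pth', model=demo_model, tile=0, half=False)
    demo_img = cv2.imread('input.png')  # hypothetical input image
    out, mode = upsampler.enhance(cv2.cvtColor(demo_img, cv2.COLOR_BGR2YUV), outscale=2)
    print(out.shape, mode)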
class PrefetchReader(threading.Thread):
    """Prefetch images.

    Args:
        img_list (list[str]): A list of image paths to be read.
        num_prefetch_queue (int): Number of prefetch queue.
    """

    def __init__(self, img_list, num_prefetch_queue):
        super().__init__()
self.que = queue.Queue(num_prefetch_queue)
self.img_list = img_list
    def run(self):
        for img_path in self.img_list:
            img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
            self.que.put(img)

        self.que.put(None)

    def __next__(self):
        next_item = self.que.get()
        if next_item is None:
            raise StopIteration
        return next_item

    def __iter__(self):
        return self
class IOConsumer(threading.Thread):

    def __init__(self, opt, que, qid):
        super().__init__()
self._queue = que
self.qid = qid
self.opt = opt
    def run(self):
        while True:
            msg = self._queue.get()
            if isinstance(msg, str) and msg == 'quit':
                break

            output = msg['output']
            save_path = msg['save_path']
            cv2.imwrite(save_path, output)
        print(f'IO worker {self.qid} is done.')
import argparse
import cv2
import glob
import os
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
import numpy as np
def main():
    """Inference demo for Real-ESRGAN."""
parser = argparse.ArgumentParser()
parser.add_argument('-i','--input',type=str, default='/home/sunyingli/inputs/test_004_lq',help='Input image or folder')
    parser.add_argument(
        '-n',
        '--model_name',
        type=str,
        default='RealESRGAN_x2plus',
        help=('Model names: RealESRGAN_x4plus | RealESRNet_x4plus | RealESRGAN_x4plus_anime_6B | RealESRGAN_x2plus | '
              'realesr-animevideov3 | realesr-general-x4v3'))
    parser.add_argument('-o', '--output', type=str, default='/home/sunyingli/inputs/test_004_lq_2X', help='Output folder')
    parser.add_argument(
        '-dn',
        '--denoise_strength',
        type=float,
        default=0.5,
        help=('Denoise strength. 0 for weak denoise (keep noise), 1 for strong denoise ability. '
              'Only used for the realesr-general-x4v3 model'))
    parser.add_argument('-s', '--outscale', type=float, default=2, help='The final upsampling scale of the image')
    # parser.add_argument(
    #     '--model_path', type=str,
    #     default='/home/sunyingli/Real-ESRGAN/experiments/train_realesrnet_x2plus_32_1_16_4channel__123conv_1rdb1_net_oneresize_no_conv_hr_pairdata/models/net_g_520000.pth',
    #     help='[Option] Model path. Usually, you do not need to specify it')
    parser.add_argument(
        '--model_path',
        type=str,
        default='/home/sunyingli/Real-ESRGAN/experiments/train_realesrnet_x2plus_32_1_16_4channel__123conv_1rdb1_net_oneresize_no_conv_hr_pairdata_0806/models/net_g_410000.pth',
        help='[Option] Model path. Usually, you do not need to specify it')
parser.add_argument('--suffix',type=str, default='',help='Suffix of the restored image')
parser.add_argument('-t','--tile',type=int, default=0,help='Tile size, 0 for no tile during testing')
parser.add_argument('--tile_pad',type=int, default=10,help='Tile padding')
parser.add_argument('--pre_pad',type=int, default=0,help='Pre padding size at each border')
parser.add_argument('--face_enhance', action='store_true',help='Use GFPGAN to enhance face')
parser.add_argument('--fp32', action='store_true',help='Use fp32 precision during inference. Default: fp16 (half precision).')
    parser.add_argument(
        '--alpha_upsampler',
        type=str,
        default='realesrgan',
        help='The upsampler for the alpha channels. Options: realesrgan | bicubic')
    parser.add_argument(
        '--ext',
        type=str,
        default='auto',
        help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs')
    parser.add_argument(
        '-g', '--gpu-id', type=int, default=3, help='gpu device to use (default=3); can be 0,1,2,... for multi-gpu')
    args = parser.parse_args()

    # determine models according to model names
    args.model_name = args.model_name.split('.')[0]
    if args.model_name == 'RealESRGAN_x4plus':  # x4 RRDBNet model
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        netscale = 4
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
    elif args.model_name == 'RealESRNet_x4plus':  # x4 RRDBNet model
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        netscale = 4
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth']
    elif args.model_name == 'RealESRGAN_x4plus_anime_6B':  # x4 RRDBNet model with 6 blocks
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
        netscale = 4
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth']
    elif args.model_name == 'RealESRGAN_x2plus':  # x2 RRDBNet model
        model = RRDBNet(num_in_ch=1, num_out_ch=1, num_feat=32, num_block=1, num_grow_ch=16, scale=2)
        netscale = 2
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth']
    elif args.model_name == 'realesr-animevideov3':  # x4 VGG-style model (XS size)
        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
        netscale = 4
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth']
    elif args.model_name == 'realesr-general-x4v3':  # x4 VGG-style model (S size)
        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
        netscale = 4
        file_url = [
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth',
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'
        ]

    # determine model paths
    if args.model_path is not None:
model_path = args.model_path
else:
        model_path = os.path.join('weights', args.model_name + '.pth')
        if not os.path.isfile(model_path):
            ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
            for url in file_url:
                # model_path will be updated
                model_path = load_file_from_url(
                    url=url, model_dir=os.path.join(ROOT_DIR, 'weights'), progress=True, file_name=None)

    # use dni to control the denoise strength
    dni_weight = None
    if args.model_name == 'realesr-general-x4v3' and args.denoise_strength != 1:
        wdn_model_path = model_path.replace('realesr-general-x4v3', 'realesr-general-wdn-x4v3')
        model_path = [model_path, wdn_model_path]
        dni_weight = [args.denoise_strength, 1 - args.denoise_strength]

    # restorer
upsampler = RealESRGANer(
scale=netscale,
model_path=model_path,
dni_weight=dni_weight,
model=model,
tile=args.tile,
tile_pad=args.tile_pad,
pre_pad=args.pre_pad,
half=not args.fp32,
        gpu_id=args.gpu_id)

    if args.face_enhance:  # Use GFPGAN for face enhancement
        from gfpgan import GFPGANer
face_enhancer = GFPGANer(
model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
upscale=args.outscale,
arch='clean',
channel_multiplier=2,
bg_upsampler=upsampler)
    os.makedirs(args.output, exist_ok=True)

    if os.path.isfile(args.input):
        paths = [args.input]
    else:
        paths = sorted(glob.glob(os.path.join(args.input, '*')))

    for idx, path in enumerate(paths):
        imgname, extension = os.path.splitext(os.path.basename(path))
        print('Testing', idx, imgname)
img_BGR = cv2.imread(path)
img_yuv = cv2.cvtColor(img_BGR, cv2.COLOR_BGR2YUV)
img = img_yuv
img_resized = cv2.resize(img_BGR,None, fx=2, fy=2, interpolation=cv2.INTER_CUBIC)
        img_YUV = cv2.cvtColor(img_resized, cv2.COLOR_BGR2YUV)
        # split the YUV channels
        y, u, v = cv2.split(img_YUV)

        if len(img.shape) == 3 and img.shape[2] == 4:
            img_mode = 'RGBA'
        else:
            img_mode = None

        try:
            # if True:
            if args.face_enhance:
                _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
            else:
                output, _ = upsampler.enhance(img, outscale=args.outscale)
        except RuntimeError as error:
            print('Error', error)
            print('If you encounter CUDA out of memory, try to set --tile with a smaller number.')
        else:
            if args.ext == 'auto':
                extension = extension[1:]
            else:
                extension = args.ext
            if img_mode == 'RGBA':  # RGBA images should be saved in png format
                extension = 'png'
            if args.suffix == '':
                save_path = os.path.join(args.output, f'{imgname}.{extension}')
            else:
                save_path = os.path.join(args.output, f'{imgname}_{args.suffix}.{extension}')

            # cv2.imwrite(save_path, output)
            # save as a color image: create an empty color image first
            h, w = output.shape
            img_YUV_OUT = np.zeros((h, w, 3), dtype=np.uint8)
            # assign the Y, U and V channels of the image
            img_YUV_OUT[:, :, 0] = output  # Y channel (network output)
            img_YUV_OUT[:, :, 1] = u  # U channel (bicubic-upscaled)
            img_YUV_OUT[:, :, 2] = v  # V channel (bicubic-upscaled)
            # convert the image from the YUV color space back to BGR
            img_BGR_OUT = cv2.cvtColor(img_YUV_OUT, cv2.COLOR_YUV2BGR)
            # save the merged image
            cv2.imwrite(save_path, img_BGR_OUT)


if __name__ == '__main__':
    main()
import torch
from torch import nn as nn
from torch.nn import functional as F
from basicsr.utils.registry import ARCH_REGISTRY
from .arch_util import default_init_weights, make_layer, pixel_unshuffle
import time
class ResidualDenseBlock(nn.Module):
    """Residual Dense Block.

    Used in RRDB block in ESRGAN.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_grow_ch (int): Channels for each growth.
    """

    def __init__(self, num_feat=64, num_grow_ch=32):
        super(ResidualDenseBlock, self).__init__()
        # original 5-conv version:
        # self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
        # self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
        # self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
        # self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
        # self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
        # self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        # default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)

        # slimmed 3-conv version:
self.conv1 = nn.Conv2d(num_feat, num_grow_ch,3,1,1)
self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch,3,1,1)
self.conv5 = nn.Conv2d(num_feat +2* num_grow_ch, num_feat,3,1,1)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        default_init_weights([self.conv1, self.conv2, self.conv5], 0.1)

    def forward(self, x):
x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        # x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        # x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        # x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        x5 = self.conv5(torch.cat((x, x1, x2), 1))
        # Empirically, we use 0.2 to scale the residual for better performance
        return x5 * 0.2 + x
class RRDB(nn.Module):
    """Residual in Residual Dense Block.

    Used in RRDB-Net in ESRGAN.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_grow_ch (int): Channels for each growth.
    """

    def __init__(self, num_feat, num_grow_ch=32):
        super(RRDB, self).__init__()
        self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
        # self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
        # self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)

    def forward(self, x):
        out = self.rdb1(x)
        # out = self.rdb2(out)
        # out = self.rdb3(out)
        # Empirically, we use 0.2 to scale the residual for better performance
        return out * 0.2 + x
# The original RRDBNet, kept for reference:
#
# @ARCH_REGISTRY.register()
# class RRDBNet(nn.Module):
#     """Networks consisting of Residual in Residual Dense Block, which is used
#     in ESRGAN.
#     ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.
#     We extend ESRGAN for scale x2 and scale x1.
#     Note: This is one option for scale 1, scale 2 in RRDBNet.
#     We first employ pixel-unshuffle (an inverse operation of pixel-shuffle) to reduce the spatial size
#     and enlarge the channel size before feeding inputs into the main ESRGAN architecture.
#     Args:
#         num_in_ch (int): Channel number of inputs.
#         num_out_ch (int): Channel number of outputs.
#         num_feat (int): Channel number of intermediate features. Default: 64
#         num_block (int): Block number in the trunk network. Default: 23
#         num_grow_ch (int): Channels for each growth. Default: 32.
#     """
#
#     def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
#         super(RRDBNet, self).__init__()
#         self.scale = scale
#         if scale == 2:
#             num_in_ch = num_in_ch * 4
#         elif scale == 1:
#             num_in_ch = num_in_ch * 16
#         self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
#         self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
#         self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
#         # upsample
#         self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
#         self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
#         self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
#         self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
#         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
#
#     def forward(self, x):
#         if self.scale == 2:
#             feat = pixel_unshuffle(x, scale=2)
#         elif self.scale == 1:
#             feat = pixel_unshuffle(x, scale=4)
#         else:
#             feat = x
#         feat = self.conv_first(feat)
#         body_feat = self.conv_body(self.body(feat))
#         feat = feat + body_feat
#         # upsample
#         feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
#         feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
#         out = self.conv_last(self.lrelu(self.conv_hr(feat)))
#         return out


def pixel_shuffle_resume(x, scale):
    """Pixel shuffle.
Args:
x (Tensor): Input feature with shape (b, c, h, w).
scale (int): Upsample ratio.
Returns:
Tensor: the pixel shuffled feature.
"""# scale = scale * 2
b, c, h, w = x.size()
    out_channel = 1
    assert h % 2 == 0 and w % 2 == 0
hh = h * scale
ww = w * scale
    # x_view = x.view(b, out_channel, 2, 4, h, w)
    # return x_view.permute(0, 1, 4, 2, 5, 3).reshape(b, out_channel, hh, ww)
    x_view = x.view(b, out_channel, scale, scale, h, w)
    return x_view.permute(0, 1, 4, 2, 5, 3).reshape(b, out_channel, hh, ww)


@ARCH_REGISTRY.register()
class RRDBNet(nn.Module):
    """Networks consisting of Residual in Residual Dense Block, which is used
    in ESRGAN.
in ESRGAN.
ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.
We extend ESRGAN for scale x2 and scale x1.
Note: This is one option for scale 1, scale 2 in RRDBNet.
    We first employ pixel-unshuffle (an inverse operation of pixel-shuffle) to reduce the spatial size
    and enlarge the channel size before feeding inputs into the main ESRGAN architecture.
Args:
num_in_ch (int): Channel number of inputs.
num_out_ch (int): Channel number of outputs.
num_feat (int): Channel number of intermediate features.
Default: 64
        num_block (int): Block number in the trunk network. Default: 23
num_grow_ch (int): Channels for each growth. Default: 32.
"""def__init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):super(RRDBNet, self).__init__()
self.scale = scale
        if scale == 2:
            num_in_ch = num_in_ch * 4
        elif scale == 1:
            num_in_ch = num_in_ch * 16
self.conv_first = nn.Conv2d(num_in_ch, num_feat,3,1,1)
self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
        self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        # upsample
        self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        # self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        # self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch * 4, 3, 1, 1)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        # start_time = time.time()
        if self.scale == 2:
            feat = pixel_unshuffle(x, scale=2)
        elif self.scale == 1:
            feat = pixel_unshuffle(x, scale=4)
        else:
            feat = x
        # end_time = time.time()
        # print(f'pixel_unshuffle: {end_time - start_time} s')
        # start_time = time.time()
feat = self.conv_first(feat)
body_feat = self.conv_body(self.body(feat))
feat = feat + body_feat
        # upsample
        # end_time = time.time()
        # print(f'body_feat: {end_time - start_time} s')
        # start_time = time.time()
        feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='bilinear')))
        # end_time = time.time()
        # print(f'interpolate: {end_time - start_time} s')
        # feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
        # out = self.conv_last(self.lrelu(self.conv_hr(feat)))
        # start_time = time.time()
        out = self.conv_last(feat)
        # end_time = time.time()
        # print(f'self.conv_last(feat): {end_time - start_time} s')
        # start_time = time.time()
        # result = pixel_shuffle_resume(out, self.scale)
        # end_time = time.time()
        # print(f'pixel_shuffle_resume(out, self.scale): {end_time - start_time} s')
        # return result
        return pixel_shuffle_resume(out, self.scale)
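

# A quick shape sanity check for the slimmed-down RRDBNet above; the sizes are
# illustrative. For scale=2, pixel_unshuffle halves the spatial size (x4
# channels) and the final pixel_shuffle_resume restores it, so a (1, 1, 64, 64)
# Y-channel input should come out as (1, 1, 128, 128).
if __name__ == '__main__':
    net = RRDBNet(num_in_ch=1, num_out_ch=1, scale=2, num_feat=32, num_block=1, num_grow_ch=16)
    with torch.no_grad():
        y = net(torch.randn(1, 1, 64, 64))
    print(y.shape)  # expected: torch.Size([1, 1, 128, 128])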
import collections.abc
import math
import torch
import torchvision
import warnings
from distutils.version import LooseVersion
from itertools import repeat
from torch import nn as nn
from torch.nn import functional as F
from torch.nn import init as init
from torch.nn.modules.batchnorm import _BatchNorm
from basicsr.ops.dcn import ModulatedDeformConvPack, modulated_deform_conv
from basicsr.utils import get_root_logger
@torch.no_grad()
def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
    """Initialize network weights.
Args:
module_list (list[nn.Module] | nn.Module): Modules to be initialized.
scale (float): Scale initialized weights, especially for residual
blocks. Default: 1.
bias_fill (float): The value to fill bias. Default: 0
kwargs (dict): Other arguments for initialization function.
"""ifnotisinstance(module_list,list):
module_list =[module_list]for module in module_list:for m in module.modules():ifisinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, **kwargs)
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
            elif isinstance(m, nn.Linear):
                init.kaiming_normal_(m.weight, **kwargs)
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
            elif isinstance(m, _BatchNorm):
                init.constant_(m.weight, 1)
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)


def make_layer(basic_block, num_basic_block, **kwarg):
    """Make layers by stacking the same blocks.
Args:
basic_block (nn.module): nn.module class for basic block.
num_basic_block (int): number of blocks.
Returns:
nn.Sequential: Stacked blocks in nn.Sequential.
"""
    layers = []
    for _ in range(num_basic_block):
        layers.append(basic_block(**kwarg))
    return nn.Sequential(*layers)


class ResidualBlockNoBN(nn.Module):
    """Residual block without BN.
Args:
num_feat (int): Channel number of intermediate features.
Default: 64.
res_scale (float): Residual scale. Default: 1.
pytorch_init (bool): If set to True, use pytorch default init,
otherwise, use default_init_weights. Default: False.
"""def__init__(self, num_feat=64, res_scale=1, pytorch_init=False):super(ResidualBlockNoBN, self).__init__()
self.res_scale = res_scale
self.conv1 = nn.Conv2d(num_feat, num_feat,3,1,1, bias=True)
self.conv2 = nn.Conv2d(num_feat, num_feat,3,1,1, bias=True)
        self.relu = nn.ReLU(inplace=True)

        if not pytorch_init:
            default_init_weights([self.conv1, self.conv2], 0.1)

    def forward(self, x):
        identity = x
        out = self.conv2(self.relu(self.conv1(x)))
        return identity + out * self.res_scale
class Upsample(nn.Sequential):
    """Upsample module.
Args:
scale (int): Scale factor. Supported scales: 2^n and 3.
num_feat (int): Channel number of intermediate features.
"""def__init__(self, scale, num_feat):
m =[]if(scale &(scale -1))==0:# scale = 2^nfor _ inrange(int(math.log(scale,2))):
m.append(nn.Conv2d(num_feat,4* num_feat,3,1,1))
m.append(nn.PixelShuffle(2))elif scale ==3:
m.append(nn.Conv2d(num_feat,9* num_feat,3,1,1))
m.append(nn.PixelShuffle(3))else:raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')super(Upsample, self).__init__(*m)defflow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):"""Warp an image or feature map with optical flow.
Args:
x (Tensor): Tensor with size (n, c, h, w).
flow (Tensor): Tensor with size (n, h, w, 2), normal value.
interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
padding_mode (str): 'zeros' or 'border' or 'reflection'.
Default: 'zeros'.
align_corners (bool): Before pytorch 1.3, the default value is
align_corners=True. After pytorch 1.3, the default value is
align_corners=False. Here, we use the True as default.
Returns:
Tensor: Warped image or feature map.
"""assert x.size()[-2:]== flow.size()[1:3]
_, _, h, w = x.size()# create mesh grid
grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x))
grid = torch.stack((grid_x, grid_y),2).float()# W(x), H(y), 2
grid.requires_grad =False
vgrid = grid + flow
# scale grid to [-1,1]
vgrid_x =2.0* vgrid[:,:,:,0]/max(w -1,1)-1.0
vgrid_y =2.0* vgrid[:,:,:,1]/max(h -1,1)-1.0
vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)# TODO, what if align_corners=Falsereturn output
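

# A small sanity check for flow_warp (illustrative): warping with an all-zero
# flow should reproduce the input up to interpolation error.
if __name__ == '__main__':
    x = torch.arange(16.).reshape(1, 1, 4, 4)
    zero_flow = torch.zeros(1, 4, 4, 2)
    print(torch.allclose(flow_warp(x, zero_flow), x, atol=1e-5))  # expected: True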
def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False):
    """Resize a flow according to ratio or shape.
Args:
flow (Tensor): Precomputed flow. shape [N, 2, H, W].
size_type (str): 'ratio' or 'shape'.
sizes (list[int | float]): the ratio for resizing or the final output
shape.
1) The order of ratio should be [ratio_h, ratio_w]. For
downsampling, the ratio should be smaller than 1.0 (i.e., ratio
< 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
ratio > 1.0).
2) The order of output_size should be [out_h, out_w].
interp_mode (str): The mode of interpolation for resizing.
Default: 'bilinear'.
align_corners (bool): Whether align corners. Default: False.
Returns:
Tensor: Resized flow.
"""
    _, _, flow_h, flow_w = flow.size()
    if size_type == 'ratio':
        output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
    elif size_type == 'shape':
        output_h, output_w = sizes[0], sizes[1]
    else:
        raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.')
input_flow = flow.clone()
ratio_h = output_h / flow_h
ratio_w = output_w / flow_w
input_flow[:,0,:,:]*= ratio_w
input_flow[:,1,:,:]*= ratio_h
    resized_flow = F.interpolate(
        input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners)
    return resized_flow
# TODO: may write a cpp file
def pixel_unshuffle(x, scale):
    """Pixel unshuffle.
Args:
x (Tensor): Input feature with shape (b, c, hh, hw).
scale (int): Downsample ratio.
Returns:
Tensor: the pixel unshuffled feature.
"""# scale = scale * 2
b, c, hh, hw = x.size()
    out_channel = c * (scale**2)
    assert hh % scale == 0 and hw % scale == 0
    h = hh // scale
    w = hw // scale
    x_view = x.view(b, c, h, scale, w, scale)
    return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
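

# Illustrative check: pixel_unshuffle trades spatial size for channels, so a
# (1, 3, 8, 8) tensor unshuffled with scale=2 becomes (1, 12, 4, 4).
if __name__ == '__main__':
    t = torch.randn(1, 3, 8, 8)
    print(pixel_unshuffle(t, scale=2).shape)  # expected: torch.Size([1, 12, 4, 4])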
class DCNv2Pack(ModulatedDeformConvPack):
    """Modulated deformable conv for deformable alignment.
Different from the official DCNv2Pack, which generates offsets and masks
from the preceding features, this DCNv2Pack takes another different
features to generate offsets and masks.
``Paper: Delving Deep into Deformable Alignment in Video Super-Resolution``
"""defforward(self, x, feat):
out = self.conv_offset(feat)
o1, o2, mask = torch.chunk(out,3, dim=1)
offset = torch.cat((o1, o2), dim=1)
mask = torch.sigmoid(mask)
        offset_absmean = torch.mean(torch.abs(offset))
        if offset_absmean > 50:
            logger = get_root_logger()
            logger.warning(f'Offset abs mean is {offset_absmean}, larger than 50.')

        if LooseVersion(torchvision.__version__) >= LooseVersion('0.9.0'):
            return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding,
                                                 self.dilation, mask)
        else:
            return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding,
                                         self.dilation, self.groups, self.deformable_groups)


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
            'The distribution of values may be incorrect.',
            stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        low = norm_cdf((a - mean) / std)
        up = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [low, up], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * low - 1, 2 * up - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    r"""Fills the input Tensor with values drawn from a truncated
normal distribution.
From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
The values are effectively drawn from the
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \leq \text{mean} \leq b`.
Args:
tensor: an n-dimensional `torch.Tensor`
mean: the mean of the normal distribution
std: the standard deviation of the normal distribution
a: the minimum cutoff value
b: the maximum cutoff value
Examples:
>>> w = torch.empty(3, 5)
>>> nn.init.trunc_normal_(w)
"""return _no_grad_trunc_normal_(tensor, mean, std, a, b)# From PyTorchdef_ntuple(n):defparse(x):ifisinstance(x, collections.abc.Iterable):return x
returntuple(repeat(x, n))return parse
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
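

# Illustrative usage of the _ntuple helpers: a scalar is repeated n times,
# while an iterable passes through unchanged.
if __name__ == '__main__':
    print(to_2tuple(3))       # (3, 3)
    print(to_2tuple((1, 2)))  # (1, 2)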
import numpy as np
import random
import torch
from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
from basicsr.data.transforms import paired_random_crop
from basicsr.models.srgan_model import SRGANModel
from basicsr.utils import DiffJPEG, USMSharp
from basicsr.utils.img_process_util import filter2D
from basicsr.utils.registry import MODEL_REGISTRY
from collections import OrderedDict
from torch.nn import functional as F
import time
@MODEL_REGISTRY.register()
class RealESRGANModel(SRGANModel):
    """RealESRGAN Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
It mainly performs:
1. randomly synthesize LQ images in GPU tensors
2. optimize the networks with GAN training.
"""def__init__(self, opt):super(RealESRGANModel, self).__init__(opt)
self.jpeger = DiffJPEG(differentiable=False).cuda()# simulate JPEG compression artifacts
self.usm_sharpener = USMSharp().cuda()# do usm sharpening
self.queue_size = opt.get('queue_size',180)@torch.no_grad()def_dequeue_and_enqueue(self):"""It is the training pair pool for increasing the diversity in a batch.
Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
batch could not have different resize scaling factors. Therefore, we employ this training pair pool
to increase the degradation diversity in a batch.
"""# initialize
b, c, h, w = self.lq.size()ifnothasattr(self,'queue_lr'):assert self.queue_size % b ==0,f'queue size {self.queue_size} should be divisible by batch size {b}'
self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
_, c, h, w = self.gt.size()
self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
self.queue_ptr =0if self.queue_ptr == self.queue_size:# the pool is full# do dequeue and enqueue# shuffle
            idx = torch.randperm(self.queue_size)
            self.queue_lr = self.queue_lr[idx]
            self.queue_gt = self.queue_gt[idx]
            # get first b samples
            lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
            gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
            # update the queue
            self.queue_lr[0:b, :, :, :] = self.lq.clone()
            self.queue_gt[0:b, :, :, :] = self.gt.clone()

            self.lq = lq_dequeue
            self.gt = gt_dequeue
        else:
            # only do enqueue
            self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone()
            self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone()
            self.queue_ptr = self.queue_ptr + b
    @torch.no_grad()
    def feed_data(self, data):
        """Accept data from the dataloader, and then add two-order degradations to obtain LQ images."""
        if self.is_train and self.opt.get('high_order_degradation', True):
            # training data synthesis
self.gt = data['gt'].to(self.device)
self.gt_usm = self.usm_sharpener(self.gt)
self.kernel1 = data['kernel1'].to(self.device)
self.kernel2 = data['kernel2'].to(self.device)
self.sinc_kernel = data['sinc_kernel'].to(self.device)
            ori_h, ori_w = self.gt.size()[2:4]

            # ----------------------- The first degradation process ----------------------- #
            # blur
            out = filter2D(self.gt_usm, self.kernel1)
            # random resize
            updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0]
            if updown_type == 'up':
                scale = np.random.uniform(1, self.opt['resize_range'][1])
            elif updown_type == 'down':
                scale = np.random.uniform(self.opt['resize_range'][0], 1)
            else:
                scale = 1
            mode = random.choice(['area', 'bilinear', 'bicubic'])
            out = F.interpolate(out, scale_factor=scale, mode=mode)
            # add noise
            gray_noise_prob = self.opt['gray_noise_prob']
            if np.random.uniform() < self.opt['gaussian_noise_prob']:
                out = random_add_gaussian_noise_pt(
                    out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob)
            else:
                out = random_add_poisson_noise_pt(
                    out,
                    scale_range=self.opt['poisson_scale_range'],
                    gray_prob=gray_noise_prob,
                    clip=True,
                    rounds=False)
            # JPEG compression
            jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
            out = torch.clamp(out, 0, 1)  # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
            out = self.jpeger(out, quality=jpeg_p)

            # ----------------------- The second degradation process ----------------------- #
            # blur
            if np.random.uniform() < self.opt['second_blur_prob']:
                out = filter2D(out, self.kernel2)
            # random resize
            updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0]
            if updown_type == 'up':
                scale = np.random.uniform(1, self.opt['resize_range2'][1])
            elif updown_type == 'down':
                scale = np.random.uniform(self.opt['resize_range2'][0], 1)
            else:
                scale = 1
            mode = random.choice(['area', 'bilinear', 'bicubic'])
            out = F.interpolate(
                out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
            # add noise
            gray_noise_prob = self.opt['gray_noise_prob2']
            if np.random.uniform() < self.opt['gaussian_noise_prob2']:
                out = random_add_gaussian_noise_pt(
                    out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob)
            else:
                out = random_add_poisson_noise_pt(
                    out,
                    scale_range=self.opt['poisson_scale_range2'],
                    gray_prob=gray_noise_prob,
                    clip=True,
                    rounds=False)

            # JPEG compression + the final sinc filter
            # We also need to resize images to desired sizes. We group [resize back + sinc filter] together
            # as one operation.
            # We consider two orders:
            #   1. [resize back + sinc filter] + JPEG compression
            #   2. JPEG compression + [resize back + sinc filter]
            # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
            if np.random.uniform() < 0.5:
                # resize back + the final sinc filter
                mode = random.choice(['area', 'bilinear', 'bicubic'])
                out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
                out = filter2D(out, self.sinc_kernel)
                # JPEG compression
                jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
                out = torch.clamp(out, 0, 1)
                out = self.jpeger(out, quality=jpeg_p)
            else:
                # JPEG compression
                jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
                out = torch.clamp(out, 0, 1)
                out = self.jpeger(out, quality=jpeg_p)
                # resize back + the final sinc filter
                mode = random.choice(['area', 'bilinear', 'bicubic'])
                out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
                out = filter2D(out, self.sinc_kernel)

            # clamp and round
            self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.

            # random crop
            gt_size = self.opt['gt_size']
            (self.gt, self.gt_usm), self.lq = paired_random_crop([self.gt, self.gt_usm], self.lq, gt_size,
                                                                 self.opt['scale'])

            # modification: keep a single channel, matching the 1-channel network
            self.lq = self.lq[:, 0, :, :].unsqueeze(1)
            self.gt = self.gt[:, 0, :, :].unsqueeze(1)
            self.gt_usm = self.gt_usm[:, 0, :, :].unsqueeze(1)

            # training pair pool
            self._dequeue_and_enqueue()
            # sharpen self.gt again, as we have changed the self.gt with self._dequeue_and_enqueue
            self.gt_usm = self.usm_sharpener(self.gt)
            self.lq = self.lq.contiguous()  # for the warning: grad and param do not obey the gradient layout contract
        else:
            # for paired training or validation
            self.lq = data['lq'].to(self.device)
            # modification: keep a single channel, matching the 1-channel network
            self.lq = self.lq[:, 0, :, :].unsqueeze(1)
            if 'gt' in data:
                self.gt = data['gt'].to(self.device)
                self.gt_usm = self.usm_sharpener(self.gt)
                # modification: keep a single channel, matching the 1-channel network
                self.gt = self.gt[:, 0, :, :].unsqueeze(1)
                self.gt_usm = self.gt_usm[:, 0, :, :].unsqueeze(1)

    def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
        # do not use the synthetic process during validation
        self.is_train = False
        super(RealESRGANModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img)
        self.is_train = True

    def optimize_parameters(self, current_iter):
        # usm sharpening
l1_gt = self.gt_usm
percep_gt = self.gt_usm
gan_gt = self.gt_usm
        if self.opt['l1_gt_usm'] is False:
            l1_gt = self.gt
        if self.opt['percep_gt_usm'] is False:
            percep_gt = self.gt
        if self.opt['gan_gt_usm'] is False:
            gan_gt = self.gt
        # optimize net_g
        for p in self.net_d.parameters():
            p.requires_grad = False
self.optimizer_g.zero_grad()
self.output = self.net_g(self.lq)
l_g_total =0
        loss_dict = OrderedDict()
        if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
            # pixel loss
            if self.cri_pix:
                l_g_pix = self.cri_pix(self.output, l1_gt)
                l_g_total += l_g_pix
                loss_dict['l_g_pix'] = l_g_pix
            # perceptual loss
            if self.cri_perceptual:
                l_g_percep, l_g_style = self.cri_perceptual(self.output, percep_gt)
                if l_g_percep is not None:
                    l_g_total += l_g_percep
                    loss_dict['l_g_percep'] = l_g_percep
                if l_g_style is not None:
                    l_g_total += l_g_style
                    loss_dict['l_g_style'] = l_g_style
            # gan loss
            fake_g_pred = self.net_d(self.output)
            l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
            l_g_total += l_g_gan
            loss_dict['l_g_gan'] = l_g_gan

            l_g_total.backward()
            self.optimizer_g.step()

        # optimize net_d
        for p in self.net_d.parameters():
            p.requires_grad = True

        self.optimizer_d.zero_grad()
        # real
real_d_pred = self.net_d(gan_gt)
l_d_real = self.cri_gan(real_d_pred,True, is_disc=True)
loss_dict['l_d_real']= l_d_real
loss_dict['out_d_real']= torch.mean(real_d_pred.detach())
        l_d_real.backward()
        # fake
fake_d_pred = self.net_d(self.output.detach().clone())# clone for pt1.9
l_d_fake = self.cri_gan(fake_d_pred,False, is_disc=True)
loss_dict['l_d_fake']= l_d_fake
loss_dict['out_d_fake']= torch.mean(fake_d_pred.detach())
l_d_fake.backward()
        self.optimizer_d.step()

        if self.ema_decay > 0:
            self.model_ema(decay=self.ema_decay)

        self.log_dict = self.reduce_loss_dict(loss_dict)
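

# The channel-slicing modification used throughout this model reduces a
# (b, 3, h, w) batch to the single channel the slimmed network expects; a
# minimal illustration with assumed sizes:
if __name__ == '__main__':
    batch = torch.randn(4, 3, 64, 64)
    single = batch[:, 0, :, :].unsqueeze(1)
    print(single.shape)  # expected: torch.Size([4, 1, 64, 64])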
import numpy as np
import random
import torch
from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
from basicsr.data.transforms import paired_random_crop
from basicsr.models.sr_model import SRModel
from basicsr.utils import DiffJPEG, USMSharp
from basicsr.utils.img_process_util import filter2D
from basicsr.utils.registry import MODEL_REGISTRY
from torch.nn import functional as F
@MODEL_REGISTRY.register()
class RealESRNetModel(SRModel):
    """RealESRNet Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.

    It is trained without GAN losses.
    It mainly performs:
    1. randomly synthesize LQ images in GPU tensors
    2. optimize the network (without GAN training).
    """

    def __init__(self, opt):
        super(RealESRNetModel, self).__init__(opt)
self.jpeger = DiffJPEG(differentiable=False).cuda()# simulate JPEG compression artifacts
self.usm_sharpener = USMSharp().cuda()# do usm sharpening
        self.queue_size = opt.get('queue_size', 180)

    @torch.no_grad()
    def _dequeue_and_enqueue(self):
        """It is the training pair pool for increasing the diversity in a batch.
Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
batch could not have different resize scaling factors. Therefore, we employ this training pair pool
to increase the degradation diversity in a batch.
"""# initialize
b, c, h, w = self.lq.size()ifnothasattr(self,'queue_lr'):assert self.queue_size % b ==0,f'queue size {self.queue_size} should be divisible by batch size {b}'
self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
_, c, h, w = self.gt.size()
self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
self.queue_ptr =0if self.queue_ptr == self.queue_size:# the pool is full# do dequeue and enqueue# shuffle
            idx = torch.randperm(self.queue_size)
            self.queue_lr = self.queue_lr[idx]
            self.queue_gt = self.queue_gt[idx]
            # get first b samples
            lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
            gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
            # update the queue
            self.queue_lr[0:b, :, :, :] = self.lq.clone()
            self.queue_gt[0:b, :, :, :] = self.gt.clone()

            self.lq = lq_dequeue
            self.gt = gt_dequeue
        else:
            # only do enqueue
            self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone()
            self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone()
            self.queue_ptr = self.queue_ptr + b
    @torch.no_grad()
    def feed_data(self, data):
        """Accept data from the dataloader, and then add two-order degradations to obtain LQ images."""
        if self.is_train and self.opt.get('high_order_degradation', True):
            # training data synthesis
            self.gt = data['gt'].to(self.device)
            # USM sharpen the GT images
            if self.opt['gt_usm'] is True:
self.gt = self.usm_sharpener(self.gt)
self.kernel1 = data['kernel1'].to(self.device)
self.kernel2 = data['kernel2'].to(self.device)
self.sinc_kernel = data['sinc_kernel'].to(self.device)
            ori_h, ori_w = self.gt.size()[2:4]

            # ----------------------- The first degradation process ----------------------- #
            # blur
            out = filter2D(self.gt, self.kernel1)
            # random resize
            updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0]
            if updown_type == 'up':
                scale = np.random.uniform(1, self.opt['resize_range'][1])
            elif updown_type == 'down':
                scale = np.random.uniform(self.opt['resize_range'][0], 1)
            else:
                scale = 1
            mode = random.choice(['area', 'bilinear', 'bicubic'])
            out = F.interpolate(out, scale_factor=scale, mode=mode)
            # add noise
            gray_noise_prob = self.opt['gray_noise_prob']
            if np.random.uniform() < self.opt['gaussian_noise_prob']:
                out = random_add_gaussian_noise_pt(
                    out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob)
            else:
                out = random_add_poisson_noise_pt(
                    out,
                    scale_range=self.opt['poisson_scale_range'],
                    gray_prob=gray_noise_prob,
                    clip=True,
                    rounds=False)
            # JPEG compression
            jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
            out = torch.clamp(out, 0, 1)  # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
            out = self.jpeger(out, quality=jpeg_p)

            # ----------------------- The second degradation process ----------------------- #
            # blur
            if np.random.uniform() < self.opt['second_blur_prob']:
                out = filter2D(out, self.kernel2)
            # random resize
            updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0]
            if updown_type == 'up':
                scale = np.random.uniform(1, self.opt['resize_range2'][1])
            elif updown_type == 'down':
                scale = np.random.uniform(self.opt['resize_range2'][0], 1)
            else:
                scale = 1
            mode = random.choice(['area', 'bilinear', 'bicubic'])
            out = F.interpolate(
                out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
            # add noise
            gray_noise_prob = self.opt['gray_noise_prob2']
            if np.random.uniform() < self.opt['gaussian_noise_prob2']:
                out = random_add_gaussian_noise_pt(
                    out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob)
            else:
                out = random_add_poisson_noise_pt(
                    out,
                    scale_range=self.opt['poisson_scale_range2'],
                    gray_prob=gray_noise_prob,
                    clip=True,
                    rounds=False)

            # JPEG compression + the final sinc filter
            # We also need to resize images to desired sizes. We group [resize back + sinc filter] together
            # as one operation.
            # We consider two orders:
            #   1. [resize back + sinc filter] + JPEG compression
            #   2. JPEG compression + [resize back + sinc filter]
            # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
            if np.random.uniform() < 0.5:
                # resize back + the final sinc filter
                mode = random.choice(['area', 'bilinear', 'bicubic'])
                out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
                out = filter2D(out, self.sinc_kernel)
                # JPEG compression
                jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
                out = torch.clamp(out, 0, 1)
                out = self.jpeger(out, quality=jpeg_p)
            else:
                # JPEG compression
                jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
                out = torch.clamp(out, 0, 1)
                out = self.jpeger(out, quality=jpeg_p)
                # resize back + the final sinc filter
                mode = random.choice(['area', 'bilinear', 'bicubic'])
                out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
                out = filter2D(out, self.sinc_kernel)

            # clamp and round
            self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.

            # random crop
gt_size = self.opt['gt_size']
            self.gt, self.lq = paired_random_crop(self.gt, self.lq, gt_size, self.opt['scale'])

            # modification: keep a single channel, matching the 1-channel network
            self.lq = self.lq[:, 0, :, :].unsqueeze(1)
            self.gt = self.gt[:, 0, :, :].unsqueeze(1)

            # training pair pool
            self._dequeue_and_enqueue()
            self.lq = self.lq.contiguous()  # for the warning: grad and param do not obey the gradient layout contract
        else:
            # for paired training or validation
            self.lq = data['lq'].to(self.device)
            # modification: keep a single channel, matching the 1-channel network
            self.lq = self.lq[:, 0, :, :].unsqueeze(1)
            if 'gt' in data:
                self.gt = data['gt'].to(self.device)
                self.gt_usm = self.usm_sharpener(self.gt)
                # modification: keep a single channel, matching the 1-channel network
                self.gt = self.gt[:, 0, :, :].unsqueeze(1)
                self.gt_usm = self.gt_usm[:, 0, :, :].unsqueeze(1)

    def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
        # do not use the synthetic process during validation
        self.is_train = False
        super(RealESRNetModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img)
        self.is_train = True