Super-resolution series: link to the SRGAN paper walkthrough
1. Main training file main.py
import argparse
import os
from math import log10
import pandas as pd
import torch.optim as optim
import torch.utils.data
import torchvision.utils as utils
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tqdm import tqdm
import pytorch_ssim
from data_utils import TrainDatasetFromFolder, ValDatasetFromFolder, display_transform
from loss import GeneratorLoss
from model import Generator, Discriminator
parser = argparse.ArgumentParser(description='Train Super Resolution Models')
parser.add_argument('--crop_size', default=88, type=int, help='training images crop size')
parser.add_argument('--upscale_factor', default=4, type=int, choices=[2, 4, 8],
                    help='super resolution upscale factor')
parser.add_argument('--num_epochs', default=2, type=int, help='train epoch number')
if __name__ == '__main__':
    opt = parser.parse_args()

    CROP_SIZE = opt.crop_size
    UPSCALE_FACTOR = opt.upscale_factor
    NUM_EPOCHS = opt.num_epochs

    train_set = TrainDatasetFromFolder('E:\\Datasets\\SR\\DIV2K\\DIV2K_train_HR', crop_size=CROP_SIZE, upscale_factor=UPSCALE_FACTOR)
    val_set = ValDatasetFromFolder('E:\\Datasets\\SR\\DIV2K\\DIV2K_valid_HR', upscale_factor=UPSCALE_FACTOR)
    train_loader = DataLoader(dataset=train_set, num_workers=4, batch_size=16, shuffle=True)
    val_loader = DataLoader(dataset=val_set, num_workers=4, batch_size=1, shuffle=False)

    netG = Generator(UPSCALE_FACTOR)
    print('# generator parameters:', sum(param.numel() for param in netG.parameters()))
    netD = Discriminator()
    print('# discriminator parameters:', sum(param.numel() for param in netD.parameters()))

    generator_criterion = GeneratorLoss()

    if torch.cuda.is_available():
        netG.cuda()
        netD.cuda()
        generator_criterion.cuda()

    optimizerG = optim.Adam(netG.parameters())
    optimizerD = optim.Adam(netD.parameters())

    results = {'d_loss': [], 'g_loss': [], 'd_score': [], 'g_score': [], 'psnr': [], 'ssim': []}

    for epoch in range(1, NUM_EPOCHS + 1):
        train_bar = tqdm(train_loader)
        running_results = {'batch_sizes': 0, 'd_loss': 0, 'g_loss': 0, 'd_score': 0, 'g_score': 0}

        netG.train()
        netD.train()
        for data, target in train_bar:
            g_update_first = True
            batch_size = data.size(0)
            running_results['batch_sizes'] += batch_size

            ############################
            # (1) Update D network: maximize D(x)-1-D(G(z))
            ############################
            real_img = Variable(target)
            if torch.cuda.is_available():
                real_img = real_img.cuda()
            z = Variable(data)
            if torch.cuda.is_available():
                z = z.cuda()
            fake_img = netG(z)

            netD.zero_grad()
            real_out = netD(real_img).mean()
            fake_out = netD(fake_img).mean()
            d_loss = 1 - real_out + fake_out
            d_loss.backward(retain_graph=True)
            optimizerD.step()

            ############################
            # (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss
            ############################
            netG.zero_grad()
            fake_img = netG(z)
            fake_out = netD(fake_img).mean()
            g_loss = generator_criterion(fake_out, fake_img, real_img)
            g_loss.backward()

            fake_img = netG(z)
            fake_out = netD(fake_img).mean()

            optimizerG.step()

            running_results['g_loss'] += g_loss.item() * batch_size
            running_results['d_loss'] += d_loss.item() * batch_size
            running_results['d_score'] += real_out.item() * batch_size
            running_results['g_score'] += fake_out.item() * batch_size

            train_bar.set_description(desc='[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f' % (
                epoch, NUM_EPOCHS, running_results['d_loss'] / running_results['batch_sizes'],
                running_results['g_loss'] / running_results['batch_sizes'],
                running_results['d_score'] / running_results['batch_sizes'],
                running_results['g_score'] / running_results['batch_sizes']))

        netG.eval()
        out_path = 'training_results/SRF_' + str(UPSCALE_FACTOR) + '/'
        if not os.path.exists(out_path):
            os.makedirs(out_path)

        with torch.no_grad():
            val_bar = tqdm(val_loader)
            valing_results = {'mse': 0, 'ssims': 0, 'psnr': 0, 'ssim': 0, 'batch_sizes': 0}
            val_images = []
            for val_lr, val_hr_restore, val_hr in val_bar:
                batch_size = val_lr.size(0)
                valing_results['batch_sizes'] += batch_size
                lr = val_lr
                hr = val_hr
                if torch.cuda.is_available():
                    lr = lr.cuda()
                    hr = hr.cuda()
                sr = netG(lr)

                batch_mse = ((sr - hr) ** 2).data.mean()
                valing_results['mse'] += batch_mse * batch_size
                batch_ssim = pytorch_ssim.ssim(sr, hr).item()
                valing_results['ssims'] += batch_ssim * batch_size
                valing_results['psnr'] = 10 * log10((hr.max() ** 2) / (valing_results['mse'] / valing_results['batch_sizes']))
                valing_results['ssim'] = valing_results['ssims'] / valing_results['batch_sizes']
                val_bar.set_description(
                    desc='[converting LR images to SR images] PSNR: %.4f dB SSIM: %.4f' % (
                        valing_results['psnr'], valing_results['ssim']))

                val_images.extend(
                    [display_transform()(val_hr_restore.squeeze(0)), display_transform()(hr.data.cpu().squeeze(0)),
                     display_transform()(sr.data.cpu().squeeze(0))])

            val_images = torch.stack(val_images)
            val_images = torch.chunk(val_images, val_images.size(0) // 15)
            val_save_bar = tqdm(val_images, desc='[saving training results]')
            index = 1
            for image in val_save_bar:
                image = utils.make_grid(image, nrow=3, padding=5)
                utils.save_image(image, out_path + 'epoch_%d_index_%d.png' % (epoch, index), padding=5)
                index += 1

        # save model parameters
        torch.save(netG.state_dict(), 'epochs/netG_epoch_%d_%d.pth' % (UPSCALE_FACTOR, epoch))
        torch.save(netD.state_dict(), 'epochs/netD_epoch_%d_%d.pth' % (UPSCALE_FACTOR, epoch))

        # record losses, scores and metrics for this epoch
        results['d_loss'].append(running_results['d_loss'] / running_results['batch_sizes'])
        results['g_loss'].append(running_results['g_loss'] / running_results['batch_sizes'])
        results['d_score'].append(running_results['d_score'] / running_results['batch_sizes'])
        results['g_score'].append(running_results['g_score'] / running_results['batch_sizes'])
        results['psnr'].append(valing_results['psnr'])
        results['ssim'].append(valing_results['ssim'])

        if epoch % 10 == 0 and epoch != 0:
            out_path = 'statistics/'
            data_frame = pd.DataFrame(
                data={'Loss_D': results['d_loss'], 'Loss_G': results['g_loss'], 'Score_D': results['d_score'],
                      'Score_G': results['g_score'], 'PSNR': results['psnr'], 'SSIM': results['ssim']},
                index=range(1, epoch + 1))
            data_frame.to_csv(out_path + 'srf_' + str(UPSCALE_FACTOR) + '_train_results.csv', index_label='Epoch')
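Training is driven entirely by the argparse flags defined at the top of main.py (the dataset paths are hard-coded above). As a usage example, with the epoch count chosen just for illustration:

python main.py --crop_size 88 --upscale_factor 4 --num_epochs 100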
2. Custom training, validation and test set file data_utils.py
from os import listdir
from os.path import join
from PIL import Image
from torch.utils.data.dataset import Dataset
from torchvision.transforms import Compose, RandomCrop, ToTensor, ToPILImage, CenterCrop, Resize, InterpolationMode
def is_image_file(filename):
    """Check whether filename is an image in png, jpg, jpeg, etc. format."""
    return any(filename.endswith(extension) for extension in ['.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG'])

def calculate_valid_crop_size(crop_size, upscale_factor):
    """Shrink the crop size to the largest integer multiple of the upscale factor."""
    return crop_size - (crop_size % upscale_factor)

def train_hr_transform(crop_size):
    return Compose([
        RandomCrop(crop_size),
        ToTensor(),
    ])

def train_lr_transform(crop_size, upscale_factor):
    return Compose([
        ToPILImage(),
        Resize(crop_size // upscale_factor, interpolation=InterpolationMode.BICUBIC),
        ToTensor()
    ])

def display_transform():
    return Compose([
        ToPILImage(),
        Resize(400),
        CenterCrop(400),
        ToTensor()
    ])
class TrainDatasetFromFolder(Dataset):
    def __init__(self, dataset_dir, crop_size, upscale_factor):
        super(TrainDatasetFromFolder, self).__init__()
        self.image_filenames = [join(dataset_dir, x) for x in listdir(dataset_dir) if is_image_file(x)]
        crop_size = calculate_valid_crop_size(crop_size, upscale_factor)
        self.hr_transform = train_hr_transform(crop_size)
        self.lr_transform = train_lr_transform(crop_size, upscale_factor)

    def __getitem__(self, index):
        hr_image = self.hr_transform(Image.open(self.image_filenames[index]))
        lr_image = self.lr_transform(hr_image)
        return lr_image, hr_image

    def __len__(self):
        return len(self.image_filenames)

class ValDatasetFromFolder(Dataset):
    def __init__(self, dataset_dir, upscale_factor):
        super(ValDatasetFromFolder, self).__init__()
        self.upscale_factor = upscale_factor
        self.image_filenames = [join(dataset_dir, x) for x in listdir(dataset_dir) if is_image_file(x)]

    def __getitem__(self, index):
        hr_image = Image.open(self.image_filenames[index])
        w, h = hr_image.size
        crop_size = calculate_valid_crop_size(min(w, h), self.upscale_factor)
        lr_scale = Resize(crop_size // self.upscale_factor, interpolation=InterpolationMode.BICUBIC)
        hr_scale = Resize(crop_size, interpolation=InterpolationMode.BICUBIC)
        hr_image = CenterCrop(crop_size)(hr_image)
        lr_image = lr_scale(hr_image)
        hr_restore_img = hr_scale(lr_image)
        return ToTensor()(lr_image), ToTensor()(hr_restore_img), ToTensor()(hr_image)

    def __len__(self):
        return len(self.image_filenames)

class TestDatasetFromFolder(Dataset):
    def __init__(self, dataset_dir, upscale_factor):
        super(TestDatasetFromFolder, self).__init__()
        self.lr_path = dataset_dir + '/SRF_' + str(upscale_factor) + '/data/'
        self.hr_path = dataset_dir + '/SRF_' + str(upscale_factor) + '/target/'
        self.upscale_factor = upscale_factor
        self.lr_filenames = [join(self.lr_path, x) for x in listdir(self.lr_path) if is_image_file(x)]
        self.hr_filenames = [join(self.hr_path, x) for x in listdir(self.hr_path) if is_image_file(x)]

    def __getitem__(self, index):
        image_name = self.lr_filenames[index].split('/')[-1]
        lr_image = Image.open(self.lr_filenames[index])
        w, h = lr_image.size
        hr_image = Image.open(self.hr_filenames[index])
        hr_scale = Resize((self.upscale_factor * h, self.upscale_factor * w), interpolation=InterpolationMode.BICUBIC)
        hr_restore_img = hr_scale(lr_image)
        return image_name, ToTensor()(lr_image), ToTensor()(hr_restore_img), ToTensor()(hr_image)

    def __len__(self):
        return len(self.lr_filenames)
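A minimal sanity check for the dataset classes might look like the sketch below; the DIV2K path is the same assumption as in main.py, and the printed shapes follow from crop_size=88 and upscale_factor=4 (the LR crop is 88 // 4 = 22 pixels on a side):

# Sketch: verify the LR/HR pair produced by TrainDatasetFromFolder
# (the dataset path is an assumption; any folder of HR images works).
from data_utils import TrainDatasetFromFolder

train_set = TrainDatasetFromFolder('E:\\Datasets\\SR\\DIV2K\\DIV2K_train_HR', crop_size=88, upscale_factor=4)
lr, hr = train_set[0]
print(lr.shape)  # torch.Size([3, 22, 22])
print(hr.shape)  # torch.Size([3, 88, 88])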
3. Custom GAN network model file model.py
import math
import torch
from torch import nn
class Generator(nn.Module):
    def __init__(self, scale_factor):
        # number of x2 PixelShuffle upsampling blocks needed to reach scale_factor
        upsample_block_num = int(math.log(scale_factor, 2))

        super(Generator, self).__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=9, padding=4),
            nn.PReLU()
        )
        self.block2 = ResidualBlock(64)
        self.block3 = ResidualBlock(64)
        self.block4 = ResidualBlock(64)
        self.block5 = ResidualBlock(64)
        self.block6 = ResidualBlock(64)
        self.block7 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64)
        )
        block8 = [UpsampleBLock(64, 2) for _ in range(upsample_block_num)]
        block8.append(nn.Conv2d(64, 3, kernel_size=9, padding=4))
        self.block8 = nn.Sequential(*block8)

    def forward(self, x):
        block1 = self.block1(x)
        block2 = self.block2(block1)
        block3 = self.block3(block2)
        block4 = self.block4(block3)
        block5 = self.block5(block4)
        block6 = self.block6(block5)
        block7 = self.block7(block6)
        # global skip connection from block1 around the residual body
        block8 = self.block8(block1 + block7)

        # map the tanh output from [-1, 1] into [0, 1]
        return (torch.tanh(block8) + 1) / 2

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),

            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),

            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),

            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),

            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),

            nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),

            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(512, 1024, kernel_size=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(1024, 1, kernel_size=1)
        )

    def forward(self, x):
        batch_size = x.size(0)
        return torch.sigmoid(self.net(x).view(batch_size))
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.prelu = nn.PReLU()
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        residual = self.conv1(x)
        residual = self.bn1(residual)
        residual = self.prelu(residual)
        residual = self.conv2(residual)
        residual = self.bn2(residual)
        return x + residual

class UpsampleBLock(nn.Module):
    """Upsampling block: a conv expands the channels by up_scale ** 2, then PixelShuffle rearranges them into an up_scale-times larger feature map."""

    def __init__(self, in_channels, up_scale):
        super(UpsampleBLock, self).__init__()
        self.conv = nn.Conv2d(in_channels, in_channels * up_scale ** 2, kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(up_scale)
        self.prelu = nn.PReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.pixel_shuffle(x)
        x = self.prelu(x)
        return x
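To see the shapes these modules produce, here is a quick forward-pass sketch (the 24x24 input size is arbitrary): with upscale_factor=4 the generator stacks log2(4) = 2 PixelShuffle x2 blocks, so a 24x24 input becomes 96x96, and the discriminator returns one probability per image.

# Sketch: shape check for Generator and Discriminator.
import torch
from model import Generator, Discriminator

netG = Generator(4)
netD = Discriminator()
x = torch.randn(1, 3, 24, 24)   # dummy LR batch
sr = netG(x)
print(sr.shape)        # torch.Size([1, 3, 96, 96])
print(netD(sr).shape)  # torch.Size([1]) -- one score per image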
4. Custom loss function file loss.py
import torch
from torch import nn
from torchvision.models.vgg import vgg16
class GeneratorLoss(nn.Module):
    def __init__(self):
        super(GeneratorLoss, self).__init__()
        vgg = vgg16(pretrained=True)
        # use the first 31 layers of VGG16 as a fixed feature extractor
        loss_network = nn.Sequential(*list(vgg.features)[:31]).eval()
        for param in loss_network.parameters():
            param.requires_grad = False
        self.loss_network = loss_network
        self.mse_loss = nn.MSELoss()
        self.tv_loss = TVLoss()

    def forward(self, out_labels, out_images, target_images):
        # Adversarial Loss
        adversarial_loss = torch.mean(1 - out_labels)
        # Perception Loss
        perception_loss = self.mse_loss(self.loss_network(out_images), self.loss_network(target_images))
        # Image Loss
        image_loss = self.mse_loss(out_images, target_images)
        # TV Loss
        tv_loss = self.tv_loss(out_images)
        return image_loss + 0.001 * adversarial_loss + 0.006 * perception_loss + 2e-8 * tv_loss

class TVLoss(nn.Module):
    def __init__(self, tv_loss_weight=1):
        super(TVLoss, self).__init__()
        self.tv_loss_weight = tv_loss_weight

    def forward(self, x):
        batch_size = x.size()[0]
        h_x = x.size()[2]
        w_x = x.size()[3]
        count_h = self.tensor_size(x[:, :, 1:, :])
        count_w = self.tensor_size(x[:, :, :, 1:])
        h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum()
        w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum()
        return self.tv_loss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size

    @staticmethod
    def tensor_size(t):
        return t.size()[1] * t.size()[2] * t.size()[3]

if __name__ == "__main__":
    g_loss = GeneratorLoss()
    print(g_loss)
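The returned value is the weighted sum image_loss + 0.001 * adversarial + 0.006 * perceptual + 2e-8 * TV. A sketch of evaluating it on random tensors (the first call downloads the VGG16 weights; the 88x88 size is illustrative and only needs to be large enough to survive VGG's five poolings):

# Sketch: compute the combined generator loss on dummy data.
import torch
from loss import GeneratorLoss

criterion = GeneratorLoss()
fake_out = torch.rand(4)             # discriminator outputs for the fake batch
fake_img = torch.rand(4, 3, 88, 88)  # generator output in [0, 1]
real_img = torch.rand(4, 3, 88, 88)  # HR ground truth
print(criterion(fake_out, fake_img, real_img))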
5. Custom SSIM evaluation metric file pytorch_ssim/__init__.py
from math import exp
import torch
import torch.nn.functional as F
from torch.autograd import Variable
def gaussian(window_size, sigma):
    """Generate a 1-D Gaussian kernel."""
    gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
    return gauss / gauss.sum()

def create_window(window_size, channel):
    """Create a 2-D Gaussian window, replicated once per channel."""
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
    return window

def _ssim(img1, img2, window, window_size, channel, size_average=True):
    """Structural similarity: measures how similar two images are."""
    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2

    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)

class SSIM(torch.nn.Module):
    def __init__(self, window_size=11, size_average=True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()

        if channel == self.channel and self.window.data.type() == img1.data.type():
            window = self.window
        else:
            window = create_window(self.window_size, channel)

            if img1.is_cuda:
                window = window.cuda(img1.get_device())
            window = window.type_as(img1)

            self.window = window
            self.channel = channel

        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)

def ssim(img1, img2, window_size=11, size_average=True):
    (_, channel, _, _) = img1.size()
    window = create_window(window_size, channel)

    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)

    return _ssim(img1, img2, window, window_size, channel, size_average)
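A quick check of the metric (sketch): SSIM of an image with itself is exactly 1, while two independent noise images score near 0.

import torch
import pytorch_ssim

img1 = torch.rand(1, 3, 64, 64)
img2 = torch.rand(1, 3, 64, 64)
print(pytorch_ssim.ssim(img1, img1).item())  # 1.0
print(pytorch_ssim.ssim(img1, img2).item())  # close to 0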
6. Single-image test file test_image.py
import argparse
import time
import torch
from PIL import Image
from torch.autograd import Variable
from torchvision.transforms import ToTensor, ToPILImage
from model import Generator
parser = argparse.ArgumentParser(description='Test Single Image')
parser.add_argument('--upscale_factor', default=4, type=int, help='super resolution upscale factor')
parser.add_argument('--test_mode', default='GPU', type=str, choices=['GPU', 'CPU'], help='using GPU or CPU')
parser.add_argument('--image_name', default='SUT1.jpg', type=str, help='test low resolution image name')
parser.add_argument('--model_name', default='netG_epoch_4_100.pth', type=str, help='generator model epoch name')
opt = parser.parse_args()

UPSCALE_FACTOR = opt.upscale_factor
TEST_MODE = True if opt.test_mode == 'GPU' else False
IMAGE_NAME = opt.image_name
IMAGE_PATH = 'test_photo/'
MODEL_NAME = opt.model_name

model = Generator(UPSCALE_FACTOR).eval()
if TEST_MODE:
    model.cuda()
    model.load_state_dict(torch.load('epochs/' + MODEL_NAME))
else:
    model.load_state_dict(torch.load('epochs/' + MODEL_NAME, map_location=lambda storage, loc: storage))

image = Image.open(IMAGE_PATH + IMAGE_NAME)
image = Variable(ToTensor()(image)).unsqueeze(0)
print(image.shape)
if TEST_MODE:
    image = image.cuda()

start = time.process_time()
with torch.no_grad():  # inference only, no autograd graph needed
    out = model(image)
elapsed = (time.process_time() - start)
print('cost ' + str(elapsed) + ' s')
out_img = ToPILImage()(out[0].data.cpu())
out_img.save('test_photo/out_srf_' + str(UPSCALE_FACTOR) + '_' + IMAGE_NAME)
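Usage example (the image is read from test_photo/ and the result is saved next to it; the file and checkpoint names below are simply the argparse defaults):

python test_image.py --image_name SUT1.jpg --model_name netG_epoch_4_100.pth --test_mode GPU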
7. Video test file test_video.py
import argparse
import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.autograd import Variable
from torchvision.transforms import ToTensor, ToPILImage
from tqdm import tqdm
from model import Generator
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Test Single Video')
    parser.add_argument('--upscale_factor', default=4, type=int, help='super resolution upscale factor')
    parser.add_argument('--video_name', type=str, help='test low resolution video name')
    parser.add_argument('--model_name', default='netG_epoch_4_100.pth', type=str, help='generator model epoch name')
    opt = parser.parse_args()

    UPSCALE_FACTOR = opt.upscale_factor
    VIDEO_NAME = opt.video_name
    MODEL_NAME = opt.model_name

    model = Generator(UPSCALE_FACTOR).eval()
    if torch.cuda.is_available():
        model = model.cuda()
    model.load_state_dict(torch.load('epochs/' + MODEL_NAME))

    videoCapture = cv2.VideoCapture(VIDEO_NAME)
    fps = videoCapture.get(cv2.CAP_PROP_FPS)
    frame_numbers = videoCapture.get(cv2.CAP_PROP_FRAME_COUNT)
    sr_video_size = (int(videoCapture.get(cv2.CAP_PROP_FRAME_WIDTH) * UPSCALE_FACTOR),
                     int(videoCapture.get(cv2.CAP_PROP_FRAME_HEIGHT)) * UPSCALE_FACTOR)
    compared_video_size = (int(videoCapture.get(cv2.CAP_PROP_FRAME_WIDTH) * UPSCALE_FACTOR * 2 + 10),
                           int(videoCapture.get(cv2.CAP_PROP_FRAME_HEIGHT)) * UPSCALE_FACTOR + 10 + int(
                               int(videoCapture.get(cv2.CAP_PROP_FRAME_WIDTH) * UPSCALE_FACTOR * 2 + 10) / int(
                                   10 * int(int(
                                       videoCapture.get(cv2.CAP_PROP_FRAME_WIDTH) * UPSCALE_FACTOR) // 5 + 1)) * int(
                                   int(videoCapture.get(cv2.CAP_PROP_FRAME_WIDTH) * UPSCALE_FACTOR) // 5 - 9)))
    output_sr_name = 'out_srf_' + str(UPSCALE_FACTOR) + '_' + VIDEO_NAME.split('.')[0] + '.avi'
    output_compared_name = 'compare_srf_' + str(UPSCALE_FACTOR) + '_' + VIDEO_NAME.split('.')[0] + '.avi'
    sr_video_writer = cv2.VideoWriter(output_sr_name, cv2.VideoWriter_fourcc('M', 'P', 'E', 'G'), fps, sr_video_size)
    compared_video_writer = cv2.VideoWriter(output_compared_name, cv2.VideoWriter_fourcc('M', 'P', 'E', 'G'), fps,
                                            compared_video_size)

    success, frame = videoCapture.read()
    test_bar = tqdm(range(int(frame_numbers)), desc='[processing video and saving result videos]')
    for index in test_bar:
        if success:
            image = Variable(ToTensor()(frame)).unsqueeze(0)
            if torch.cuda.is_available():
                image = image.cuda()

            with torch.no_grad():  # replaces the removed volatile=True flag
                out = model(image)
            out = out.cpu()
            out_img = out.data[0].numpy()
            out_img *= 255.0
            out_img = (np.uint8(out_img)).transpose((1, 2, 0))
            sr_video_writer.write(out_img)

            # build the comparison frame: bicubic vs. SR on top, five crops of each below
            out_img = ToPILImage()(out_img)
            crop_out_imgs = transforms.FiveCrop(size=out_img.width // 5 - 9)(out_img)
            crop_out_imgs = [np.asarray(transforms.Pad(padding=(10, 5, 0, 0))(img)) for img in crop_out_imgs]
            out_img = transforms.Pad(padding=(5, 0, 0, 5))(out_img)
            compared_img = transforms.Resize(size=(sr_video_size[1], sr_video_size[0]),
                                             interpolation=transforms.InterpolationMode.BICUBIC)(ToPILImage()(frame))
            crop_compared_imgs = transforms.FiveCrop(size=compared_img.width // 5 - 9)(compared_img)
            crop_compared_imgs = [np.asarray(transforms.Pad(padding=(0, 5, 10, 0))(img)) for img in crop_compared_imgs]
            compared_img = transforms.Pad(padding=(0, 0, 5, 5))(compared_img)
            top_image = np.concatenate((np.asarray(compared_img), np.asarray(out_img)), axis=1)
            bottom_image = np.concatenate(crop_compared_imgs + crop_out_imgs, axis=1)
            bottom_image = np.asarray(transforms.Resize(
                size=(int(top_image.shape[1] / bottom_image.shape[1] * bottom_image.shape[0]), top_image.shape[1]))(
                ToPILImage()(bottom_image)))
            final_image = np.concatenate((top_image, bottom_image))
            compared_video_writer.write(final_image)

            # read the next frame
            success, frame = videoCapture.read()
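Usage example (--video_name has no default and must be supplied; video.mp4 below is a placeholder, the checkpoint name is the argparse default, and the two .avi results are written to the working directory):

python test_video.py --video_name video.mp4 --model_name netG_epoch_4_100.pth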