盲超分率-元学习和KernelGAN结合-MetaKernelGAN-Meta-Learned Kernel For Blind Super-Resolution Kernel Estimation
论文链接:Meta-Learned Kernel For Blind Super-Resolution Kernel Estimation
源码链接: royson/metakernelgan
MetaKernelGAN通过结合元学习与KernelGAN的方法,实现了对模糊核的有效估计和高质量图像的恢复,从而有效解决了盲超分辨率问题。
主要涉及到一下两个重点前置知识
元学习
元学习的简单理解
通过一下公式来做一个简单的理解,元学习分为内外层,具体训练逻辑如下:
首先外层会初始化一个w,这个作为算法参数,在外层更新,这个w会作为条件输入到内层循环中
内层用外层初始化的参数w在内层的支撑集上进行训练,得到在算法参数w的情况下最好内层参数o,用内层训练得到的参数作为外层的输入,在查询集上进行训练,如果w是好的算法参数,那么o在外层也能有好的表现,如果内层参数o不好,则说明是原来的外层参数w是不好的,则更新外层的算法参数w
更新算法参数w之后,循环上面的过程输入到内存进行内层参数更新,再更新外层。
如此循环后,w算法参数就能够得到适用于不同任务的算法参数,只需要输入一个类型任务,经过内层循环后,w就能够很快适用到该类型任务上。
主要的研究方向
论文使用的是学习初始化,下面介绍一些主要研究方向:学习优化(Learning to Optimize)、学习初始化(Learning to Initialize)、学习权重(Learning to Weight)、学习奖励(Learning to Reward)、学习增强(Learning to Augment)、数据集蒸馏(Dataset Distillation)、神经架构搜索(Neural Architecture Search)
学习优化是元学习的一个重要方向,旨在通过训练学习算法本身的优化过程。传统的优化算法(如 SGD、Adam)是手动设计的,而学习优化则试图通过元学习的方法来自动生成或改进优化算法,从而提高模型在新任务上的训练效率和效果。这一方向的研究通常涉及到设计优化器,使其能够快速适应不同的学习任务和数据分布。
本篇文章主要使用的学习初始化(Learning to Initialize),最经典的也就是下面这种模型MAML,学习初始化关注的是如何选择或学习模型参数的初始值,以便在训练新任务时能够更快地收敛。通过元学习,模型可以从多个任务的经验中获取初始化参数的有效性,从而在面对新的任务时,能够更迅速地找到适合的参数配置。这对于处理少样本学习等任务尤为重要,因为良好的初始化可以显著减少训练所需的时间和样本数量。
MAML
MAML的全称是Model-Agnostic Meta-Learning。它是一种元学习方法,旨在通过学习模型的初始化参数,使得模型能够在新任务上快速适应。
快速适应:MAML的核心思想是通过在多个任务上训练一个模型,使得这个模型在遇到新的任务时能够仅通过少量的梯度更新(通常是几次迭代)就快速适应。
模型无关:MAML的一个重要特点是“模型无关”,意味着它可以应用于各种类型的模型(如神经网络、线性模型等),并不局限于特定的架构或算法。
优化过程:在MAML中,模型在多个任务上进行内层训练(任务特定的更新),然后通过外层优化来调整模型的初始化参数。这使得模型在面对新任务时能够迅速收敛。
简单来说,就是算法参数和内层参数是同一个,内外层更新的都是同一个参数。
MetaKernelGAN
主要思想是使用元学习的方法来估计图像的模糊核(blur kernel)。它通过训练一个生成模糊核的模型,使得该模型能够适应多种类型的图像。当需要估计某一特定类型图像的模糊核时,只需提供少量的图像(甚至只是一张),模型就能够有效地生成适合该图像的模糊核,从而使得图像去模糊或模糊化处理效果达到最优。
MetaKernelGAN架构图
主要有内外层循环组成,内外层循环都在更新一个下采样器(downsampling),一个判别器(discriminator)
内层循环:仅计算下采样器和判别器的损失来更新下采样器和判别器。
外层循环:除了需要计算下采样器和判别器以外,还要需要计算核损失,来更新我们的模糊核。
从图中可以看出,对于外层参数的更新,是经过多步累计之后才对外层参数的进行更新的。下面将从源码层面剖析它的训练过程
数据集准备
在数据集制作过程中,首先从DIV2K数据集中采样高分辨率图像,将其裁剪为192×192的图像块。裁剪后的图像块会随机进行90度旋转、垂直翻转或水平翻转等数据增强操作,然后下采样生成低分辨率图像,作为用于超分辨率任务的输入数据。在元学习中同时学习生成器和判别器的参数,并将仅元学习生成器作为消融实验进行对比。
在实际制作的时候的,裁剪后的Patch是使用lmdb数据库,实际上就是生成两个数据文件,图片大小都是3x192x192,在执行之前会把这个文件的图片数据读取到dataloder,之后每个step取一张图训练。
加粗样式
外层训练
外层主要是先初始化下采样器和判别器的参数,外层的更新在内层更新之后,相当于内层更新完了参数之后,然后将更新后的参数拿到外层循环去进行训练, 来检测一下内层训练的怎么样,然后更新下采样器和判别器。不断的循环之后,最终外层会得到最好下采样器参数,同时也可以通过下采样器计算出最好的模糊核。
这里不算严格意义上的外层训练,就是简单的控制需要更新多少个任务,具体的外层训练在下面的另一张图里进行。
# 这里不算严格意义上的外层训练,就是简单的控制需要更新多少个任务
def train(self):
# 初始化训练数据迭代器,使用cycle函数使得迭代器无限循环遍历数据
self.train_iter = iter(cycle(self.train_dataloader))
# 确保训练步数 existing_step 小于总步数 steps,否则训练已经完成
assert self.existing_step < self.steps, 'Training is done.'
# 记录剩余的训练步数
logger.info(f'{self.steps - self.existing_step} steps left.')
# 从 existing_step 开始训练到 steps 结束
for i in range(self.existing_step, self.steps):
# 逐个将所有模型设为训练模式
for model_typ, model in self.meta_models.items():
model.train() # 设置为训练模式
self.optimizers[model_typ].zero_grad() # 清空模型的梯度
# 调用 meta_train 函数执行具体训练步骤,返回损失值
loss = self.meta_train(i)
# 更新所有优化器的参数
for opt in self.optimizers.values():
opt.step()
# 更新学习率调度器
for sch in self.schedulers.values():
sch.step()
# 记录训练的日志,包含当前的步数和损失
log_msg = f'[Step {i+1}] Train error: {loss}, '
logger.debug(log_msg)
# 每隔指定的步数(save_every)保存一次模型
if (i + 1) % self.args.optim.save_every == 0:
logger.info(log_msg) # 输出日志信息
self.save_ckp(i + 1, models=self.meta_models) # 保存检查点
# 记录当前的训练步数
self.log_current_step(i + 1)
内层训练
- 首先从192x192大小的图片进行模糊化,生成两张图
patch_lrs_son_t
(# 第二次下采样,作为生成模糊的标准图像,用来更新外层参数)—>gt_patch_son
patch_lrs_dad_t
(# 第一次降采样,用来送入内层来更新内层参数)—>patch_dad
# 来源于下面代码 gt_patch_son, patch_dad, gt_kernel = self.sample_task()
def sample_task(self):
'''
具体实现时,以2X为例,是从输入图上,分别裁剪两个patch,一个给G网络,分辨率高些,代码里是64x64,这张图经过G网络之后,分辨率变为26x26,因为没有做padding,所以不是32x32,
同时裁剪一个26x26的patch,是给D网络的,同时加入了少量的噪声,这样就构成了训练数据对。当然,裁剪patch时,并不是随机的,是通过构建了一个概率map,跟图像内容相关,
应该是让尽量能取到边缘细节吧,G patch和D patch尽量纹理接近吧
'''
kernel = dutils.get_downsampling_ops(self.degradation_operation, train=True, scale=self.scale) # 将标准kernel记录下来
logger.debug(f'Train task: kernel: {kernel.shape if kernel is not None else kernel}')
labels = next(self.train_iter)
# 第一次降采样,用来送入内层来更新内层参数
patch_lrs_dad_t = fkp.batch_degradation(labels, kernel, np.array([self.scale, self.scale]), self.args.data.img_noise, device=self.device)
print(patch_lrs_dad_t.shape)
# 第二次下采样,作为生成模糊的标准图像,用来更新外层参数
patch_lrs_son_t = fkp.batch_degradation(patch_lrs_dad_t, kernel, np.array([2, 2]), self.args.data.img_noise, device=self.device)
print(patch_lrs_son_t.shape)
if kernel is not None:
kernel = torch.from_numpy(kernel).float()
return patch_lrs_son_t, patch_lrs_dad_t, kernel
- 也就是上面采样的patch_lrs_dad_t,基于这张裁剪图进行内层更新,作为内层训练的输入(96x96)
patch_dad
,内层训练是现将patch_dad
下采样得到patch_son
,然后再patch_dad
和patch_son
在相对位置来进行随机裁剪,得到fake_d_in
(patch_son
得到)和real_d_in
(patch_dad
得到),首先用fake_d_in
下采样器生成的图像来计算下采样生成器的loss
值,然后更新下采样器的参数,然后用更新后的下采样器再对patch_dad
重新进行下采样在得到patch_son
,再从新的patch_son
中找到原来fake_d_in
的位置取出来得到g_out
,然后让g_out
和real_d_in
来计算判别器的loss值,然后更新判别器,这样完成了一次内层更新,是基于patch_dad
图像
def meta_train(self, step):
# 初始化总损失和用于存储克隆后的模型字典 learners
loss = 0
learners = {}
# 克隆每个 meta 模型,存入 learners 中
for model_typ, model in self.meta_models.items():
learners[model_typ] = model.clone()
# 从任务中采样获取真实的低分辨率图像、父图像和真实核
gt_patch_son, patch_dad, gt_kernel = self.sample_task()
print(gt_patch_son.shape)
print(patch_dad.shape)
# 将数据传输到设备上(CPU或GPU)
gt_patch_son, patch_dad, gt_kernel = \
tutils.to_device(*[gt_patch_son, patch_dad, gt_kernel], cpu=self.args.sys.cpu)
# 初始化验证过程中低分辨率子图像和判别器误差的列表
v_lr_son_error = []
v_dis_error = []
# 计算 patch_dad 的面积
area = np.prod(patch_dad.shape[-2:])
# 通过多个适应步骤来调整模型
for adapt_step in range(learners['downsampling'].optim_args.task_adapt_steps):
# 使用克隆的下采样模型生成子图像 patch_son
patch_son = learners['downsampling'](patch_dad)
print(patch_son.shape)
# 随机裁剪生成器输出的子图像和父图像,获取裁剪后的图像块和坐标
fake_d_in, real_d_in, coords = tutils.random_image_crop(patch_son, \
img_dad=patch_dad, scale=self.scale, patch_size=self.dis_input_ps)
print(fake_d_in.shape) # 裁剪后的子图像
print(real_d_in.shape) # 裁剪后的父图像
# 计算生成器损失(对抗损失)
gen_error = self.compute_loss_generator(learners['discriminator'], fake_i=fake_d_in)
# 添加正则化损失
reg_error = self.compute_kernel_reg(learners['downsampling'])
gen_error += reg_error
# 更新下采样模块参数
self.lr_decay(adapt_step, learners['downsampling'], area)
learners['downsampling'].adapt(gen_error, first_order=learners['downsampling'].optim_args.first_order)
# 更新判别器参数
patch_son = learners['downsampling'](patch_dad)
patch_son = patch_son.detach()
# 根据生成器生成的坐标从子图像中提取图像块
g_out = tutils.get_patch_using_coordinates(patch_son, coords, patch_size=self.dis_input_ps)
print(g_out.shape)
# 计算判别器损失(基于真实图像块和生成图像块)
dis_error = self.compute_loss_discriminator(learners['discriminator'], real_i=real_d_in, fake_i=g_out)
# 更新判别器参数
learners['discriminator'].adapt(dis_error, first_order=learners['discriminator'].optim_args.first_order)
# 在指定的步数进行验证
if self.args.optim.validate_steps and (adapt_step + 1) in self.args.optim.validate_steps:
logger.debug(f'Validating at step {adapt_step}')
# 进行验证,计算当前模型的生成器和判别器误差
g_er, d_er = self._validate_meta_train(learners, patch_dad, gt_patch_son, gt_kernel)
v_lr_son_error.append(g_er)
v_dis_error.append(d_er)
# 如果没有指定验证步数,直接在最后一次计算外层误差
if not self.args.optim.validate_steps:
v_lr_son_error, v_dis_error = self._validate_meta_train(learners, patch_dad, gt_patch_son, gt_kernel)
else:
# 根据步骤获取损失权重并计算加权损失
loss_weights = self._get_loss_weights(step).to(self.device)
v_lr_son_error = torch.sum(loss_weights * torch.stack(v_lr_son_error).squeeze())
if self.outer_lsgan_loss:
v_dis_error = torch.sum(loss_weights * torch.stack(v_dis_error).squeeze())
# 外层循环中的生成器损失计算与回传
loss += v_lr_son_error.item()
v_lr_son_error.backward(retain_graph=not learners['discriminator'].optim_args.first_order)
# 如果不是一阶优化,则判别器的梯度归零
if not learners['downsampling'].optim_args.first_order and 'discriminator' in learners:
self.optimizers['discriminator'].zero_grad()
# 外层循环中的判别器损失计算与回传
if self.outer_lsgan_loss:
loss += v_dis_error.item()
self.optimizers['discriminator'].zero_grad()
v_dis_error.backward()
# 更新模型参数
for model_typ, model in self.meta_models.items():
if model.optim_args.task_opt.upper() == 'ADAM':
# 将克隆模型的梯度复制到原始模型中
for model_p, clone_p in zip(model.parameters(), learners[model_typ].parameters()):
if model_p.requires_grad:
model_p.grad = clone_p.grad.clone()
# 如果任务设置了梯度裁剪,则进行梯度裁剪
if model.optim_args.task_gradient_clip is not None and model.optim_args.task_gradient_clip > 0:
nn.utils.clip_grad_value_(model.parameters(), model.optim_args.task_gradient_clip)
return loss
- 内层更新完成以后,我们每隔一些步数就计算一次外层的loss值(这里使用
patch_dad
和gt_patch_son
来进行计算的,以gt_patch_son
作为下采样之后的标准,通过采用下采样器采样patch_dad
得到patch_son
,让它和gt_patch_son
来计算loss值更新参数),计算loss
值的时候就不仅仅是下采样器和判别器了,还有kernel的loss值,方便后续在外层更新kernel
的参数,但是不进行更新,只是记录下来,等到完成一次任务的所有内层步数之后,再对外层参数进行加权累加求loss
值,更新外层参数,但是实际上由于是基于MAML
的,所以外层参数还是下采样器和判别器,但是外层还需要获得一个可以适应多个任务的核,所以还需要更新kernel的参数。
def _validate_meta_train(self, learners, patch_dad, gt_patch_son, gt_kernel):
# 使用 learners['downsampling'] 对输入的 patch_dad 进行下采样,得到 patch_son
patch_son = learners['downsampling'](patch_dad)
# 初始化验证过程中低分辨率图像的误差和总损失
v_lr_son_error = 0
loss = 0
# 验证集上的训练,使用已经更新的下采样和判别器参数
if self.gt_son_loss:
# 计算生成的低分辨率图像 patch_son 和真实低分辨率图像 gt_patch_son 之间的 L1 损失
v_lr_son_error += self.l1_loss(patch_son, gt_patch_son)
logger.debug(f'L1 LR_son Loss: {v_lr_son_error.item()}')
if self.outer_lsgan_loss:
# 随机裁剪生成器输出 patch_son,获取裁剪的图像块 d_in 和相应的坐标 coords
d_in, _, coords = tutils.random_image_crop(patch_son, patch_size=self.dis_input_ps)
# 计算生成器的损失(对抗损失)
g_error = self.compute_loss_generator(learners['discriminator'], d_in)
v_lr_son_error += g_error
# 计算正则化误差,使用双三次插值损失作为正则化的一部分
reg_error = self.compute_kernel_reg(learners['downsampling'], enable_bicubic_loss=True, g_input=patch_dad, g_output=patch_son)
v_lr_son_error += reg_error
logger.debug(f'Reg Error: {reg_error.item()}, Gen Error: {g_error.item()}')
if self.kernel_loss:
# 计算生成的退化核和真实核 gt_kernel 之间的误差
k_error = self.compute_kernel_loss(learners['downsampling'], gt_kernel)
v_lr_son_error += k_error
logger.debug(f'Total Kernel Error: {k_error.item()}')
logger.debug(f'Total Gen error: {v_lr_son_error.item()}')
# 初始化判别器误差
dis_error = 0
if self.outer_lsgan_loss:
# 使用生成器生成的图像(patch_son)和裁剪坐标,获取生成图像对应的图像块 g_out
g_out = tutils.get_patch_using_coordinates(patch_son.detach(), coords, patch_size=self.dis_input_ps)
# 获取真实图像 gt_patch_son 中与生成器输出对应的图像块 real_d_in
real_d_in = tutils.get_patch_using_coordinates(gt_patch_son, coords, patch_size=self.dis_input_ps)
# 计算判别器的损失,比较真实图像块和生成图像块
dis_error = self.compute_loss_discriminator(learners['discriminator'], real_i=real_d_in, fake_i=g_out)
logger.debug(f'Dis error: {dis_error.item()}')
# 返回生成器误差和判别器误差
return v_lr_son_error, dis_error
如需源码讲解可以联系我