Generative Adversarial Networks for Extreme Learned Image Compression论文中的量化操作
论文链接:https://arxiv.org/abs/1804.02958
代码地址:https://github.com/Justin-Tan/generative-compression(只找到了基于tensoflow1.x版本的代码,哭了)
Agustsson, Eirikur, et al. “Generative adversarial networks for extreme learned image compression.” Proceedings of the IEEE/CVF International Conference on Computer Vision. 2019.
量化方法
在这篇论文中并没有详细提到量化的方法细节,仅提到量化方法使用的是https://arxiv.org/abs/1801.04260这篇论文中提到的量化方法,经过查阅另一篇论文我弄清楚了该量化方法的细节。

由于硬量化在反向传播的过程中存在不可导的问题,所以该方法是在前向传播中使用了硬量化,反向传播中使用的是软量化部分的梯度。
量化方法的PyTorch实现
import torch
import torch.nn.functional as F
def quantizer(w, temperature=1, L=5):
# 定义中心点
centers = torch.arange(-2, 3, dtype=torch.float32, device=w.device) # [-2.0, -1.0, 0.0, 1.0, 2.0]
# 堆叠w在最后一个维度上L次
w_stack = w.unsqueeze(-1).repeat(*[1]*w.dim(), L) # 形状:[batch, height, width, channels, L]
# 计算w_stack与每个中心点的绝对差值
abs_diff = torch.abs(w_stack - centers)
# 取最小差值的位置,得到w_hard的索引
w_hard_indices = torch.argmin(abs_diff, dim=-1)
# 转换回实际的值
w_hard = w_hard_indices.float() + centers.min()
# 计算softmax,温度参数temperature
smx = F.softmax(-1.0 / temperature * abs_diff, dim=-1)
# 计算w_soft,通过einsum
w_soft = torch.einsum('ijklm,m->ijkl', smx, centers)
# 实现stop_gradient,用detach()
w_bar = (w_hard - w_soft).detach() + w_soft
return w_bar
953

被折叠的 条评论
为什么被折叠?



