Generative Adversarial Networks for Extreme Learned Image Compression论文中的量化操作

部署运行你感兴趣的模型镜像

Generative Adversarial Networks for Extreme Learned Image Compression论文中的量化操作

论文链接:https://arxiv.org/abs/1804.02958

代码地址:https://github.com/Justin-Tan/generative-compression(只找到了基于tensoflow1.x版本的代码,哭了)

Agustsson, Eirikur, et al. “Generative adversarial networks for extreme learned image compression.” Proceedings of the IEEE/CVF International Conference on Computer Vision. 2019.

量化方法

在这篇论文中并没有详细提到量化的方法细节,仅提到量化方法使用的是https://arxiv.org/abs/1801.04260这篇论文中提到的量化方法,经过查阅另一篇论文我弄清楚了该量化方法的细节。

image-20241223141943460

由于硬量化在反向传播的过程中存在不可导的问题,所以该方法是在前向传播中使用了硬量化,反向传播中使用的是软量化部分的梯度。

量化方法的PyTorch实现

import torch
import torch.nn.functional as F

def quantizer(w, temperature=1, L=5):
    # 定义中心点
    centers = torch.arange(-2, 3, dtype=torch.float32, device=w.device)  # [-2.0, -1.0, 0.0, 1.0, 2.0]
    
    # 堆叠w在最后一个维度上L次
    w_stack = w.unsqueeze(-1).repeat(*[1]*w.dim(), L)  # 形状:[batch, height, width, channels, L]
    
    # 计算w_stack与每个中心点的绝对差值
    abs_diff = torch.abs(w_stack - centers)
    
    # 取最小差值的位置,得到w_hard的索引
    w_hard_indices = torch.argmin(abs_diff, dim=-1)
    
    # 转换回实际的值
    w_hard = w_hard_indices.float() + centers.min()
    
    # 计算softmax,温度参数temperature
    smx = F.softmax(-1.0 / temperature * abs_diff, dim=-1)
    
    # 计算w_soft,通过einsum
    w_soft = torch.einsum('ijklm,m->ijkl', smx, centers)
    
    # 实现stop_gradient,用detach()
    w_bar = (w_hard - w_soft).detach() + w_soft
    
    return w_bar

您可能感兴趣的与本文相关的镜像

PyTorch 2.5

PyTorch 2.5

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

一只通信仔

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值