论文地址:https://arxiv.org/abs/2011.03029
代码地址:https://github.com/InterDigitalInc/CompressAI
一. 简介
一个关于End-to-End Image Compression 的pytorch库,预计复现以下四篇论文的模型。
- 《End-to-end Optimized Image Compression》
- 《Variational Image Compression With A Scale Hyperprior》
- 《Joint Autoregressive and Hierarchical Priors for Learned Image Compression》
- 《Learned Image Compression with Discretized Gaussian Mixture Likelihoods and
Attention Modules》 ------该论文正在进行中
二. 安装
因为国内网的问题,先下载好torch是比较好的安装方式,先将Conda内源换成清华镜像源,虽然清华源也很慢很慢很慢,但是比国外源略微快上一丢丢。
- 安装cudatoolkit和cudnn包,需要和pytorch版本相对应。
- 安装pytorch。
- 下载compressAi 库,并且用 pip 进行安装。
- 检测安装结果。
用conda创建python=3.8, cuda=10.1的虚拟环境(建议)
conda create -n env_name python=3.8 cudatoolkit=10.1 cudnn
创建环境后需要激活虚拟环境,以在该虚拟环境下下载对应的库包。
conda activate env_name
如果使用基础的环境,则直接通过conda下载cuda即可
conda install cudatoolkit=10.1 cudnn
安装pytotch,进入pytorch官网:https://pytorch.org,选择对应的版本
根据提示输入下载命令, -c pytorch 是指指定官方通道,torch的服务器在国外,emmm,会断的厉害,但是清华源断的也挺厉害的,多下载几次=-=:
conda install pytorch torchvision torchaudio
torch下载完成后,根据官方指导,开始下载CompressAI, 从clone工程到你的机器上,下载结束后,进行pip安装。
git clone https://github.com/InterDigitalInc/CompressAI compressai
cd compressai
pip install -U pip && pip install -e . //该命令用下一条命令替换,更快地安装。
pip install -e . -i https://pypi.douban.com/simple //pip 豆瓣源比清华源好=-=
安装结束后,输入
conda list //查看该环境下的安装包,如果出现compressai,即安装成功
用 python 验证也可,
python
import compressai //不报错即安装成功
三. 使用
Compressai的一级结构如下,具体使用API指导:https://interdigitalinc.github.io/CompressAI/
其中,主要关注两个目录,compressai目录下即pip编译的源码,修改这里的代码会修改compresssai的API应用, example目录下的是代码是使用范例。
使用:
其中
/path/to/my/image/dataset/ 表示数据集的目录, 该数据集下分为 train 和test目录, train内部放train的 .png图像, test放测试图像。 --cuda 使用GPU,–save保存训练好的模型。
python examples/train.py -d /path/to/my/image/dataset/ --epochs 300 -lr 1e-4 --batch-size 16 --cuda --save
训练结束后需要更新CDF保证熵编码的正常运行:
python -m compressai.utils.update_model [-h] [-n NAME] [-d DIR] [–no-update] [–architecture {factorized-prior,jarhp,mean-scale-hyperprior,scale-hyperprior}] filepath
python -m compressai.utils.update_model [-d DIR] [--architecture {factorized-prior,jarhp,mean-scale-hyperprior,scale-hyperprior}] filepath
评价模型:
/path/to/images/folder/ 和上述的不同,该文件夹内直接存储需要test的png图像。
-a $ARCH 表示采用的预设定的模型,列表如下六种。
- bmshj2018_factorized
- bmshj2018_hyperprior
- mbt2018
- mbt2018_mean
- cheng2020_anchor
- cheng2020_attn
-p $MODEL_CHECKPOINT 表示存储的网络模型。
python -m compressai.utils.eval_model checkpoint /path/to/images/folder/ -a $ARCH -p $MODEL_CHECKPOINT...
四. 注意
由于训练结束需要更新Entropy的CDF以正常进行测试阶段的熵编码工作,但是上述的CDF更新制定了预先定义好的框架,当采用自己的框架的时候,CDF的更新需要自行阅读对应源码并且修改进行CDF的更新。