前言
文献名称:Learned Image Compression with Discretized Gaussian Mixture Likelihoods and Attention Modules
本文基于CompressAI的库进行复现
github地址:compressAI
关于compressAI相关博客说明:CompressAI:基于pytorch的图像压缩框架使用
安装好compressai后相当于把这个底层库引入了我们的工程
相关环境搭配可以参考Joint Autoregressive and Hierarchical Priors for Learned Image Compression文献复现文献复现
同样都是使用compressAI进行复现
1、创建工程
我这里在做好compressAI所需前置工作后,把我需要的几个文件单独移入了自己的工程
2、创建训练命令脚本
因为这篇文献其实是可以直接调用相关命令对实验进行复现,所以主要把相关命令输入,按照步骤一步步执行即可,相关脚本命令可以参考sh脚本运行命令
创建成功后:
# conda activate torchll
python3 train.py -m cheng2020-anchor -d /xxx/imageNet --epochs 1400000 -lr 1e-4 --lambda 0.0016 --batch-size 16 --cuda --save > log.out
这里的学习率,在文献中分成了两段,最后80k次迭代使用1e-5的学习率,这里我们就直接使用一个学习率进行尝试看看
其中patchsize取的默认256,最后可以加上自己保持的路径
具体的命令说明可以参见官网以及CompressAI:基于pytorch的图像压缩框架使用
训练完成后,会在.out文件中打印相关输出
3、整理训练流程
这里主要是在joint实验上进行一个网络的更改,所以仍然沿用joint的方式,更改对应的网络
-
找到对应网络结构
-
沿用joint方法,只改变网络结构相关
具体模型结构见附录
其他具体训练流程同joint实验
-
加载训练命令,如上述图片的数据集、 e p o c h epoch epoch、学习率、 p a t c h s i z e patchsize patchsize、随机种子等
-
对数据集图片进行剪裁
分别读取train训练数据集与test测试数据集的图片,剪裁大小 p a t c h s i z e patchsize patchsize初始值是256,train训练集随机剪裁成 p a t c h s i z e ∗ p a t h s i z e patchsize*pathsize patchsize∗pathsize,test测试数据集中心剪裁成 p a t c h s i z e ∗ p a t h s i z e patchsize*pathsize patchsize∗pathsize -
加载剪裁好的两个数据集,以及训练相关数据如 b a t c h s i z e batchsize batchsize等
-
读取需要训练的模型网络、进行优化器的初始设置
(1)模型引用位置compressai.zoo
这里的网络是通过之前的训练命令,调用cheng2020的模型,对应本篇文献的网络模型,模型代码可以见上面的网络结构代码 -
训练前先判断是否在之前有保存的节点,有则继续训练,无则重新开始
-
对train训练集以及test测试集分别带入优化器进行梯度下降,开始模型训练
-
损失收敛后训练完成,保存模型以及相应的节点
4、更新CDF保证熵编码的正常运行(模型地址与1保持一致)
相关脚本命令可以参考sh脚本运行命令
官方命令
python -m compressai.utils.update_model --architecture ARCH checkpoint_best_loss.pth.tar
自己命令
python -m compressai.utils.update_model --architecture mbt2018 checkpoint_best_loss.pth.tar
5、评价模型,获取性能值
官方命令
python -m compressai.utils.eval_model checkpoint /path/to/image/dataset -a ARCH -p path/to/checkpoint-xxxxxxxx.pth.tar
自己的命令
python3 -m compressai.utils.eval_model checkpoint /home/zqb/lilin/code/Joint-autoregressive-and-hierarchical-priors-reappearance/kodak -a mbt2018 -p /home/zqb/lilin/code/Joint-autoregressive-and-hierarchical-priors-reappearance/checkpoint_best_loss-779c59c4.pth.tar > eval_log.out
可以通过输出的日志看到评估结果
6、整理评估流程
评估流程从代码命令上看可以知道是直接调用comp