前言
文献名称:Joint Autoregressive and Hierarchical Priors for Learned Image Compression
本文基于CompressAI的库进行复现
github地址:compressAI
关于compressAI相关博客说明:CompressAI:基于pytorch的图像压缩框架使用
安装好compressai后相当于把这个底层库引入了我们的工程
1、创建工程
我这里在做好compressAI所需前置工作后,把我需要的几个文件单独移入了自己的工程
其实这里如果安装了compressai,可以不用引入google文件,我这里放进来只是方便自己查看
2、创建训练命令脚本
因为这篇文献其实是可以直接调用相关命令对实验进行复现,所以主要把相关命令输入,按照步骤一步步执行即可,相关脚本命令可以参考sh脚本运行命令
创建成功后:
conda activate torchll
python3 train.py -m mbt2018 -d /xxx/imageNet --epochs 500 -lr 1e-4 --lambda 0.0016 --batch-size 16 --cuda --save > log.out
其中patchsize取的默认256,最后可以加上自己保持的路径
具体的命令说明可以参见官网以及CompressAI:基于pytorch的图像压缩框架使用
训练完成后,会在.out文件中打印相关输出
3、整理训练流程
-
加载训练命令,如上述图片的数据集、 e p o c h epoch epoch、学习率、 p a t c h s i z e patchsize patchsize、随机种子等
-
对数据集图片进行剪裁
分别读取train训练数据集与test测试数据集的图片,剪裁大小 p a t c h s i z e patchsize patchsize初始值是256,train训练集随机剪裁成 p a t c h s i z e ∗ p a t h s i z e patchsize*pathsize patchsize∗pathsize,test测试数据集中心剪裁成 p a t c h s i z e ∗ p a t h s i z e patchsize*pathsize patchsize∗pathsize -
加载剪裁好的两个数据集,以及训练相关数据如 b a t c h s i z e batchsize batchsize等
-
读取需要训练的模型网络、进行优化器的初始设置
(1)模型引用位置compressai.zoo
这里的网络是通过之前的训练命令,调用mbt2018的模型,对应本篇文献的网络模型,模型代码可以参见附录 -
训练前先判断是否在之前有保存的节点,有则继续训练,无则重新开始
-
对train训练集以及test测试集分别带入优化器进行梯度下降,开始模型训练
-
损失收敛后训练完成,保存模型以及相应的节点
4、更新CDF保证熵编码的正常运行(模型地址与1保持一致)
相关脚本命令可以参考sh脚本运行命令
官方命令
python -m compressai.utils.update_model --architecture ARCH checkpoint_best_loss.pth.tar
自己命令
python -m compressai.utils.update_model --architecture mbt2018 checkpoint_best_loss.pth.tar
5、评价模型,获取性能值
官方命令
python -m compressai.utils.eval_model checkpoint /path/to/image/dataset -a ARCH -p path/to/checkpoint-xxxxxxxx.pth.tar
自己的命令
python3 -m compressai.utils.eval_model checkpoint /home/zqb/lilin/code/Joint-autoregressive-and-hierarchical-priors-reappearance/kodak -a mbt2018 -p /home/zqb/lilin/code/Joint-autoregressive-and-hierarchical-priors-reappearance/checkpoint_best_loss-779c59c4.pth.tar > eval_log.out
可以通过输出的日志看到评估结果
6、整理评估流程
评估流程从代码命令上看可以知道是直接调用compressai.utils.eval_model
,即下图所示文件
核心代码如下
# 评估模型
def eval_model(model, filepaths, entropy_estimation=False, half=False):
device = next(model.parameters()).device
metrics = defaultdict(float)
for f in filepaths:
x = read_image(f).to(device)
if not entropy_estimation:
if half:
model = model.half()
x = x.half()
rv = inference(model, x)
else:
rv = inference_entropy_estimation(model, x) # 获取相关性能值psnr等
for k, v in rv.items():
metrics[k] += v
for k, v in metrics.items():
metrics[k] = v / len(filepaths)
return metrics
- 遍历给定路径下的所有图片,读取图片
- 每一层遍历调用模型前向传播、compress、depress方法,获取重建后的x_hat
- 每一层遍历中计算像素数、psnr等性能值
- 获取指定路径下所有性能值相加,求平均
7、实验
1、官方文献实验数据
bpp | psnr |
---|---|
0.071997 | 26.804116 |
0.153354 | 28.880747 |
0.264381 | 30.927089 |
0.428511 | 33.028649 |
0.635404 | 34.998064 |
0.904279 | 37.053312 |
1.258828 | 39.120817 |
1.982050 | 42.165220 |
2.992778 | 45.074915 |
2、自己实验数据
N (int): Number of channels
M (int): Number of channels in the expansion layers (last layer of the encoder and last layer of the hyperprior decoder)
上面两个参数我记为num_filters 、num_filters_M
1、第一个码率点
(1)joint_FIR_01(学习率1e-4,lambda=0.0016,epochs=500)
model_name | lambda | bpp | epochs | num_filters | num_filters_M | psnr | ms-ssim | encoding_time | decoding_time |
---|---|---|---|---|---|---|---|---|---|
joint_FIR_01 | 0.0016 | 0.11311848958333331 | 500 | 192 | 192 | 27.9804873104851 | 0.9130257219076157 | 8.318476100762686 | 19.977182120084763 |
增大epoch进行测试训练
(2)joint_FIR_02(学习率1e-4,lambda=0.0016,epochs=1000)
model_name | lambda | bpp | epochs | num_filters | num_filters_M | psnr | ms-ssim | encoding_time | decoding_time |
---|---|---|---|---|---|---|---|---|---|
joint_FIR_01 | 0.0016 | 0.11311848958333331 | 500 | 192 | 192 | 27.9804873104851 | 0.9130257219076157 | 8.318476100762686 | 19.977182120084763 |
joint_FIR_02 | 0.0016 | 0.11440022786458331 | 1000 | 192 | 192 | 27.919279305731433 | 0.911840446293354 | 5.100616623957952 | 10.740893175204596 |
可以看到几乎没有差别,最后取joint_FIR_02作为第一个码率点模型,因为解码时间更少,方便和后面做对比
2、第二个码率点
(1)joint_SEC_01(学习率1e-4,lambda=0.0032,epochs=500)
model_name | lambda | bpp | epochs | num_filters | num_filters_M | psnr | ms-ssim | encoding_time | decoding_time |
---|---|---|---|---|---|---|---|---|---|
joint_SEC_01 | 0.0032 | 0.1898600260416667 | 500 | 192 | 192 | 29.69223977516354 | 0.9429223984479904 | 3.911218543847402 | 8.840558975934982 |
3、第三个码率点
(1)joint_THI_01(学习率1e-4,lambda=0.0075,epochs=500)
model_name | lambda | bpp | epochs | num_filters | num_filters_M | psnr | ms-ssim | encoding_time | decoding_time |
---|---|---|---|---|---|---|---|---|---|
joint_THI_01 | 0.0075 | 0.35229153103298616 | 500 | 192 | 192 | 31.694376047084095 | 0.9636343965927759 | 3.875972956418991 | 8.78865984082222 |
性能有些许下降,考虑增大epoch试试
(2)joint_THI_02(学习率1e-4,lambda=0.0075,epochs=1000)
model_name | lambda | bpp | epochs | num_filters | num_filters_M | psnr | ms-ssim | encoding_time | decoding_time |
---|---|---|---|---|---|---|---|---|---|
joint_THI_01 | 0.0075 | 0.35229153103298616 | 500 | 192 | 192 | 31.694376047084095 | 0.963634396592775 |