VARIATIONAL IMAGE COMPRESSION WITH A SCALE HYPERPRIOR文献实验复现

前言

这篇文章是在END-TO-END OPTIMIZED IMAGE COMPRESSION文献基础上进行的改进,主要是加入了超先验网络对边信息进行了处理。相关环境配置与基础可以参考END-TO-END OPTIMIZED IMAGE COMPRESSION
github地址:github

1、相关命令

(1)训练

python bmshj2018.py -V train

同样,我这里方便调试,加入了launch.json

{
   
  // 使用 IntelliSense 了解相关属性。 
  // 悬停以查看现有属性的描述。
  // 欲了解更多信息,请访问: https://go.microsoft.com/fwlink/?linkid=830387
  "version": "0.2.0",
  "configurations": [
    {
   
      "name": "Python: 当前文件",
      "type": "python",
      "request": "launch",
      "program": "${file}",
      "console": "integratedTerminal",
      "justMyCode": true,

      // bmshj2018
      "args": ["--verbose","train"] // 训练


    
    } 
  ]
}

2、数据集

  if args.train_glob:
    train_dataset = get_custom_dataset("train", args)
    validation_dataset = get_custom_dataset("validation", args)
  else:
    train_dataset = get_dataset("clic", "train", args)
    validation_dataset = get_dataset("clic", "validation", args)
  validation_dataset = validation_dataset.take(args.max_validation_steps)

给定数据集地址的情况下使用对应数据集,否则使用默认CLIC数据集

3、整理训练流程

整个代码其实主要还是在bls2017.py上进行的补充,所以整个流转流程是和bls2017的过程一样,这里我简单整理一下这篇文献的训练过程

  1. 实例化对象
    init中初始化一些参数(LocationScaleIndexedEntropyModel、ContinuousBatchedEntropyModel两个熵模型中所需参数,对应的模型说明可以见tfc熵模型,具体的参数有scale_fn以及给定的num_scales)
    初始化非线性分析变换、非线性综合变换、超先验非线性分析变换、超先验非线性综合变换、添加均匀噪声、获得先验概率等
    call中通过两个熵模型计算
    (1)x->y(非线性分析变换)
    (2)y->z(超先验非线性分析变换)
    (3)z->z_hat、边信息bit数(ContinuousBatchedEntropyModel熵模型)
    (4)y->y_hat、比特数(LocationScaleIndexedEntropyModel熵模型)
    (5)计算bpp(这里bpp=bit/px,bit数等于上面两项bit数相加)、mse、loss
    其中非线性分析变换即cnn卷积的过程,其中有一步补零的过程是为了输入与输出图片尺寸相等

  2. model.compile
    通过model.compile配置训练方法,使用优化器进行梯度下降、算bpp、mse、lose的加权平均

  3. 过滤剪裁数据集
    通过参数查看是否给定了数据集路径
    给定数据集路径:直接剪裁成256x256(统一剪裁成256*256送入网络训练,后面压缩的图片不会改变大小)
    未给定数据集路径:用CLIC数据集,过滤出图片大小大于256x256的三通的图片,然后进行剪裁
    分出训练数据集与验证数据集

  4. model.fit
    传入训练数据集,设置相关参数(epoch等)进行训练
    通过retval = super().fit(*args, **kwargs)进入model的train_step(不得不说这里封装的太严实了…但从单文件的代码根本找不到怎么跳进去的)

  5. train_step
    self.trainable_variables获取变量集,通过传入变量集与定义的损失函数,进行前向传播与反向误差传播更新参数。然后更新loss, bpp, mse

  def train_step(self, x):
    with tf.GradientTape() as tape:
      loss, bpp, mse = self(x, training=True)
    variables = self.trainable_variables
    gradients = tape.gradient(loss, variables)
    self.optimizer.apply_gradients(zip(gradients, variables))
    self.loss.update_state(loss)
    self.bpp.update_state(bpp)
    self.mse.update_state(mse)
    return {
   m.name: m.result() for m in [self.loss, self.bpp, self.mse]}
  1. 接下来就是重复训练的过程,直到到终止条件(比如epoch达到10000)

4、压缩一张图片

python bmshj2018.py [options] compress original.png compressed.tfci

这里通过launch.json调试

"args": ["--verbose","compress","./models/kodak/kodim01.png", "./models/kodim01.tfci"] // 压缩一张图

引入你自己的path

 "--model_path", default="./models/bmshj2018Model/bmshj2018_test",

我这里压缩一张图片是通过终端调试进行的
压缩成功后

Mean squared error: 208.0925
PSNR (dB): 24.95
Multiscale SSIM: 0.9040
Multiscale SSIM (dB): 10.18
Bits per pixel: 0.2020

生成的tfci文件如下
在这里插入图片描述

5、整理压缩一张图片流程

  1. 加载模型、读取图片文件
  2. 调用模型的compress方法,更改x、y、z即原图像与重建后图像、以及由y经超先验网络提取的随便变量z的维度,进行非线性分析变换(编码)、超先验网络的非线性变换,返回对应的张量
  3. 获取压缩张量的压缩表示,通过压缩表示和shape信息写一个二进制文件
  4. 调用模型depress方法(调用两个熵模型的解压方法获得y_hat与z_hat,更改对应shape信息,进而通过非线性综合变换进行解码),解压获取重建后的图像tensor
  5. 对比原图像和重建图像,计算相应性能(mse、psnr、msssim、msssim_db)
  6. 计算码率bpp(先计算出总的像素值,码率就是每像素所占bit数)

6、压缩一个文件夹下的所有图片

  1. 添加压缩文件夹下所有图片的命令
  # 'compressAll' subcommand.
  compressAll_cmd = subparsers.add_parser(
      "compressAll",
      formatter_class=argparse.ArgumentDefaultsHelpFormatter,
      description="读取文件下的文件进行压缩操作")
  
  # Arguments for 'compressAll'.
  compressAll_cmd.add_argument(
    "input_folder",
    help="输入文件夹.")
  compressAll_cmd.add_argument(
    "output_folder",
    help="输出文件夹.")
  1. 主分支加入压缩所有图片对应的方法
  elif args.command == "compressAll":
    compressAll(args)

在这里插入图片描述

  1. 添加压缩所有图片对应的方法
    主要是利用调用compress方法获取相应的属性值,其中此方法中需要拼接compress所需命令
def compressAll(args):
  """压缩文件夹的文件"""
  # print(args, 'args')
  files = glob.glob(args.input_folder + '/*png')
  # print(files, 'files')
  perArgs = copy.copy(args) # 浅拷贝,不改变args的值
  # print(perArgs, 'perArgs')
  bpp_list = []
  mse_list = []
  psnr_list = []
  mssim_list = []
  msssim_db_list = []
  # 循环遍历kodak数据集
  for img in files:
    # print(img, 'img')
    # img为图片完整的相对路径
    imgIndexFirst = img.find('/kodim') # 索引
    imgIndexNext = img.find('.png')
    imgName = img[imgIndexFirst: imgIndexNext] # 单独的图片文件名,如kodim01.png
    # print(imgName, 'imgName') # 单独的图片文件名,如/kodim01
    perArgs.input_file = img
    perArgs.output_file = args.output_folder + imgName + '.tfci'
    # print(perArgs, 'perArgs')
    # print(args, 'args')
    bpp, mse, psnr, msssim, msssim_db = perCompress(perArgs)
    print(bpp, mse
  • 2
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值