Checkerboard Context Model for Efficient Learned Image Compression文献复现(非官方,改动较大)

本文档详细记录了Checkerboard Context Model的复现过程,使用PyTorch框架,并基于CompressAI库进行实验。作者通过调整训练参数、增加数据集,对比不同码率点的性能和解码时间,旨在实现高效的图像压缩。在实验中,作者遇到了滤波器数量和性能提升的问题,并给出了解决方案。
摘要由CSDN通过智能技术生成

前言

引入大佬的讲解博客棋盘格上下文图像压缩
以及对应的github:github参考地址棋盘格实验复现
本文仍然沿用CompressAI的库进行复现,在joint实验基础上实现joint+棋盘格的结合
关于compressAI相关博客说明:CompressAI:基于pytorch的图像压缩框架使用
joint实验复现:joint实验复现

1、创建工程

主要分为以下几个部分,主要参考代码为github参考地址
在这里插入图片描述
这里的train.py是我复制compressai中的train.py可以进行复用。其他几个部分沿用大佬的文件。

1.1 train.py的具体更改

  1. 引入更改后的模型
    在这里插入图片描述

  2. 修改调用的模型
    在这里插入图片描述
    在这里插入图片描述

2、创建训练命令脚本

因为这篇文献其实是可以直接调用相关命令对实验进行复现,所以主要把相关命令输入,按照步骤一步步执行即可,相关脚本命令可以参考sh脚本运行命令

创建成功后:
在这里插入图片描述

# conda activate torchll
python3 train.py -d /xxx/imageNet1000t --epochs 2000 -lr 1e-4 --lambda 0.0016 --batch-size 8 --cuda --save > log.out



其中patchsize取的默认256,最后可以加上自己保持的路径
具体的命令说明可以参见官网以及CompressAI:基于pytorch的图像压缩框架使用

训练完成后,会在.out文件中打印相关输出
在这里插入图片描述

3、整理训练流程

  1. 加载训练命令,如上述图片的数据集、 e p o c h epoch epoch、学习率、 p a t c h s i z e patchsize patchsize、随机种子等
  2. 对数据集图片进行剪裁
    分别读取train训练数据集与test测试数据集的图片,剪裁大小 p a t c h s i z e patchsize patchsize初始值是256,train训练集随机剪裁成 p a t c h s i z e ∗ p a t h s i z e patchsize*pathsize patchsizepathsize,test测试数据集中心剪裁成 p a t c h s i z e ∗ p a t h s i z e patchsize*pathsize patchsizepathsize
  3. 加载剪裁好的两个数据集,以及训练相关数据如 b a t c h s i z e batchsize batchsize
  4. 读取需要训练的模型网络(这里的上下文用了棋盘格掩膜)、进行优化器的初始设置
    (1)模型引用位置,引入自己的模型
    在这里插入图片描述

模型代码可以参见附录

  1. 训练前先判断是否在之前有保存的节点,有则继续训练,无则重新开始
  2. 对train训练集以及test测试集分别带入优化器进行梯度下降,开始模型训练
  3. 损失收敛后训练完成,保存模型以及相应的节点

其实对于训练流程而言,其实都是一样的,这几篇基于compressai都是调用一个训练方法,区别在于具体的模型更改,比如这里加入了棋盘格上下文模型

4、重新修改训练流程(引用JiangWeibeta的方法,zqb_dev分支)

针对以上的方法,涉及到CDF的更新,需要在compressai中逐层添加对应新加的模型,所以下面尝试一下新的train.py
棋盘格实验复现
训练流程整理,同上文,模型的引用和使用如下
在这里插入图片描述
在这里插入图片描述

4.1 train.py


def options():
    parser = argparse.ArgumentParser(description="Example training script.")
    parser.add_argument(
        "-m",
        "--model",
        default="CheckerboardAutogressive",
        choices=models.keys(),
        help="Model architecture (default: %(default)s)",
    )
    parser.add_argument(
        "-d",
        "--dataset",
        default='/home/ll/datasets/imageNet1000t',
        type=str,
        required=False,
        help="Training dataset"
    )
    parser.add_argument(
        "-e",
        "--epochs",
        default=11,
        type=int,
        help="Number of epochs (default: %(default)s)",
    )
    parser.add_argument(
        "-lr",
        "--learning-rate",
        default=1e-4,
        type=float,
        help="Learning rate (default: %(default)s)",
    )
    parser.add_argument(
        "-n",
        "--num-workers",
        type=int,
        default=4,
        help="Dataloaders threads (default: %(default)s)",
    )
    parser.add_argument(
        "--lambda",
        dest="lmbda",
        type=float,
        default=1e-2,
        help="Bit-rate distortion parameter (default: %(default)s)",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=8,
        help="Batch size (default: %(default)s)"
    )
    parser.add_argument(
        "--test-batch-size",
        type=int,
        default=64,
        help="Test batch size (default: %(default)s)",
    )
    parser.add_argument(
        "--aux-learning-rate",
        default=1e-3,
        help="Auxiliary loss learning rate (default: %(default)s)",
    )
    parser.add_argument(
        "--patch-size",
        type=int,
        nargs=2,
        default=(256, 256),
        help="Size of the patches to be cropped (default: %(default)s)",
    )
    parser.add_argument(
        "--cuda",
        default=True,
        action="store_true",
        help="Use cuda")
    parser.add_argument(
        "--save",
        action="store_true",
        default=True,
        help="Save model to disk"
    )
    parser.add_argument(
        "--seed",
        type=float,
        default=234,
        help="Set random seed for reproducibility"
    )
    parser.add_argument(
        "--clip_max_norm",
        default=1.0,
        type=float,
        help="gradient clipping max norm (default: %(default)s",
    )
    parser.add_argument(
        "--checkpoint",
        default='',
        type=str,
        help="Path to a checkpoint"
    )
    parser.add_argument(
        "--N",
        default=192,
        type=int
    )
    parser.add_argument(
        "--M",
        default=192,
        type=int
    )
    parser.add_argument(
        "--G",
        default=24,
        type=int
    )

    args = parser.parse_args()

    return args


if __name__ == '__main__':
    args = options()
    if args.seed is not None:
        torch.manual_seed(args.seed)
        random.seed(args.seed)

    train_transforms = transforms.Compose(
        [transforms.RandomCrop(args.patch_size), transforms.ToTensor()]
    )

    test_transforms = transforms.Compose(
        [transforms.CenterCrop(args.patch_size), transforms.ToTensor()]
    )

    train_dataset = ImageFolder(args.dataset, split="train", transform=train_transforms)
    test_dataset = ImageFolder(args.dataset, split="test", transform=test_transforms)

    device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu"

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=True,
        pin_memory=(device == "cuda"),
        drop_last=True
    )
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=args.test_batch_size,
        num_workers=args.num_workers,
        shuffle=False,
        pin_memory=(device == "cuda"),
        drop_last=True
    )
    model = CheckerboardAutogressive()
    net = model.to(device)

    if args.cuda and torch.cuda.device_count() > 1:
        net = train_method.CustomDataParallel(net)

    optimizer, aux_optimizer 
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值