前言
引入大佬的讲解博客棋盘格上下文图像压缩
以及对应的github:github参考地址、棋盘格实验复现
本文仍然沿用CompressAI的库进行复现,在joint实验基础上实现joint+棋盘格的结合
关于compressAI相关博客说明:CompressAI:基于pytorch的图像压缩框架使用
joint实验复现:joint实验复现
1、创建工程
主要分为以下几个部分,主要参考代码为github参考地址
这里的train.py是我复制compressai中的train.py可以进行复用。其他几个部分沿用大佬的文件。
1.1 train.py的具体更改
-
引入更改后的模型
-
修改调用的模型
2、创建训练命令脚本
因为这篇文献其实是可以直接调用相关命令对实验进行复现,所以主要把相关命令输入,按照步骤一步步执行即可,相关脚本命令可以参考sh脚本运行命令
创建成功后:
# conda activate torchll
python3 train.py -d /xxx/imageNet1000t --epochs 2000 -lr 1e-4 --lambda 0.0016 --batch-size 8 --cuda --save > log.out
其中patchsize取的默认256,最后可以加上自己保持的路径
具体的命令说明可以参见官网以及CompressAI:基于pytorch的图像压缩框架使用
训练完成后,会在.out文件中打印相关输出
3、整理训练流程
- 加载训练命令,如上述图片的数据集、 e p o c h epoch epoch、学习率、 p a t c h s i z e patchsize patchsize、随机种子等
- 对数据集图片进行剪裁
分别读取train训练数据集与test测试数据集的图片,剪裁大小 p a t c h s i z e patchsize patchsize初始值是256,train训练集随机剪裁成 p a t c h s i z e ∗ p a t h s i z e patchsize*pathsize patchsize∗pathsize,test测试数据集中心剪裁成 p a t c h s i z e ∗ p a t h s i z e patchsize*pathsize patchsize∗pathsize - 加载剪裁好的两个数据集,以及训练相关数据如 b a t c h s i z e batchsize batchsize等
- 读取需要训练的模型网络(这里的上下文用了棋盘格掩膜)、进行优化器的初始设置
(1)模型引用位置,引入自己的模型
模型代码可以参见附录
- 训练前先判断是否在之前有保存的节点,有则继续训练,无则重新开始
- 对train训练集以及test测试集分别带入优化器进行梯度下降,开始模型训练
- 损失收敛后训练完成,保存模型以及相应的节点
其实对于训练流程而言,其实都是一样的,这几篇基于compressai都是调用一个训练方法,区别在于具体的模型更改,比如这里加入了棋盘格上下文模型
4、重新修改训练流程(引用JiangWeibeta的方法,zqb_dev分支)
针对以上的方法,涉及到CDF的更新,需要在compressai中逐层添加对应新加的模型,所以下面尝试一下新的train.py
棋盘格实验复现
训练流程整理,同上文,模型的引用和使用如下
4.1 train.py
def options():
parser = argparse.ArgumentParser(description="Example training script.")
parser.add_argument(
"-m",
"--model",
default="CheckerboardAutogressive",
choices=models.keys(),
help="Model architecture (default: %(default)s)",
)
parser.add_argument(
"-d",
"--dataset",
default='/home/ll/datasets/imageNet1000t',
type=str,
required=False,
help="Training dataset"
)
parser.add_argument(
"-e",
"--epochs",
default=11,
type=int,
help="Number of epochs (default: %(default)s)",
)
parser.add_argument(
"-lr",
"--learning-rate",
default=1e-4,
type=float,
help="Learning rate (default: %(default)s)",
)
parser.add_argument(
"-n",
"--num-workers",
type=int,
default=4,
help="Dataloaders threads (default: %(default)s)",
)
parser.add_argument(
"--lambda",
dest="lmbda",
type=float,
default=1e-2,
help="Bit-rate distortion parameter (default: %(default)s)",
)
parser.add_argument(
"--batch-size",
type=int,
default=8,
help="Batch size (default: %(default)s)"
)
parser.add_argument(
"--test-batch-size",
type=int,
default=64,
help="Test batch size (default: %(default)s)",
)
parser.add_argument(
"--aux-learning-rate",
default=1e-3,
help="Auxiliary loss learning rate (default: %(default)s)",
)
parser.add_argument(
"--patch-size",
type=int,
nargs=2,
default=(256, 256),
help="Size of the patches to be cropped (default: %(default)s)",
)
parser.add_argument(
"--cuda",
default=True,
action="store_true",
help="Use cuda")
parser.add_argument(
"--save",
action="store_true",
default=True,
help="Save model to disk"
)
parser.add_argument(
"--seed",
type=float,
default=234,
help="Set random seed for reproducibility"
)
parser.add_argument(
"--clip_max_norm",
default=1.0,
type=float,
help="gradient clipping max norm (default: %(default)s",
)
parser.add_argument(
"--checkpoint",
default='',
type=str,
help="Path to a checkpoint"
)
parser.add_argument(
"--N",
default=192,
type=int
)
parser.add_argument(
"--M",
default=192,
type=int
)
parser.add_argument(
"--G",
default=24,
type=int
)
args = parser.parse_args()
return args
if __name__ == '__main__':
args = options()
if args.seed is not None:
torch.manual_seed(args.seed)
random.seed(args.seed)
train_transforms = transforms.Compose(
[transforms.RandomCrop(args.patch_size), transforms.ToTensor()]
)
test_transforms = transforms.Compose(
[transforms.CenterCrop(args.patch_size), transforms.ToTensor()]
)
train_dataset = ImageFolder(args.dataset, split="train", transform=train_transforms)
test_dataset = ImageFolder(args.dataset, split="test", transform=test_transforms)
device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu"
train_dataloader = DataLoader(
train_dataset,
batch_size=args.batch_size,
num_workers=args.num_workers,
shuffle=True,
pin_memory=(device == "cuda"),
drop_last=True
)
test_dataloader = DataLoader(
test_dataset,
batch_size=args.test_batch_size,
num_workers=args.num_workers,
shuffle=False,
pin_memory=(device == "cuda"),
drop_last=True
)
model = CheckerboardAutogressive()
net = model.to(device)
if args.cuda and torch.cuda.device_count() > 1:
net = train_method.CustomDataParallel(net)
optimizer, aux_optimizer