深度学习学习记录-1【端到端压缩/compressai/自编码器】

目录

1.compressai库安装

1.1安装步骤

1.2报错与解决:

1.2.1 No module named 'compressai._CXX'

1.2.2 关于C++的报错,一般建议升级项目所在环境中的gcc/g++版本(linux系统)

2.模型训练(examples)

2.1模型训练:

2.2模型评估(参考):

2.3模型推理:单张图片


1.compressai库安装

1.1安装步骤

建议项目安装至虚拟环境

git clone https://github.com/InterDigitalInc/CompressAI compressai
cd compressai
pip install -U pip && pip install -e .

1.2报错与解决:

1.2.1 No module named 'compressai._CXX'

报错来源:运行train.py时,在from compressai_.CXX处,compressai文件下的两个pyd文件无法被正常调用(在本地pycharm上模型可以正常训练,代码上传至服务器后不可以): _CXX.cp39-win_amd64.pyd和ans.cp39-win_amd64.pyd。

解决思路:在服务器中重新安装一遍compressai库,模型可用正常训练,不再报错

1.2.2 关于C++的报错,一般建议升级项目所在环境中的gcc/g++版本(linux系统)

解决思路:不同环境根据需求安装不同版本GCC,需要保证各版本共存

1.安装gcc/g++

sudo add-apt-repository ppa:ubuntu-toolchain-r/test
sudo apt update
sudo apt install gcc-9 g++-9

在运行第一行时出现如下报错,可用直接运行第三行:

Cannot add PPA: 'ppa:~jonathonf/ubuntu/gcc-9.4'.
The user named '~jonathonf' has no PPA named 'ubuntu/gcc-9.4'

检测apt源是否有想要安装的包的版本:

sudo apt-cache search gcc # gcc可替换为其他包名,同样是有效的查询
sudo apt-cache show gcc #展示版本号

2.检查是否安装成功

运行:gcc --version 和 g++ --version,出现以下即安装新版本成功

如果出现的版本仍为原来的版本,采用3.中的方法(设置版本优先级)

3.设置默认版本

查看已安装版本:dpkg -l | grep gcc

使用update-alternatives管理系统多版本,将安装好的新版本设置为默认版本(通过优先级设置)。其中主命令--install 将/usr/bin/gcc gcc更新至/usr/bin/gcc-4.7版本,20是优先级别(数字越大,优先级别越高),--slave将随主命令的更新而更新,保证gcc和g++编译器版本一致。

#安装软链接
sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-9 20 --slave /usr/bin/g++ g++ /usr/bin/g++-9
sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-7 10 --slave /usr/bin/g++ g++ /usr/bin/g++-7
#删除软链接
sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-9 --slave /usr/bin/g++ g++ /usr/bin/g++-9

查看可选择的gcc版本:sudo update-alternatives --config gcc

前面带*即现在的版本,可以输入selection列对应数字选择想要的版本

2.模型训练(examples)

examples中包含三类对象(image、pointcloud、video)的模型训练代码,本项目只关注image。compressai库中包含多个压缩模型,各模型都是在base class:CompressionModel的基础上添加压缩/重建模块构成自编码器。以bmshj2018-factorized为例:

注释给出了模型框架:输入图像数据x,通过编码器g_a(将输入映射到一个低维的潜在空间latent space,形成编码code或潜在表示latent representation)得到输出y,经过量化Q、熵瓶颈EB得到y_hat,最后输入解码器g_s得到重建数据x_hat。

class FactorizedPrior(CompressionModel):
    r"""Factorized Prior model from J. Balle, D. Minnen, S. Singh, S.J. Hwang,
    N. Johnston: `"Variational Image Compression with a Scale Hyperprior"
    <https://arxiv.org/abs/1802.01436>`_, Int Conf. on Learning Representations
    (ICLR), 2018.

    .. code-block:: none

                  ┌───┐    y
            x ──►─┤g_a├──►─┐
                  └───┘    │
                           ▼
                         ┌─┴─┐
                         │ Q │
                         └─┬─┘
                           │
                     y_hat ▼
                           │
                           ·
                        EB :
                           ·
                           │
                     y_hat ▼
                           │
                  ┌───┐    │
        x_hat ──◄─┤g_s├────┘
                  └───┘

        EB = Entropy bottleneck

    Args:
        N (int): Number of channels
        M (int): Number of channels in the expansion layers (last layer of the
            encoder and last layer of the hyperprior decoder)
    """

    def __init__(self, N, M, **kwargs):
        super().__init__(**kwargs)

        self.entropy_bottleneck = EntropyBottleneck(M)
        #编码
        self.g_a = nn.Sequential(
            conv(3, N),
            GDN(N),
            conv(N, N),
            GDN(N),
            conv(N, N),
            GDN(N),
            conv(N, M),
        )
        #解码
        self.g_s = nn.Sequential(
            deconv(M, N),
            GDN(N, inverse=True),
            deconv(N, N),
            GDN(N, inverse=True),
            deconv(N, N),
            GDN(N, inverse=True),
            deconv(N, 3),
        )

        self.N = N
        self.M = M
    #下采样因子,每经过一个卷积层,图像像素减小一半
    @property
    def downsampling_factor(self) -> int:
        return 2**4
    #前向传播
    def forward(self, x):
        y = self.g_a(x)
        y_hat, y_likelihoods = self.entropy_bottleneck(y)
        x_hat = self.g_s(y_hat)

        return {
            "x_hat": x_hat,
            "likelihoods": {
                "y": y_likelihoods,
            },
        }
    #从预训练中得到通道数
    @classmethod
    def from_state_dict(cls, state_dict):
        """Return a new model instance from `state_dict`."""
        N = state_dict["g_a.0.weight"].size(0)
        M = state_dict["g_a.6.weight"].size(0)
        net = cls(N, M)
        net.load_state_dict(state_dict)
        return net
    #压缩
    def compress(self, x):
        y = self.g_a(x)
        y_strings = self.entropy_bottleneck.compress(y)
        return {"strings": [y_strings], "shape": y.size()[-2:]}
    #重建
    def decompress(self, strings, shape):
        assert isinstance(strings, list) and len(strings) == 1
        y_hat = self.entropy_bottleneck.decompress(strings[0], shape)
        x_hat = self.g_s(y_hat).clamp_(0, 1)
        return {"x_hat": x_hat}

2.1模型训练:

只进行forward

python train.py -m model -d dataset/path --cuda --save

可以采用compressai.zoo中自带的预训练模型,main函数中具体操作为:

#args.model是采用的模型,调用img_model连接到zoo而不是models
#quality可选1-8,metric是评价指标(mse/ms-ssim),pretrained是否调用预训练
#quality的选择决定了输入输出通道数/预训练质量
net = image_models[args.model](quality, metric="mse", pretrained=False, progress=True)

2.2模型评估(参考):

import torch.nn.functional as F
from torchvision import transforms
import pandas as pd
import torch
import os
import sys
import math
import argparse
import time
from pytorch_msssim import ms_ssim
from PIL import Image
from net import FactorizedPrior

print(torch.cuda.is_available())


def compute_psnr(a, b):
    mse = torch.mean((a - b) ** 2).item()
    return -10 * math.log10(mse)


def compute_msssim(a, b):
    return -10 * math.log10(1 - ms_ssim(a, b, data_range=1.).item())


def compute_bpp(out_net):
    size = out_net['x_hat'].size()
    num_pixels = size[0] * size[2] * size[3]
    return sum(torch.log(likelihoods).sum() / (-math.log(2) * num_pixels)
               for likelihoods in out_net['likelihoods'].values()).item()


def pad(x, p):
    h, w = x.size(2), x.size(3)
    new_h = (h + p - 1) // p * p
    new_w = (w + p - 1) // p * p
    padding_left = (new_w - w) // 2
    padding_right = new_w - w - padding_left
    padding_top = (new_h - h) // 2
    padding_bottom = new_h - h - padding_top
    x_padded = F.pad(
        x,
        (padding_left, padding_right, padding_top, padding_bottom),
        mode="constant",
        value=0,
    )
    return x_padded, (padding_left, padding_right, padding_top, padding_bottom)


def crop(x, padding):
    return F.pad(
        x,
        (-padding[0], -padding[1], -padding[2], -padding[3]),
    )


def parse_args(argv):
    parser = argparse.ArgumentParser(description="Example testing script.")
    parser.add_argument("--cuda", action="store_true", help="Use cuda")
    parser.add_argument(
        "--clip_max_norm",
        default=1.0,
        type=float,
        help="gradient clipping max norm (default: %(default)s",
    )
    parser.add_argument("--checkpoint", type=str, help="Path to a checkpoint")
    parser.add_argument("--data", type=str, help="Path to dataset")
    parser.add_argument(
        "--real", action="store_true", default=True
    )
    parser.set_defaults(real=False)
    args = parser.parse_args(argv)
    return args


def main(argv):
    args = parse_args(argv)
    p = 128
    path = args.data
    img_list = []
    for file in os.listdir(path):
        if file[-3:] in ["jpg", "png", "peg"]:
            img_list.append(file)
    if args.cuda:
        device = 'cuda:0'
    else:
        device = 'cpu'
    net = FactorizedPrior(64,128)
    net = net.to(device)
    net.eval()
    count = 0
    PSNR = 0
    Bit_rate = 0
    MS_SSIM = 0
    total_time = 0
    dictory = {}
    if args.checkpoint:  # load from previous checkpoint
        print("Loading", args.checkpoint)
        checkpoint = torch.load(args.checkpoint, map_location=device)
        for k, v in checkpoint["state_dict"].items():
            dictory[k.replace("module.", "")] = v
        net.load_state_dict(dictory)
    if args.real:
        net.update()
        for img_name in img_list:
            img_path = os.path.join(path, img_name)
            img = transforms.ToTensor()(Image.open(img_path).convert('RGB')).to(device)
            x = img.unsqueeze(0)
            x_padded, padding = pad(x, p)
            count += 1
            with torch.no_grad():
                if args.cuda:
                    torch.cuda.synchronize()
                s = time.time()
                out_enc = net.compress(x_padded)
                out_dec = net.decompress(out_enc["strings"], out_enc["shape"])
                if args.cuda:
                    torch.cuda.synchronize()
                e = time.time()
                total_time += (e - s)
                out_dec["x_hat"] = crop(out_dec["x_hat"], padding)
                num_pixels = x.size(0) * x.size(2) * x.size(3)
                print(f'Bitrate: {(sum(len(s[0]) for s in out_enc["strings"]) * 8.0 / num_pixels):.3f}bpp')
                print(f'MS-SSIM: {compute_msssim(x, out_dec["x_hat"]):.2f}dB')
                print(f'PSNR: {compute_psnr(x, out_dec["x_hat"]):.2f}dB')
                Bit_rate += sum(len(s[0]) for s in out_enc["strings"]) * 8.0 / num_pixels
                PSNR += compute_psnr(x, out_dec["x_hat"])
                MS_SSIM += compute_msssim(x, out_dec["x_hat"])

    else:
        for img_name in img_list:
            img_path = os.path.join(path, img_name)
            img = Image.open(img_path).convert('RGB')
            x = transforms.ToTensor()(img).unsqueeze(0).to(device)
            x_padded, padding = pad(x, p)
            count += 1
            with torch.no_grad():
                if args.cuda:
                    torch.cuda.synchronize()
                s = time.time()
                out_net = net.forward(x_padded)
                if args.cuda:
                    torch.cuda.synchronize()
                e = time.time()
                total_time += (e - s)
                out_net['x_hat'].clamp_(0, 1)
                out_net["x_hat"] = crop(out_net["x_hat"], padding)
                print(f'PSNR: {compute_psnr(x, out_net["x_hat"]):.2f}dB')
                print(f'MS-SSIM: {compute_msssim(x, out_net["x_hat"]):.2f}dB')
                print(f'Bit-rate: {compute_bpp(out_net):.3f}bpp')
                PSNR += compute_psnr(x, out_net["x_hat"])
                MS_SSIM += compute_msssim(x, out_net["x_hat"])
                Bit_rate += compute_bpp(out_net)
    PSNR = PSNR / count
    MS_SSIM = MS_SSIM / count
    Bit_rate = Bit_rate / count
    total_time = total_time / count
    print(f'average_PSNR: {PSNR:.2f}dB')
    print(f'average_MS-SSIM: {MS_SSIM:.4f}')
    print(f'average_Bit-rate: {Bit_rate:.3f} bpp')
    print(f'average_time: {total_time:.3f} ms')


if __name__ == "__main__":
    print(torch.cuda.is_available())
    main(sys.argv[1:])

2.3模型推理:单张图片

import math
import torch
from torchvision import transforms

from PIL import Image

from pytorch_msssim import ms_ssim
from compressai.zoo import bmshj2018_factorized
from compressai.models.google import FactorizedPrior

def compute_psnr(a, b):
    mse = torch.mean((a - b) ** 2).item()
    return -10 * math.log10(mse)


def compute_msssim(a, b):
    return ms_ssim(a, b, data_range=1.).item()


def compute_bpp(out_net):
    size = out_net['x_hat'].size()
    num_pixels = size[0] * size[2] * size[3]
    return sum(torch.log(likelihoods).sum() / (-math.log(2) * num_pixels)
               for likelihoods in out_net['likelihoods'].values()).item()


if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(device)
    checkpoint_path = 'checkpoint_loss_best.pth.tar'
    checkpoint = torch.load(checkpoint_path, map_location=device)
    net=FactorizedPrior(128,192)
    net.load_state_dict(checkpoint['state_dict'])
    print(f'Parameters: {sum(p.numel() for p in net.parameters())}')

    img = Image.open('Kodak24/kodim01.png').convert('RGB')
    transform = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
    ])
    x = transform(img).unsqueeze(0).to(device)

    with torch.no_grad():
        out_net = net.forward(x)
        out_net['x_hat'].clamp_(0, 1)
        print(out_net.keys())

    print(f'PSNR: {compute_psnr(x, out_net["x_hat"]):.2f}dB')
    print(f'MS-SSIM: {compute_msssim(x, out_net["x_hat"]):.4f}')
    print(f'Bit-rate: {compute_bpp(out_net):.3f} bpp')

  • 22
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值