FBCNN: Removing the "Digital Patina" from Images with an AI Model

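1. Introduction

  • When a JPEG image is repeatedly compressed and re-shared, the lossy compression can leave blocky (mosaic-like) or blurry artifacts, jokingly called the image's "digital patina".

  • JPEG artifacts removal techniques can mitigate this problem to a certain extent.

  • They typically rely on image-processing or deep-learning algorithms to remove the blocking and blurring artifacts introduced by compression.

  • This post briefly introduces FBCNN, a deep-learning model for JPEG artifacts removal published at ICCV 2021 ("Towards Flexible Blind JPEG Artifacts Removal").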
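The "patina" comes from the lossy step of JPEG: each 8×8 block is transformed with a DCT, its coefficients are divided by a quantization table and rounded, and a lower quality setting means a coarser table. Below is a minimal numpy sketch of just that step. It assumes a grayscale image whose sides are multiples of 8 and leaves out color conversion, chroma subsampling and entropy coding; the function names are illustrative, not part of FBCNN or any library. Because every block is quantized on its own, the error shows up as blocking and lost detail.

```python
import numpy as np

# Standard JPEG luminance quantization table (ITU-T T.81, Annex K)
BASE_Q = np.array([
    [16, 11, 10, 16, 24, 40, 51, 61],
    [12, 12, 14, 19, 26, 58, 60, 55],
    [14, 13, 16, 24, 40, 57, 69, 56],
    [14, 17, 22, 29, 51, 87, 80, 62],
    [18, 22, 37, 56, 68, 109, 103, 77],
    [24, 35, 55, 64, 81, 104, 113, 92],
    [49, 64, 78, 87, 103, 121, 120, 101],
    [72, 92, 95, 98, 112, 100, 103, 99],
], dtype=np.float64)


def dct_matrix(n=8):
    """Orthonormal DCT-II basis D, so that D @ block @ D.T is the 2-D DCT of a block."""
    k = np.arange(n)[:, None]
    i = np.arange(n)[None, :]
    d = np.cos(np.pi * (2 * i + 1) * k / (2 * n)) * np.sqrt(2.0 / n)
    d[0, :] = np.sqrt(1.0 / n)
    return d


def quant_table(quality):
    """Scale the base table with the IJG (libjpeg) quality formula, quality in 1..100."""
    scale = 5000 / quality if quality < 50 else 200 - 2 * quality
    return np.clip(np.floor((BASE_Q * scale + 50) / 100), 1, 255)


def jpeg_like(img, quality):
    """Blockwise DCT -> quantize -> dequantize -> inverse DCT on an (H, W) uint8 image."""
    d, q = dct_matrix(), quant_table(quality)
    h, w = img.shape
    x = img.astype(np.float64) - 128.0                                # level shift
    blocks = x.reshape(h // 8, 8, w // 8, 8).transpose(0, 2, 1, 3)    # (bh, bw, 8, 8)
    coeffs = d @ blocks @ d.T                                         # DCT of every block
    coeffs = np.round(coeffs / q) * q                                 # the lossy step
    rec = (d.T @ coeffs @ d).transpose(0, 2, 1, 3).reshape(h, w) + 128.0
    return np.clip(np.round(rec), 0, 255).astype(np.uint8)


rng = np.random.default_rng(0)
img = rng.integers(0, 256, size=(64, 64)).astype(np.uint8)   # random 64x64 test image
for q in (10, 50, 90):
    err = np.abs(jpeg_like(img, q).astype(int) - img.astype(int)).mean()
    print(f"quality={q:3d}  mean abs error={err:.2f}")
```

The mean error grows as the quality drops, which is the damage FBCNN is trained to undo.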

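2. Results

  • Color JPEG artifacts removal:

    [Figure: color JPEG artifacts removal example]

  • Grayscale JPEG artifacts removal:

    [Figure: grayscale JPEG artifacts removal example]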

3. References

4. Quick Start

  • Use PaddleHub to call the FBCNN model quickly

  • Pick the PaddleHub Module that matches the image type (color / grayscale): fbcnn_color or fbcnn_gray

# Color image
!hub run fbcnn_color \
    --input_path tests/color_test.jpg \
    --quality_factor -1 \
    --output_dir fbcnn_color_output
# Grayscale image
!hub run fbcnn_gray \
    --input_path tests/gray_test.jpg \
    --quality_factor -1 \
    --output_dir fbcnn_gray_output
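`--quality_factor -1` acts as an "auto" setting if, as assumed here, the hub modules apply the same range check as the predictor classes in Section 6: any value outside [0, 1] leaves the network input `qf_input` as None, so the model's own predicted quality factor is used. The rule can be sketched as a hypothetical helper; note the explicit `is not None` test, which keeps 0.0 usable where a plain truthiness check would silently discard it.

```python
from typing import List, Optional


def resolve_quality_factor(quality_factor: Optional[float]) -> Optional[List[List[float]]]:
    """Hypothetical helper: the [[qf]] value fed to the network, or None to use the predicted QF."""
    if quality_factor is not None and 0.0 <= quality_factor <= 1.0:
        return [[float(quality_factor)]]   # shape (1, 1), as in paddle.to_tensor([[qf]])
    return None                            # e.g. -1: let the model predict the QF itself


for value in (-1, None, 0.0, 0.7, 1.5):
    print(value, "->", resolve_quality_factor(value))
```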

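5. Model

  • Overview:

    • FBCNN is a convolutional neural network for removing JPEG artifacts.

    • It predicts an adjustable quality factor, which can be tuned to balance artifact removal against detail preservation.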

  • Architecture:

    [Figure: FBCNN architecture]

  • FBCNN consists of four modules:

    • Decoupler

    • Quality Factor Predictor

    • Flexible Controller

    • Image Reconstructor

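  • Pipeline:

    • The Decoupler extracts deep features from the degraded JPEG input and splits them into image features and quality-factor features, which are fed to the Image Reconstructor and the Quality Factor Predictor respectively.

    • The Controller takes the estimated quality factor from the Quality Factor Predictor and turns it into quality-factor embeddings (QF Embeddings).

    • QF Attention Blocks let the Controller steer the Reconstructor, so that different QF embeddings produce different outputs.

    • The predicted quality factor can be overridden interactively, trading off artifact removal against detail preservation.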
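The Controller and QF Attention steps above can be sketched in numpy. This is a toy under stated assumptions: tiny sizes, random weights, no biases, a single decoder stage, and `np.tanh(x)` standing in for the conv-ReLU-conv residual branch. Only the data flow (QF, then embedding, then sigmoid gamma and tanh beta, then the channel-wise `x + gamma * res(x) + beta`) mirrors the Paddle code in Section 6; `ToyController` and `qf_attention` are hypothetical names.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


class ToyController:
    """Toy Flexible Controller: QF scalar -> embedding -> per-channel (gamma, beta)."""

    def __init__(self, channels=4, emb_dim=16, seed=0):
        rng = np.random.default_rng(seed)
        self.w_emb = rng.normal(size=(1, emb_dim))
        self.w_gamma = rng.normal(size=(emb_dim, channels))
        self.w_beta = rng.normal(size=(emb_dim, channels))

    def __call__(self, qf):
        emb = np.maximum(np.array([[qf]]) @ self.w_emb, 0.0)  # (1, emb_dim) QF embedding, ReLU
        gamma = sigmoid(emb @ self.w_gamma)                     # (1, C), values in (0, 1)
        beta = np.tanh(emb @ self.w_beta)                       # (1, C), values in (-1, 1)
        return gamma, beta


def qf_attention(x, residual, gamma, beta):
    """x + gamma * residual + beta, broadcasting (N, C) gamma/beta over H and W."""
    return x + gamma[:, :, None, None] * residual + beta[:, :, None, None]


rng = np.random.default_rng(1)
x = rng.normal(size=(1, 4, 8, 8))   # feature map (N, C, H, W)
residual = np.tanh(x)               # stand-in for the conv-ReLU-conv branch self.res(x)
ctrl = ToyController()
outputs = {}
for qf in (0.1, 0.9):
    gamma, beta = ctrl(qf)
    outputs[qf] = qf_attention(x, residual, gamma, beta)
```

Changing `qf` changes gamma and beta, and therefore the output, without touching any convolution weights; this is what lets a user-supplied quality factor tune the restoration strength.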

6. Code Implementation

  • Build FBCNN with Paddle, load the pretrained weights into predictor classes, and remove artifacts from color / grayscale images
# Build the model
from collections import OrderedDict

import numpy as np
import paddle.nn as nn


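def sequential(*args):
    """Chain layers into one nn.Sequential, flattening nested nn.Sequential; a single layer is returned as-is."""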
    if len(args) == 1:
        if isinstance(args[0], OrderedDict):
            raise NotImplementedError('sequential does not support OrderedDict input.')
        return args[0]  # No sequential is needed.
    modules = []
    for module in args:
        if isinstance(module, nn.Sequential):
            for submodule in module.children():
                modules.append(submodule)
        elif isinstance(module, nn.Layer):
            modules.append(module)
    return nn.Sequential(*modules)


def conv(in_channels=64,
         out_channels=64,
         kernel_size=3,
         stride=1,
         padding=1,
         bias=True,
         mode='CBR',
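         negative_slope=0.2):
    """Build a stack of layers from `mode`, one character per layer.

    'C' conv, 'T' transposed conv, 'B' batch norm, 'I' instance norm, 'R'/'L' ReLU/LeakyReLU,
    '2'/'3'/'4' pixel shuffle, 'U'/'u'/'v' nearest upsampling (x2/x3/x4), 'M'/'A' max/avg pooling.
    """
    L = []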
    for t in mode:
        if t == 'C':
            L.append(
                nn.Conv2D(in_channels=in_channels,
                          out_channels=out_channels,
                          kernel_size=kernel_size,
                          stride=stride,
                          padding=padding,
                          bias_attr=bias))
        elif t == 'T':
            L.append(
                nn.Conv2DTranspose(in_channels=in_channels,
                                   out_channels=out_channels,
                                   kernel_size=kernel_size,
                                   stride=stride,
                                   padding=padding,
                                   bias_attr=bias))
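        elif t == 'B':
            # Paddle names the argument `epsilon` (not `eps`); scale and shift are learnable by default
            L.append(nn.BatchNorm2D(out_channels, momentum=0.9, epsilon=1e-04))
        elif t == 'I':
            L.append(nn.InstanceNorm2D(out_channels))
        elif t in ('R', 'r'):
            # ResBlock / QFAttention lowercase a leading 'R'/'L'; Paddle has no in-place flag, so treat both alike
            L.append(nn.ReLU())
        elif t in ('L', 'l'):
            L.append(nn.LeakyReLU(negative_slope=negative_slope))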
        elif t == '2':
            L.append(nn.PixelShuffle(upscale_factor=2))
        elif t == '3':
            L.append(nn.PixelShuffle(upscale_factor=3))
        elif t == '4':
            L.append(nn.PixelShuffle(upscale_factor=4))
        elif t == 'U':
            L.append(nn.Upsample(scale_factor=2, mode='nearest'))
        elif t == 'u':
            L.append(nn.Upsample(scale_factor=3, mode='nearest'))
        elif t == 'v':
            L.append(nn.Upsample(scale_factor=4, mode='nearest'))
        elif t == 'M':
            L.append(nn.MaxPool2D(kernel_size=kernel_size, stride=stride, padding=0))
        elif t == 'A':
            L.append(nn.AvgPool2D(kernel_size=kernel_size, stride=stride, padding=0))
        else:
            raise NotImplementedError('Undefined type: {}'.format(t))
    return sequential(*L)


class ResBlock(nn.Layer):

    def __init__(self,
                 in_channels=64,
                 out_channels=64,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 bias=True,
                 mode='CRC',
                 negative_slope=0.2):
        super(ResBlock, self).__init__()

        assert in_channels == out_channels, 'Only support in_channels==out_channels.'
        if mode[0] in ['R', 'L']:
            mode = mode[0].lower() + mode[1:]

        self.res = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode, negative_slope)

    def forward(self, x):
        res = self.res(x)
        return x + res


def upsample_pixelshuffle(in_channels=64,
                          out_channels=3,
                          kernel_size=3,
                          stride=1,
                          padding=1,
                          bias=True,
                          mode='2R',
                          negative_slope=0.2):
    assert len(mode) < 4 and mode[0] in ['2', '3', '4'], 'mode examples: 2, 2R, 2BR, 3, ..., 4BR.'
    up1 = conv(in_channels,
               out_channels * (int(mode[0])**2),
               kernel_size,
               stride,
               padding,
               bias,
               mode='C' + mode,
               negative_slope=negative_slope)
    return up1


def upsample_upconv(in_channels=64,
                    out_channels=3,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=True,
                    mode='2R',
                    negative_slope=0.2):
    assert len(mode) < 4 and mode[0] in ['2', '3', '4'], 'mode examples: 2, 2R, 2BR, 3, ..., 4BR'
    if mode[0] == '2':
        uc = 'UC'
    elif mode[0] == '3':
        uc = 'uC'
    elif mode[0] == '4':
        uc = 'vC'
    mode = mode.replace(mode[0], uc)
    up1 = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode=mode, negative_slope=negative_slope)
    return up1


def upsample_convtranspose(in_channels=64,
                           out_channels=3,
                           kernel_size=2,
                           stride=2,
                           padding=0,
                           bias=True,
                           mode='2R',
                           negative_slope=0.2):
    assert len(mode) < 4 and mode[0] in ['2', '3', '4'], 'mode examples: 2, 2R, 2BR, 3, ..., 4BR.'
    kernel_size = int(mode[0])
    stride = int(mode[0])
    mode = mode.replace(mode[0], 'T')
    up1 = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode, negative_slope)
    return up1


def downsample_strideconv(in_channels=64,
                          out_channels=64,
                          kernel_size=2,
                          stride=2,
                          padding=0,
                          bias=True,
                          mode='2R',
                          negative_slope=0.2):
    assert len(mode) < 4 and mode[0] in ['2', '3', '4'], 'mode examples: 2, 2R, 2BR, 3, ..., 4BR.'
    kernel_size = int(mode[0])
    stride = int(mode[0])
    mode = mode.replace(mode[0], 'C')
    down1 = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode, negative_slope)
    return down1


def downsample_maxpool(in_channels=64,
                       out_channels=64,
                       kernel_size=3,
                       stride=1,
                       padding=0,
                       bias=True,
                       mode='2R',
                       negative_slope=0.2):
    assert len(mode) < 4 and mode[0] in ['2', '3'], 'mode examples: 2, 2R, 2BR, 3, ..., 3BR.'
    kernel_size_pool = int(mode[0])
    stride_pool = int(mode[0])
    mode = mode.replace(mode[0], 'MC')
    pool = conv(kernel_size=kernel_size_pool, stride=stride_pool, mode=mode[0], negative_slope=negative_slope)
    pool_tail = conv(in_channels,
                     out_channels,
                     kernel_size,
                     stride,
                     padding,
                     bias,
                     mode=mode[1:],
                     negative_slope=negative_slope)
    return sequential(pool, pool_tail)


def downsample_avgpool(in_channels=64,
                       out_channels=64,
                       kernel_size=3,
                       stride=1,
                       padding=1,
                       bias=True,
                       mode='2R',
                       negative_slope=0.2):
    assert len(mode) < 4 and mode[0] in ['2', '3'], 'mode examples: 2, 2R, 2BR, 3, ..., 3BR.'
    kernel_size_pool = int(mode[0])
    stride_pool = int(mode[0])
    mode = mode.replace(mode[0], 'AC')
    pool = conv(kernel_size=kernel_size_pool, stride=stride_pool, mode=mode[0], negative_slope=negative_slope)
    pool_tail = conv(in_channels,
                     out_channels,
                     kernel_size,
                     stride,
                     padding,
                     bias,
                     mode=mode[1:],
                     negative_slope=negative_slope)
    return sequential(pool, pool_tail)


class QFAttention(nn.Layer):

    def __init__(self,
                 in_channels=64,
                 out_channels=64,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 bias=True,
                 mode='CRC',
                 negative_slope=0.2):
        super(QFAttention, self).__init__()

        assert in_channels == out_channels, 'Only support in_channels==out_channels.'
        if mode[0] in ['R', 'L']:
            mode = mode[0].lower() + mode[1:]

        self.res = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode, negative_slope)

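    def forward(self, x, gamma, beta):
        # gamma / beta have shape (N, C); reshape to (N, C, 1, 1) so they scale and
        # shift the residual branch channel-wise
        gamma = gamma.unsqueeze(-1).unsqueeze(-1)
        beta = beta.unsqueeze(-1).unsqueeze(-1)
        res = gamma * self.res(x) + beta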
        return x + res


class FBCNN(nn.Layer):

    def __init__(self,
                 in_nc=3,
                 out_nc=3,
                 nc=[64, 128, 256, 512],
                 nb=4,
                 act_mode='R',
                 downsample_mode='strideconv',
                 upsample_mode='convtranspose'):
        super(FBCNN, self).__init__()

        self.m_head = conv(in_nc, nc[0], bias=True, mode='C')
        self.nb = nb
        self.nc = nc
        # downsample
        if downsample_mode == 'avgpool':
            downsample_block = downsample_avgpool
        elif downsample_mode == 'maxpool':
            downsample_block = downsample_maxpool
        elif downsample_mode == 'strideconv':
            downsample_block = downsample_strideconv
        else:
            raise NotImplementedError('downsample mode [{:s}] is not found'.format(downsample_mode))

        self.m_down1 = sequential(*[ResBlock(nc[0], nc[0], bias=True, mode='C' + act_mode + 'C') for _ in range(nb)],
                                  downsample_block(nc[0], nc[1], bias=True, mode='2'))
        self.m_down2 = sequential(*[ResBlock(nc[1], nc[1], bias=True, mode='C' + act_mode + 'C') for _ in range(nb)],
                                  downsample_block(nc[1], nc[2], bias=True, mode='2'))
        self.m_down3 = sequential(*[ResBlock(nc[2], nc[2], bias=True, mode='C' + act_mode + 'C') for _ in range(nb)],
                                  downsample_block(nc[2], nc[3], bias=True, mode='2'))

        self.m_body_encoder = sequential(
            *[ResBlock(nc[3], nc[3], bias=True, mode='C' + act_mode + 'C') for _ in range(nb)])

        self.m_body_decoder = sequential(
            *[ResBlock(nc[3], nc[3], bias=True, mode='C' + act_mode + 'C') for _ in range(nb)])

        # upsample
        if upsample_mode == 'upconv':
            upsample_block = upsample_upconv
        elif upsample_mode == 'pixelshuffle':
            upsample_block = upsample_pixelshuffle
        elif upsample_mode == 'convtranspose':
            upsample_block = upsample_convtranspose
        else:
            raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))

        self.m_up3 = nn.LayerList([
            upsample_block(nc[3], nc[2], bias=True, mode='2'),
            *[QFAttention(nc[2], nc[2], bias=True, mode='C' + act_mode + 'C') for _ in range(nb)]
        ])

        self.m_up2 = nn.LayerList([
            upsample_block(nc[2], nc[1], bias=True, mode='2'),
            *[QFAttention(nc[1], nc[1], bias=True, mode='C' + act_mode + 'C') for _ in range(nb)]
        ])

        self.m_up1 = nn.LayerList([
            upsample_block(nc[1], nc[0], bias=True, mode='2'),
            *[QFAttention(nc[0], nc[0], bias=True, mode='C' + act_mode + 'C') for _ in range(nb)]
        ])

        self.m_tail = conv(nc[0], out_nc, bias=True, mode='C')

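        # Quality Factor Predictor: ResBlocks + global average pooling + MLP -> one QF value in (0, 1)
        self.qf_pred = sequential(*[ResBlock(nc[3], nc[3], bias=True, mode='C' + act_mode + 'C') for _ in range(nb)],
                                  nn.AdaptiveAvgPool2D((1, 1)), nn.Flatten(), nn.Linear(512, 512), nn.ReLU(),
                                  nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 1), nn.Sigmoid())

        # Flexible Controller: embed the QF, then map it to channel-wise gamma / beta for each decoder stage
        self.qf_embed = sequential(nn.Linear(1, 512), nn.ReLU(), nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512),
                                   nn.ReLU())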

        self.to_gamma_3 = sequential(nn.Linear(512, nc[2]), nn.Sigmoid())
        self.to_beta_3 = sequential(nn.Linear(512, nc[2]), nn.Tanh())
        self.to_gamma_2 = sequential(nn.Linear(512, nc[1]), nn.Sigmoid())
        self.to_beta_2 = sequential(nn.Linear(512, nc[1]), nn.Tanh())
        self.to_gamma_1 = sequential(nn.Linear(512, nc[0]), nn.Sigmoid())
        self.to_beta_1 = sequential(nn.Linear(512, nc[0]), nn.Tanh())

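    def forward(self, x, qf_input=None):
        # Reflect-pad H and W up to multiples of 8 so that the three 2x downsamplings and the
        # three 2x upsamplings give matching sizes for the skip connections; cropped at the end
        h, w = x.shape[-2:]
        paddingBottom = int(np.ceil(h / 8) * 8 - h)
        paddingRight = int(np.ceil(w / 8) * 8 - w)
        x = nn.functional.pad(x, (0, paddingRight, 0, paddingBottom), mode='reflect')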

        x1 = self.m_head(x)
        x2 = self.m_down1(x1)
        x3 = self.m_down2(x2)
        x4 = self.m_down3(x3)
        x = self.m_body_encoder(x4)
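        qf = self.qf_pred(x)  # predicted quality factor, shape (N, 1)
        x = self.m_body_decoder(x)
        # A user-supplied qf_input overrides the predicted QF (the "flexible" part of FBCNN)
        qf_embedding = self.qf_embed(qf_input) if qf_input is not None else self.qf_embed(qf)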
        gamma_3 = self.to_gamma_3(qf_embedding)
        beta_3 = self.to_beta_3(qf_embedding)

        gamma_2 = self.to_gamma_2(qf_embedding)
        beta_2 = self.to_beta_2(qf_embedding)

        gamma_1 = self.to_gamma_1(qf_embedding)
        beta_1 = self.to_beta_1(qf_embedding)

        x = x + x4
        x = self.m_up3[0](x)
        for i in range(self.nb):
            x = self.m_up3[i + 1](x, gamma_3, beta_3)

        x = x + x3

        x = self.m_up2[0](x)
        for i in range(self.nb):
            x = self.m_up2[i + 1](x, gamma_2, beta_2)
        x = x + x2

        x = self.m_up1[0](x)
        for i in range(self.nb):
            x = self.m_up1[i + 1](x, gamma_1, beta_1)

        x = x + x1
        x = self.m_tail(x)
        x = x[..., :h, :w]  # crop away the padding added at the start

        return x, qf
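`FBCNN.forward` pads H and W up to multiples of 8 before the encoder and crops back with `x[..., :h, :w]` at the end. The numpy sketch below shows why, assuming the default `strideconv` downsampling (kernel 2, stride 2, so each stage halves the size with floor division). numpy's `'reflect'` mode is used as a stand-in for Paddle's, and `pad_to_multiple` / `encoder_shapes` are hypothetical helper names.

```python
import math

import numpy as np


def pad_to_multiple(x, multiple=8):
    """Reflect-pad the last two axes (H, W) at the bottom/right to a multiple of `multiple`.

    Returns the padded array and the original (H, W) needed to crop back.
    """
    h, w = x.shape[-2:]
    pad_bottom = math.ceil(h / multiple) * multiple - h
    pad_right = math.ceil(w / multiple) * multiple - w
    pad_width = [(0, 0)] * (x.ndim - 2) + [(0, pad_bottom), (0, pad_right)]
    return np.pad(x, pad_width, mode='reflect'), (h, w)


def encoder_shapes(h, w, nc=(64, 128, 256, 512)):
    """(C, H, W) after m_head and after each of the three 2x downsampling stages."""
    shapes = [(nc[0], h, w)]
    for c in nc[1:]:
        h, w = h // 2, w // 2
        shapes.append((c, h, w))
    return shapes


x = np.random.default_rng(0).random((1, 3, 37, 50)).astype(np.float32)
padded, (h, w) = pad_to_multiple(x)
print(padded.shape)                        # (1, 3, 40, 56)
print(encoder_shapes(*padded.shape[-2:]))  # 40x56 -> 20x28 -> 10x14 -> 5x7
print(encoder_shapes(37, 50))              # without padding: 37 -> 18 -> 9 -> 4
restored = padded[..., :h, :w]             # the crop at the end of forward
```

Without padding, the deepest 4-row map would be upsampled to 8 rows and then added to a 9-row skip feature (`x + x3`), so the shapes would not match; with padding every stage halves and doubles cleanly.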
# FBCNN color image predictor
import os
import time
from typing import Union

import cv2
import numpy as np
import paddle


class FBCNNColor:
    """Load pretrained FBCNN weights and remove artifacts from BGR color images."""

    def __init__(self):
        # Pretrained weights path used in the original project; adjust to your local copy
        self.default_pretrained_model_path = 'data/data173244/fbcnn_color.pdparams'
        self.fbcnn = FBCNN()
        state_dict = paddle.load(self.default_pretrained_model_path)
        self.fbcnn.set_state_dict(state_dict)
        self.fbcnn.eval()

    def preprocess(self, img: np.ndarray) -> np.ndarray:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = img.transpose((2, 0, 1))
        img = img / 255.0
        return img.astype(np.float32)

    def postprocess(self, img: np.ndarray) -> np.ndarray:
        img = img.clip(0, 1)
        img = np.round(img * 255.0)  # round rather than truncate in astype below
        img = img.transpose((1, 2, 0))
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        return img.astype(np.uint8)

    def artifacts_removal(self,
                          image: Union[str, np.ndarray],
                          quality_factor: float = None,
                          visualization: bool = True,
                          output_dir: str = "fbcnn_color_output") -> np.ndarray:
        if isinstance(image, str):
            _, file_name = os.path.split(image)
            save_name, _ = os.path.splitext(file_name)
            save_name = save_name + '_' + str(int(time.time())) + '.jpg'
            image = cv2.imdecode(np.fromfile(image, dtype=np.uint8), cv2.IMREAD_COLOR)
        elif isinstance(image, np.ndarray):
            save_name = str(int(time.time())) + '.jpg'
        else:
            raise TypeError("image should be a str / np.ndarray")

        with paddle.no_grad():
            img_input = self.preprocess(image)
            img_input = paddle.to_tensor(img_input[None, ...], dtype=paddle.float32)
            # Explicit `is not None` check so that 0.0 is accepted as a valid quality factor
            if quality_factor is not None and 0 <= quality_factor <= 1:
                qf_input = paddle.to_tensor([[quality_factor]], dtype=paddle.float32)
            else:
                qf_input = None
            img_output, _ = self.fbcnn(img_input, qf_input)
            img_output = img_output.numpy()[0]
            img_output = self.postprocess(img_output)

        if visualization:
            if not os.path.isdir(output_dir):
                os.makedirs(output_dir)
            save_path = os.path.join(output_dir, save_name)
            cv2.imwrite(save_path, img_output)

        return img_output
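`FBCNNColor.preprocess` and `postprocess` convert between OpenCV's HWC, BGR, uint8 layout and the CHW, RGB, float layout the network expects (the batch axis is added in `artifacts_removal`). The numpy-only sketch below mirrors that round trip without cv2, adding the batch axis inside the helper; `preprocess_color` / `postprocess_color` are hypothetical names. It rounds before casting, because `astype(np.uint8)` on floats truncates toward zero and would bias results down by up to one level.

```python
import numpy as np


def preprocess_color(img_bgr):
    """(H, W, 3) BGR uint8 -> (1, 3, H, W) RGB float32 in [0, 1]."""
    rgb = img_bgr[..., ::-1]                        # BGR -> RGB (what cv2.COLOR_BGR2RGB does)
    chw = rgb.transpose(2, 0, 1)                    # HWC -> CHW
    return (chw / 255.0).astype(np.float32)[None]   # scale and add the batch axis


def postprocess_color(batch):
    """(1, 3, H, W) RGB float -> (H, W, 3) BGR uint8, rounding to the nearest level."""
    chw = np.clip(batch[0], 0.0, 1.0)
    hwc = chw.transpose(1, 2, 0)[..., ::-1]         # CHW -> HWC, RGB -> BGR
    return np.round(hwc * 255.0).astype(np.uint8)


img = np.random.default_rng(0).integers(0, 256, size=(5, 7, 3)).astype(np.uint8)
x = preprocess_color(img)
print(x.shape, x.dtype)   # (1, 3, 5, 7) float32
back = postprocess_color(x)
print(np.array([254.6]).astype(np.uint8), np.round(np.array([254.6])).astype(np.uint8))  # [254] [255]
```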
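# FBCNN grayscale image predictor
import os
import time
from typing import Union

import cv2
import numpy as np
import paddle


class FBCNNGary:
    """Load pretrained single-channel FBCNN weights and remove artifacts from grayscale images."""

    def __init__(self):
        # Pretrained weights path used in the original project; adjust to your local copy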
        self.default_pretrained_model_path = 'data/data173244/fbcnn_gray.pdparams'
        self.fbcnn = FBCNN(in_nc=1, out_nc=1)
        state_dict = paddle.load(self.default_pretrained_model_path)
        self.fbcnn.set_state_dict(state_dict)
        self.fbcnn.eval()

    def preprocess(self, img: np.ndarray) -> np.ndarray:
        img = img[None, ...]
        img = img / 255.0
        return img.astype(np.float32)

    def postprocess(self, img: np.ndarray) -> np.ndarray:
        img = img.clip(0, 1)
        img = np.round(img * 255.0)  # round rather than truncate in astype below
        return img.astype(np.uint8)

    def artifacts_removal(self,
                          image: Union[str, np.ndarray],
                          quality_factor: float = None,
                          visualization: bool = True,
                          output_dir: str = "fbcnn_gray_output") -> np.ndarray:
        if isinstance(image, str):
            _, file_name = os.path.split(image)
            save_name, _ = os.path.splitext(file_name)
            save_name = save_name + '_' + str(int(time.time())) + '.jpg'
            image = cv2.imdecode(np.fromfile(image, dtype=np.uint8), cv2.IMREAD_GRAYSCALE)
        elif isinstance(image, np.ndarray):
            save_name = str(int(time.time())) + '.jpg'
        else:
            raise TypeError("image should be a str / np.ndarray")

        with paddle.no_grad():
            img_input = self.preprocess(image)
            img_input = paddle.to_tensor(img_input[None, ...], dtype=paddle.float32)
            # Explicit `is not None` check so that 0.0 is accepted as a valid quality factor
            if quality_factor is not None and 0 <= quality_factor <= 1:
                qf_input = paddle.to_tensor([[quality_factor]], dtype=paddle.float32)
            else:
                qf_input = None
            img_output, _ = self.fbcnn(img_input, qf_input)
            img_output = img_output.numpy()[0][0]
            img_output = self.postprocess(img_output)

        if visualization:
            if not os.path.isdir(output_dir):
                os.makedirs(output_dir)
            save_path = os.path.join(output_dir, save_name)
            cv2.imwrite(save_path, img_output)

        return img_output
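The grayscale `preprocess` only adds a channel axis, so it assumes an (H, W) input; a 3-channel array (for example from `cv2.imread` without `cv2.IMREAD_GRAYSCALE`) would become a 5-D tensor that the network rejects. A small guard could normalize inputs first. The sketch below is a hypothetical helper, `to_gray_batch`, that converts BGR with the BT.601 luma weights OpenCV documents for its color-to-gray conversion; OpenCV's own integer implementation may differ by one level.

```python
import numpy as np


def to_gray_batch(img):
    """Return a (1, 1, H, W) float32 batch in [0, 1] for a single-channel FBCNN.

    Accepts (H, W) grayscale or (H, W, 3) BGR uint8; BGR is converted with
    Y = 0.299 R + 0.587 G + 0.114 B.
    """
    if img.ndim == 3 and img.shape[2] == 3:
        b, g, r = (img[..., i].astype(np.float64) for i in range(3))
        img = np.clip(np.round(0.299 * r + 0.587 * g + 0.114 * b), 0, 255).astype(np.uint8)
    if img.ndim != 2:
        raise ValueError(f"expected (H, W) or (H, W, 3), got {img.shape}")
    return (img / 255.0).astype(np.float32)[None, None]


gray = np.full((4, 6), 128, dtype=np.uint8)
white_bgr = np.full((4, 6, 3), 255, dtype=np.uint8)
print(to_gray_batch(gray).shape)        # (1, 1, 4, 6)
print(to_gray_batch(white_bgr).max())   # 1.0
```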
# Color image artifacts removal
from PIL import Image


fbcnn_color = FBCNNColor()
image = cv2.imread('tests/color_test.jpg')
output = fbcnn_color.artifacts_removal(
    image=image,
    quality_factor=0.7,
    visualization=True,
    output_dir='fbcnn_color_output'
)
vis = np.concatenate([image, output], 1)  # input on the left, restored result on the right
Image.fromarray(vis[..., ::-1])  # BGR -> RGB for PIL

[Output: input (left) and restored result (right)]

# Grayscale image artifacts removal

fbcnn_gray = FBCNNGary()
image = cv2.imread('tests/gray_test.jpg', cv2.IMREAD_GRAYSCALE)
output = fbcnn_gray.artifacts_removal(
    image=image,
    quality_factor=1.0,
    visualization=True,
    output_dir='fbcnn_gray_output'
)
vis = np.concatenate([image, output], 1)
Image.fromarray(vis)

[Output: input (left) and restored result (right)]

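7. Summary

  • This post briefly showed how to use the FBCNN model to remove artifacts caused by lossy compression, i.e. to strip an image of its "digital patina".

  • These artifacts are defects, of course, but for images such as memes the blurriness carries a certain sense of history; remove it and the image loses some of its flavor.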

This article is a repost of the original project.
