数字水印之JPEG压缩不可微分导致网络无法端到端训练的解决方法(附PyTorch代码)


一、引言

  在基于深度学习的图像数字水印领域,网络模型大致包含:编码器Encoder、解码器Decoder、鉴别器Discriminator和噪声层Noise-Layer四部分,其中Encoder功能是水印嵌入,输入为水印信息和载体图片,输出为包含水印信息的水印图片;Noise-Layer层包含多种类型的攻击,其作用是模拟水印图像在现实生活中遇到的各种攻击,训练网络模型对这些攻击的鲁棒性;Decoder功能为水印提取,输入为攻击后的水印图像,输出为之前嵌入的水印信息;Discriminator鉴别器的作用是提升生成水印图像的视觉质量,其输入为Encoder生成的水印图像或者不包含水印信息的原始载体图像,输出为0到1之间的小数,该数值表示输入图像是原始图像的概率,1表示Discriminator输入的是水印图像,Discriminator的理念在于通过GAN对抗训练的方式,让Encoder生成的水印图像尽可能和载体图像相似,提升水印图像的视觉质量。
  
  在众多攻击鲁棒性训练当中,水印信息对JPEG压缩攻击的鲁棒性至关重要,可以排在首位,在保存方面,当前大部分图像都是以 jpg,jpeg 格式保存,在图像传输方面,图像在QQ、微信、facebook等平台传输时,会经受平台自带的压缩算法,其底层或多或少会用到JPEG压缩。但是JPEG压缩在量化阶段采用的是rounding取整操作(也称为截断操作),这样的操作是不可微分的(not differentiable),这种不可导的存在对于依靠梯度反向传播实现参数优化的深度学习网络而言是不可接受的,换言之,若将JPEG压缩直接放到Noise-Layer中,梯度传递到Noise-Layer就会截止,Encoder无法得到有效训练。因此如何使网络获得对JPEG压缩的鲁棒性是本文要解决的问题。

  
  本文对学术界现有的方法进行总结,主要阐述:HiDDeN (2018 ECCV)、StegsStamp (2020 CVPR)、TSDL(2019 ACM MM)和MBRS(2021 ACM MM) 四篇论文的解决方法,以上四篇论文均为计算机视觉领域的A类会议,其中 HiDDeN 和 StegsStamp 更是CV顶会。

二、JPEG压缩流程

  JPEG压缩流程大致包含:颜色空间变换、下采样、分块、DCT变换、量化、熵编码六个步骤;
  (1) 颜色变换是将图像从RGB颜色空间转换到YUV空间;
  (2) 下采样是考虑到人眼对亮度信息相比色度信息更加敏感,因此会按照YUV444、YUV422或YUV420的格式下采样,其中最常用的是YUV420;
  (3) 分块是指将图像分为若干8x8大小的块;
  (4) DCT变换是指对之前分出来的每个8x8块进行(Discrete Cosine Transform, DCT)离散余弦变换,将图像在频域分为64个不同频率的余弦波;
  (5) 之后就是量化,亮度分量和色度分量各自对应不同的量化表,用DCT变换所得结果除以量化表,对于除不尽的小数部分直接进行截断操作(忘记是直接舍去小数点后的部分还是四舍五入取整,似乎是后者,看官若确定,可在评论区指出),量化这一步会极大压缩图像的高频分量,JPEG压缩的核心所在就是量化这一步,也正是因为量化取整的操作,才导致量化之前JPEG压缩是可逆的,但是量化之后,图像变换就不可逆了,即使后续解码显示图片,也会存在一定的精度损失。
  (6) 熵编码,用ZigZag扫描方式、游程长度编码、Huffman编码等方式进行编码,将图像数据转换为二进制码流,进而保存在磁盘上。
  JPEG压缩流程这里不细讲,若感兴趣,各位看官可自行搜索相关文章。

三、基于模拟的可微分JPEG压缩(Simulated-based JPEG)

  既然真实的JPEG压缩不可微分,那么可尝试用可微分的JPEG压缩模拟真实JPEG压缩对图像造成的影响,将基于模拟的可微分JPEG压缩放入Noise Layer层中不会影响模型的训练,该方式的代表作有斯坦福于2018年发表在 ECCV上的HiDDeN,以及加州大学伯克利分校发表在2020 CVPR上的StegaStamp。
  
  1. HiDDeN 处理方法
  经JPEG压缩后,图像高频区域会受到较大压缩,换言之,JPEG压缩主要压缩的是图像的高频区域,也就是DCT系数矩阵中靠近右下方的部分。对此,HiDDeN采用如图3-1所示的两种模拟方法:JPEG-Mask和JPEG-Drop。JPEG-Mask通过Mask蒙版的方式,将DCT系数矩阵右下角的部分置0,仅保留左上角25个元素的值;JPEG-Drop没有使用固定的蒙版,而是通过概率的方式将DCT系数矩阵任意位置的值置为0,其中越靠近右下角的位置,其被置0的概率越大。有学者认为,JPEG-Mask方法对于真实JPEG压缩是一种过拟合,其对高频分量的压缩力度过大,但对低频部分则完全没有压缩,因此模型抵抗JPEG压缩的性能并不是太好。
在这里插入图片描述

图3-1 HiDDeN JPEG压缩处理方式
  

  2. StegaStamp 处理方法
  相比HiDDeN,StegaStamp模拟JPEG压缩的方式更加细腻,StegaStamp从头至尾完全模拟了真实JPEG压缩,并且在量化这一步,作者将DCT系数矩阵和标准量化表相除后 (HiDDeN没有除以标准量化表,直接对DCT系数矩阵进行操作),对于所除得的结果,作者采用Song等人(2017 NIPS,顶会中的顶会,羡慕的淌口水。。。)如下所示的方式进行操作。
Q ( x ) = { x 3 , ∣ x ∣ < 0.5 x , ∣ x ∣ > = 0.5 Q(x) = \left\{ \begin {aligned} x^{3}, \vert x \vert <0.5 \\ x, \vert x \vert >= 0.5 \\ \end{aligned} \right. Q(x)={x3,x<0.5x,x>=0.5
  观察上式可知,若输入值小于0.5,对该值求立方后,所得结果将接近于0,这与JPEG压缩本意是符合的,通常而言图像大部分能量集中在低频部分,低频部分的值较大,高频部分的值较低 (对图像DCT变换后,图像左上角直流分量部位会出现一个明亮点),带入公式后,高频部分的值会越来越小,通过这样的方式可近似模拟JPEG对高频部分的压缩。
  
  HiDDeN和StegsStamp采用的方式能很好的避免真实JPEG压缩量化不可导的问题,将基于模拟的可微分JPEG压缩模块放在Encoder和Decoder之间,即可对整体网络进行端到端训练,使模型获得对JPEG压缩的鲁棒性。

四、双阶段解码器附加训练(Real JPEG)

  基于模拟的可微分JPEG压缩虽然能有效解决量化不可微分问题,但模拟出来的JPEG压缩毕竟不能代替真实的JPEG压缩噪声,该方法训练出的网络模型对JPEG压缩鲁棒性不高,为此有学者提出双阶段解码器附加训练的方法,通过优化模型训练策略,将真实的JPEG噪声引入网络训练,从而提高模型对JPEG压缩的鲁棒性。
  该方案的代表作为2019北京大学和深圳鹏城实验室推出的 TSDL,其具体方案为:将网络训练分为两个阶段,情况如图3-2所示:第一阶段训练时,Encoder和Decoder之间的Noise Layer层为空,即对水印图像没有攻击,第一阶段模型训练收敛后,固定编码器Encoder的网络参数不变,对Decoder解码器进行第二阶段的附加训练,将Encoder生成的水印图像JPEG压缩后,送入Decoder进行水印提取,通过解码器附加训练,使网络模型获得对JPEG压缩的鲁棒性。
在这里插入图片描述

图3-2 TSDL双阶段解码器附加训练

五、上述两类方法的不足

  上述两类方案都能一定程度上使网络模型获得对JPEG压缩攻击的鲁棒性,但这两种方案也有各自不足之处。
   (1) 基于模拟的可微分JPEG压缩,其缺点在于:基于模拟的可微分JPEG压缩,虽然能有效解决网络端到端训练的问题,但模拟的JPEG压缩噪声毕竟不是真正的JPEG压缩噪声,因此模型对JPEG压缩攻击的鲁棒性还有待提高。
  
   (2) 双阶段解码器附加训练策略,其致命缺点在于:单纯对Decoder进行附加训练并不能使得模型对JPEG压缩鲁棒,Encoder需要根据来自Decoder解码情况的信息,不断优化水印嵌入方式,从而找到最终合适的嵌入策略。在第一阶段训练时,由于没有攻击层的存在,故在水印解码准确率和水印图像视觉质量的监督下,水印信息会被嵌入图像的中高频区域,模型训练收敛后,图像视觉质量和Decoder水印解码准确率都能得到保证;第二阶段训练时,Encoder网络参数被固定,这也就意味着Encoder水印嵌入的策略随之固定,水印信息仍旧被嵌入至图像的中高频区域,此时引入JPEG压缩攻击,并对Decoder进行附加训练,由于JPEG会压缩图像的高频信息,即Encoder嵌入在高频区域的水印信息经JPEG压缩后已经出现丢失,因此不论后续再如何对Decoder附加训练,水印信息的丢失已经是既定事实,故Decoder不能很好的提取水印信息。

六、MBRS (极佳的JPEG压缩鲁棒性)

  MBRS是截止目前抵抗JPEG压缩最好的网络模型,直接上结果,图3-3是MBRS模型对JPEG压缩的水印提取误码率,由图可知MBRS在嵌入密度远大于HiDDeN、StegaStamp和TSDL的情况,不论是生成水印图像的视觉质量还是水印提取的误码率,都远超以往的模型,误码率更是仅有0.0092%和0.0012%,让后来人难以望其项背。
在这里插入图片描述

图3-3 MBRS对JPEG压缩攻击鲁棒性
  

  面对上文提高的缺点,MBRS采用一种Simulated-JPEG和Read-JPEG混合训练的方式,这样一方面既解决了真实JPEG压缩不可微分导致网络无法端到端训练的问题,另一方面也解决了模拟JPEG压缩导致的模型鲁棒性不强的问题,具体训练见图3-4右侧所示。

在这里插入图片描述

图3-4 MBRS 模型训练策略
  

  观察图3-4可知,Noise-Layer中包含:Identity (不对水印图像进行攻击)、基于模拟的可微分JPEG压缩,以及真实JPEG压缩三种攻击类型,每一次Iteration迭代都会从Noise-Layer的三种攻击中随机选择一种,模型训练时,基于模拟的可微分JPEG压缩会引导模型的训练,当Noise-Layer中为真实JPEG压缩噪声时,Decoder会学习到经真实JPEG后图像的哪些信道出现问题 (JPEG压缩导致Decoder解码失败,Loss值较大),而后在下一次端到端训练时 (当Noise-Layer为Identity和Simulated-JPEG时,梯度可正常传递至Encoder),将该信息反馈至Encoder,从而指导Encoder进行水印嵌入。

七、JPEG-Real、JPEG-Mask和JPEG-SS三者代码实现

  话不多说,直接看代码实现,本代码分别实现:真实JPEG压缩流程、基于Mask的模拟JPEG压缩以及和StegaStamp类似的JPEG压缩JPEG-SS。

import os
import random
import string

import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms


class JpegBasic(nn.Module):
    def __init__(self):
        super(JpegBasic, self).__init__()

    def std_quantization(self, image_yuv_dct, scale_factor, round_func=torch.round):
        luminance_quant_tbl = (torch.tensor([
            [16, 11, 10, 16, 24, 40, 51, 61],
            [12, 12, 14, 19, 26, 58, 60, 55],
            [14, 13, 16, 24, 40, 57, 69, 56],
            [14, 17, 22, 29, 51, 87, 80, 62],
            [18, 22, 37, 56, 68, 109, 103, 77],
            [24, 35, 55, 64, 81, 104, 113, 92],
            [49, 64, 78, 87, 103, 121, 120, 101],
            [72, 92, 95, 98, 112, 100, 103, 99]
        ], dtype=torch.float) * scale_factor).round().to(image_yuv_dct.device).clamp(min=1).repeat(
            image_yuv_dct.shape[2] // 8, image_yuv_dct.shape[3] // 8)

        chrominance_quant_tbl = (torch.tensor([
            [17, 18, 24, 47, 99, 99, 99, 99],
            [18, 21, 26, 66, 99, 99, 99, 99],
            [24, 26, 56, 99, 99, 99, 99, 99],
            [47, 66, 99, 99, 99, 99, 99, 99],
            [99, 99, 99, 99, 99, 99, 99, 99],
            [99, 99, 99, 99, 99, 99, 99, 99],
            [99, 99, 99, 99, 99, 99, 99, 99],
            [99, 99, 99, 99, 99, 99, 99, 99]
        ], dtype=torch.float) * scale_factor).round().to(image_yuv_dct.device).clamp(min=1).repeat(
            image_yuv_dct.shape[2] // 8, image_yuv_dct.shape[3] // 8)

        q_image_yuv_dct = image_yuv_dct.clone()
        q_image_yuv_dct[:, :1, :, :] = image_yuv_dct[:, :1, :, :] / luminance_quant_tbl
        q_image_yuv_dct[:, 1:, :, :] = image_yuv_dct[:, 1:, :, :] / chrominance_quant_tbl
        q_image_yuv_dct_round = round_func(q_image_yuv_dct)
        return q_image_yuv_dct_round

    def std_reverse_quantization(self, q_image_yuv_dct, scale_factor):

        luminance_quant_tbl = (torch.tensor([
            [16, 11, 10, 16, 24, 40, 51, 61],
            [12, 12, 14, 19, 26, 58, 60, 55],
            [14, 13, 16, 24, 40, 57, 69, 56],
            [14, 17, 22, 29, 51, 87, 80, 62],
            [18, 22, 37, 56, 68, 109, 103, 77],
            [24, 35, 55, 64, 81, 104, 113, 92],
            [49, 64, 78, 87, 103, 121, 120, 101],
            [72, 92, 95, 98, 112, 100, 103, 99]
        ], dtype=torch.float) * scale_factor).round().to(q_image_yuv_dct.device).clamp(min=1).repeat(
            q_image_yuv_dct.shape[2] // 8, q_image_yuv_dct.shape[3] // 8)

        chrominance_quant_tbl = (torch.tensor([
            [17, 18, 24, 47, 99, 99, 99, 99],
            [18, 21, 26, 66, 99, 99, 99, 99],
            [24, 26, 56, 99, 99, 99, 99, 99],
            [47, 66, 99, 99, 99, 99, 99, 99],
            [99, 99, 99, 99, 99, 99, 99, 99],
            [99, 99, 99, 99, 99, 99, 99, 99],
            [99, 99, 99, 99, 99, 99, 99, 99],
            [99, 99, 99, 99, 99, 99, 99, 99]
        ], dtype=torch.float) * scale_factor).round().to(q_image_yuv_dct.device).clamp(min=1).repeat(
            q_image_yuv_dct.shape[2] // 8, q_image_yuv_dct.shape[3] // 8)

        image_yuv_dct = q_image_yuv_dct.clone()
        image_yuv_dct[:, :1, :, :] = q_image_yuv_dct[:, :1, :, :] * luminance_quant_tbl
        image_yuv_dct[:, 1:, :, :] = q_image_yuv_dct[:, 1:, :, :] * chrominance_quant_tbl
        return image_yuv_dct

    def dct(self, image):
        # coff for dct and idct
        coff = torch.zeros((8, 8), dtype=torch.float).to(image.device)
        coff[0, :] = 1 * np.sqrt(1 / 8)
        for i in range(1, 8):
            for j in range(8):
                coff[i, j] = np.cos(np.pi * i * (2 * j + 1) / (2 * 8)) * np.sqrt(2 / 8)

        split_num = image.shape[2] // 8
        image_dct = torch.cat(torch.cat(image.split(8, 2), 0).split(8, 3), 0)
        image_dct = torch.matmul(coff, image_dct)
        image_dct = torch.matmul(image_dct, coff.permute(1, 0))
        image_dct = torch.cat(torch.cat(image_dct.chunk(split_num, 0), 3).chunk(split_num, 0), 2)

        return image_dct

    def idct(self, image_dct):
        # coff for dct and idct
        coff = torch.zeros((8, 8), dtype=torch.float).to(image_dct.device)
        coff[0, :] = 1 * np.sqrt(1 / 8)
        for i in range(1, 8):
            for j in range(8):
                coff[i, j] = np.cos(np.pi * i * (2 * j + 1) / (2 * 8)) * np.sqrt(2 / 8)

        split_num = image_dct.shape[2] // 8
        image = torch.cat(torch.cat(image_dct.split(8, 2), 0).split(8, 3), 0)
        image = torch.matmul(coff.permute(1, 0), image)
        image = torch.matmul(image, coff)
        image = torch.cat(torch.cat(image.chunk(split_num, 0), 3).chunk(split_num, 0), 2)

        return image

    def rgb2yuv(self, image_rgb):
        image_yuv = torch.empty_like(image_rgb)
        image_yuv[:, 0:1, :, :] = 0.299 * image_rgb[:, 0:1, :, :] \
                                  + 0.587 * image_rgb[:, 1:2, :, :] + 0.114 * image_rgb[:, 2:3, :, :]
        image_yuv[:, 1:2, :, :] = -0.1687 * image_rgb[:, 0:1, :, :] \
                                  - 0.3313 * image_rgb[:, 1:2, :, :] + 0.5 * image_rgb[:, 2:3, :, :]
        image_yuv[:, 2:3, :, :] = 0.5 * image_rgb[:, 0:1, :, :] \
                                  - 0.4187 * image_rgb[:, 1:2, :, :] - 0.0813 * image_rgb[:, 2:3, :, :]
        return image_yuv

    def yuv2rgb(self, image_yuv):
        image_rgb = torch.empty_like(image_yuv)
        image_rgb[:, 0:1, :, :] = image_yuv[:, 0:1, :, :] + 1.40198758 * image_yuv[:, 2:3, :, :]
        image_rgb[:, 1:2, :, :] = image_yuv[:, 0:1, :, :] - 0.344113281 * image_yuv[:, 1:2, :, :] \
                                  - 0.714103821 * image_yuv[:, 2:3, :, :]
        image_rgb[:, 2:3, :, :] = image_yuv[:, 0:1, :, :] + 1.77197812 * image_yuv[:, 1:2, :, :]
        return image_rgb

    def yuv_dct(self, image, subsample):
        # clamp and convert from [-1,1] to [0,255]
        # image = (image.clamp(-1, 1) + 1) * 255 / 2

        # 图像值域变为[0, 1], 故clamp and convert from [0,1] to [0,255]
        image = image.clamp(0, 1) * 255

        # pad the image so that we can do dct on 8x8 blocks
        pad_height = (8 - image.shape[2] % 8) % 8
        pad_width = (8 - image.shape[3] % 8) % 8
        image = nn.ZeroPad2d((0, pad_width, 0, pad_height))(image)

        # convert to yuv
        image_yuv = self.rgb2yuv(image)

        assert image_yuv.shape[2] % 8 == 0
        assert image_yuv.shape[3] % 8 == 0

        # subsample
        image_subsample = self.subsampling(image_yuv, subsample)

        # apply dct
        image_dct = self.dct(image_subsample)

        return image_dct, pad_width, pad_height

    def idct_rgb(self, image_quantization, pad_width, pad_height):
        # apply inverse dct (idct)
        image_idct = self.idct(image_quantization)

        # transform from yuv to to rgb
        image_ret_padded = self.yuv2rgb(image_idct)

        # un-pad
        image_rgb = image_ret_padded[:, :, :image_ret_padded.shape[2] - pad_height,
                    :image_ret_padded.shape[3] - pad_width].clone()

        # return image_rgb * 2 / 255 - 1
        return image_rgb / 255

    def subsampling(self, image, subsample):
        if subsample == 2:
            split_num = image.shape[2] // 8
            image_block = torch.cat(torch.cat(image.split(8, 2), 0).split(8, 3), 0)
            for i in range(8):
                if i % 2 == 1:
                    image_block[:, 1:3, i, :] = image_block[:, 1:3, i - 1, :]
            for j in range(8):
                if j % 2 == 1:
                    image_block[:, 1:3, :, j] = image_block[:, 1:3, :, j - 1]
            image = torch.cat(torch.cat(image_block.chunk(split_num, 0), 3).chunk(split_num, 0), 2)
        return image


class Jpeg(JpegBasic):
    """
    标准量化
    """

    def __init__(self, Q=50, subsample=0, jpeg_quality_ramp=2500, jpeg_quality=25):
        super(Jpeg, self).__init__()

        # quantization table
        self.Q = Q
        self.scale_factor = 2 - self.Q * 0.02 if self.Q >= 50 else 50 / self.Q
        # subsample
        self.subsample = subsample
        self.jpeg_quality_ramp = jpeg_quality_ramp
        self.jpeg_quality = jpeg_quality
        print("Jpeg()")

    def quality_to_factor(self, quality):
        self.Q = quality
        self.scale_factor = 2 - self.Q * 0.02 if self.Q >= 50 else 50 / self.Q

    def forward(self, image_and_cover_step):
        image, cover_image, train_step = image_and_cover_step

        # calculate jpeg factor
        ramp_fn = lambda ramp: np.min([train_step / ramp, 1.])
        jpeg_quality = 100. - torch.rand(1)[0] * ramp_fn(self.jpeg_quality_ramp) * (100. - self.jpeg_quality)
        # jpeg_quality = 100. - ramp_fn(jpeg_quality_ramp) * (100. - jpeg_quality)
        self.quality_to_factor(jpeg_quality)

        # [0,1] to [0,255], rgb2yuv, dct
        image_dct, pad_width, pad_height = self.yuv_dct(image, self.subsample)

        # quantization
        image_quantization = self.std_quantization(image_dct, self.scale_factor)

        # reverse quantization
        image_quantization = self.std_reverse_quantization(image_quantization, self.scale_factor)

        # idct, yuv2rgb, [0,255] to [0,1]
        noised_image = self.idct_rgb(image_quantization, pad_width, pad_height)
        return noised_image.clamp(0, 1)


class JpegSS(JpegBasic):
    """
    和StageStamp一样,用可微函数替换原始量化
    """

    def __init__(self, Q=50, subsample=0, jpeg_quality_ramp=2500, jpeg_quality=25):
        super(JpegSS, self).__init__()

        # quantization table
        self.Q = Q
        self.scale_factor = 2 - self.Q * 0.02 if self.Q >= 50 else 50 / self.Q
        # subsample
        self.subsample = subsample

        self.jpeg_quality_ramp = jpeg_quality_ramp
        self.jpeg_quality = jpeg_quality
        print("JpegSS()")

    def quality_to_factor(self, quality):
        self.Q = quality
        self.scale_factor = 2 - self.Q * 0.02 if self.Q >= 50 else 50 / self.Q

    def round_ss(self, x):
        # 用可微方式模拟JPEG压缩
        # cond = torch.tensor((torch.abs(x) < 0.5), dtype=torch.float).to(x.device)
        cond = (torch.abs(x) < 0.5).type(torch.float).clone().detach().requires_grad_(True)
        # cond.as_type(torch.float)
        return cond * (x ** 3) + (1 - cond) * x

    def forward(self, image_and_cover_step):
        image, cover_image, train_step = image_and_cover_step

        # calculate jpeg factor
        ramp_fn = lambda ramp: np.min([train_step / ramp, 1.])
        jpeg_quality = 100. - torch.rand(1)[0] * ramp_fn(self.jpeg_quality_ramp) * (100. - self.jpeg_quality)
        # jpeg_quality = 100. - ramp_fn(jpeg_quality_ramp) * (100. - jpeg_quality)
        self.quality_to_factor(jpeg_quality)

        # [0,1] to [0,255], rgb2yuv, dct
        image_dct, pad_width, pad_height = self.yuv_dct(image, self.subsample)

        # quantization
        image_quantization = self.std_quantization(image_dct, self.scale_factor, self.round_ss)

        # reverse quantization
        image_quantization = self.std_reverse_quantization(image_quantization, self.scale_factor)

        # idct, yuv2rgb, [0,255] to [0,1]
        noised_image = self.idct_rgb(image_quantization, pad_width, pad_height)
        return noised_image.clamp(0, 1)


class JpegMask(JpegBasic):
    """
    和HiDDeN一样,采用Mask蒙版模拟量化
    """

    def __init__(self, Q=50, subsample=0, jpeg_quality_ramp=2500, jpeg_quality=25):
        super(JpegMask, self).__init__()

        # quantization table
        self.Q = Q
        self.scale_factor = 2 - self.Q * 0.02 if self.Q >= 50 else 50 / self.Q
        # subsample
        self.subsample = subsample
        self.jpeg_quality_ramp = jpeg_quality_ramp
        self.jpeg_quality = jpeg_quality
        print("JpegMask()")

    def quality_to_factor(self, quality):
        self.Q = quality
        self.scale_factor = 2 - self.Q * 0.02 if self.Q >= 50 else 50 / self.Q

    def round_mask(self, x):
        mask = torch.zeros(1, 3, 8, 8).to(x.device)
        mask[:, 0:1, :5, :5] = 1
        mask[:, 1:3, :3, :3] = 1
        mask = mask.repeat(1, 1, x.shape[2] // 8, x.shape[3] // 8)
        return x * mask

    def forward(self, image_and_cover_step):
        image, cover_image, train_step = image_and_cover_step

        # calculate jpeg factor
        ramp_fn = lambda ramp: np.min([train_step / ramp, 1.])
        jpeg_quality = 100. - torch.rand(1)[0] * ramp_fn(self.jpeg_quality_ramp) * (100. - self.jpeg_quality)
        # jpeg_quality = 100. - ramp_fn(jpeg_quality_ramp) * (100. - jpeg_quality)
        self.quality_to_factor(jpeg_quality)

        # [0,1] to [0,255], rgb2yuv, dct
        image_dct, pad_width, pad_height = self.yuv_dct(image, self.subsample)

        # mask
        image_mask = self.round_mask(image_dct)

        # idct, yuv2rgb, [0,255] to [0,1]
        noised_image = self.idct_rgb(image_mask, pad_width, pad_height)
        return noised_image.clamp(0, 1)

  在上述代码中,本文通过传入的train_step调整QF参数的大小(QF的大小最低为25),随着训练迭代的进行,JPEG攻击强度会线性增强,线性增强阶段结束后,QF值会在[25, 100]之间随机选择,保证模型不出现过拟合,对任意大小的QF参数都有较高的鲁棒性。

八、References

  1. 2018 ECCV, HiDDeN论文地址:https://arxiv.org/abs/1807.09937
  2. 2020 CVPR, StegaStamp论文地址:https://arxiv.org/abs/1904.05343
  3. 2019 ACM MM, TSDL论文地址:https://dl.acm.org/doi/10.1145/3343031.3351025
  4. 2021 ACM MM, MBRS论文地址:https://arxiv.org/abs/2108.08211
  5. 2017 NIPS, JPEG-SS,Song等人的模拟方式,论文地址:https://paperswithcode.com/paper/jpeg-resistant-adversarial-images

  非作者同意,不可复制、转载本文,共建良好创作环境,尊重作者版权(这不也正是数字水印技术所追求的吗 :)

  • 6
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值