AIGC-《ArtCoder: An End-to-end Method for Generating Scanning-robust Stylized QR Codes》论文+代码解读

《ArtCoder: An End-to-end Method for Generating Scanning-robust Stylized QR Codes》CVPR2021
二维码风格化论文阅读+风格迁移讲解
在这里插入图片描述

论文

https://arxiv.org/abs/2011.07815

二维码的原理可以参考这篇博客:https://blog.csdn.net/y4x5M0nivSrJaY3X92c/article/details/113667206

贡献

• 提出了一种新颖的端到端方法ArtCoder 来生成个性化、多样化且耐扫描的风格化QR 码。
• 我们提出采样模拟层(Sampling-Simulation layer)来提取QR 码的消息,并引入基于模块的代码丢失以保持风格化QR 码的扫描鲁棒性。
• 提出了一种竞争机制,以保证风格化二维码在扫描鲁棒性和视觉效果上的高质量。
• 我们风格化的二维码中的所有模块都更加隐形,并且与整个图像融合得很好。

相关技术

风格迁移

  • 参数化
    参数方法迭代更新初始图像,直到满足所需的全局统计数据。Gatys等人利用 CNN 和 Gram 矩阵的强大功能,开创了参数化 NST 方法。
  • 非参数化
    非参数方法使用简单的补丁表示,并通过最近邻搜索找到最相似的补丁。 非参数 NST 方法由 Li 等人首创。 他们使用马尔可夫随机场(MRF)重新表述风格迁移,即从风格图像中搜索神经补丁以匹配内容图像的结构。

二维码的艺术化

  • 模块变形Module-deformation
    首先选择一个方形模块作为变形的基础,然后对这个方形模块进行一定的变形操作,比如缩小、拉伸、旋转等,使其形状发生改变。接着,在变形后的区域中插入另一幅图像,使其与原始图像融合在一起。通过这种方法,可以实现对图像的局部区域进行精细的处理,同时保持整体图像的连贯性和完整性。

  • 模块重组Module-reshuffle
    基于模块重组的方法受到了开创性工作Qart代码的启发,该工作提出可以利用Gauss-Jordan Elimination Proce-dure来重新排列模块的位置,以满足混合图像的特征。随后,为了提高QR码的视觉质量,后续工作设计了不同的策略来重新排列模块,利用不同的图像特征,例如感兴趣区域、中央显著性、全局灰度值等。

  • NST-based method
    Xu等人首先引入NST技术来生成风格化QR码,并提出SEE(Stylized aEsthEtic)QR码。 他们的方法解决了风格迁移会损害扫描鲁棒性的问题,但是由风格化引起的错误模块通过后处理算法进行修复,这会产生无法与整个图像很好融合的干扰模块。(问题所在)

核心技术

整体架构

在视觉效果上,Q结合了Is的风格特征和Ic的语义内容; 就功能而言,任何标准 QR 码阅读器都可以扫描 Q 以显示消息 M。
Q = Ψ ( I s , I c , M ) Q=Ψ(Is, Ic,M) Q=Ψ(Is,Ic,M)
L t o t a l = λ 1 L s t y l e ( I s , Q ) + λ 2 L c o n t e n t ( I c , Q ) + λ 3 L c o d e ( M , Q ) L_total =λ1L_{style}(Is, Q) + λ2L_{content}(Ic, Q)+ λ3L_{code}(M, Q) Ltotal=λ1Lstyle(Is,Q)+λ2Lcontent(Ic,Q)+λ3Lcode(M,Q)
L_total,L_content,L_style,L_code:对应的损失函数
Is:目标风格图像
Ic:内容目标图像
M:生成的消息
在这里插入图片描述

风格和内容的特征由VGG-19 提取,二维码特征由所提出的采样模拟(SS)层提取。 在优化器的每次迭代中,虚拟二维码读取器RQR将读取程式化结果Q以区分所有错误并纠正模块。在优化器的每次迭代中,虚拟二维码读取器 R_QR 将读取程式化结果 Q 来区分所有错误并纠正模块 。对于第 k 个模块 M_k,如果 M_k 错误(or正确),我们控制激活映射 K 来激活(或停用)the k-th sub-code-loss L_
code(Mk),以优化其鲁棒性,但可能会损害表示 风格和内容(或者Lstyle和Lcontent会尽力优化视觉质量)。

损失函数(参考风格迁移)

风格图像和内容图像对应的损失函数和style stransfer里面提到的一样:
在这里插入图片描述
二维码对应的损失函数:

Lcode​ 是基于 QR 码 Q 中的每个模块 𝑀_k 设定的。对于 QR 码中的每个模块 𝑀_k ,设置一个子代码损失 𝐿code_Mk,然后将所有这些子代码损失加总以得到总的代码损失 𝐿codeLcode​。这个损失函数的目的是确保每个模块在视觉上的变化不会影响其在 QR 码读取器中的解码能力,从而保持 QR 码的扫描鲁棒性。数学表达式为:
L code = ∑ M k ∈ Q L code M k L_{\text{code}} = \sum_{M_k∈Q} L_{\text{code}}^{M_k} Lcode=MkQLcodeMk
对于输入消息 𝑀, QR 码编码器 𝐸_𝑄𝑅 将 𝑀编码为一个代码目标 𝑀=𝐸_𝑄𝑅(𝑀)​,这个代码目标是一个 𝑚×𝑚矩阵,矩阵中的的 1 或 0 标记每个模块的理想颜色(0/1 表示黑/白)。对于每个模块 𝑀𝑘​,定义一个子代码损失 𝐿codeL_𝑀𝑘如下:
L code M k = K M k ∥ M M k − F M k ∥ 2 L_{\text{code}}^{M_k} = K_{M_k} \left\| M_{M_k} - F_{M_k} \right\|^2 LcodeMk=KMkMMkFMk2
F:是由 Sampling-Simulation (SS) 层提取的特征图
K:K 是由竞争机制计算的激活图,𝐾_𝑀𝑘用于激活子代码损失 𝐿code_Mk(后面会讲到竞争机制)
M_MK:E_QR这是第 𝑘 个模块的理想或目标二进制值

Sampling-Simulation layer

  • QR码的采样:
    Google ZXing规定,QR码阅读器仅对QR码中每个模块的中心像素进行采样,然后对这些像素进行二值化和解码。对像素中心的采样概率遵从高斯分布

    作者认为如果使用卷积层来模拟QR码阅读器的采样过程,就可以通过损失反向传播来控制QR码的鲁棒性。

    • 二维 码 Q:对于由 m×m 个大小为 a×a 的模块组成的已经风格化了的二维码
    • l_ss :卷积核,被设计为具有kernal为 a、步幅 a、填充 0,内核权重遵循高斯分布。
    • F:当我们将Q输入到lss时,内核会对Q的每个模块进行一次卷积,输出一个m×m的特征图F=lss(Q),表示Q的采样结果。特征图F中的每一位F_Mk对应于 Q 中的模块 Mk,表示为
      G M k ( i , j ) = 1 2 π σ 2 e − i 2 + j 2 2 σ 2 \mathcal{G}_{M_k}(i,j) = \frac{1}{2\pi\sigma^2} e^{-\frac{i^2 + j^2}{2\sigma^2}} GMk(i,j)=2πσ21e2σ2i2+j2
      F M k = ∑ ( i , j ) ∈ M k G M k ( i , j ) ⋅ Q M k ( i , j ) \mathcal{F}_{M_{k}}=\sum_{(i, j) \in M_{k}} \mathcal{G}_{M_{k}(i, j)} \cdot Q_{M_{k}(i, j)} FMk=(i,j)MkGMk(i,j)QMk(i,j)
      在这里插入图片描述

竞争机制

在这里插入图片描述
L_code 试图使每个模块扫描具有鲁棒性,但会牺牲视觉质量
L_style和L_content试图提高Q的视觉质量并损害扫描鲁棒性
在每次迭代中,虚拟二维码阅读器RQR读取Q以找出所有错误模块,然后构造激活图K,定义为
在这里插入图片描述
M_Mk:这个值是从输入消息 𝑀 通过美学 QR 码编码器 𝐸_𝑄𝑅编码得到的第 𝑘 个模块的理想值
R_QR(Q_Mk ):是虚拟二维码阅读器R_QR阅读Q的第k个模块QMk的读取结果。每个模块被非黑即白。这个矩阵代表了 QR 码的理想结构,即在没有任何视觉风格化干预的情况下,应该如何正确地显示和扫描。
Q_MK:这是这个值是由风格化函数 Ψ 生成的 QR 码中第 𝑘个模块的值.
K_MK:如果一个模块QMk是正确的(鲁棒的),那么K_Mk=0,Lcode_MK=0,我们的模型将尽力优化Lvisual={Lstyle, Lcontent },改进风格和内容特征。 如果这些修改使 QMk 出错,则 K_Mk=1, Lcode_Mk 代码将被激活以优化 QMk 的鲁棒性。

Virtual QR Code Reader

这个虚拟阅读器的设计是为了在优化过程中实时检测和调整QR码的可读性,从而确保生成的QR码既具有艺术风格又能被标准的QR码扫描设备有效识别。

  • 采样与二值化
    QR码阅读器通常只采样每个模块中心的像素点,然后将这些像素点二值化(转换为黑白)以解码。将采样的彩色像素二值化(通过一个阈值转换成非黑即白,0黑1白)
    在这里插入图片描述
  • 虚拟二维码读取器进行判断
    结合公式(7)(8),𝑅_𝑄𝑅(𝑄_𝑀𝑘)表示的是虚拟QR码阅读器对第 𝑘 个模块 𝑄_𝑀𝑘的读取结果。这个读取结果是基于风格化后的QR码模块 𝑄_𝑀𝑘​​ 的视觉表现进行的解码尝试。
    在这里插入图片描述
    T_b:表示对黑色模块进行二值化所采用的虚拟阈值,低于T_b设为黑色
    T_w:表示对白色模块进行二值化所采用的虚拟阈值,高于T_w设为白色
    η=|T−Tb|/T= |Tw−T|/(255−T):判别的严格程度主要受T_b和T_w影响。可以通过设置η来控制Tw和Tb,进一步权衡视觉质量和鲁棒性

R_QR(Q_MK)⊕M_MK=1:Q_MK is error
R_QR(Q_MK)⊕M_MK=0:Q_MK is correct

扫描鲁棒性

对于每个模块的采样像素,无论其颜色如何变化,只要保留与理想颜色相同的二值结果,就可以保持扫描鲁棒性。
参数η控制判别误差模块的严格程度(即控制模块的鲁棒性)。 具体来说,当设置较高的 η 时,每个模块必须更黑/更白才能被归类为鲁棒模块

不足

  • 生成二维码的速度相对来说不够快
  • 代码的稳健性略低于传统的二维码(但它们足以支持现实世界的应用程序)

补充:风格迁移

在论文中,损失函数用于指导图像风格迁移的过程,确保生成的图像既符合内容图像的内容也表示风格图像的风格。损失函数由两部分组成:内容损失(Content Loss)和风格损失(Style Loss)。下面是损失函数的具体公式以及各符号的含义:

  1. 内容损失(Content Loss):
    L c o n t e n t ( p , x , l ) = 1 2 ∑ i , j ( F i j l − P i j l ) 2 L_{content}(\mathbf{p}, \mathbf{x}, l) = \frac{1}{2} \sum_{i,j} (F^{l}_{ij} - P^{l}_{ij})^2 Lcontent(p,x,l)=21i,j(FijlPijl)2

    • p:原始的内容图像。
    • x:生成的图像,通过迭代优化以匹配内容和风格。
    • l:CNN中的特定层级。
    • Fij_L:在层l中,原始内容图像的特征表示(即滤波器在位置j的激活)。
    • Pij_L:在层l中,生成图像的特征表示。

    内容损失测量的是生成图像在特定层级上的特征表示与内容图像特征表示之间的差异。

  2. 风格损失(Style Loss):
    E l = 1 4 N l 2 M l 2 ∑ i , j ( G i j l − A i j l ) 2 E_l = \frac{1}{4N^2_l M^2_l} \sum_{i,j} (G^{l}_{ij} - A^{l}_{ij})^2 El=4Nl2Ml21i,j(GijlAijl)2

    • A_l:原始风格图像在层l的Gram矩阵。
    • G_l:生成图像在层l的Gram矩阵。
    • N_l:层l中滤波器的数量。
    • M_l:层l中特征图的大小。

    Gram矩阵的元素( G^{l}_{ij} )是滤波器响应之间的相关性度量,计算公式为:
    G i j l = ∑ k F i k l F j k l G^{l}_{ij} = \sum_{k} F^{l}_{ik} F^{l}_{jk} Gijl=kFiklFjkl

    风格损失测量的是生成图像在特定层级上的风格表示(即Gram矩阵)与风格图像的风格表示之间的差异。

  3. 总损失(Total Loss):
    L t o t a l ( p , a , x ) = α L c o n t e n t ( p , x ) + β ∑ l = 0 L w l E l L_{total}(\mathbf{p}, \mathbf{a}, \mathbf{x}) = \alpha L_{content}(\mathbf{p}, \mathbf{x}) + \beta \sum_{l=0}^{L} w_l E_l Ltotal(p,a,x)=αLcontent(p,x)+βl=0LwlEl

    • α:内容损失的权重。
    • β:风格损失的权重。
    • w:不同层级风格损失的权重。
    • l:参与计算的CNN层级总数。

    总损失是内容损失和风格损失的加权和,其中风格损失是所有层级风格损失的加权求和。通过最小化总损失,可以在保持内容的同时,使生成图像的风格与目标风格图像相匹配。

在风格迁移的过程中,通过梯度下降等优化算法调整生成图像x的像素值,从而最小化损失函数L_total。通过这种方式,生成的图像既保留了内容图像的视觉内容,又复现了风格图像的风格特征。

代码分析

https://github.com/SwordHolderSH/ArtCoder

SS_layer

这种设计可以模拟QR码扫描过程中的采样行为,其中高斯权重模拟了采样点周围像素对采样结果的影响。

import os
from PIL import Image
import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt
from torchvision import transforms
import numpy as np
import utils

# 定义SSlayer类,继承自nn.Module
class SSlayer(nn.Module):
    def __init__(self, requires_grad=False):
        super(SSlayer, self).__init__()  

        # 创建一个卷积层,输入输出通道数为3,核大小为16,步长为16,无填充,不使用偏置
        cov_module = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=16, stride=16, padding=0, bias=False)
        # 获取一个3D高斯权重
        weight = utils.get_3DGauss()  # [16,16]
        weight = weight.unsqueeze(0).unsqueeze(0)  # 增加两个维度,变为[1,1,16,16]
        weight = torch.cat([weight, weight, weight], dim=1)  # 变为[1,3,16,16],适应于RGB图像
        cov_module.weight = nn.Parameter(weight) 
        self.conv_module = nn.Sequential(
            cov_module
        )
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, x):
        x = x.repeat(1, 1, 1, 1)  # 重复,保持维度不变
        x = self.conv_module(x)  # 通过卷积层处理数据
        return x  # 返回处理后的数据

utils.py

包含了多个函数,主要用于图像处理和QR码的生成与分析。

import torch
from PIL import Image
import os
import numpy as np
import math
import pandas as pd
from torchvision import transforms
import shutil

unloader = transforms.ToPILImage()
load = transforms.ToTensor()


def load_image(filename, size=None, scale=None):#加载图像
    img = Image.open(filename)
    if size is not None:
        img = img.resize((size, size), Image.ANTIALIAS)
    elif scale is not None:
        img = img.resize((int(img.size[0] / scale), int(img.size[1] / scale)), Image.ANTIALIAS)
    return img

# 将二维码的图案添加到目标图像
def add_pattern(target_PIL, code_PIL, module_number=37, 
module_size=16):
    target_img = np.asarray(target_PIL)
    code_img = np.array(code_PIL)
    output = target_img
    output = np.require(output, dtype='uint8', requirements=['O', 'W'])
    ms = module_size  # module size
    mn = module_number  # module_number
     # 在特定区域将目标图像替换为二维码图像
     #将四个角复制到相同区域
    output[0 * ms:(8 * ms) - 1, 0 * ms:(8 * ms) - 1, :] = code_img[0 * ms:(8 * ms) - 1, 0 * ms:(8 * ms) - 1, :]
    output[((mn - 8) * ms) + 1:(mn * ms), 0 * ms:(8 * ms) - 1, :] = code_img[((mn - 8) * ms) + 1:(mn * ms),
                                                                    0 * ms:(8 * ms) - 1,
                                                                    :]
    output[0 * ms: (8 * ms) - 1, ((mn - 8) * ms) + 1:(mn * ms), :] = code_img[0 * ms: (8 * ms) - 1,
                                                                     ((mn - 8) * ms) + 1:(mn * ms), :]
    output[28 * ms: (33 * ms) - 1, 28 * ms:(33 * ms) - 1, :] = code_img[28 * ms: (33 * ms) - 1, 28 * ms:(33 * ms) - 1,
                                                               :]

    output = Image.fromarray(output.astype('uint8'))
    print('Added finder and alignment patterns.')
    return output


# 删除文件夹中的所有文件和子文件夹的函数
def del_file(filepath):
    del_list = os.listdir(filepath)
    for f in del_list:
        file_path = os.path.join(filepath, f)
        if os.path.isfile(file_path):
            os.remove(file_path)
        elif os.path.isdir(file_path):
            shutil.rmtree(file_path)

# 计算特征图的Gram矩阵的函数
def gram_matrix(y):
    (b, ch, h, w) = y.size()
    features = y.view(b, ch, w * h)
    features_t = features.transpose(1, 2)
    gram = features.bmm(features_t) / (ch * h * w)
    return gram

# 获取目标图像和二维码图像的误差矩阵的函数
def get_action_matrix(img_target, img_code, module_size=16, IMG_SIZE=592, Dis_b=50, Dis_w=200):
    img_code = np.require(np.asarray(img_code.convert('L')), dtype='uint8', requirements=['O', 'W'])
    img_target = np.require(np.array(img_target.convert('L')), dtype='uint8', requirements=['O', 'W'])
	# 获取二维码图像的理想二进制结果
    ideal_result = get_binary_result(img_code, module_size)
     # 获取目标图像的中心像素矩阵
    center_mat = get_center_pixel(img_target, module_size)
     # 获取误差模块
    error_module = get_error_module(center_mat, code_result=ideal_result,
                                    threshold_b=Dis_b,
                                    threshold_w=Dis_w)
    return error_module, ideal_result

# 根据二维码图像获取二进制结果的函数
def get_binary_result(img_code, module_size, module_number=37):
    binary_result = np.zeros((module_number, module_number))
    for j in range(module_number):
        for i in range(module_number):
            module = img_code[i * module_size:(i + 1) * module_size, j * module_size:(j + 1) * module_size]
            module_color = np.around(np.mean(module), decimals=2)
            if module_color < 128:
                binary_result[i, j] = 0
            else:
                binary_result[i, j] = 1
    return binary_result

# 获取目标图像中心像素值的函数
#37模块每个16像素,则整个图像是592x592像素
def get_center_pixel(img_target, module_size):
#module_size是每个模块的像素尺寸
    center_mat = np.zeros((37, 37))#初始化
    for j in range(37):
        for i in range(37):
            module = img_target[i * module_size:(i + 1) * module_size, j * module_size:(j + 1) * module_size]
            #提取一个大矩形区域
            module_color = np.mean(module[5:12, 5:12])# 计算模块中心的像素值
            center_mat[i, j] = module_color
    return center_mat


def get_error_module(center_mat, code_result, threshold_b, threshold_w):
    error_module = np.ones((37, 37))  # 0 means correct,1 means error
    for j in range(37):
        for i in range(37):
            center_pixel = center_mat[i, j]  # 获取中心像素值
            right_result = code_result[i, j]  # 获取正确的二进制结果
            if right_result == 0 and center_pixel < threshold_b:
            # 如果正确的结果是0且中心像素值小于阈值
                error_module[i, j] = 0
            elif right_result == 1 and center_pixel > threshold_w:
            # 如果正确的结果是1且中心像素值大于阈值
                error_module[i, j] = 0
            else:
                error_module[i, j] = 1
    return error_module

# 根据二进制结果和颜色值获取目标图像
def get_target(binary_result, b_robust, w_robust, module_num=37, module_size=16):
    img_size = module_size * module_num  # 计算图像大小
    target = np.require(np.ones((img_size, img_size)), dtype='uint8', requirements=['O', 'W'])  # 初始化目标图像
    for i in range(module_num):  # 遍历模块
        for j in range(module_num):  # 遍历模块
            one_binary_result = binary_result[i, j]  # 获取单个二进制结果
            if one_binary_result == 0:  # 如果二进制结果是0
                target[i * module_size:(i + 1) * module_size, j * module_size:(j + 1) * module_size] = b_robust  # 设置为黑色
            else:  # 如果二进制结果是1
                target[i * module_size:(i + 1) * module_size, j * module_size:(j + 1) * module_size] = w_robust  # 设置为白色
    target = load(Image.fromarray(target.astype('uint8')).convert('RGB')).unsqueeze(0).cuda()
    return target  # 返回目标图像


def save_image_epoch(tensor, path, name, code_pil, addpattern=True):
    """Save a single image."""
    image = tensor.cpu().clone()
    image = image.squeeze(0)
    image = unloader(image)
    if addpattern == True:
        image = add_pattern(image, code_pil, module_number=37, module_size=16)
    image.save(os.path.join(path, "epoch_" + str(name)))


def tensor_to_PIL(tensor):
    image = tensor.cpu().clone()
    image = image.squeeze(0)
    image = unloader(image)
    return image


def get_3DGauss(s=0, e=15, sigma=1.5, mu=7.5):
    x, y = np.mgrid[s:e:16j, s:e:16j]
    z = (1 / (2 * math.pi * sigma ** 2)) * np.exp(-((x - mu) ** 2 + (y - mu) ** 2) / (2 * sigma ** 2))
    z = torch.from_numpy(MaxMinNormalization(z.astype(np.float32)))
    for j in range(16):
        for i in range(16):
            if z[i, j] < 0.1:
                z[i, j] = 0
    return z

#归一化
def MaxMinNormalization(loss_img):
    maxvalue = np.max(loss_img)
    minvalue = np.min(loss_img)
    img = (loss_img - minvalue) / (maxvalue - minvalue)
    img = np.around(img, decimals=2)
    return img


def print_options(opt):
    """Print and save options
    It will print both current options and default values(if different).
    It will save options into a text file / [checkpoints_dir] / opt.txt
    """
    message = ''
    message += '----------------- Options ---------------\n'
    for k, v in sorted(vars(opt).items()):
        comment = ''
        message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
    message += '----------------- End -------------------'
    print(message)

ArtCoder.py

from vgg import Vgg16
import utils
from torchvision import transforms
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torch
from SS_layer import SSlayer
# 定义artcoder函数,接收风格图像、内容图像、代码图像路径、输出目录和其他超参数
def artcoder(STYLE_IMG_PATH, CONTENT_IMG_PATH, CODE_PATH, OUTPUT_DIR,
             LEARNING_RATE=0.01, CONTENT_WEIGHT=1e8, STYLE_WEIGHT=1e15, CODE_WEIGHT=1e15, MODULE_SIZE=16, MODULE_NUM=37,
             EPOCHS=50000, Dis_b=80, Dis_w=180, Correct_b=50, Correct_w=200, USE_ACTIVATION_MECHANISM=True):
    # STYLE_IMG_PATH = './style/redwave4.jpg'
    # CONTENT_IMG_PATH = './content/boy.jpg'
    # CODE_PATH = './code/boy.jpg'
    # OUTPUT_DIR = './output/'
    utils.del_file(OUTPUT_DIR)
    IMAGE_SIZE = MODULE_SIZE * MODULE_NUM

    transform = transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
    ])

    vgg = Vgg16(requires_grad=False).cuda()  # vgg16 model
    ## 加载自定义的SSlayer,也不需要梯度
    ss_layer = SSlayer(requires_grad=False).cuda()
 # 使用定义的load_image函数加载风格、内容和代码图像,并添加二维码图案
    style_img = utils.load_image(filename=STYLE_IMG_PATH, size=IMAGE_SIZE)
    content_img = utils.load_image(filename=CONTENT_IMG_PATH, size=IMAGE_SIZE)
    code_img = utils.load_image(filename=CODE_PATH, size=IMAGE_SIZE)
    init_img = utils.add_pattern(content_img, code_img)


    style_img = transform(style_img)
    content_img = transform(content_img)
    init_img = transform(init_img)

    init_img = init_img.repeat(1, 1, 1, 1).cuda()
    style_img = style_img.repeat(1, 1, 1, 1).cuda()  # make fake batch
    content_img = content_img.repeat(1, 1, 1, 1).cuda()
	# 从VGG模型提取风格图像和内容图像的特征
    features_style = vgg(style_img) 
    features_content = vgg(content_img)

    gram_style = [utils.gram_matrix(i) for i in features_style]  # gram matrix of style feature
    mse_loss = nn.MSELoss()

    y = init_img.detach()  # y is the target output. Optimized start from the content image.
    y = y.requires_grad_()  # let y to require grad

    optimizer = optim.Adam([y], lr=LEARNING_RATE)  # let optimizer to optimize the tensor y
	# 获取代码图像的误差矩阵和理想二进制结果
    error_matrix, ideal_result = utils.get_action_matrix(
        img_target=utils.tensor_to_PIL(y),
        img_code=code_img,
        Dis_b=Dis_b, Dis_w=Dis_w
    )
    # 使用SSlayer处理理想结果,生成目标二进制图像
    code_target = ss_layer(utils.get_target(ideal_result, b_robust=Correct_b, w_robust=Correct_w))
#开始训练
    print(" Start training =============================================")
    for epoch in range(EPOCHS):

        def closure(code_target=code_target):

            optimizer.zero_grad()
            y.data.clamp_(0, 1)
            features_y = vgg(y)  # feature maps of y extracted from VGG
            gram_style_y = [utils.gram_matrix(i) for i in
                            features_y]  # gram matrixs of feature_y in relu1_2,2_2,3_3,4_3
			# 获取内容和风格目标特征
            fc = features_content.relu3_3  # content target in relu4_3
            fy = features_y.relu3_3  # y in relu4_3
				
            style_loss = 0  # add style_losses in relu1_2,2_2,3_3,4_3
            for i in [0, 1, 2, 3]:
                style_loss += mse_loss(gram_style_y[i], gram_style[i])
            style_loss = STYLE_WEIGHT * style_loss

            code_y = ss_layer(y)

            if USE_ACTIVATION_MECHANISM == 1:
            # 重新计算误差矩阵和理想结果
                error_matrix, ideal_result = utils.get_action_matrix(
                    img_target=utils.tensor_to_PIL(y),
                    img_code=code_img,
                    Dis_b=Dis_b, Dis_w=Dis_w)
                activate_num = np.sum(error_matrix)
                activate_weight = torch.tensor(error_matrix.astype('float32'))
                code_y = code_y.cpu() * activate_weight
                code_target = code_target.cpu() * activate_weight
            else:
                code_y = code_y.cpu()
                code_target = code_target.cpu()
                activate_num = MODULE_NUM * MODULE_NUM

            code_loss = CODE_WEIGHT * mse_loss(code_target.cuda(), code_y.cuda())
            content_loss = CONTENT_WEIGHT * mse_loss(fc, fy)  # content loss

            # tv_loss = TV_WEIGHT * (torch.sum(torch.abs(y[:, :, :, :-1] - y[:, :, :, 1:])) +
            #                        torch.sum(torch.abs(y[:, :, :-1, :] - y[:, :, 1:, :])))

            total_loss = style_loss + code_loss + content_loss
            total_loss.backward(retain_graph=True)

            if epoch % 20 == 0:
                print(
                    "Epoch {}: Style Loss : {:4f}. Content Loss: {:4f}. Code Loss: {:4f}. Activated module number: {:4.2f}. Discriminate_b:{:4.2f}. Discriminate_w:{:4.2f}.".format(
                        epoch, style_loss, content_loss, code_loss, activate_num, Dis_b, Dis_w)
                )
            if epoch % 200 == 0:
                img_name = 'epoch=' + str(epoch) + '__Wstyle=' + str("%.1e" % STYLE_WEIGHT) + '__Wcode=' + str(
                    "%.1e" % CODE_WEIGHT) + '__Wcontent' + str(
                    "%.1e" % CONTENT_WEIGHT) + '.jpg'
                utils.save_image_epoch(y, OUTPUT_DIR, img_name, code_img, addpattern=True)
                print('Save output: ' + img_name)
                return total_loss

        optimizer.step(closure)
  • 21
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值