HiDDeN Encoder——基于深度学习的水印生成网络“编码器”详解

一、简介

论文链接:点击此链接查看HiDDeN的文献

论文中的Encoder结构,如下图所示:
论文Encoder结果
Encoder具体结构如图2所示。

Encoder功能: 将30bit消息串嵌入到原始图像中,形成编码图像

Encoder输入1: Cover Image(原始图像) = 3×128×128
Encoder输入2: Message(随机二进制bit) = 30bit

Encoder输出: Encoded Image(编码图像) = 3×128×128

Encoder实现方法:

  1. 对Cover Image(原始图像)应用4次convBNReLU卷积,提取特征
    将这一步的结果命名为卷积图像
  2. 对Message(随机二进制bit)进行复制和扩展
  3. 将Cover Image(原始图像)、卷积图像、扩展后的二进制消息进行连接
  4. 对连接结果应用卷积
  5. 形成3×128×128的Encoded Image(编码图像)

在这里插入图片描述

图2 网络具体结构图

在这里插入图片描述

图3 ConvBNRelu结构图

二、编码器代码模块

import torch
import torch.nn as nn


class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),

            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),

            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),

            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
        self.after_concat_layer = nn.Sequential(
            nn.Conv2d(97, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        self.final_layer = nn.Conv2d(64, 3, kernel_size=1)

        # initialize H, W
        self.H = 128
        self.W = 128

    def forward(self, image, message):
        # 首先,将消息的最后两个维度添加两个虚拟维度。
        expanded_message = message.unsqueeze(-1)
        expanded_message.unsqueeze_(-1)
        # 然后,将消息扩展为图像大小
        expanded_message = expanded_message.expand(-1, -1, self.H, self.W)
        conved_image = self.conv_layers(image)
        # 接着,连接 expanded_message, conved_image, image
        print("conved_image.shape: ", conved_image.shape)
        print("expanded_message.shape: ", expanded_message.shape)
        print("image.shape: ", image.shape)
        concat = torch.cat([expanded_message, conved_image, image], dim=1)
        im_w = self.after_concat_layer(concat)
        im_w = self.final_layer(im_w)
        return im_w


# 创建图片和消息张量
image = torch.randn(1, 3, 128, 128)
print("image.shape: ", image.shape)
message = torch.randn(1, 30)
print("message.shape: ", message.shape)
print()

# 初始化模型
model = MyNet()

# 调用模型
output = model(image, message)

# 打印输出张量的形状
print("output.shape: ", output.shape)

运行结果如下图所示:
在这里插入图片描述

三、网络结构可视化

pytorch Graphviz安装使用方法:点击此链接,查看pytorch Graphviz安装使用方法

添加代码如下所示:

from torchviz import make_dot
import graphviz

......

# 将模型输出和模型参数传递给make_dot函数
dot = make_dot(output, params=dict(model.named_parameters()))
# 使用Graphviz渲染图形
graphviz.Source(dot)
# 将图形渲染为PDF文件
dot.render('output.pdf', format='pdf')

可视化结果

四、重点函数详解

4.1 expand函数

expand函数功能: 对变量内容进行复制和填充,扩展成指定格式。
expand函数参数: -1表示不进行扩展。

备注: 需要先使用x.unsqueeze(-1)函数添加维度,再进行复制和扩展

expand函数 例1

import torch

x = torch.tensor([11, 20, 15])
print("x shape:", x.shape)
print("x values:", x)

# 将x扩展为[3, 2]
y = x.unsqueeze(-1).expand(-1, 2)
print("y shape:", y.shape)
print("y values:"), print(y)

运行结果,如下所示:
输出结果

expand函数 例2

# x [6]
x = torch.tensor([11, 20, 15, 255, 145, 75])
print("x shape:", x.shape)
print("x values:", x)

# 将x扩展为[6, 128, 128]
y = x.unsqueeze(-1).unsqueeze(-1).expand(-1, 128, 128)
print("y shape:", y.shape)
print("y values:"), print(y)

运行结果,如下所示:
运行结果

4.2 cat函数

cat函数功能: 将多个张量沿着指定的维度拼接起来
cat函数参数: 参数1:要连接的变量列表; 参数2:在哪个维度上进行连接

cat函数 例1:

import torch

a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
print(a.shape)
print(b.shape)

c = torch.cat([a, b], dim=0)
print(c)

运行结果,如下所示:
运行结果
cat函数 例2: 沿着维度0进行连接

# 两个形状为 (1, 3, 4) 的张量
x1 = torch.tensor([[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]])
x2 = torch.tensor([[[11, 22, 33, 44], [55, 66, 77, 88], [99, 100, 111, 122]]])
print(x1)
print(x1.shape)
print(x2)
print(x2.shape)
# 将 x1 和 x2 沿着第 0 维度连接
x3 = torch.cat([x1, x2], dim=0)
print(x3)
print(x3.shape)

运行结果,如下所示:
在这里插入图片描述
cat函数 例2: 沿着维度1进行连接

# 两个形状为 (1, 3, 4) 的张量
x1 = torch.tensor([[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]])
x2 = torch.tensor([[[11, 22, 33, 44], [55, 66, 77, 88], [99, 100, 111, 122]]])
print(x1)
print(x1.shape)
print(x2)
print(x2.shape)
# 将 x1 和 x2 沿着第 1 维度连接
x3 = torch.cat([x1, x2], dim=1)
print(x3)
print(x3.shape)

运行结果,如下所示:
在这里插入图片描述

  • 9
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值