一、简介
论文链接:点击此链接查看HiDDeN的文献
论文中的Encoder结构,如下图所示:
Encoder具体结构如图2所示。
Encoder功能: 将30bit消息串嵌入到原始图像中,形成编码图像
Encoder输入1: Cover Image(原始图像) = 3×128×128
Encoder输入2: Message(随机二进制bit) = 30bit
Encoder输出: Encoded Image(编码图像) = 3×128×128
Encoder实现方法:
- 对Cover Image(原始图像)应用4次convBNReLU卷积,提取特征
将这一步的结果命名为卷积图像 - 对Message(随机二进制bit)进行复制和扩展
- 将Cover Image(原始图像)、卷积图像、扩展后的二进制消息进行连接
- 对连接结果应用卷积
- 形成3×128×128的Encoded Image(编码图像)
二、编码器代码模块
import torch
import torch.nn as nn
class MyNet(nn.Module):
def __init__(self):
super(MyNet, self).__init__()
self.conv_layers = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
)
self.after_concat_layer = nn.Sequential(
nn.Conv2d(97, 64, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True)
)
self.final_layer = nn.Conv2d(64, 3, kernel_size=1)
# initialize H, W
self.H = 128
self.W = 128
def forward(self, image, message):
# 首先,将消息的最后两个维度添加两个虚拟维度。
expanded_message = message.unsqueeze(-1)
expanded_message.unsqueeze_(-1)
# 然后,将消息扩展为图像大小
expanded_message = expanded_message.expand(-1, -1, self.H, self.W)
conved_image = self.conv_layers(image)
# 接着,连接 expanded_message, conved_image, image
print("conved_image.shape: ", conved_image.shape)
print("expanded_message.shape: ", expanded_message.shape)
print("image.shape: ", image.shape)
concat = torch.cat([expanded_message, conved_image, image], dim=1)
im_w = self.after_concat_layer(concat)
im_w = self.final_layer(im_w)
return im_w
# 创建图片和消息张量
image = torch.randn(1, 3, 128, 128)
print("image.shape: ", image.shape)
message = torch.randn(1, 30)
print("message.shape: ", message.shape)
print()
# 初始化模型
model = MyNet()
# 调用模型
output = model(image, message)
# 打印输出张量的形状
print("output.shape: ", output.shape)
运行结果如下图所示:
三、网络结构可视化
pytorch Graphviz安装使用方法:点击此链接,查看pytorch Graphviz安装使用方法
添加代码如下所示:
from torchviz import make_dot
import graphviz
......
# 将模型输出和模型参数传递给make_dot函数
dot = make_dot(output, params=dict(model.named_parameters()))
# 使用Graphviz渲染图形
graphviz.Source(dot)
# 将图形渲染为PDF文件
dot.render('output.pdf', format='pdf')
四、重点函数详解
4.1 expand函数
expand函数功能: 对变量内容进行复制和填充,扩展成指定格式。
expand函数参数: -1表示不进行扩展。
备注: 需要先使用x.unsqueeze(-1)函数添加维度,再进行复制和扩展
expand函数 例1 :
import torch
x = torch.tensor([11, 20, 15])
print("x shape:", x.shape)
print("x values:", x)
# 将x扩展为[3, 2]
y = x.unsqueeze(-1).expand(-1, 2)
print("y shape:", y.shape)
print("y values:"), print(y)
运行结果,如下所示:
expand函数 例2 :
# x [6]
x = torch.tensor([11, 20, 15, 255, 145, 75])
print("x shape:", x.shape)
print("x values:", x)
# 将x扩展为[6, 128, 128]
y = x.unsqueeze(-1).unsqueeze(-1).expand(-1, 128, 128)
print("y shape:", y.shape)
print("y values:"), print(y)
运行结果,如下所示:
4.2 cat函数
cat函数功能: 将多个张量沿着指定的维度拼接起来
cat函数参数: 参数1:要连接的变量列表; 参数2:在哪个维度上进行连接
cat函数 例1:
import torch
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
print(a.shape)
print(b.shape)
c = torch.cat([a, b], dim=0)
print(c)
运行结果,如下所示:
cat函数 例2: 沿着维度0进行连接
# 两个形状为 (1, 3, 4) 的张量
x1 = torch.tensor([[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]])
x2 = torch.tensor([[[11, 22, 33, 44], [55, 66, 77, 88], [99, 100, 111, 122]]])
print(x1)
print(x1.shape)
print(x2)
print(x2.shape)
# 将 x1 和 x2 沿着第 0 维度连接
x3 = torch.cat([x1, x2], dim=0)
print(x3)
print(x3.shape)
运行结果,如下所示:
cat函数 例2: 沿着维度1进行连接
# 两个形状为 (1, 3, 4) 的张量
x1 = torch.tensor([[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]])
x2 = torch.tensor([[[11, 22, 33, 44], [55, 66, 77, 88], [99, 100, 111, 122]]])
print(x1)
print(x1.shape)
print(x2)
print(x2.shape)
# 将 x1 和 x2 沿着第 1 维度连接
x3 = torch.cat([x1, x2], dim=1)
print(x3)
print(x3.shape)
运行结果,如下所示: