1. Introduction
Paper: HiDDeN: Hiding Data With Deep Networks
The Discriminator structure described in the paper is shown in the figure below:
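Discriminator purpose: decide whether the input image is a Cover Image (the original image) or an Encoded Image (an image with a message embedded in it)
Discriminator input: a Cover Image or an Encoded Image
Discriminator output: one score per image. The code in this post returns a raw logit because the sigmoid call in forward is commented out; passing that logit through sigmoid gives the probability, in [0, 1], that the image is an encoded image
Discriminator implementation steps:
step 1: pass the input image through several ConvBNRelu blocks (3x3 convolution, batch normalization, ReLU) to extract features
step 2: apply AdaptiveAvgPool2d global average pooling, which reduces the feature map to [1, 64, 1, 1]
step 3: squeeze out the last two dimensions, leaving a tensor of shape [1, 64]
step 4: apply a Linear (fully connected) layer that maps the 64 features to a single value, the prediction, with shape [1, 1]
Here the leading 1 is the batch size, and 64 is discriminator_channels, the discriminator channel count configured before training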
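The four steps above can be traced shape by shape. The following is a minimal sketch rather than the paper's implementation: it uses a single convolution block as a stand-in for the full ConvBNRelu stack, and the random 128x128 input is a hypothetical placeholder for a real image.

```python
import torch
import torch.nn as nn

channels = 64  # discriminator_channels, as in the text

# Step 1 stand-in: one 3x3 conv + BatchNorm + ReLU block
# (the real network stacks several of these)
conv_block = nn.Sequential(
    nn.Conv2d(3, channels, kernel_size=3, stride=1, padding=1),
    nn.BatchNorm2d(channels),
    nn.ReLU(inplace=True),
)
pool = nn.AdaptiveAvgPool2d(output_size=(1, 1))  # step 2
linear = nn.Linear(channels, 1)                  # step 4

# Hypothetical input: one RGB image, 128x128 (any H and W work)
image = torch.rand(1, 3, 128, 128)

shapes = []
with torch.no_grad():
    x = conv_block(image)        # step 1: feature extraction
    shapes.append(tuple(x.shape))
    x = pool(x)                  # step 2: global average pooling
    shapes.append(tuple(x.shape))
    x = x.squeeze(3).squeeze(2)  # step 3: drop the two trailing size-1 dims
    shapes.append(tuple(x.shape))
    x = linear(x)                # step 4: 64 features -> 1 value
    shapes.append(tuple(x.shape))

for name, shape in zip(["conv", "pool", "squeeze", "linear"], shapes):
    print(f"{name:8s} -> {shape}")
```

Because the convolution uses padding=1 with stride 1, the spatial size is unchanged until the pooling step, and AdaptiveAvgPool2d makes the rest of the network independent of the input resolution.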
2. Discriminator Code Module
import torch.nn as nn
from PIL import Image
from torchvision.transforms import transforms


class ConvBNRelu(nn.Module):
    def __init__(self, channels_in, channels_out, stride=1):
        super(ConvBNRelu, self).__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(channels_in, channels_out, 3, stride, padding=1),
            nn.BatchNorm2d(channels_out),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.layers(x)


class Discriminator(nn.Module):
    """
    Discriminator network. Receives an image and has to figure out whether it has a watermark inserted into it, or not.
    """
    def __init__(self):
        super(Discriminator, self).__init__()
        self.config = {
            "discriminator_blocks": 3,
            "discriminator_channels": 64,
        }
        # First block maps the 3 RGB channels to discriminator_channels
        layers = [ConvBNRelu(3, self.config['discriminator_channels'])]
        for _ in range(self.config['discriminator_blocks'] - 1):
            layers.append(ConvBNRelu(self.config['discriminator_channels'], self.config['discriminator_channels']))
        # Global average pooling: [N, C, H, W] -> [N, C, 1, 1]
        layers.append(nn.AdaptiveAvgPool2d(output_size=(1, 1)))
        self.before_linear = nn.Sequential(*layers)
        self.linear = nn.Linear(self.config['discriminator_channels'], 1)

    def forward(self, image):
        X = self.before_linear(image)
        # Squeeze out the last two dimensions: [N, C, 1, 1] -> [N, C]
        X.squeeze_(3).squeeze_(2)
        X = self.linear(X)
        # Sigmoid is left commented out, so the network returns a raw logit
        # X = torch.sigmoid(X)
        return X
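

# Read the image and convert it to a tensor
# convert("RGB") guarantees the 3 channels the first convolution expects
image = Image.open("./data/cover.jpg").convert("RGB")
transform = transforms.ToTensor()
image = transform(image)
# Add a batch dimension: [3, H, W] -> [1, 3, H, W]
image = image.unsqueeze(0)
print("image.shape: ", image.shape)

# Create the model and run a forward pass
model = Discriminator()
Discriminator_Result = model(image)

# Print the shape and value of the output tensor
print("Discriminator_Result.shape: ", Discriminator_Result.shape)
print("Discriminator_Result: ", Discriminator_Result)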
The output of running the script is shown in the figure below:
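Because the sigmoid call in forward is commented out, the printed Discriminator_Result is a raw logit rather than a probability. The sketch below shows one way to turn such logits into probabilities, and why a loss that accepts logits directly is a common choice when the sigmoid is left out of the network. The logit values are made up, and the label convention (1 = encoded image, 0 = cover image) is an assumption chosen to match the description of the output in this post.

```python
import torch
import torch.nn as nn

# Hypothetical logits standing in for Discriminator outputs, shape [batch, 1]
logits = torch.tensor([[-2.0], [0.0], [3.0]])

# Sigmoid maps each logit to a probability in (0, 1)
probs = torch.sigmoid(logits)

# Assumed label convention: 1 = encoded image, 0 = cover image
targets = torch.tensor([[0.0], [0.0], [1.0]])

# BCEWithLogitsLoss applies the sigmoid internally, so it takes raw logits
loss_from_logits = nn.BCEWithLogitsLoss()(logits, targets)

# BCELoss expects probabilities, so it takes the sigmoid output instead
loss_from_probs = nn.BCELoss()(probs, targets)

print("probabilities:", probs.flatten().tolist())
print("loss from logits:", loss_from_logits.item())
print("loss from probs: ", loss_from_probs.item())
```

Both losses agree up to floating-point error; the logit-based version is generally preferred for numerical stability, which is one plausible reason for keeping the sigmoid out of forward.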