论文Delving into the Local: Dynamic Inconsistency Learning for DeepFake Video
Detection(以下简称DILD)是STIL(Spatiotemporal Inconsistency Learning for DeepFake Video
Detection)升级版, 同一作者, 两篇论文都来自腾讯优图实验室, 其中STIL有公开源码, 但DILD没有, 详细阅读原文后, 觉得是一个很好的idea, 因此手动复现了论文模型
STIL(论文源码):
https://github.com/Tencent/TFace/tree/master/security/tasks/Face-Forgery-Detection/STIL
DILD(本人复现)
https://github.com/yblir/DILD
一 论文解读
视频人脸deepfake伪造通常先把视频拆帧, 对每帧图片进行伪造, 再把图片帧合成视频. 论文认为这种伪造方式无法在各个角度都贴合人脸, 表现在视频中即为视频帧之间的不连续性, 因此可以通过检测视频帧之间的差异来鉴定是否存在人脸伪造. 作者的前一篇论文STIL也持这种观点, 其做法是从整个视频中抽取8或16张不连续的视频帧进行检测; 作者认为这样抽取的视频帧间隔太大, 不利于捕获伪造痕迹, 因此又发表了这篇升级版论文, 思路为: 从整个视频中抽取u个互不连续的视频段, 每个视频段由连续的t张视频帧组成, 通过这种组合来检测是否存在伪造.
二 主要模块复现
- Intra-SIM模块
该模块的作用是提取一个视频段内的伪造信息. 它用两个1x1卷积分别把输入通道压缩为一半: 其中一半送入Intra-SMA提取段内帧间差异信息, 并乘以由全局池化+全连接层得到的通道权重; 另一半只经过1x1卷积+BN+ReLU, 不做其他处理. 最后把两部分concat在一起, 推测这样做的目的是兼顾原始信息与伪造信息.
class IntraSIM(nn.Module):
    """Intra-segment Spatiotemporal Inconsistency Module (Intra-SIM).

    Two separate 1x1 conv -> norm -> ReLU projections each produce ch_in // 2
    channels. The first projection is passed through unchanged. The second is
    fed to IntraSMA, and the result is re-weighted by a per-channel attention
    vector (global average pool -> fc -> ReLU -> fc -> softmax over channels).
    The two halves are concatenated along the channel axis.

    Args:
        u: number of segments per video (forwarded to IntraSMA).
        t: number of consecutive frames per segment (forwarded to IntraSMA).
        ch_in: number of input channels; each branch gets ch_in // 2 channels.
        norm_layer: factory for the normalization layer, called with a channel
            count (e.g. nn.BatchNorm2d).
    """

    def __init__(self, u, t, ch_in, norm_layer):
        super(IntraSIM, self).__init__()
        self.u = u
        self.t = t
        half = ch_in // 2
        # Registration order is kept identical so parameters() order is unchanged.
        self.conv1 = nn.Conv2d(ch_in, half, kernel_size=1, bias=False)
        self.bn1 = norm_layer(half)
        self.conv2 = nn.Conv2d(ch_in, half, kernel_size=1, bias=False)
        self.bn2 = norm_layer(half)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(half, half)
        self.fc2 = nn.Linear(half, half)
        self.relu = nn.ReLU(inplace=True)
        self.intra_sma = IntraSMA(u, t, half, norm_layer)

    def forward(self, x):
        """Apply Intra-SIM to a batch of frame features.

        Args:
            x: 4D tensor (n, c, h, w), as required by the Conv2d layers.
                NOTE(review): IntraSMA regroups the leading dim by ``t``, so n
                is presumably (#videos * u * t); confirm against the caller.

        Returns:
            Tensor of shape (n, 2 * (c // 2), h, w): the pass-through half
            followed by the attention-weighted IntraSMA half.
        """
        passthrough = self.relu(self.bn1(self.conv1(x)))
        motion_in = self.relu(self.bn2(self.conv2(x)))

        # Channel attention computed from the IntraSMA input branch.
        pooled = self.avg_pool(motion_in).reshape(motion_in.shape[0], -1)
        hidden = self.relu(self.fc1(pooled))
        channel_weights = torch.softmax(self.fc2(hidden), dim=1)

        motion_out = self.intra_sma(motion_in)
        weighted = channel_weights[:, :, None, None] * motion_out
        return torch.concat([passthrough, weighted], dim=1)
- IntraSMA模块
该模块的作用是提取段内连续帧之间的伪造痕迹, 具体操作为用前一帧减去(经3x3卷积的)后一帧, 再反过来用后一帧减去(经卷积的)前一帧, 之后对两个方向的结果求平均. 对这个模块, 有一点不太赞同: H、W两个方向的卷积结果相乘相当于注意力权重, 再乘上输入, 相当于对输入各通道的信息进行筛选, 但最后又把结果与输入相加(残差连接)? 不太认同这种操作, 因为这会削弱注意力对各通道信息的筛选作用.
class IntraSMA(nn.Module):
    """Intra-segment Short-term Motion Attention (Intra-SMA).

    After a 1x1 conv -> norm -> ReLU, frame differences inside each segment
    are turned into two attention maps (along H and along W). These re-weight
    the features with a residual connection:
    ``out = att_h * att_w * x + x``.

    Differences are taken in both temporal directions,
    ``f[i] - conv2(f[i+1])`` and ``f[i+1] - conv2(f[i])``. Each direction
    yields its own H/W attention, and the two directions are averaged.

    Args:
        u: number of segments per video (stored; not used in this module's math).
        t: number of consecutive frames per segment; must be >= 2.
        ch_in: number of input channels (the output has the same count).
        norm_layer: factory for the normalization layer, called with a channel count.

    Review fixes:
        * The backward-direction differences were computed but never used:
          ``torch.stack(t_list)`` was stacked instead of ``t_list2``, and the
          second W branch read the forward stack. The "bidirectional" average
          therefore equalled the forward direction alone.
        * ``conv_wt1`` was a copy of ``conv_ht1`` with kernel (3, 1). In the
          W branch's ``(bu*h, c, t-1, w)`` layout that convolves over time,
          not width, so it now uses kernel (1, 3).
          NOTE: this changes that layer's weight shape; checkpoints saved with
          the old kernel will not load for ``conv_wt1``.
    """

    def __init__(self, u, t, ch_in, norm_layer):
        super(IntraSMA, self).__init__()
        self.u = u
        self.t = t
        self.reduced_channels = ch_in // 1
        self.conv1 = nn.Conv2d(ch_in, self.reduced_channels, kernel_size=1, padding=0, bias=False)
        self.bn1 = norm_layer(self.reduced_channels)
        self.relu = nn.ReLU(inplace=True)
        self.ch_in2 = self.reduced_channels // 1
        # Applied to the "other" frame of each pair before subtracting.
        self.conv2 = nn.Conv2d(self.ch_in2, self.ch_in2, kernel_size=3, padding=1, bias=False)
        # H branch input is laid out (bu*w, c, h, t-1): depthwise conv along h.
        self.conv_ht1 = nn.Conv2d(self.ch_in2, self.ch_in2,
                                  kernel_size=(3, 1), padding=(1, 0), groups=self.ch_in2, bias=False)
        self.conv_ht2 = nn.Conv2d(self.ch_in2, self.ch_in2, kernel_size=1, padding=0, bias=False)
        # W branch input is laid out (bu*h, c, t-1, w): depthwise conv along w.
        self.conv_wt1 = nn.Conv2d(self.ch_in2, self.ch_in2,
                                  kernel_size=(1, 3), padding=(0, 1), groups=self.ch_in2, bias=False)
        self.conv_wt2 = nn.Conv2d(self.ch_in2, self.ch_in2, kernel_size=1, padding=0, bias=False)

    def _attention_from_diffs(self, diffs):
        """Turn stacked frame differences (t-1, bu, c, h, w) into (att_h, att_w), each (bu, 1, c, h, w)."""
        t1, _, c, h, w = diffs.shape
        # H branch: (bu*w, c, h, t-1); residual depthwise conv, 1x1 conv, mean over time.
        diff_h = diffs.permute(1, 4, 2, 3, 0).contiguous().reshape(-1, c, h, t1)
        diff_h = self.conv_ht2(self.conv_ht1(diff_h) + diff_h)
        diff_h = torch.sigmoid(torch.mean(diff_h, dim=-1, keepdim=True))
        # (bu*w, c, h, 1) -> (bu, w, c, h, 1) -> (bu, 1, c, h, w)
        diff_h = diff_h.reshape(-1, w, c, h, 1).permute(0, 4, 2, 3, 1).contiguous()
        # W branch: (bu*h, c, t-1, w); residual depthwise conv, 1x1 conv, mean over time.
        diff_w = diffs.permute(1, 3, 2, 0, 4).contiguous().reshape(-1, c, t1, w)
        diff_w = self.conv_wt2(self.conv_wt1(diff_w) + diff_w)
        diff_w = torch.sigmoid(torch.mean(diff_w, dim=-2, keepdim=True))
        # (bu*h, c, 1, w) -> (bu, h, c, 1, w) -> (bu, 1, c, h, w)
        diff_w = diff_w.reshape(-1, h, c, 1, w).permute(0, 3, 2, 1, 4).contiguous()
        return diff_h, diff_w

    def reshape_feat(self, feat_):
        """Compute the bidirectional H/W motion attention for every segment.

        Args:
            feat_: frame features (n, c, h, w), where consecutive groups of
                ``self.t`` rows form one segment (n = #segments * t).

        Returns:
            (diff_h, diff_w), each of shape (n // t, 1, c, h, w) with values in
            (0, 1), ready to broadcast against (n // t, t, c, h, w).

        Raises:
            ValueError: if ``self.t < 2`` (there are no frame pairs to difference).
        """
        if self.t < 2:
            raise ValueError(f"IntraSMA needs at least 2 frames per segment, got t={self.t}")
        # (bu, t, c, h, w) -> (t, bu, c, h, w) so frames can be indexed first.
        frames = feat_.reshape((-1, self.t) + feat_.shape[1:]).permute(1, 0, 2, 3, 4).contiguous()
        pairs = range(self.t - 1)
        forward_diffs = torch.stack([frames[i] - self.conv2(frames[i + 1]) for i in pairs], dim=0)
        backward_diffs = torch.stack([frames[i + 1] - self.conv2(frames[i]) for i in pairs], dim=0)
        fwd_h, fwd_w = self._attention_from_diffs(forward_diffs)
        bwd_h, bwd_w = self._attention_from_diffs(backward_diffs)
        return (fwd_h + bwd_h) / 2, (fwd_w + bwd_w) / 2

    def forward(self, x2):
        """Re-weight frame features (n, c, h, w) by intra-segment motion attention; same shape out."""
        x2 = self.relu(self.bn1(self.conv1(x2)))
        diff_h, diff_w = self.reshape_feat(x2)
        but, c, h, w = x2.shape
        sma = diff_h * diff_w * x2.reshape(-1, self.t, c, h, w)
        # Residual connection back onto the projected input.
        return sma.reshape(but, c, h, w) + x2
- Inter-SIM
上述Intra-SIM捕获了段内伪造信息, 在不同分段间也存在不连续的伪造痕迹, 这部分信息通过当前模块捕获, 该模块是个信息整合模块, 真正捕获信息是在Inter-SMA中进行.
class InterSIM(nn.Module):
    """Inter-segment Spatiotemporal Inconsistency Module (Inter-SIM).

    Each frame's features are spatially averaged into a ``(b, c, u, t)``
    tensor, which feeds two branches:

    * branch 1: conv (3, 1) over the segment axis u -> norm -> 1x1 conv -> sigmoid,
      giving ``(b, c, u, t)``;
    * branch 2: InterSMA on ``(b*u, c, t, 1)``, giving ``(b, c, 1, t)``.

    The broadcast product of the two branches re-weights the input features
    with a residual connection. The result then goes through conv_u -> bn2 -> ReLU.

    Args:
        u: number of segments per video.
        t: number of consecutive frames per segment.
        ch_in: number of input channels (the output has the same count).
        norm_layer: factory for the normalization layer, called with a channel count.
    """

    def __init__(self, u, t, ch_in, norm_layer):
        # nn.Module.__init__ must run before any attribute assignment.
        super(InterSIM, self).__init__()
        self.u = u
        self.t = t
        # Kept for attribute compatibility; forward() pools with an explicit
        # mean because AdaptiveAvgPool2d is documented for 3D/4D input only.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.reduced_channels = ch_in // 1
        self.conv3x1 = nn.Conv2d(ch_in, self.reduced_channels, kernel_size=(3, 1), padding=(1, 0), bias=False)
        self.bn1 = norm_layer(self.reduced_channels)
        self.conv1x1 = nn.Conv2d(self.reduced_channels, self.reduced_channels, kernel_size=1, padding=0, bias=False)
        self.bn2 = norm_layer(self.reduced_channels)
        self.inter_sma = InterSMA(u, t, self.reduced_channels, norm_layer)
        self.relu = nn.ReLU(inplace=True)
        # NOTE(review): despite the name, this (3, 1) kernel runs over the h axis
        # of the per-frame (b*u*t, c, h, w) features it is applied to.
        self.conv_u = nn.Conv2d(ch_in, ch_in, kernel_size=(3, 1), padding=(1, 0), bias=False)

    def forward(self, x):
        """Apply Inter-SIM.

        Args:
            x: features (b*u*t, c, h, w). The leading dim is reshaped as (b, u, t).

        Returns:
            Tensor of shape (b*u*t, c, h, w).

        Raises:
            ValueError: if the leading dim is not a multiple of ``u * t``.
        """
        but, c, h, w = x.shape
        group = self.u * self.t
        if but % group != 0:
            raise ValueError(f"leading dim {but} is not a multiple of u*t={group}")

        # (b, u, t, c, h, w) -> (b, t, c, u, h, w); kept for the residual product.
        feat_raw = x.reshape(-1, self.u, self.t, c, h, w).permute(0, 2, 3, 1, 4, 5).contiguous()
        # Spatial global average -> (b, t, c, u). Reducing explicit dims (instead
        # of a bare squeeze()) keeps b, t, c or u equal to 1 from being dropped.
        feat = feat_raw.mean(dim=(-2, -1))
        # (b, c, u, t)
        feat = feat.permute(0, 2, 3, 1).contiguous()

        # Branch 1: conv over the segment axis u.
        x1 = self.bn1(self.conv3x1(feat))
        x1 = torch.sigmoid(self.conv1x1(x1))

        # Branch 2: (b, u, c, t) -> (b*u, c, t, 1) for InterSMA -> (b, c, 1, t).
        seg_len = feat.shape[-1]
        feat_2 = feat.permute(0, 2, 1, 3).reshape(-1, c, seg_len, 1)
        x2 = self.inter_sma(feat_2)

        # (b, c, u, t) -> (b, t, c, u, 1, 1) so it broadcasts over h, w.
        x12 = (x1 * x2).permute(0, 3, 1, 2)[..., None, None]
        x_merge = x12 * feat_raw + feat_raw
        # (b, t, c, u, h, w) -> (b, u, t, c, h, w) -> (b*u*t, c, h, w)
        x_return = x_merge.permute(0, 3, 1, 2, 4, 5).reshape(-1, c, h, w)
        return self.relu(self.bn2(self.conv_u(x_return)))
- Inter-SMA
这是真正捕获段间伪造痕迹的模块, 结构看起来与Intra-SMA类似, 但具体落到代码层面, 就有很多种解读. 论文仅给出模型结构图与计算公式, 而且论文中的文字说明与结构图有些细节还对不上, 那么代码怎么写就只能靠自己领悟了…. 这里我的理解是: 每个分段的t张图片特征与前一个或后一个分段的t张图片特征相减, 以提取伪造痕迹.
class InterSMA(nn.Module):
    """Inter-segment motion attention (Inter-SMA).

    Takes per-segment, spatially pooled features laid out as
    ``(b*u, c, t, 1)`` and computes differences between neighbouring segments:
    ``s[i] - conv2(s[i+1])`` and ``s[i+1] - conv2(s[i])``. These are averaged
    over segments and directions, then mapped through a 1x1 conv and a sigmoid
    to an attention tensor of shape ``(b, ch_in, 1, t)``.

    Args:
        u: number of segments per video; must be >= 2.
        t: frames per segment (stored; the temporal length is read from the input).
        ch_in: input/output channel count; the internal width is ``ch_in // 2``.
        norm_layer: factory for the normalization layer, called with a channel count.
    """

    def __init__(self, u, t, ch_in, norm_layer):
        super(InterSMA, self).__init__()
        self.u = u
        self.t = t
        ch_in_half = ch_in // 2
        self.conv1 = nn.Conv2d(ch_in, ch_in_half, kernel_size=1, bias=False)
        self.bn1 = norm_layer(ch_in_half)
        # (3, 1) kernel on (b, c, t, 1): convolves along the temporal axis t.
        self.conv2 = nn.Conv2d(ch_in_half, ch_in_half, kernel_size=(3, 1), padding=(1, 0), bias=False)
        self.conv3 = nn.Conv2d(ch_in_half, ch_in, kernel_size=1, bias=False)
        self.relu = nn.ReLU(inplace=True)

    def _segment_diffs(self, feat_, reverse=False):
        """Stack neighbour-segment differences of (b*u, c, t, 1) input into (u-1, b, c, t, 1).

        ``reverse=False`` gives ``s[i] - conv2(s[i+1])``; ``reverse=True`` gives
        ``s[i+1] - conv2(s[i])``.
        """
        if self.u < 2:
            raise ValueError(f"InterSMA needs at least 2 segments, got u={self.u}")
        _, c, t, n = feat_.shape
        # (b, u, c, t, 1) -> (u, b, c, t, 1) so segments can be indexed first.
        segs = feat_.reshape(-1, self.u, c, t, n).permute(1, 0, 2, 3, 4).contiguous()
        neighbours = [(segs[i], segs[i + 1]) for i in range(self.u - 1)]
        if reverse:
            diffs = [nxt - self.conv2(cur) for cur, nxt in neighbours]
        else:
            diffs = [cur - self.conv2(nxt) for cur, nxt in neighbours]
        return torch.stack(diffs, dim=0)

    def reshape_feat(self, feat_):
        """Forward-direction variant: sigmoid(conv3(mean_i(s[i] - conv2(s[i+1])))).

        Previously this referenced an undefined name ``u`` and always raised
        NameError. It now uses the same input layout as ``reshape_feat2``.

        Args:
            feat_: (b*u, c//2-width features, t, 1) tensor.

        Returns:
            Tensor of shape (b, ch_in, t, 1) with values in (0, 1).
        """
        diff_avg = self._segment_diffs(feat_).mean(dim=0)
        return torch.sigmoid(self.conv3(diff_avg))

    def reshape_feat2(self, feat_):
        """Bidirectional neighbour-segment difference, averaged over segments and directions.

        Args:
            feat_: tensor of shape (b*u, c, t, 1); the trailing dim must be 1.

        Returns:
            Tensor of shape (b, c, 1, t).

        Raises:
            ValueError: if ``self.u < 2``.
        """
        forward_avg = self._segment_diffs(feat_).mean(dim=0)
        backward_avg = self._segment_diffs(feat_, reverse=True).mean(dim=0)
        diff_u = (forward_avg + backward_avg) / 2
        # (b, c, t, 1) -> (b, c, 1, t)
        return diff_u.permute(0, 1, 3, 2).contiguous()

    def reshape_feat3(self, feat_):
        """Forward-direction mean of neighbour-segment differences, without conv3/sigmoid.

        Previously this called ``self.self.conv2(f)`` with undefined names and
        always raised. It now uses the same input layout as ``reshape_feat2``.

        Args:
            feat_: tensor of shape (b*u, c, t, 1).

        Returns:
            Tensor of shape (b, c, t, 1).
        """
        return self._segment_diffs(feat_).mean(dim=0)

    def forward(self, x2):
        """Map (b*u, ch_in, t, 1) features to a (b, ch_in, 1, t) attention tensor in (0, 1)."""
        x2 = self.relu(self.bn1(self.conv1(x2)))
        diff_u = self.reshape_feat2(x2)
        return torch.sigmoid(self.conv3(diff_u))