手动复现论文Dynamic Inconsistency Learning for DeepFake Video Detection

论文Delving into the Local: Dynamic Inconsistency Learning for DeepFake Video
Detection(以下简称DILD)是STIL(Spatiotemporal Inconsistency Learning for DeepFake Video
Detection)升级版, 同一作者, 两篇论文都来自腾讯优图实验室, 其中STIL有公开源码, 但DILD没有, 详细阅读原文后, 觉得是一个很好的idea, 因此手动复现了论文模型

STIL(论文源码):
https://github.com/Tencent/TFace/tree/master/security/tasks/Face-Forgery-Detection/STIL

DILD论文:
https://www.semanticscholar.org/paper/Delving-into-the-Local%3A-Dynamic-Inconsistency-for-Gu-Chen/0f138a9b3d90f3874ca5b2b0dd25db13bc3ce32b

DILD(本人复现)
https://github.com/yblir/DILD

一 论文解读

在这里插入图片描述
视频人脸deepfake伪造,先把视频拆帧,对图片进行伪造,然后再把图片帧合成视频. 论文认为这种伪造方式不能在多个角度都贴合人脸,表现在视频中,即会出现视频帧的不连续性. 因此可以通过检测视频帧之间的差异来鉴定是否存在人脸伪造现象. 作者前一篇paper STIL也是这种观点, 是从这个视频中抽取8或16张不连续的视频帧进行检测, 作者认为视频帧间隔太大不利于捕获伪造痕迹,因此又发表这篇升级版paper, 思路为: 从整个视频中抽取不连续的u个视频段, 每个视频段由连续的t张视频帧组成, 通过这样的组合来检测是否存在伪造.

二 主要模块复现

  • Intra-SIM模块
    该模块作用是提取一个视频段内的伪造信息,将输入通道一分为二, 对一半通道进行位置信息提取(Intra-SMA), 另一半啥都不做, 直接concat在一起, 推测这样做目的是兼顾原始信息与伪造信息.
class IntraSIM(nn.Module):
    def __init__(self, u, t, ch_in, norm_layer):
        super(IntraSIM, self).__init__()

        self.u = u
        self.t = t

        ch_in_half = ch_in // 2
        self.conv1 = nn.Conv2d(ch_in, ch_in_half, kernel_size=1, bias=False)
        self.bn1 = norm_layer(ch_in_half)
        self.conv2 = nn.Conv2d(ch_in, ch_in_half, kernel_size=1, bias=False)
        self.bn2 = norm_layer(ch_in_half)

        self.avg_pool = nn.AdaptiveAvgPool2d(1)

        self.fc1 = nn.Linear(ch_in_half, ch_in_half)
        self.fc2 = nn.Linear(ch_in_half, ch_in_half)

        self.relu = nn.ReLU(inplace=True)

        self.intra_sma = IntraSMA(u, t, ch_in_half, norm_layer)

    def forward(self, x):
        """

        Args:
            u: 片段数量,t:每个片段图片数量
            x: u * c * t * h * w
        Returns:

        """
        # x1,x2 相对于x通道数减半
        i1 = self.relu(self.bn1(self.conv1(x)))
        i2 = self.relu(self.bn2(self.conv2(x)))
        i2_k = i2

        # i2的两条分支
        # x1分支接全连接层
        i2_k = self.avg_pool(i2_k)
        i2_k = i2_k.reshape(i2_k.shape[0], -1)
        i2_k = self.fc1(i2_k)
        i2_k = self.relu(i2_k)

        i2_k = self.fc2(i2_k)
        i2_k = torch.softmax(i2_k, dim=1)

        # x2分支进入sma模块
        x2 = self.intra_sma(i2)
        o2 = i2_k.unsqueeze(-1).unsqueeze(-1) * x2

        return torch.concat([i1, o2], dim=1)
  • IntraSMA模块
    该模块作用是提取段内连续帧之间的伪造痕迹, 具体操作为前一帧减去后一帧, 然后反过来再操作一遍,之后求平均值. 对这个模块, 有一点不太赞同, H,W两个方向卷积乘相当于注意力机制, 再乘上输入, 相当于对输入通道各通道的信息筛选, 最后再与输入信息相加? 不太认同这种操作, 因为会弱化各通道信息的强化筛选功能.
class IntraSMA(nn.Module):
    def __init__(self, u, t, ch_in, norm_layer):
        super(IntraSMA, self).__init__()

        self.u = u
        self.t = t

        # self.reduction=16
        self.reduced_channels = ch_in // 1
        self.conv1 = nn.Conv2d(ch_in, self.reduced_channels, kernel_size=1, padding=0, bias=False)
        self.bn1 = norm_layer(self.reduced_channels)
        self.relu = nn.ReLU(inplace=True)

        self.ch_in2 = self.reduced_channels // 1
        # 这里,应该是每张图片的输入通道
        self.conv2 = nn.Conv2d(self.ch_in2, self.ch_in2, kernel_size=3, padding=1, bias=False)

        # self.conv_ht1 = nn.Conv2d(self.ch_in2, self.ch_in2, kernel_size=3, padding=1, bias=False)
        self.conv_ht1=nn.Conv2d(self.ch_in2, self.ch_in2,
                                  kernel_size=(3, 1), padding=(1, 0), groups=self.ch_in2, bias=False)
        self.conv_ht2 = nn.Conv2d(self.ch_in2, self.ch_in2, kernel_size=1, padding=0, bias=False)

        # self.conv_wt1 = nn.Conv2d(self.ch_in2, self.ch_in2, kernel_size=3, padding=1, bias=False)
        self.conv_wt1=nn.Conv2d(self.ch_in2, self.ch_in2,
                                  kernel_size=(3, 1), padding=(1, 0), groups=self.ch_in2, bias=False)
        self.conv_wt2 = nn.Conv2d(self.ch_in2, self.ch_in2, kernel_size=1, padding=0, bias=False)

        # self.conv_ht3 = nn.Conv2d(self.ch_in2, self.ch_in2, kernel_size=1, padding=0, bias=False)
        # self.conv_wt3 = nn.Conv2d(self.ch_in2, self.ch_in2, kernel_size=1, padding=0, bias=False)

    def reshape_feat(self, feat_):
        """

        Args:
            feat: shape=n,c,h,w, n=b*u*t

        Returns:

        """
        feat = feat_.reshape((-1, self.t) + feat_.shape[1:])
        # u分段数 t每段的图片数量
        bu, t, c, h, w = feat.shape

        # 使得每个分段的图片次序在首位
        # t,bu,c,h,w
        feat = feat.permute(1, 0, 2, 3, 4).contiguous()
        t_list = []

        # =====================================================================================================
        for i in range(t):
            if i == t - 1:
                break
            diff_feat = feat[i] - self.conv2(feat[i + 1])
            t_list.append(diff_feat)

        feat_stack = torch.stack(t_list, dim=0)
        t1, bu, c_, h_, w_ = feat_stack.shape
        # bu*w,c,h,t
        diff_h = feat_stack.permute(1, 4, 2, 3, 0).contiguous().reshape(-1, c_, h_, t1)
        diff_h = self.conv_ht2(self.conv_ht1(diff_h) + diff_h)
        # bu*h,c,t,w
        diff_w = feat_stack.permute(1, 3, 2, 0, 4).contiguous().reshape(-1, c_, t1, w_)
        diff_w = self.conv_wt2(self.conv_wt1(diff_w) + diff_w)

        diff_h = torch.sigmoid(torch.mean(diff_h, dim=-1, keepdim=True)).reshape(-1, w_, c_, h_, 1)
        diff_h = diff_h.permute(0, 4, 2, 3, 1).contiguous()

        diff_w = torch.sigmoid(torch.mean(diff_w, dim=-2, keepdim=True)).reshape(-1, h_, c_, 1, w_)
        diff_w = diff_w.permute(0, 3, 2, 1, 4).contiguous()

        # ====================================================================================================
        t_list2 = []
        for i in range(t):
            if i == t - 1:
                break
            diff_feat = feat[i + 1] - self.conv2(feat[i])
            t_list2.append(diff_feat)

        feat_stack2 = torch.stack(t_list, dim=0)
        t1, bu, c_, h_, w_ = feat_stack2.shape
        # bu*w,c,h,t
        diff_h2 = feat_stack2.permute(1, 4, 2, 3, 0).contiguous().reshape(-1, c_, h_, t1)
        diff_h2 = self.conv_ht2(self.conv_ht1(diff_h2) + diff_h2)
        # bu*h,c,t,w
        diff_w2 = feat_stack.permute(1, 3, 2, 0, 4).contiguous().reshape(-1, c_, t1, w_)
        diff_w2 = self.conv_wt2(self.conv_wt1(diff_w2) + diff_w2)

        diff_h2 = torch.sigmoid(torch.mean(diff_h2, dim=-1, keepdim=True)).reshape(-1, w_, c_, h_, 1)
        diff_h2 = diff_h2.permute(0, 4, 2, 3, 1).contiguous()

        diff_w2 = torch.sigmoid(torch.mean(diff_w2, dim=-2, keepdim=True)).reshape(-1, h_, c_, 1, w_)
        diff_w2 = diff_w2.permute(0, 3, 2, 1, 4).contiguous()

        diff_h = (diff_h + diff_h2) / 2
        diff_w = (diff_w + diff_w2) / 2

        # (b*u,1,c,h,w), 1是指每个分段所有特征的平均值
        return diff_h, diff_w
      
    def forward(self, x2):
        x2 = self.relu(self.bn1(self.conv1(x2)))
        diff_h, diff_w = self.reshape_feat(x2)
        but, c, h, w = x2.shape
        x2_ = x2.reshape(-1, self.t, c, h, w)

        sma1 = diff_h * diff_w * x2_
        sma1 = sma1.reshape(but, c, h, w)

        sma = sma1 + x2

        return sma

  • Inter-SIM
    上述Intra-SIM捕获了段内伪造信息, 在不同分段间也存在不连续的伪造痕迹, 这部分信息通过当前模块捕获, 该模块是个信息整合模块, 真正捕获信息是在Inter-SMA中进行.
class InterSIM(nn.Module):
    def __init__(self, u, t, ch_in, norm_layer):
        self.u = u
        self.t = t
        super(InterSIM, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.reduced_channels = ch_in // 1

        self.conv3x1 = nn.Conv2d(ch_in, self.reduced_channels, kernel_size=(3, 1), padding=(1, 0), bias=False)
        self.bn1 = norm_layer(self.reduced_channels)

        self.conv1x1 = nn.Conv2d(self.reduced_channels, self.reduced_channels, kernel_size=1, padding=0, bias=False)
        self.bn2 = norm_layer(self.reduced_channels)

        self.inter_sma = InterSMA(u, t, self.reduced_channels, norm_layer)

        self.relu = nn.ReLU(inplace=True)

        self.conv_u = nn.Conv2d(ch_in, ch_in, kernel_size=(3, 1), padding=(1, 0), bias=False)

    def forward(self, x):
        # but, c, h, w = x.shape
        # bu,t,c,h,w
        x_ = x.reshape((-1, self.t) + x.shape[1:])
        bu, t, c, h, w = x_.shape
        x_ = x_.reshape(-1, self.u, t, c, h, w)
        # b, u, t, c, h, w = x_.shape
        # todo b,t,c,u,h,w 模块输入shape, 用于点乘
        feat_raw = x_.permute(0, 2, 3, 1, 4, 5).contiguous()
        # b,t,c,u,1,1
        feat = self.avg_pool(feat_raw)
        # b,t,c,u
        feat = feat.squeeze().squeeze()
        # b,c,u,t
        feat = feat.permute(0, 2, 3, 1).contiguous()

        # 分支1
        # x1 = feat.squeeze().squeeze()
        # b,c,u,t
        x1 = self.bn1(self.conv3x1(feat))  # 2, 32,2,4
        # 没有bn层
        x1 = torch.sigmoid(self.conv1x1(x1))

        # 分支2
        # b, t, c, u, 1, 1
        # x2 = feat.squeeze(-1)
        # b,u,c,t,1
        # x2 = x2.permute(0, 3, 2, 1, 4).contiguous()
        # bu,c,t,1
        # x2 = x2.reshape(-1, c, t, 1)

        # b,u,c,t
        feat_2 = feat.permute(0, 2, 1, 3).contiguous()
        # b,u,c,t,1 -> bu,c,t,1
        b, u, c, t = feat_2.shape
        feat_2 = feat_2.squeeze(-1).reshape(-1, c, t, 1)
        # b,c,1,t
        x2 = self.inter_sma(feat_2)
        # b,c,u,t
        x12 = x1 * x2
        # b,t,c,u
        x12 = x12.permute(0, 3, 1, 2).contiguous()
        # b,t,c,u,1,1
        x12 = x12.unsqueeze(-1).unsqueeze(-1)
        # x2_avg = self.avg_pool(x1 * x2)
        # b,t,c,u,h,w
        x_merge = x12 * feat_raw + feat_raw

        x_return = x_merge.permute(0, 3, 1, 2, 4, 5).contiguous().reshape(-1, c, h, w)
        x_return = self.relu(self.bn2(self.conv_u(x_return)))

        return x_return
  • Inter-SMA
    这是真正捕获段鉴伪造痕迹的模块,结构看起来与Inter-SMA类型, 但具体写在代码层面, 就有很多种解读. 论文仅给出上述模型结构图与计算公式, 关键论文说明与结构图,有些细节还对不上号, 那么怎么写代码就靠自己领悟…, 这里我认为是每个分段的t张图片特征与前一个或后一个t张图片特征相减, 提取伪造痕迹.
    在这里插入图片描述
class InterSMA(nn.Module):
    def __init__(self, u, t, ch_in, norm_layer):
        super(InterSMA, self).__init__()
        self.u = u
        self.t = t

        ch_in_half = ch_in // 2
        self.conv1 = nn.Conv2d(ch_in, ch_in_half, kernel_size=1, bias=False)
        self.bn1 = norm_layer(ch_in_half)

        # self.conv2 = nn.Conv2d(ch_in_half, ch_in_half, kernel_size=(1, 3), padding=(0, 1), bias=False)
        self.conv2 = nn.Conv2d(ch_in_half, ch_in_half, kernel_size=(3, 1), padding=(1, 0), bias=False)
        # self.bn2 = norm_layer(ch_in_half)

        self.conv3 = nn.Conv2d(ch_in_half, ch_in, kernel_size=1, bias=False)
        # self.avg_pool = nn.AdaptiveAvgPool2d(1)

        # self.fc1 = nn.Linear(512 * block.expansion, num_classes * 64)
        # self.fc2 = nn.Linear(num_classes * 64, num_classes)

        self.relu = nn.ReLU(inplace=True)

    def reshape_feat(self, feat_):
        """
        Args:
            feat: shape=n,c,h,w, n=b*u*t
        Returns:
        """
        feat = feat_.reshape((-1, self.t) + feat_.shape[1:])
        # u分段数 t每段的图片数量
        # b,u,c,t,1
        # b, u, c, t, n = feat_.shape
        # x2 = feat_.reshape(-1, self.u, c, t, n)
        # =================bt,u,c,1

        # b, u, c, t, n = x2.shape
        # u,b,c,t,n
        x2 = feat_.permute(1, 0, 2, 3, 4).contiguous()
        # bu, c, t, n = x2.shape

        u_list = []
        for i in range(u):
            if i == u - 1:
                break
            diff_feat = x2[i] - self.conv2(x2[i + 1])
            u_list.append(diff_feat)

        diff_avg = sum(u_list) / len(u_list)

        diff_avg = torch.sigmoid(self.conv3(diff_avg))

        return diff_avg

    def reshape_feat2(self, feat_):
        """
        Args:
            feat: shape=n,c,h,w, n=b*u*t
        Returns:
        """
        # b, c, u, t = feat_.shape

        # bu, c, t, 1
        bu, c, t, n = feat_.shape
        feat_ = feat_.reshape(-1, self.u, c, t, n)
        b, u, c, t, n = feat_.shape
        # u,b,c,t,1
        feat_ = feat_.permute(1, 0, 2, 3, 4).contiguous()

        # feat = feat_.reshape((-1, self.t) + feat_.shape[1:])
        # bu,c,t
        # feat_new = feat_.squeeze(-1).reshape(-1, self.u, c, t)
        # b, u, c, t = feat_new.shape
        # b,c,u,t
        # x2 = feat_.permute(2, 0, 1, 3).contiguous()
        # u,b,c,t,1
        # x2 = x2.unsqueeze(-1)

        u_list = []
        for i in range(u):
            if i == u - 1:
                break
            diff_feat = feat_[i] - self.conv2(feat_[i + 1])
            u_list.append(diff_feat)
        # u-1,b,c,t,1
        diff_u = torch.stack(u_list, dim=0)
        # b,c,u,t,1
        diff_u = diff_u.permute(1, 2, 0, 3, 4).contiguous()
        # b,c,u,t
        diff_u = diff_u.squeeze(-1)
        # b,c,1,t
        diff_u = torch.mean(diff_u, dim=-2, keepdim=True)

        u_list2 = []
        for i in range(u):
            if i == u - 1:
                break
            diff_feat2 = feat_[i + 1] - self.conv2(feat_[i])
            u_list2.append(diff_feat2)
        # u-1,b,c,t,1
        diff_u2 = torch.stack(u_list2, dim=0)
        # b,c,u,t,1
        diff_u2 = diff_u2.permute(1, 2, 0, 3, 4).contiguous()
        # b,c,u,t
        diff_u2 = diff_u2.squeeze(-1)
        # b,c,1,t
        diff_u2 = torch.mean(diff_u2, dim=-2, keepdim=True)

        diff_u = (diff_u + diff_u2) / 2

        # b,c,t,1
        # diff_u = sum(u_list) / len(u_list)
        # u-1,b,c,t,1
        # diff_u = torch.stack(u_list, dim=0)
        # diff_avg = torch.sigmoid(self.conv3(diff_avg))

        return diff_u

    def reshape_feat3(self, feat_):
        """
        Args:
            feat: shape=n,c,h,w, n=b*u*t
        Returns:
        """
        b, c, u, t = feat_.shape

        diff_u1 = feat_ - self.self.conv2(f)

        # bu, c, t, 1
        # bu, c, t, n = feat_.shape
        feat_ = feat_.reshape(-1, self.u, c, t, n)
        b, u, c, t, n = feat_.shape
        # u,b,c,t,1
        feat_ = feat_.permute(1, 0, 2, 3, 4).contiguous()

        # feat = feat_.reshape((-1, self.t) + feat_.shape[1:])
        # bu,c,t
        # feat_new = feat_.squeeze(-1).reshape(-1, self.u, c, t)
        # b, u, c, t = feat_new.shape
        # b,c,u,t
        # x2 = feat_.permute(2, 0, 1, 3).contiguous()
        # u,b,c,t,1
        # x2 = x2.unsqueeze(-1)

        u_list = []
        for i in range(u):
            if i == u - 1:
                break
            diff_feat = feat_[i] - self.conv2(feat_[i + 1])
            u_list.append(diff_feat)
        # b,c,t,1
        diff_u = sum(u_list) / len(u_list)
        # u-1,b,c,t,1
        # diff_u = torch.stack(u_list, dim=0)
        # diff_avg = torch.sigmoid(self.conv3(diff_avg))

        return diff_u

    def forward(self, x2):
        # bu,c,t,1
        # b,c,u,t
        x2 = self.relu(self.bn1(self.conv1(x2)))
        # 通道变为c//4
        # diff_avg = self.reshape_feat(x2)
        diff_u = self.reshape_feat2(x2)
        # sma = diff_w * diff_w * x2 + x2
        diff_u = torch.sigmoid(self.conv3(diff_u))
        # return diff_avg
        # b, c, 1, t
        return diff_u
评论 10
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值