Code Walkthrough, Part 3: Training and the ADM Module — CVPR 2023 — Implicit Identity Leakage: The Stumbling Block to Improving Deepfake Detection

For a walkthrough of the paper, see: https://blog.csdn.net/JustWantToLearn/article/details/138758033
Code: https://github.com/megvii-research/CADDM
Here we briefly describe the algorithm flow, focusing on how the model is built and why it is built that way.
Part 1: dataset preparation, see https://blog.csdn.net/JustWantToLearn/article/details/138773005
Part 2: dataset loading, including the Multi-scale Facial Swap (MFS) module: https://blog.csdn.net/JustWantToLearn/article/details/139092687
Part 3: the training process and the ADM module (this post)

1. Training: train.py

python train.py --cfg ./configs/caddm_train.cfg
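
The config file provides the hyper-parameters used throughout train.py. The keys it reads are sketched below as a Python dict for readability; this is only an illustrative placeholder, not the repo's actual defaults or file format (those live in configs/caddm_train.cfg and load_config).

# Illustrative only -- mirrors the keys accessed by train.py; all values are placeholders.
cfg = {
    'model':   {'backbone': 'efficientnet-b4', 'save_path': './checkpoints/'},
    'dataset': {'img_path': './train_images'},        # placeholder path to the prepared dataset
    'train':   {'batch_size': 32, 'epoch_num': 100},  # placeholder values
    'det_loss': {
        'num_classes': 2, 'overlap_thresh': 0.5,      # SSD-style matching threshold
        'prior_for_matching': True, 'bkg_label': 0,
        'neg_mining': True, 'neg_pos': 3, 'neg_overlap': 0.5,
        'encode_target': False, 'use_gpu': True,
    },
}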

def train():
    args = args_func()

    # load configs
    cfg = load_config(args.cfg)

    # init model
    net = model.get(backbone=cfg['model']['backbone'])
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net = net.to(device)
    net = nn.DataParallel(net)

    # init losses: the multi-task detection loss MultiBoxLoss and the classification loss nn.CrossEntropyLoss
    det_criterion = MultiBoxLoss(
        cfg['det_loss']['num_classes'],
        cfg['det_loss']['overlap_thresh'],
        cfg['det_loss']['prior_for_matching'],
        cfg['det_loss']['bkg_label'],
        cfg['det_loss']['neg_mining'],
        cfg['det_loss']['neg_pos'],
        cfg['det_loss']['neg_overlap'],
        cfg['det_loss']['encode_target'],
        cfg['det_loss']['use_gpu']
    )
    criterion = nn.CrossEntropyLoss()

    # optimizer init.
    optimizer = optim.AdamW(net.parameters(), lr=1e-3, weight_decay=4e-3)

    # load checkpoint if given
    base_epoch = 0
    if args.ckpt:
        net, optimizer, base_epoch = load_checkpoint(args.ckpt, net, optimizer, device)

    # load the training dataset
    print(f"Load deepfake dataset from {cfg['dataset']['img_path']}..")
    train_dataset = DeepfakeDataset('train', cfg)
    train_loader = DataLoader(train_dataset,
                              batch_size=cfg['train']['batch_size'],
                              shuffle=True, num_workers=4,
                              collate_fn=my_collate
                              )

    # start training: switch to train mode and iterate over epochs and batches; the learning rate is recomputed from the current epoch
    net.train()
    for epoch in range(base_epoch, cfg['train']['epoch_num']):
        for index, (batch_data, batch_labels) in enumerate(train_loader):

            lr = update_learning_rate(epoch)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            labels, location_labels, confidence_labels = batch_labels
            labels = labels.long().to(device)
            location_labels = location_labels.to(device)
            confidence_labels = confidence_labels.long().to(device)
            # compute the classification loss and the detection loss, sum them into the total loss, and backpropagate
            optimizer.zero_grad()
            locations, confidence, outputs = net(batch_data)
            loss_end_cls = criterion(outputs, labels)
            loss_l, loss_c = det_criterion(
                (locations, confidence),
                confidence_labels, location_labels
            )
            acc = sum(outputs.max(-1).indices == labels).item() / labels.shape[0]
            det_loss = 0.1 * (loss_l + loss_c)
            loss = det_loss + loss_end_cls
            loss.backward()
            # gradient clipping and optimizer step
            torch.nn.utils.clip_grad_value_(net.parameters(), 2)
            optimizer.step()

            outputs = [
                "e:{},iter: {}".format(epoch, index),
                "acc: {:.2f}".format(acc),
                "loss: {:.8f} ".format(loss.item()),
                "lr:{:.4g}".format(lr),
            ]
            print(" ".join(outputs))
        save_checkpoint(net, optimizer,
                        cfg['model']['save_path'],
                        epoch)
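
train.py also relies on a few helpers that are not reproduced in this post (update_learning_rate, my_collate, load_checkpoint, save_checkpoint). As an illustration of the first one, a simple epoch-based step decay could look like the sketch below; this is an assumption for readability, not the repo's actual schedule.

# Hypothetical sketch of an epoch-based step decay; the repo's update_learning_rate may differ.
def update_learning_rate(epoch, base_lr=1e-3, decay=0.1, milestones=(30, 60)):
    lr = base_lr
    for m in milestones:
        if epoch >= m:
            lr *= decay        # drop the learning rate at every milestone that has been passed
    return lr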

2. Loss function: MultiBoxLoss

The MultiBoxLoss class implements the SSD loss, which combines a localization loss and a confidence loss.
Below is a high-level description of what each part does; for the implementation details see the SSD paper: https://arxiv.org/pdf/1512.02325.pdf. A standalone usage sketch with random tensors follows the class.


class MultiBoxLoss(nn.Module):
    """SSD Weighted Loss Function
    Compute Targets:
        1) Produce Confidence Target Indices by matching  ground truth boxes
           with (default) 'priorboxes' that have jaccard index > threshold parameter
           (default threshold: 0.5).
        2) Produce localization target by 'encoding' variance into offsets of ground
           truth boxes and their matched  'priorboxes'.
        3) Hard negative mining to filter the excessive number of negative examples
           that comes with using a large number of default bounding boxes.
           (default negative:positive ratio 3:1)
    Objective Loss:
        L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
        Where, Lconf is the CrossEntropy Loss and Lloc is the SmoothL1 Loss
        weighted by α which is set to 1 by cross val.
        Args:
            c: class confidences,
            l: predicted boxes,
            g: ground truth boxes
            N: number of matched default boxes
        See: https://arxiv.org/pdf/1512.02325.pdf for more details.
    """

    def __init__(self, num_classes, overlap_thresh, prior_for_matching,
                 bkg_label, neg_mining, neg_pos, neg_overlap, encode_target,
                 use_gpu=True):
        super(MultiBoxLoss, self).__init__()
        self.use_gpu = use_gpu
        self.num_classes = num_classes
        self.threshold = overlap_thresh
        self.background_label = bkg_label
        self.encode_target = encode_target
        self.use_prior_for_matching = prior_for_matching
        self.do_neg_mining = neg_mining
        self.negpos_ratio = neg_pos
        self.neg_overlap = neg_overlap
        self.variance = [0.1, 0.2]  # cfg['variance']

    # def forward(self, predictions, targets):
    def forward(self, predictions, conf_t, loc_t):
        """Multibox Loss
        Args:
            predictions (tuple): A tuple containing loc preds, conf preds,
            and prior boxes from SSD net.
                conf shape: torch.size(batch_size,num_priors,num_classes)
                loc shape: torch.size(batch_size,num_priors,4)
                priors shape: torch.size(num_priors,4)

            targets (tensor): Ground truth boxes and labels for a batch,
                shape: [batch_size,num_objs,5] (last idx is the label).
        """
        '''
        priors = priors[:loc_data.size(1), :]
        num_priors = (priors.size(0))
        num_classes = self.num_classes

        # match priors (default boxes) and ground truth boxes
        loc_t = torch.Tensor(num, num_priors, 4)
        conf_t = torch.LongTensor(num, num_priors)
        for idx in range(num):
            truths = targets[idx][:, :-1].data
            labels = targets[idx][:, -1].data
            defaults = priors.data
            match(self.threshold, truths, defaults, self.variance, labels,
                  loc_t, conf_t, idx)
        '''
        # predictions: the model's outputs, i.e. the location predictions and the confidence predictions
        loc_data, conf_data = predictions
        num = loc_data.size(0)
        # conf_t: confidence (class) targets; loc_t: localization targets
        if self.use_gpu:
            loc_t = loc_t.cuda()
            conf_t = conf_t.cuda()
        # wrap targets in Variables with requires_grad=False so no gradients are computed for them
        loc_t = Variable(loc_t, requires_grad=False)
        conf_t = Variable(conf_t, requires_grad=False)
        # mask of positive anchors and the number of positives per sample
        pos = conf_t > 0
        num_pos = pos.sum(dim=1, keepdim=True)

        # Localization Loss (Smooth L1)
        # Shape: [batch, num_priors, 4]; the Smooth L1 loss uses only the positive anchors
        pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data)
        loc_p = loc_data[pos_idx].view(-1, 4)
        loc_t = loc_t[pos_idx].view(-1, 4)
        loss_l = F.smooth_l1_loss(loc_p, loc_t, size_average=False)

        # Compute max conf across batch for hard negative mining
        # confidence loss with hard negative mining to keep the negative:positive ratio in check
        batch_conf = conf_data.view(-1, self.num_classes)
        loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1))

        # Hard Negative Mining
        # loss_c[pos] = 0  # filter out pos boxes for now
        loss_c[pos.view(-1, 1)] = 0
        loss_c = loss_c.view(num, -1)
        _, loss_idx = loss_c.sort(1, descending=True)
        _, idx_rank = loss_idx.sort(1)
        num_pos = pos.long().sum(1, keepdim=True)
        num_neg = torch.clamp(self.negpos_ratio*num_pos, max=pos.size(1)-1)
        neg = idx_rank < num_neg.expand_as(idx_rank)

        # Confidence Loss Including Positive and Negative Examples
        pos_idx = pos.unsqueeze(2).expand_as(conf_data)
        neg_idx = neg.unsqueeze(2).expand_as(conf_data)
        conf_p = conf_data[(pos_idx+neg_idx).gt(0)].view(-1, self.num_classes)
        targets_weighted = conf_t[(pos+neg).gt(0)]
        loss_c = F.cross_entropy(conf_p, targets_weighted, size_average=False)

        # Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
        # normalize both losses by the number of matched (positive) anchors before returning
        N = num_pos.data.sum() if num_pos.data.sum() else 1
        loss_l /= N
        loss_c /= N
        return loss_l, loss_c
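
To make the expected shapes concrete, here is a minimal standalone call with random tensors. The import path and num_priors are assumptions for illustration; in real training the targets come from the MFS pipeline described in part 2.

import torch
from detection_layers.modules import MultiBoxLoss   # assumed import path; adjust to the repo layout

batch, num_priors = 4, 432                           # num_priors is an arbitrary placeholder
criterion = MultiBoxLoss(2, 0.5, True, 0, True, 3, 0.5, False, use_gpu=False)

loc_pred  = torch.randn(batch, num_priors, 4)        # predicted box offsets
conf_pred = torch.randn(batch, num_priors, 2)        # per-anchor class scores
loc_t     = torch.randn(batch, num_priors, 4)        # encoded ground-truth offsets
conf_t    = torch.randint(0, 2, (batch, num_priors)) # 0 = background, 1 = positive anchor

loss_l, loss_c = criterion((loc_pred, conf_pred), conf_t, loc_t)
print(loss_l.item(), loss_c.item())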

3. Building the model: CADDM

def get(pretrained_model=None, backbone='efficientnet-b4'):
    """
    load one model
    :param model_path: ./models
    :param model_type: source/target/det
    :param model_backbone: res18/res34/Efficient
    :param use_cuda: True/False
    :return: model
    """
    if backbone not in ['resnet34', 'efficientnet-b3', 'efficientnet-b4']:
        raise ValueError("Unsupported type of models!")

    model = CADDM(2, backbone=backbone)

    if pretrained_model:
        checkpoint = torch.load(pretrained_model)
        model.load_state_dict(checkpoint['network'])
    return model

3.1 CADDM

The CADDM class is a network for forged-image detection and classification. It combines a pretrained backbone (ResNet or EfficientNet) with the Artifact Detection Module (ADM): image features are extracted, refined by the ADM, and classified to decide whether the image is fake. In training mode the model returns the location predictions, the anchor confidences and the classification logits; in evaluation mode it returns the classification probabilities. A short usage sketch follows the class definition.

class CADDM(nn.Module):

    def __init__(self, num_classes, backbone='resnet34'):
        super(CADDM, self).__init__()

        self.num_classes = num_classes
        # backbone: the type of backbone network, 'resnet34' by default
        self.backbone = backbone
        if backbone == 'resnet34':
            self.base_model = resnet34(pretrained=True)
        elif backbone == 'efficientnet-b3':
            self.base_model = EfficientNet.from_pretrained(
                'efficientnet-b3', out_size=[1, 3]
            )
        elif backbone == 'efficientnet-b4':
            self.base_model = EfficientNet.from_pretrained(
                'efficientnet-b4', out_size=[1, 3]
            )
        else:
            raise ValueError("Unsupported Backbone!")
        # number of output features of the backbone (the channel count of its feature map)
        self.inplanes = self.base_model.out_num_features
        # the Artifact Detection Module (ADM), which localizes forgery artifacts
        self.adm = Artifact_Detection_Module(self.inplanes)
        # fully connected layer for the final real/fake classification
        self.fc = nn.Linear(self.inplanes, num_classes)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        batch_num = x.size(0)
        # extract features with the backbone: a feature map x and a global feature global_feat
        x, global_feat = self.base_model(x)
        # location result, confidence of each anchor, final feature map of adm.
        # refine the feature map with the Artifact Detection Module (ADM)
        loc, cof, adm_final_feat = self.adm(x)
        # add the global feature and the final ADM feature map to obtain the classification feature final_cls_feat
        final_cls_feat = global_feat + adm_final_feat
        final_cls = self.fc(final_cls_feat.view(batch_num, -1))
        # in training mode, return the location predictions, confidences and classification logits
        if self.training:
            return loc, cof, final_cls
        # in evaluation mode, return the softmax class probabilities
        return self.softmax(final_cls)
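
A minimal usage sketch of the two output modes. It assumes the repo's model.py is importable; the 224x224 input resolution is a placeholder and must match whatever the data pipeline of part 2 actually produces.

import torch
import model                              # the repo's model.py

net = model.get(backbone='efficientnet-b4')
imgs = torch.randn(2, 3, 224, 224)        # placeholder resolution

net.train()
loc, conf, logits = net(imgs)             # training mode: detection outputs + classification logits

net.eval()
with torch.no_grad():
    probs = net(imgs)                     # eval mode: softmax class probabilities, shape (2, num_classes)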

4. The ADM module

4.1 Artifact_Detection_Module

The Artifact_Detection_Module detects forgery artifacts in the backbone feature map. It consists of several extra layers plus a multi-scale detection module; its forward pass produces location predictions, confidences, and a final feature map used for the subsequent classification and detection tasks. A shape sketch with a dummy feature map follows the class.

class Artifact_Detection_Module(nn.Module):

    def __init__(
            self, inplanes, blocks=1, class_num=2,
            width_hight_ratios=2, extra_layers=None,
    ):

        super(Artifact_Detection_Module, self).__init__()

        # Artifact Detection Module Extra Layers.

        self.cls_num = class_num
        self.inplanes = inplanes
        # adm_extra_layers collects the extra layers
        adm_extra_layers = list()
        # if no extra_layers are given, use the default configuration: three ADM_ExtraBlock followed by one ADM_EndBlock
        if extra_layers is None:
            extra_layers = [ADM_ExtraBlock] * 3 + [ADM_EndBlock]
        # the ADM_EndBlock is appended directly; every other block is built through _make_layer
        for i, extra_block in enumerate(extra_layers):
            ks = 3 if i else 1
            if extra_block != ADM_EndBlock:
                adm_extra_layers.append(
                    self._make_layer(
                        extra_block, inplanes,
                        blocks=blocks, kernel_size=ks, stride=1
                    )
                )
            else:
                adm_extra_layers.append(extra_block(inplanes, inplanes))
        # register the extra layers as an nn.ModuleList so they are tracked as submodules
        self.adm_extra_layers = nn.ModuleList(adm_extra_layers)
        # the multi-scale detection module receives the channel count and the extra-layer configuration
        self.multi_scale_detection_module = Multi_scale_Detection_Module(
            inplanes, extra_layers=extra_layers
        )

    def _make_layer(self, block, planes, blocks, kernel_size, stride=1):
        # downsample branch for the residual: conv + batch norm with the same kernel size and stride as conv1, so the shapes match
        downsample = nn.Sequential(
            nn.Conv2d(self.inplanes, planes * block.expansion,
                      kernel_size=kernel_size, stride=stride, bias=False),
            nn.BatchNorm2d(planes * block.expansion)
        )
        layers = []
        # the first block receives the downsample branch
        layers.append(block(
            self.inplanes, planes * block.expansion, kernel_size=kernel_size,
            stride=stride, downsample=downsample))
        # append the remaining blocks (without downsample)
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes * block.expansion, ))
        # return the blocks as an nn.Sequential
        return nn.Sequential(*layers)

    def forward(self, x):
        bs = x.size(0)
        adm_feats = list()
        # adm_feats stores the feature map produced by every extra layer
        for adm_layer in self.adm_extra_layers:
            x = adm_layer(x)
            adm_feats.append(x)
        # run the multi-scale detection module over adm_feats to get locations and confidences
        location, confidence = self.multi_scale_detection_module(adm_feats)

        location = location.view(bs, -1, 4)
        confidence = confidence.view(bs, -1, self.cls_num)

        # the output of the last extra layer is the final ADM feature map
        adm_final_feat = adm_feats[-1]
        # return locations, confidences and adm_final_feat
        return location, confidence, adm_final_feat
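
The tensor shapes are easiest to see with a dummy feature map. In the sketch below the 512 channels and the 10x10 spatial size are arbitrary placeholders standing in for the backbone output, and the import path is an assumption.

import torch
from model import Artifact_Detection_Module    # assumed import path

adm = Artifact_Detection_Module(512)
feat = torch.randn(2, 512, 10, 10)             # stand-in for the backbone feature map

loc, conf, adm_final_feat = adm(feat)
# the un-padded 3x3 convs shrink the map 10 -> 10 -> 8 -> 6 -> 4, so with 2 anchors
# per location there are 2 * (10*10 + 8*8 + 6*6 + 4*4) = 432 priors in total
print(loc.shape)                               # expected: torch.Size([2, 432, 4])
print(conf.shape)                              # expected: torch.Size([2, 432, 2])
print(adm_final_feat.shape)                    # expected: torch.Size([2, 512, 4, 4])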

4.2 ADM_ExtraBlock

ADM_ExtraBlock: a residual block of convolution and batch-normalization layers. Its first convolution uses no padding, so with a 3x3 kernel the feature map shrinks by 2 in each spatial dimension; the optional downsample branch applies the same kernel to the input so the residual addition still matches.

class ADM_ExtraBlock(nn.Module):
    expansion = 1

    def __init__(
            self, inplanes, planes,
            kernel_size=3, stride=1, downsample=None
    ):
        super(ADM_ExtraBlock, self).__init__()
        # stride/2 maybe applied on conv1
        self.conv1 = nn.Conv2d(
            inplanes, planes, kernel_size=kernel_size, stride=stride)

        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        # second 3x3 Conv + BatchNorm (the ReLU is applied after the residual add)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # Downsample: feature Map size/2 || Channel increase
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
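
A quick shape check (channel count, map size and import path are assumptions). It shows why _make_layer builds the downsample branch with the same kernel size as conv1: the un-padded 3x3 conv shrinks the map by 2, so the residual has to be shrunk the same way before the addition.

import torch
import torch.nn as nn
from model import ADM_ExtraBlock               # assumed import path

ds = nn.Sequential(                            # mirrors what _make_layer builds for kernel_size=3
    nn.Conv2d(64, 64, kernel_size=3, stride=1, bias=False),
    nn.BatchNorm2d(64),
)
block = ADM_ExtraBlock(64, 64, kernel_size=3, stride=1, downsample=ds)
out = block(torch.randn(2, 64, 10, 10))
print(out.shape)                               # expected: torch.Size([2, 64, 8, 8])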

4.3 ADM_EndBlock

ADM_EndBlock applies an un-padded 3x3 convolution followed by a 1x1 convolution for the final processing, and adds a residual connection whose downsample convolution uses the same 3x3 kernel so the shapes match.

class ADM_EndBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, kernel_size=3, stride=1):
        super(ADM_EndBlock, self).__init__()
        # stride/2 maybe applied on conv1
        self.conv1 = nn.Conv2d(
            inplanes, planes, kernel_size=kernel_size, stride=stride)

        self.relu = nn.ReLU(inplace=True)
        # 1x1 Conv for the final processing (note: no batch norm in this block)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=1, stride=1)

        self.downsample = nn.Conv2d(
            inplanes, planes, kernel_size=kernel_size, stride=stride
        )

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.relu(out)

        out = self.conv2(out)

        out += self.downsample(residual)
        out = self.relu(out)

        return out

4.4 Multi_scale_Detection_Module

The multi-scale detection module attaches, to each ADM feature map, one convolution for box regression (detection) and one for classification.

class Multi_scale_Detection_Module(nn.Module):

    def __init__(
            self, inplanes, class_num=2,
            width_hight_ratios=2, extra_layers=None
    ):
        super(Multi_scale_Detection_Module, self).__init__()

        # Multi-scale Detection Module.
        # two lists: one for the per-scale box regressors (detectors), one for the per-scale classifiers
        multi_scale_detector = list()
        multi_scale_classifier = list()
        # for each extra layer, choose kernel size ks and padding pad depending on whether it is the ADM_EndBlock
        for extra_block in extra_layers:
            ks = 3 if extra_block != ADM_EndBlock else 1
            pad = 1 if extra_block != ADM_EndBlock else 0
            # classification head for this scale: width_hight_ratios * class_num output channels per location
            multi_scale_classifier.append(
                nn.Conv2d(
                    inplanes, width_hight_ratios*class_num,
                    kernel_size=ks, stride=1, padding=pad
                )
            )
            # box regression head for this scale: width_hight_ratios * 4 output channels per location
            multi_scale_detector.append(
                nn.Conv2d(
                    inplanes, width_hight_ratios*4,
                    kernel_size=ks, stride=1, padding=pad
                )
            )

        self.ms_dets = nn.ModuleList(multi_scale_detector)
        self.ms_cls = nn.ModuleList(multi_scale_classifier)

    def forward(self, x):
        confidence, location = list(), list()
        for (feat, detector, classifier) in zip(x, self.ms_dets, self.ms_cls):
            # box regression output for this scale, permuted to (N, H, W, C) and stored in location
            location.append(detector(feat).permute(0, 2, 3, 1).contiguous())
            # classification output for this scale, permuted to (N, H, W, C) and stored in confidence
            confidence.append(classifier(feat).permute(0, 2, 3, 1).contiguous())
        # flatten each scale to (N, -1) and concatenate all scales along dim 1, for confidence and location respectively
        confidence = torch.cat([o.view(o.size(0), -1) for o in confidence], 1)
        location = torch.cat([o.view(o.size(0), -1) for o in location], 1)

        return location, confidence