论文讲解请看:https://blog.csdn.net/JustWantToLearn/article/details/138758033
代码链接:https://github.com/megvii-research/CADDM
在这里,我们简要描述算法流程,着重分析模型搭建细节,以及为什么要这样搭建。
part 1:数据集准备,请看链接 https://blog.csdn.net/JustWantToLearn/article/details/138773005
part 2: 数据集加载,包含 Multi-scale Facial Swap(MFS) 模块:https://blog.csdn.net/JustWantToLearn/article/details/139092687
part 3:训练过程与 ADM 模块(即本文)
文章目录
1、训练 train.py
python train.py --cfg ./configs/caddm_train.cfg
def train():
    """Train the network described by the ``--cfg`` configuration file.

    The function parses the command-line arguments, builds the model, the
    detection loss (``MultiBoxLoss``) and the classification loss
    (``nn.CrossEntropyLoss``), optionally resumes from ``args.ckpt``, and then
    iterates over the training set. For every batch the total loss is
    ``0.1 * (loss_l + loss_c) + classification_loss``. A checkpoint is written
    after each epoch.
    """
    args = args_func()

    # Load configs.
    cfg = load_config(args.cfg)

    # Model init: build the network, move it to the selected device and wrap
    # it in DataParallel.
    net = model.get(backbone=cfg['model']['backbone'])
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net = net.to(device)
    net = nn.DataParallel(net)

    # Loss init: detection loss plus cross-entropy for the final real/fake
    # classification.
    det_cfg = cfg['det_loss']
    det_criterion = MultiBoxLoss(
        det_cfg['num_classes'],
        det_cfg['overlap_thresh'],
        det_cfg['prior_for_matching'],
        det_cfg['bkg_label'],
        det_cfg['neg_mining'],
        det_cfg['neg_pos'],
        det_cfg['neg_overlap'],
        det_cfg['encode_target'],
        det_cfg['use_gpu']
    )
    criterion = nn.CrossEntropyLoss()

    # Optimizer init.
    optimizer = optim.AdamW(net.parameters(), lr=1e-3, weight_decay=4e-3)

    # Load checkpoint if given. The returned optimizer used to be bound to a
    # misspelled name (``optimzer``) and was therefore silently dropped.
    base_epoch = 0
    if args.ckpt:
        net, optimizer, base_epoch = load_checkpoint(
            args.ckpt, net, optimizer, device
        )

    # Get training data.
    print(f"Load deepfake dataset from {cfg['dataset']['img_path']}..")
    train_dataset = DeepfakeDataset('train', cfg)
    train_loader = DataLoader(
        train_dataset,
        batch_size=cfg['train']['batch_size'],
        shuffle=True, num_workers=4,
        collate_fn=my_collate
    )

    # Start training.
    net.train()
    for epoch in range(base_epoch, cfg['train']['epoch_num']):
        # The schedule is queried with the epoch as its only argument, so it
        # is applied once per epoch instead of once per batch.
        # NOTE(review): assumes update_learning_rate has no per-call side
        # effects - confirm against its definition.
        lr = update_learning_rate(epoch)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        for index, (batch_data, batch_labels) in enumerate(train_loader):
            labels, location_labels, confidence_labels = batch_labels
            labels = labels.long().to(device)
            location_labels = location_labels.to(device)
            confidence_labels = confidence_labels.long().to(device)

            # Forward pass, classification loss and detection loss.
            optimizer.zero_grad()
            locations, confidence, outputs = net(batch_data)
            loss_end_cls = criterion(outputs, labels)
            loss_l, loss_c = det_criterion(
                (locations, confidence),
                confidence_labels, location_labels
            )

            # Count correct predictions with a tensor reduction rather than
            # iterating the tensor with the Python builtin ``sum``.
            correct = (outputs.max(-1).indices == labels).sum().item()
            acc = correct / labels.shape[0]

            det_loss = 0.1 * (loss_l + loss_c)
            loss = det_loss + loss_end_cls
            loss.backward()

            # Clip every gradient element to [-2, 2], then step.
            torch.nn.utils.clip_grad_value_(net.parameters(), 2)
            optimizer.step()

            # Separate name so the model outputs are not shadowed.
            log_fields = [
                "e:{},iter: {}".format(epoch, index),
                "acc: {:.2f}".format(acc),
                "loss: {:.8f} ".format(loss.item()),
                "lr:{:.4g}".format(lr),
            ]
            print(" ".join(log_fields))

        save_checkpoint(net, optimizer,
                        cfg['model']['save_path'],
                        epoch)
2、损失函数 MultiBoxLoss
MultiBoxLoss 类实现了SSD模型的损失计算,包括位置损失和置信度损失。
这里大体解释每个函数模块做了什么,具体的实现细节可以看论文https://arxiv.org/pdf/1512.02325.pdf
class MultiBoxLoss(nn.Module):
    """SSD weighted loss function.

    ``forward`` receives already-matched targets (``conf_t`` / ``loc_t``);
    matching ground-truth boxes to prior boxes is not performed here.

    The loss is made of:
        1) a localization term (Smooth L1) computed on positive priors only;
        2) a confidence term (cross entropy) computed on the positive priors
           plus the hardest negatives, selected by hard negative mining with
           a negative:positive ratio of ``neg_pos``.

    Objective:
        L(x,c,l,g) = (Lconf(x, c) + alpha * Lloc(x,l,g)) / N
    where N is the number of positive (matched) priors in the batch. The two
    terms are returned separately so the caller can weight them.

    See: https://arxiv.org/pdf/1512.02325.pdf for more details.
    """

    def __init__(self, num_classes, overlap_thresh, prior_for_matching,
                 bkg_label, neg_mining, neg_pos, neg_overlap, encode_target,
                 use_gpu=True):
        """Store the loss configuration.

        Only ``num_classes`` and ``neg_pos`` are read by ``forward``; the
        remaining values are kept as attributes for compatibility.
        """
        super(MultiBoxLoss, self).__init__()
        self.use_gpu = use_gpu
        self.num_classes = num_classes
        self.threshold = overlap_thresh
        self.background_label = bkg_label
        self.encode_target = encode_target
        self.use_prior_for_matching = prior_for_matching
        self.do_neg_mining = neg_mining
        self.negpos_ratio = neg_pos
        self.neg_overlap = neg_overlap
        self.variance = [0.1, 0.2]  # cfg['variance']

    def forward(self, predictions, conf_t, loc_t):
        """Compute the localization and confidence losses.

        Args:
            predictions (tuple): ``(loc_data, conf_data)``.
                ``loc_data`` is indexed as [batch, num_priors, 4];
                ``conf_data`` is reshaped to [-1, num_classes].
            conf_t (tensor): integer class target per prior, expected as
                [batch, num_priors]; values > 0 mark positive priors.
            loc_t (tensor): localization targets, same shape as ``loc_data``.

        Returns:
            tuple: ``(loss_l, loss_c)``, both divided by the number of
            positive priors in the batch (or by 1 when there are none).
        """
        loc_data, conf_data = predictions
        num = loc_data.size(0)

        # Put the targets on the same device as the predictions. The previous
        # unconditional ``.cuda()`` failed on CPU-only hosts. ``detach``
        # replaces the deprecated ``Variable(..., requires_grad=False)``.
        loc_t = loc_t.to(loc_data.device).detach()
        conf_t = conf_t.to(conf_data.device).detach()

        # Positive priors are those matched to a non-background class.
        pos = conf_t > 0

        # Localization loss (Smooth L1), positives only.
        pos_idx = pos.unsqueeze(-1).expand_as(loc_data)
        loc_p = loc_data[pos_idx].view(-1, 4)
        loc_pos_t = loc_t[pos_idx].view(-1, 4)
        loss_l = F.smooth_l1_loss(loc_p, loc_pos_t, reduction='sum')

        # Per-prior confidence loss, used only to rank negatives.
        batch_conf = conf_data.view(-1, self.num_classes)
        loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1))

        # Hard negative mining: zero out positives, then rank the remaining
        # priors of each sample by descending loss.
        loss_c[pos.view(-1, 1)] = 0
        loss_c = loss_c.view(num, -1)
        _, loss_idx = loss_c.sort(1, descending=True)
        _, idx_rank = loss_idx.sort(1)
        num_pos = pos.long().sum(1, keepdim=True)
        num_neg = torch.clamp(self.negpos_ratio * num_pos, max=pos.size(1) - 1)
        neg = idx_rank < num_neg.expand_as(idx_rank)

        # Confidence loss over positives and the selected negatives.
        selected = pos | neg
        selected_idx = selected.unsqueeze(2).expand_as(conf_data)
        conf_p = conf_data[selected_idx].view(-1, self.num_classes)
        targets_weighted = conf_t[selected]
        loss_c = F.cross_entropy(conf_p, targets_weighted, reduction='sum')

        # Normalize by the number of positives; fall back to 1 so a batch
        # without positives does not divide by zero.
        N = num_pos.sum().clamp(min=1)
        loss_l = loss_l / N
        loss_c = loss_c / N
        return loss_l, loss_c
3、模型搭建 CADDM
def get(pretrained_model=None, backbone='efficientnet-b4'):
    """Build a two-class CADDM model and optionally load saved weights.

    :param pretrained_model: optional checkpoint passed to ``torch.load``;
        the loaded object must expose the weights under the ``'network'``
        key. Falsy values skip loading.
    :param backbone: one of ``'resnet34'``, ``'efficientnet-b3'`` or
        ``'efficientnet-b4'``.
    :return: the constructed ``CADDM`` instance.
    :raises ValueError: if ``backbone`` is not one of the supported names.
    """
    if backbone not in ('resnet34', 'efficientnet-b3', 'efficientnet-b4'):
        raise ValueError("Unsupported type of models!")

    model = CADDM(2, backbone=backbone)

    if pretrained_model:
        # Load onto CPU so a checkpoint saved on a GPU can be restored on a
        # CPU-only host; load_state_dict copies the values into the model's
        # own parameters.
        # NOTE: torch.load unpickles the file - only load trusted checkpoints.
        checkpoint = torch.load(pretrained_model, map_location='cpu')
        model.load_state_dict(checkpoint['network'])

    return model
3.1 CADDM
CADDM 类是一个用于伪造图像检测和分类的神经网络模型。它结合了预训练的主干网络(如 ResNet 或 EfficientNet)和伪造检测模块(ADM),通过提取图像特征并对其进行分类,输出图像是否为伪造的结果。在训练模式下,模型返回位置结果、置信度和分类结果,而在评估模式下,模型返回分类概率。
class CADDM(nn.Module):
    """Forgery classifier: a backbone plus an Artifact Detection Module.

    The backbone yields a feature map and a global feature. The feature map
    goes through the ADM, whose final feature is added to the global feature
    and fed to a linear classifier.
    """

    def __init__(self, num_classes, backbone='resnet34'):
        """Build the backbone, the ADM and the classification head.

        Raises ValueError for a backbone name other than 'resnet34',
        'efficientnet-b3' or 'efficientnet-b4'.
        """
        super().__init__()
        self.num_classes = num_classes
        # Name of the backbone in use.
        self.backbone = backbone

        if backbone == 'resnet34':
            feature_extractor = resnet34(pretrained=True)
        elif backbone in ('efficientnet-b3', 'efficientnet-b4'):
            feature_extractor = EfficientNet.from_pretrained(
                backbone, out_size=[1, 3]
            )
        else:
            raise ValueError("Unsupported Backbone!")
        self.base_model = feature_extractor

        # Feature count reported by the backbone; it sizes the ADM and the
        # linear head.
        self.inplanes = self.base_model.out_num_features

        # Artifact Detection Module.
        self.adm = Artifact_Detection_Module(self.inplanes)

        # Classification head.
        self.fc = nn.Linear(self.inplanes, num_classes)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Run the network.

        In training mode returns ``(loc, cof, logits)``: the ADM location
        output, the ADM confidence output and the raw classification scores.
        Otherwise returns the softmax of the classification scores.
        """
        n_samples = x.size(0)

        # Backbone: feature map for the ADM plus a global feature.
        feat_map, global_feat = self.base_model(x)

        # ADM: location result, confidence of each anchor, final feature map.
        loc, cof, adm_final_feat = self.adm(feat_map)

        # Fuse both features by addition and classify.
        fused = (global_feat + adm_final_feat).view(n_samples, -1)
        logits = self.fc(fused)

        return (loc, cof, logits) if self.training else self.softmax(logits)
4、ADM模块
4.1 Artifact_Detection_Module
Artifact_Detection_Module 类用于检测图像中的伪造痕迹。它由多个额外层和一个多尺度检测模块组成,通过前向传播,生成位置、置信度和最终的特征图,用于进一步的分类和检测任务。
class Artifact_Detection_Module(nn.Module):
    """Stack of extra layers followed by a multi-scale detection head.

    Each extra layer is applied in sequence and its output is kept. The list
    of outputs feeds ``Multi_scale_Detection_Module``, and the last output is
    also returned as the module's final feature.
    """

    def __init__(
        self, inplanes, blocks=1, class_num=2,
        width_hight_ratios=2, extra_layers=None,
    ):
        """Build the extra layers and the detection head.

        Args:
            inplanes: channel count of the input and of every extra layer.
            blocks: number of blocks stacked per non-end extra layer.
            class_num: number of classes per anchor.
            width_hight_ratios: number of anchors per location.
            extra_layers: block classes, one per extra layer; defaults to
                three ``ADM_ExtraBlock`` followed by one ``ADM_EndBlock``.
        """
        super(Artifact_Detection_Module, self).__init__()

        # Artifact Detection Module extra layers.
        self.cls_num = class_num
        self.inplanes = inplanes

        if extra_layers is None:
            extra_layers = [ADM_ExtraBlock] * 3 + [ADM_EndBlock]

        adm_extra_layers = []
        for i, extra_block in enumerate(extra_layers):
            if extra_block is ADM_EndBlock:
                # The end block is instantiated directly.
                adm_extra_layers.append(extra_block(inplanes, inplanes))
                continue
            # The first layer uses a 1x1 kernel, later ones 3x3.
            ks = 3 if i else 1
            adm_extra_layers.append(
                self._make_layer(
                    extra_block, inplanes,
                    blocks=blocks, kernel_size=ks, stride=1
                )
            )
        self.adm_extra_layers = nn.ModuleList(adm_extra_layers)

        # Forward class_num and width_hight_ratios so the head stays
        # consistent with the reshape in forward(). They used to be dropped,
        # leaving the head on its own defaults whatever was requested here.
        self.multi_scale_detection_module = Multi_scale_Detection_Module(
            inplanes,
            class_num=class_num,
            width_hight_ratios=width_hight_ratios,
            extra_layers=extra_layers,
        )

    def _make_layer(self, block, planes, blocks, kernel_size, stride=1):
        """Return ``nn.Sequential`` of ``blocks`` instances of ``block``.

        The first instance receives a conv + batch-norm shortcut built with
        the same kernel size and stride; the remaining instances are built
        with the block's default arguments.
        """
        out_planes = planes * block.expansion

        # Shortcut branch for the first block.
        downsample = nn.Sequential(
            nn.Conv2d(self.inplanes, out_planes,
                      kernel_size=kernel_size, stride=stride, bias=False),
            nn.BatchNorm2d(out_planes)
        )

        layers = [
            block(self.inplanes, out_planes, kernel_size=kernel_size,
                  stride=stride, downsample=downsample)
        ]
        # Remaining blocks (no shortcut module).
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, out_planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return ``(location, confidence, adm_final_feat)``.

        ``location`` is reshaped to (batch, -1, 4), ``confidence`` to
        (batch, -1, class_num), and ``adm_final_feat`` is the output of the
        last extra layer.
        """
        bs = x.size(0)

        # Output of every extra layer, in order.
        adm_feats = []
        for adm_layer in self.adm_extra_layers:
            x = adm_layer(x)
            adm_feats.append(x)

        # Detection head over all collected feature maps.
        location, confidence = self.multi_scale_detection_module(adm_feats)
        location = location.view(bs, -1, 4)
        confidence = confidence.view(bs, -1, self.cls_num)

        adm_final_feat = adm_feats[-1]
        return location, confidence, adm_final_feat
4.2 ADM_ExtraBlock
ADM_ExtraBlock:带残差连接的卷积块——主分支为"卷积 + 批量归一化 + ReLU"后接"3x3 卷积 + 批量归一化",再与输入(若提供 downsample,则为下采样后的输入)相加,最后经过 ReLU。
class ADM_ExtraBlock(nn.Module):
    """Residual block: conv-BN-ReLU, conv3x3-BN, shortcut add, ReLU."""

    expansion = 1

    def __init__(
        self, inplanes, planes,
        kernel_size=3, stride=1, downsample=None
    ):
        """Create the two conv/BN stages.

        ``downsample`` is an optional module applied to the input before the
        residual addition; when it is None the raw input is added.
        """
        super().__init__()

        # First stage: the configurable kernel size / stride live here.
        self.conv1 = nn.Conv2d(
            inplanes, planes, kernel_size=kernel_size, stride=stride
        )
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)

        # Second stage: conv3x3 followed by batch norm.
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply both stages, add the shortcut and finish with ReLU."""
        # Shortcut branch: the input itself unless a downsample module exists.
        shortcut = x if self.downsample is None else self.downsample(x)

        main = self.relu(self.bn1(self.conv1(x)))
        main = self.bn2(self.conv2(main))

        main = main + shortcut
        return self.relu(main)
4.3 ADM_EndBlock
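ADM_EndBlock:主分支先经过一个 kernel_size×kernel_size 卷积(默认 3x3,无填充)和 ReLU,再经过一个 1x1 卷积;残差分支用相同卷积核大小与步长的卷积(downsample)处理输入,使两个分支尺寸一致,相加后再经过 ReLU。该块不包含批量归一化。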
class ADM_EndBlock(nn.Module):
    """Final ADM block: conv-ReLU-conv1x1 plus a convolutional shortcut.

    The block contains no batch normalization. The shortcut convolution uses
    the same kernel size and stride as the first convolution, so both
    branches have the same output size.
    """

    expansion = 1

    def __init__(self, inplanes, planes, kernel_size=3, stride=1):
        """Create the main-branch convolutions and the shortcut convolution."""
        super().__init__()
        spatial_cfg = dict(kernel_size=kernel_size, stride=stride)

        # Main branch: configurable conv, ReLU, then a 1x1 conv.
        self.conv1 = nn.Conv2d(inplanes, planes, **spatial_cfg)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=1, stride=1)

        # Shortcut branch applied to the block input.
        self.downsample = nn.Conv2d(inplanes, planes, **spatial_cfg)

    def forward(self, x):
        """Return ReLU(conv2(ReLU(conv1(x))) + downsample(x))."""
        main = self.conv2(self.relu(self.conv1(x)))
        shortcut = self.downsample(x)
        return self.relu(main + shortcut)
4.4 Multi_scale_Detection_Module
多尺度检测模块,通过多个卷积层分别进行检测和分类
class Multi_scale_Detection_Module(nn.Module):
    """Per-scale convolutional heads for box regression and classification.

    One detector convolution and one classifier convolution are created for
    every entry of ``extra_layers``; in ``forward`` they are paired, in
    order, with the given feature maps.
    """

    def __init__(
        self, inplanes, class_num=2,
        width_hight_ratios=2, extra_layers=None
    ):
        """Build the detector / classifier convolutions.

        Args:
            inplanes: channel count of every incoming feature map.
            class_num: number of classes per anchor.
            width_hight_ratios: number of anchors per location.
            extra_layers: block classes, one per scale. ``None`` now selects
                the same default as ``Artifact_Detection_Module`` (three
                ``ADM_ExtraBlock`` and one ``ADM_EndBlock``); it used to
                raise ``TypeError`` because ``None`` was iterated directly.
        """
        super(Multi_scale_Detection_Module, self).__init__()

        if extra_layers is None:
            extra_layers = [ADM_ExtraBlock] * 3 + [ADM_EndBlock]

        # Multi-scale detection module.
        multi_scale_detector = []
        multi_scale_classifier = []

        for extra_block in extra_layers:
            # The end block's head is a 1x1 conv without padding; every other
            # scale uses a 3x3 conv with padding 1.
            if extra_block is ADM_EndBlock:
                ks, pad = 1, 0
            else:
                ks, pad = 3, 1

            # Classifier head: class_num scores per anchor.
            multi_scale_classifier.append(
                nn.Conv2d(
                    inplanes, width_hight_ratios * class_num,
                    kernel_size=ks, stride=1, padding=pad
                )
            )
            # Detector head: 4 box values per anchor.
            multi_scale_detector.append(
                nn.Conv2d(
                    inplanes, width_hight_ratios * 4,
                    kernel_size=ks, stride=1, padding=pad
                )
            )

        self.ms_dets = nn.ModuleList(multi_scale_detector)
        self.ms_cls = nn.ModuleList(multi_scale_classifier)

    def forward(self, x):
        """Apply the heads to the feature maps in ``x``.

        Args:
            x: iterable of feature maps, one per scale.

        Returns:
            tuple: ``(location, confidence)``, each of shape (batch, -1),
            formed by flattening every scale's output and concatenating the
            results along dim 1.
        """
        confidence, location = [], []
        for feat, detector, classifier in zip(x, self.ms_dets, self.ms_cls):
            # Move channels last so each location's anchor values stay
            # contiguous once flattened.
            location.append(detector(feat).permute(0, 2, 3, 1).contiguous())
            confidence.append(classifier(feat).permute(0, 2, 3, 1).contiguous())

        # Flatten each scale per sample, then concatenate along dim 1.
        confidence = torch.cat([o.view(o.size(0), -1) for o in confidence], 1)
        location = torch.cat([o.view(o.size(0), -1) for o in location], 1)
        return location, confidence