【Datawhale AI夏令营】Deepfake多模态检测大赛Task2
问题陈述
Deepfake是一种使用人工智能技术生成的伪造媒体,特别是视频和音频,它们看起来或听起来非常真实,但实际上是由计算机生成的。这种技术通常涉及到深度学习算法,特别是生成对抗网络(GANs),它们能够学习真实数据的特征,并生成新的、逼真的数据。
深度伪造技术通常可以分为四个主流研究方向:
- 面部交换专注于在两个人的图像之间执行身份交换;
- 面部重演强调转移源运动和姿态;
- 说话面部生成专注于在角色生成中实现口型与文本内容的自然匹配;
- 面部属性编辑旨在修改目标图像的特定面部属性;
深度学习是一种强大的机器学习技术,它通过模拟人脑处理信息的方式,使计算机能够从大量数据中自动学习和识别模式。深度学习模型,尤其是卷积神经网络(CNN),能够识别图像和视频中的复杂特征。在Deepfake检测中,模型可以学习识别伪造内容中可能存在的微妙异常。
为了训练有效的Deepfake检测模型,需要构建包含各种Deepfake和真实样本的数据集(本次比赛的数据集就是按照这种方式进行组织)。深度学习模型通过这些数据集学习区分真假内容。
本质上,该任务是一个判别任务。
Baseline关键步骤
- 数据准备:使用Pandas库读取训练集和验证集的标签,并将图片路径与标签结合,以便于后续处理。
- 定义生成MEL频谱图的函数:
generate_mel_spectrogram
函数用于从视频文件中提取音频,并生成MEL频谱图,然后将其转换为图像格式。 - 定义训练、验证和预测函数:
train
、validate
和predict
函数分别用于模型的训练、在验证集上评估模型性能以及生成预测结果。 - 模型初始化和训练:初始化
resnet18
模型,并使用Adam优化器和交叉熵损失函数进行训练。训练过程中使用了学习率调度器,并在每个epoch结束时在验证集上评估模型性能。 - 保存最佳模型:在验证过程中,如果模型的性能超过了之前的最佳性能,则保存模型的权重。
- 生成预测结果:使用训练好的模型对验证集进行预测,并将预测结果保存到
submit.csv
文件中。 - 提交结果:最后,代码将预测的分数与原始的提交模板合并,并保存为最终的提交文件。
加载预训练模型
预训练模型是指在特定的大型数据集(如ImageNet)上预先训练好的神经网络模型。这些模型已经学习到了丰富的特征表示,能够识别和处理图像中的多种模式。使用预训练模型的好处是,它们可以在新数据集或新任务上进行微调(Fine-tuning),从而加快训练过程并提高模型性能,尤其是当可用的数据量有限时。
ResNet(残差网络)是一种深度卷积神经网络,由微软研究院的Kaiming He等人在2015年提出。ResNet的核心思想是引入了“残差学习”框架,通过添加跳过一层或多层的连接(即残差连接或快捷连接),解决了随着网络深度增加时训练困难的问题。
在下面代码中,timm.create_model('resnet18', pretrained=True, num_classes=2)
这行代码就是加载了一个预训练的ResNet-18模型,其中pretrained=True
表示使用在ImageNet数据集上预训练的权重,num_classes=2
表示模型的输出层被修改为有2个类别的输出,以适应二分类任务(例如区分真实和Deepfake图像)。通过model = model.cuda()
将模型移动到GPU上进行加速。
import timm
model = timm.create_model('resnet18', pretrained=True, num_classes=2)
model = model.cuda()
提取音频特征
在识别Deepfake视频时,音频分析之所以简单,是因为Deepfake技术生成的视频中,音频可能存在不自然或重复的模式,例如重复的单词或短语。通过分析音频的频谱图,可以更容易地发现这些异常,从而帮助识别视频是否经过了深度伪造处理。
MEL频谱图(Mel-spectrogram)是一种在音频信号处理领域常用的可视化工具,它基于人耳的听觉特性来表示音频信号的频率内容。梅尔刻度是一种对频率进行非线性缩放的方法,它将线性频率映射到梅尔频率上,使得梅尔刻度上的间隔更接近人耳感知的间隔。梅尔刻度是以物理学家H. Fletcher和W. A. Munson的名字命名的。
def generate_mel_spectrogram(video_path, n_mels=128, fmax=8000, target_size=(256, 256)):
# 提取音频
audio_path = 'extracted_audio.wav'
video = mp.VideoFileClip(video_path)
video.audio.write_audiofile(audio_path, verbose=False, logger=None)
# 加载音频文件
y, sr = librosa.load(audio_path)
# 生成MEL频谱图
S = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=n_mels)
# 将频谱图转换为dB单位
S_dB = librosa.power_to_db(S, ref=np.max)
# 归一化到0-255之间
S_dB_normalized = cv2.normalize(S_dB, None, 0, 255, cv2.NORM_MINMAX)
# 将浮点数转换为无符号8位整型
S_dB_normalized = S_dB_normalized.astype(np.uint8)
# 缩放到目标大小
img_resized = cv2.resize(S_dB_normalized, target_size, interpolation=cv2.INTER_LINEAR)
return img_resized
定义模型训练步骤
在深度学习中,模型训练通常需要进行多次迭代,而不是单次完成。深度学习模型的训练本质上是一个优化问题,目标是最小化损失函数。梯度下降算法通过计算损失函数相对于模型参数的梯度来更新参数。由于每次参数更新只能基于一个数据批次来计算梯度,因此需要多次迭代,每次处理一个新的数据批次,以确保模型在整个数据集上都能得到优化。
模型训练的流程如下:
-
设置训练模式:通过调用
model.train()
将模型设置为训练模式。在训练模式下,模型的某些层(如BatchNorm
和Dropout
)会按照它们在训练期间应有的方式运行。 -
遍历数据加载器:使用
enumerate(train_loader)
遍历train_loader
提供的数据批次。input
是批次中的图像数据,target
是对应的标签。 -
数据移动到GPU:通过
.cuda(non_blocking=True)
将数据和标签移动到GPU上。non_blocking
参数设置为True
意味着如果数据正在被复制到GPU,此操作会立即返回,不会等待数据传输完成。 -
前向传播:通过
output = model(input)
进行前向传播,计算模型对输入数据的预测。 -
计算损失:使用损失函数
loss = criterion(output, target)
计算预测输出和目标标签之间的差异。 -
梯度归零:在每次迭代开始前,通过
optimizer.zero_grad()
清空(重置)之前的梯度,以防止梯度累积。 -
反向传播:调用
loss.backward()
计算损失相对于模型参数的梯度。 -
参数更新:通过
optimizer.step()
根据计算得到的梯度更新模型的参数。
transforms.Compose: 这是一个转换操作的组合,它将多个图像预处理步骤串联起来:
transforms.Resize((256, 256))
:将所有图像调整为256x256像素的大小。transforms.RandomHorizontalFlip()
:随机水平翻转图像。transforms.RandomVerticalFlip()
:随机垂直翻转图像。transforms.ToTensor()
:将PIL图像或Numpy数组转换为torch.FloatTensor
类型,并除以255以将像素值范围从[0, 255]缩放到[0, 1]。transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
:对图像进行标准化,使用ImageNet数据集的均值和标准差。
train_loader = torch.utils.data.DataLoader(
FFDIDataset(train_label['path'], train_label['target'],
transforms.Compose([
transforms.Resize((256, 256)),
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
), batch_size=40, shuffle=True, num_workers=4, pin_memory=True
)
可优化方向
- 多帧提取,面部识别,相似比较
- 音频fake
- 特征提取
- chroma_stft
- rms
- spectral_centroid
- spectral_bandwidth
- rolloff
- zero_crossing_rate
- mfcc