本文主要是利用chatGPT等工具,对baseline的内容代码进行理解和分析。
1.1 数据处理前的分析工作
'''
代码解释:统计文件的行数。
! 表示在 Jupyter 中执行 shell 命令的前缀符号;
wc 是一种统计文件字数、行数的工具;
-l 选项表示统计文件的行数。'''
!wc -l /kaggle/input/deepfake/phase1/trainset_label.txt !wc -l /kaggle/input/deepfake/phase1/valset_label.txt !ls /kaggle/input/deepfake/phase1/trainset/ | wc -l !ls /kaggle/input/deepfake/phase1/valset/ | wc -l
1.2 数据路径的处理
# Read the train/validation label files and build the absolute path of every
# image. The four statements were collapsed onto one line in the export and
# must be separated to be valid Python.
# Assumes the label files are comma-separated with an `img_name` column —
# TODO confirm against the actual file header.
train_label = pd.read_csv('/kaggle/input/deepfake/phase1/trainset_label.txt')
val_label = pd.read_csv('/kaggle/input/deepfake/phase1/valset_label.txt')

# Prepend the directory so each row carries the full image path.
train_label['path'] = '/kaggle/input/deepfake/phase1/trainset/' + train_label['img_name']
val_label['path'] = '/kaggle/input/deepfake/phase1/valset/' + val_label['img_name']
1.3 对于数据集的处理工作
# validate 函数:用于在验证集上评估深度学习模型的性能。
# model.eval() 将模型切换到评估模式;随后遍历验证数据加载器,计算损失与准确率。
# 训练集和验证集的代码相似
def validate(val_loader, model, criterion):
    """Evaluate `model` on the validation set.

    Runs one pass over `val_loader` under `torch.no_grad()`, accumulating
    the criterion loss and top-1 accuracy, and returns the accuracy meter.

    Args:
        val_loader: DataLoader yielding (images, targets) batches.
        model: the network to evaluate; moved to eval mode here.
        criterion: loss function applied to (output, target).

    Returns:
        The `AverageMeter` tracking top-1 accuracy (its `.avg` holds the
        final validation accuracy in percent).
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    progress = ProgressMeter(len(val_loader), batch_time, losses, top1)

    # Switch to evaluate mode (disables dropout, freezes batch-norm stats).
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in tqdm_notebook(enumerate(val_loader),
                                                 total=len(val_loader)):
            # NOTE(review): assumes a CUDA device is available — .cuda()
            # raises on CPU-only machines; confirm the runtime environment.
            images = images.cuda()
            target = target.cuda()

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # Measure accuracy and record loss. `acc` stays a 0-dim tensor
            # here (no .item()), matching the original behavior.
            acc = (output.argmax(1).view(-1) == target.float().view(-1)).float().mean() * 100
            losses.update(loss.item(), images.size(0))
            top1.update(acc, images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

        # TODO: this should also be done with the ProgressMeter
        print(' * Acc@1 {top1.avg:.3f}'
              .format(top1=top1))
    return top1