FSL-GNN复现过程与代码分析(二)
原码地址:https://github.com/vgsatorras/few-shot-gnn
论文链接:https://arxiv.org/pdf/1711.04043.pdf
Train过程
train在main.py 112行定义,比较容易理解,这里主要是load models(这个部分下一篇文章再分析)131行用到权重衰减weight_decay = 0 # 权重衰减
,对于mini_imagenet数据集这个值设为了1e-6
135行 优化器介绍可以参考torch.optim优化算法理解之optim.Adam()
154行 详细可以参考关于pytorch中zero_grad()函数
opt_enc_nn.zero_grad() # 梯度置零,也就是把loss关于weight的导数变成0
opt_metric_nn.zero_grad()
167行 这里的Display会输出不断更新参数过程中的loss
Test过程
if (batch_idx + 1) % args.test_interval == 0 or batch_idx == 20:
if batch_idx == 20:
test_samples = 100 # test with train 中会抽100个task
else:
test_samples = 3000 # test with train 中会抽3000个task
...
**** TESTING WITH val ***
9872 correct from 15000 Accuracy: 65.813%)
*** TEST FINISHED ***
**** TESTING WITH test ***
9883 correct from 15000 Accuracy: 65.887%)
*** TEST FINISHED ***
**** TESTING WITH train ***
2643 correct from 3000 Accuracy: 88.100%)
*** TEST FINISHED ***
Best test accuracy 66.0800 (最后一轮)
**** TESTING WITH test ***
19928 correct from 30000 Accuracy: 66.427%)
*** TEST FINISHED ***
训练完之后我们可以在checkpoints中找到run.log文件,包含有完整训练过程的输出数据可以看到,每一轮训练,用于验证和测试的均有15000个task ;用于测试的均有3000个task
每次训练完一轮,都会输出一个Best test accuracy
当我们训练结束时,最后会保存model
##