FSL-GNN
数据流
one task: c way k shot 支持图像(每类k个样本) 查询图像(1个样本)
batch:batch个task
eg:
(1)5 way 1 shot,b=10
支持集(5x1,c,w,h,)
查询集(b,c,w,h)
(2)20 way 5 shot,b=10
支持集(20x5,c,w,h,)
查询集(b,c,w,h)
模型
构建节点
图像特征提取网络
小样本常规4个卷积层网络
构建边,并更新
计算节点间的欧式距离,通过四个卷积层学习更新
节点更新
WAX进行节点更新
技巧
将数值标签转换为one-hot编码
labels_x_scalar = np.argmax(labels_x, 1)
网络定义----类嵌套定义
逐层写 逐层前向
多个层按功能封装一起 逐子模块调用
调整学习率
# 定义学习率调整函数
def adjust_learning_rate(optimizers, lr, iter):
new_lr = lr * (0.5**(int(iter/args.dec_lr)))
for optimizer in optimizers:
for param_group in optimizer.param_groups:
param_group['lr'] = new_lr
# 调用学习率调整方法
adjust_learning_rate(optimizers=[opt_enc_nn, opt_metric_nn], lr=args.lr, iter=batch_idx)