ctm代码笔记《Finding Task-Relevant Features for Few-Shot Learning by Category Traversal》

18 篇文章 0 订阅
12 篇文章 0 订阅

源码地址:https://github.com/Clarifai/few-shot-ctm

无法直接运行代码

看代码有人也有这个问题Why i can’t run the code with default setting,所以要稍微修改下代码。

  1. 修改tools/general_utils.pyline276
yaml_cfg = AttrDict(yaml.load(f, Loader=yaml.FullLoader))
  1. main代码line 28中使用了配置文件configs/demo/mini/20way_1shot.yaml,但是会有异常抛出。所以我们将配置文件写到core/config.pyConfig类属性中。然后将line 28
    改为:
opts = Config(None)

训练流程

main.py line 129开始对模型进行训练

学习策略

  1. 优化器定义在main.py line49
    根据不同的参数,分别使用了adamsgdrmsprop 三种优化器
  2. 学习率调整定义在main.py line59
    根据不同的参数,分别使用了MultiStepLRExponentialLR 2种更新学习率的策略
  • MultiStepLR每到一个milestones区间,学习率× gamma
  • ExponentialLR 指数衰减学习率,新学习率 = 旧*(gama ^ epoch)

计算模型loss

根据论文和解读:https://blog.csdn.net/qq_36104364/article/details/106363521,我们可以看论文是如何实现
Y = M ( r ( S ) ⊙ p , r ( Q ) ⊙ p ) , Y = { y i j } Y=\mathcal{M}(\boldsymbol{r}(\mathcal{S}) \odot p, \boldsymbol{r}(\mathcal{Q}) \odot p), \quad Y=\left\{y_{i j}\right\} Y=M(r(S)p,r(Q)p),Y={yij}
其中 S S S是支持集的特征向量, r ( S ) \boldsymbol{r}(\mathcal{S}) r(S)是对支持集 S S S的特征向量进行 r e s h a p e r reshaper reshaper Q Q Q是查询集的特征向量, r ( Q ) \boldsymbol{r}(\mathcal{Q}) r(Q)是对查询集 Q Q Q的特征向量进行 r e s h a p e r reshaper reshaper M \mathcal{M} M是距离函数
在这里插入图片描述
代码在core/model.py line415,根据注释函数主要包括三部分,以 20 w a y − 1 s h o t − 8 w a y 20way - 1shot - 8way 20way1shot8way为例:

  1. 特征提取,使用repnet进行特征提取:core/model.py line424
# support_sz (25), c (64), d (19), d (19)
support_xf_ori = self.repnet(support_x.view(batch_sz*support_sz, -1, _d, _d))  # torch.Size([1, 20, 3, 84, 84]) -> torch.Size([20, 3, 84, 84]) -> torch.Size([20, 64, 19, 19])
# query_sz (75), c (64), d (19), d (19)
query_xf_ori = self.repnet(query_x.view(batch_sz*query_sz, -1, _d, _d))# torch.Size([1, 160, 3, 84, 84]) -> torch.Size([160, 3, 84, 84]) -> torch.Size([160, 64, 19, 19])
  1. Concentrator:core/model.py line434
mp = self.main_component(support_xf_reshape)         #    ([20, 40, 19, 19])      # 5(n_way), 64, 3, 3
  1. projection:core/model.py line442
_input_P = mp.view(1, -1, mp.size(2), mp.size(3))   # ([1, 800, 19, 19])
P = self.projection(_input_P)   # 1, 64, 3, 3
P = F.softmax(P, dim=1)
  1. reshaper:core/model.py line457对support和query的特征进行reshaper
v = self.reshaper(support_xf_ori)
query = self.reshaper(query_xf_ori)     # 75, 64, 3, 3
  1. 相乘得到结果:core/model.py line458对support和query的特征进行reshaper
query = torch.matmul(query, P)
  1. 计算分数:core/model.py line476
  2. 输出:core/model.py line479

其他

  1. 展示模型的训练loss:main.py line161
  2. 在验证集上验证模型,并且保存最好精度的模型:main.py line169
  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值