源码地址:https://github.com/Clarifai/few-shot-ctm
无法直接运行代码
看代码有人也有这个问题Why i can’t run the code with default setting,所以要稍微修改下代码。
- 修改
tools/general_utils.py
line276为
yaml_cfg = AttrDict(yaml.load(f, Loader=yaml.FullLoader))
main
代码line 28中使用了配置文件configs/demo/mini/20way_1shot.yaml
,但是会有异常抛出。所以我们将配置文件写到core/config.py
的Config类属性中
。然后将line 28
改为:
opts = Config(None)
训练流程
从main.py line 129开始对模型进行训练
学习策略
- 优化器定义在main.py line49
根据不同的参数,分别使用了adam
、sgd
、rmsprop
三种优化器 - 学习率调整定义在main.py line59
根据不同的参数,分别使用了MultiStepLR
、ExponentialLR
2种更新学习率的策略
MultiStepLR
每到一个milestones区间,学习率× gammaExponentialLR
指数衰减学习率,新学习率 = 旧*(gama ^ epoch)
计算模型loss
根据论文和解读:https://blog.csdn.net/qq_36104364/article/details/106363521,我们可以看论文是如何实现
Y
=
M
(
r
(
S
)
⊙
p
,
r
(
Q
)
⊙
p
)
,
Y
=
{
y
i
j
}
Y=\mathcal{M}(\boldsymbol{r}(\mathcal{S}) \odot p, \boldsymbol{r}(\mathcal{Q}) \odot p), \quad Y=\left\{y_{i j}\right\}
Y=M(r(S)⊙p,r(Q)⊙p),Y={yij}
其中
S
S
S是支持集的特征向量,
r
(
S
)
\boldsymbol{r}(\mathcal{S})
r(S)是对支持集
S
S
S的特征向量进行
r
e
s
h
a
p
e
r
reshaper
reshaper,
Q
Q
Q是查询集的特征向量,
r
(
Q
)
\boldsymbol{r}(\mathcal{Q})
r(Q)是对查询集
Q
Q
Q的特征向量进行
r
e
s
h
a
p
e
r
reshaper
reshaper,
M
\mathcal{M}
M是距离函数
代码在core/model.py line415,根据注释函数主要包括三部分,以
20
w
a
y
−
1
s
h
o
t
−
8
w
a
y
20way - 1shot - 8way
20way−1shot−8way为例:
- 特征提取,使用
repnet
进行特征提取:core/model.py line424
# support_sz (25), c (64), d (19), d (19)
support_xf_ori = self.repnet(support_x.view(batch_sz*support_sz, -1, _d, _d)) # torch.Size([1, 20, 3, 84, 84]) -> torch.Size([20, 3, 84, 84]) -> torch.Size([20, 64, 19, 19])
# query_sz (75), c (64), d (19), d (19)
query_xf_ori = self.repnet(query_x.view(batch_sz*query_sz, -1, _d, _d))# torch.Size([1, 160, 3, 84, 84]) -> torch.Size([160, 3, 84, 84]) -> torch.Size([160, 64, 19, 19])
Concentrator
:core/model.py line434
mp = self.main_component(support_xf_reshape) # ([20, 40, 19, 19]) # 5(n_way), 64, 3, 3
projection
:core/model.py line442
_input_P = mp.view(1, -1, mp.size(2), mp.size(3)) # ([1, 800, 19, 19])
P = self.projection(_input_P) # 1, 64, 3, 3
P = F.softmax(P, dim=1)
reshaper
:core/model.py line457对support和query的特征进行reshaper
v = self.reshaper(support_xf_ori)
query = self.reshaper(query_xf_ori) # 75, 64, 3, 3
- 相乘得到结果:core/model.py line458对support和query的特征进行reshaper
query = torch.matmul(query, P)
其他
- 展示模型的训练loss:main.py line161
- 在验证集上验证模型,并且保存最好精度的模型:main.py line169