一文读懂few-shot learning

在这里插入图片描述
前提: 有一点代码基础,理解数据的维度。

目的: 用少量的样本训练模型,使模型具备分类能力。

数据集CUB: 200类,100 base class(训练集),50 valid class,50 novel class(测试集)。

训练方法:
1:直接用base class训练一个200类别的CNN分类器。
2:meta-learning method(训练和测试保持一致),以5-way 5-shot(从base class中随机取5个类别,每个类别随机取5张图,这里25张图片为support set)为例,假设query set 有16张图片(5个类别剩余的其他图片)。输入网络的数据为(5-way,5-shot+16,3,h,w)=(5,21,3,h,w),然后放入CNN中得到support set的特征(5,5,1600)和query set的特征(5,16,1600),接下来平均support set中每个类每张图片的特征mean((5,5,1600),dim=1)=(5,1600),reshape query set=(80,1600),最后计算两个set的距离得到(80,5)维向量,这可以表示query set中每张图片与support set中哪张图片比较相似,然后用一下交叉熵损失即可。【meta-training task】

note:每次随机取的五个类,重新定义标签为[0,1,2,3,4],所以才可以对(80,5)使用交叉熵损失。base class中的support set和query set都拿来训练,只是最后计算损失时,分开计算它们之间的距离。

验证:
简单的同训练方法1,直接把所有测试集valid class放进网络查看分类精度。

测试:
训练完网络后,得到特征提取器(backbone),固定其权重,把novel class放入其中得到每张图片的特征,假设每张图片1600维。【 接下来,从novel class中随机取5类,每类里面取5张图片5-way 5-shot(support set)和16张图片(query set),得到(5x(5+16),1600)=(5x5,1600)+(5x16,1600),把support set(25,1600)放入一层的分类器中分类,训练100个epoch,最后把query set(80,1600)放入测试得到结果。】【称为一个episode,要测试几百个episode取平均】【meta-test task】。这里你也许有疑问,这样训练(fine-tune)不是测试集也进行拟合了吗,其实每个meta-test task中的那层分类器都是重新定义的。

点个赞哦,亲(づ ̄3 ̄)づ╭❤~

  • 4
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值