pytorch元素相乘_MAML算法比较简洁的一个复现,Pytorch版本

本文探讨了在PyTorch中复现MAML算法的三种方法,包括破坏封装、使用nn.functional以及手动计算梯度。重点在于解决参数反传问题和元素相乘在计算图中的应用。通过手动计算,可以简化代码并避免复杂编程,同时介绍了如何处理叶子节点和计算梯度的过程。
摘要由CSDN通过智能技术生成

可以转载,请务必注明链接和作者名。花了好久的心血~

方法1 破坏封装

qwer在目前能找到的几个pytorch版本里面,大家都是用nn.Functional写的

这种实现方式很繁琐,在两次计算loss的时候都需要重新构建,要自己收东西写前向传播,还要自己初始化权重,(虽然用for循环可以避免)

因为pytorch中的网络参数不能继续反向传播下去(下面有详细解释)。

pytorch的作者写了一个小工具,破坏了module中parameters的封装,使得网络参数可以继续反传。所以就可以用sequential的方式来实现maml了,代码量也大大减少

但是目前仍存在问题,不能多卡并行和apex半精度加速。还不清楚问题出在哪里~~很烦~

有人想看的话,我传到github上,再把链接贴出来~

效果的话大概在MiniImagenet能跑到46(大致调了一下参数,已经差不多是48%, 和原作者的结果就差不多了。)

f99370d3bf60d2ffec1597993a5efe36.png

  • 2
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值