手动重新实现Linear()

最新做一些模型压缩的研究,需要自己手动实现nn.Linear(),而不是直接调用Linear()去进行forword()

(一)nn.Linear()的官方说明

先查看Linear()的官方说明如下:

torch.nn — PyTorch master documentation

可以看到内部有两个变量,weight和bias,这是待会实现的关键。

正常使用相信大家都不陌生,直接定义nn.Linear(256,512)这样,但是要注意的是输入的参数顺序是in_features和out_features

(二)F.linear()的说明

F是常见的torch.nn.functional,为什么要说明这个呢?实际上真正的linear()函数是这里的F.linear(),nn.Linear()只是对F.linear()做了封装

先看F.linear()的官方说明:

torch.nn.functional — PyTorch master documentation

 这里的参数列表看起来就非常符合linear的逻辑了,输入张量input,用weight进行矩阵权重计算,bias作为偏执,但要注意weight的说明(out_features,in_features),即out在前,待会后面自定义别搞错了

(三)nn.Linear()的源码

 跟踪nn.Linear()的源码,正如,上面分析的一样,内部调用了F.linear(),在init中定义了weight和bias

这里的定义使用了Parameter,来看官方解释:

torch.nn — PyTorch master documentation

 其实说白了,就是用来定义网络参数的。我们一张图片过linear层进行参数更新,就是这里的Parameter在变化。

(四)自己实现

根据上面的分析,如果要自己实现nn.Linear(),需要先定义weight和bias,也即在init中:

一定要注意这里的参数,out_features和in_features对应正确

代码:

        # 原先的linear
        # self.linear = nn.Linear(d_model, dim_feedforward)
        # 自定义weight和bias
        self.weight = nn.Parameter(torch.empty((dim_feedforward, d_model)))
        self.bias = nn.Parameter(torch.empty(dim_feedforward))

然后对应的forward中,调用F.linear():

其中src就是图片的张量

代码:

src = F.linear(src, self.weight, self.bias)

(五)总结

主要注意两点:

1  nn.Linear和F.linear、init和forward的函数的理解

2  out_features和in_features对应正确

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值