最新做一些模型压缩的研究,需要自己手动实现nn.Linear(),而不是直接调用Linear()去进行forword()
(一)nn.Linear()的官方说明
先查看Linear()的官方说明如下:
torch.nn — PyTorch master documentation
可以看到内部有两个变量,weight和bias,这是待会实现的关键。
正常使用相信大家都不陌生,直接定义nn.Linear(256,512)这样,但是要注意的是输入的参数顺序是in_features和out_features
(二)F.linear()的说明
F是常见的torch.nn.functional,为什么要说明这个呢?实际上真正的linear()函数是这里的F.linear(),nn.Linear()只是对F.linear()做了封装。
先看F.linear()的官方说明:
torch.nn.functional — PyTorch master documentation
这里的参数列表看起来就非常符合linear的逻辑了,输入张量input,用weight进行矩阵权重计算,bias作为偏执,但要注意weight的说明(out_features,in_features),即out在前,待会后面自定义别搞错了
(三)nn.Linear()的源码
跟踪nn.Linear()的源码,正如,上面分析的一样,内部调用了F.linear(),在init中定义了weight和bias
这里的定义使用了Parameter,来看官方解释:
torch.nn — PyTorch master documentation
其实说白了,就是用来定义网络参数的。我们一张图片过linear层进行参数更新,就是这里的Parameter在变化。
(四)自己实现
根据上面的分析,如果要自己实现nn.Linear(),需要先定义weight和bias,也即在init中:
一定要注意这里的参数,out_features和in_features对应正确
代码:
# 原先的linear
# self.linear = nn.Linear(d_model, dim_feedforward)
# 自定义weight和bias
self.weight = nn.Parameter(torch.empty((dim_feedforward, d_model)))
self.bias = nn.Parameter(torch.empty(dim_feedforward))
然后对应的forward中,调用F.linear():
其中src就是图片的张量
代码:
src = F.linear(src, self.weight, self.bias)
(五)总结
主要注意两点:
1 nn.Linear和F.linear、init和forward的函数的理解
2 out_features和in_features对应正确