Why nn.Linear's weight has shape [out_features, in_features] in PyTorch: a look at the source
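import torch
import torch.nn as nn

# Linear layer mapping 50 input features to 30 output features
m = nn.Linear(50, 30)
# A batch of 128 samples with 50 features each (named x to avoid shadowing the built-in input)
x = torch.randn(128, 50)
output = m(x)

print(output.size())
print(m.weight.shape)
print(m.bias.shape)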
Output:
torch.Size([128, 30])
torch.Size([30, 50])
torch.Size([30])
As the output shows, weight has shape [30, 50], not [50, 30].
Here is the __init__ method from the source of the Linear class:
def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
    super(Linear, self).__init__()
    self.in_features = in_features
    self.out_features = out_features
    self.weight = Parameter(torch.Tensor(out_features, in_features))
    if bias:
        self.bias = Parameter(torch.Tensor(out_features))
    else:
        self.register_parameter('bias', None)
    self.reset_parameters()
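In the __init__ method of the Linear class, weight is created with shape [out_features, in_features], which is why nn.Linear(50, 30) stores a [30, 50] weight.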
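A [30, 50] weight can only turn a [128, 50] input into a [128, 30] output if it is transposed before the multiplication: the layer computes y = x · Wᵀ + b. The sketch below checks this numerically by comparing the module's output against a hand-written x @ weight.T + bias and against torch.nn.functional.linear, which accepts the weight in its stored layout. The variable names (x, manual, functional) and the tolerance of 1e-5 are choices made for this illustration, not something taken from the PyTorch source.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)  # fixed seed so reruns give the same numbers

m = nn.Linear(50, 30)
x = torch.randn(128, 50)

# Manual forward pass: transpose weight from [30, 50] to [50, 30], then add the bias.
manual = x @ m.weight.T + m.bias

# Functional form: takes weight as stored, i.e. [out_features, in_features].
functional = F.linear(x, m.weight, m.bias)

print(manual.shape)
# Tolerance is an assumption chosen to absorb float32 rounding differences.
print(torch.allclose(m(x), manual, atol=1e-5))
print(torch.allclose(m(x), functional, atol=1e-5))
```

Both comparisons should agree within floating-point tolerance, which shows that the [out_features, in_features] layout is only a storage convention: the transpose happens inside the forward pass.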
Reference: https://blog.csdn.net/dss_dssssd/article/details/83537765