# 查看某一个网络层的参数量 (view the parameter count of a single network layer)
class model(nn.Module):
    """Wrap a single square ``nn.Linear`` layer and report its parameter count.

    Every forward pass prints the number of scalar parameters held by
    ``self.linear_q`` (its weight plus, when enabled, its bias). This is a
    quick way to inspect the size of one layer of a network.

    Args:
        in_features: Size of the input's last dimension; the layer maps it
            to an output of the same size.
        bias: Whether ``self.linear_q`` has an additive bias term.
    """

    def __init__(self, in_features, bias):
        # nn.Module.__init__ must run before any submodule is assigned;
        # without it, setting self.linear_q raises AttributeError.
        super().__init__()
        self.linear_q = nn.Linear(in_features, in_features, bias)

    def forward(self, x):
        """Apply ``self.linear_q`` to ``x`` and print that layer's parameter count.

        Args:
            x: Input tensor whose last dimension is ``in_features``.

        Returns:
            The output of ``self.linear_q(x)``.
        """
        out = self.linear_q(x)
        # numel() gives the number of scalar entries in each parameter tensor.
        n_params = sum(p.numel() for p in self.linear_q.parameters())
        print('self.linear_q: ', n_params)
        return out