统计Tensor的元素个数
- 基本用例:
import torch
a = torch.zeros(4,4)
print(a)
# tensor([[0., 0., 0., 0.],
# [0., 0., 0., 0.],
# [0., 0., 0., 0.],
# [0., 0., 0., 0.]])
print(a.numel()) # 16
print(torch.numel(a)) # 16
- 统计网络中的可训练参数个数:
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)