示例:
import torch
a = torch.randn(2,3,4)
print(a.numel()) # 24
统计model中所有可训练参数量:
num_params = sum(p.numel() for p in model.parameters())
注:numel() 是pytorch的函数,只适用于 tensor,不能用于统计 list、tuple、dict 等的元素数量。
示例:
import torch
a = torch.randn(2,3,4)
print(a.numel()) # 24
统计model中所有可训练参数量:
num_params = sum(p.numel() for p in model.parameters())
注:numel() 是pytorch的函数,只适用于 tensor,不能用于统计 list、tuple、dict 等的元素数量。
打赏作者