def parameter_count(module):
trainable, non_trainable = 0, 0
for p in module.parameters():
if p.requires_grad:
trainable += p.numel()
else:
non_trainable += p.numel()
return trainable, non_trainable
pytorch统计模型参数数量
最新推荐文章于 2024-03-29 17:55:31 发布