def get_parameter_number(net):
total_num = sum(p.numel() for p in net.parameters())
trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad)
return {'Total': total_num, 'Trainable': trainable_num}
Pytorch统计参数网络参数数量
最新推荐文章于 2024-09-07 19:23:02 发布