def get_parameter_number(net): total_num = sum(p.numel() for p in net.parameters()) trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad) return {'Total': total_num, 'Trainable': trainable_num}
Pytorch统计参数网络模型的总参数量和待学习参数量
最新推荐文章于 2022-10-26 12:30:24 发布