def getModelSize(model):
    """Measure and print the memory footprint of a model's parameters and buffers.

    Args:
        model: Object exposing ``parameters()`` and ``buffers()``, each
            returning an iterable of tensors that provide ``nelement()``
            and ``element_size()`` (presumably a ``torch.nn.Module`` --
            confirm against callers).

    Returns:
        A tuple ``(param_size, param_sum, buffer_size, buffer_sum, all_size)``:
        ``param_size`` / ``buffer_size`` are byte totals, ``param_sum`` /
        ``buffer_sum`` are element counts, and ``all_size`` is the combined
        byte total divided by 1024 twice (i.e. expressed in MB).

    Side effects:
        Prints a one-line summary of the three sizes in MB to stdout.
    """

    def to_mb(num_bytes):
        # Same conversion the summary and the return value both rely on.
        return num_bytes / 1024 / 1024

    param_size = 0
    param_sum = 0
    for param in model.parameters():
        count = param.nelement()
        param_size += count * param.element_size()
        param_sum += count

    buffer_size = 0
    buffer_sum = 0
    for buffer in model.buffers():
        count = buffer.nelement()
        buffer_size += count * buffer.element_size()
        buffer_sum += count

    all_size = to_mb(param_size + buffer_size)
    print('模型总大小为:{:.3f}MB,参数大小:{:.3f}MB,所有buffer的参数字节大小:{:.3f}MB'.format(all_size, to_mb(param_size), to_mb(buffer_size)))
    return (param_size, param_sum, buffer_size, buffer_sum, all_size)
打印深度学习模型大小
最新推荐文章于 2024-10-11 02:00:00 发布