对于网络模型未知且只有.pth参数的网络模型,下载.pth文件后,可采用以下方法统计网络模型参数。
1.模型由一个.pth文件组成
import torch
model_dict = torch.load("path_to_your_pth")
params_dict=model_dict['state_dict']
total_params = 0
for param_tensor in params_dict.values():
# 将当前参数的元素数(即参数大小)加到总和中
total_params += param_tensor.numel()
print(f"参数量约为:{total_params/1000000:.2f}M(百万个参数)。")
2.模型由多个.pth文件组成
import torch
model_name='your_model_name'
file_path=["path_to_your_pth1",
"path_to_your_pth2",
"path_to_your_pth3"]
total_params = 0
for i in file_path:
model_dict = torch.load(i)
params_dict=model_dict['state_dict']
for param_tensor in params_dict.values():
# 将当前参数的元素数(即参数大小)加到总和中
total_params += param_tensor.numel()
print(f"{model_name} 的参数量约为:{total_params/1000000:.2f}M(百万个参数)。")
对于需要直接从