Python 遍历嵌套字典计算 PyTorch 模型大小
import torch
from torch import nn
import sys,os
# Switch the working directory to sys.path[0] (the launching script's
# directory when run as `python script.py`) so that relative file names
# such as the checkpoint path resolve from there.
os.chdir(sys.path[0])
# Module-level byte counter that dict_traverse writes its total into.
dict_size = 0
def _nested_size_bytes(mapping):
    """Return the summed ``sys.getsizeof`` bytes of a nested dict's contents."""
    total = 0
    for key, value in mapping.items():
        if isinstance(value, dict):
            # Nested dict: count its key and the (shallow) dict object
            # itself, then descend into its entries.
            total += sys.getsizeof(key) + sys.getsizeof(value)
            total += _nested_size_bytes(value)
            continue
        # sys.getsizeof is shallow, so a tensor is measured through its
        # storage object rather than through the tensor wrapper.
        measured = value.storage() if isinstance(value, torch.Tensor) else value
        size = sys.getsizeof(measured)
        print(key, size, sep=' ')
        total += size
    return total


def dict_traverse(my_dict):
    """Compute the approximate size of a nested dict, in MiB.

    Walks ``my_dict`` recursively. For every nested dict, the size of its
    key and of the dict object itself is counted before descending. For
    every leaf value, ``key size`` is printed and the size is added to the
    total; tensors are measured via ``value.storage()``. Keys of leaf
    entries are not counted.

    Each call starts from zero, so calling the function repeatedly does not
    accumulate earlier results. The module-level ``dict_size`` is set to the
    byte total of the most recent call.

    Args:
        my_dict: A (possibly nested) dict, e.g. a model ``state_dict``.

    Returns:
        The total size in bytes divided by 1024 twice (MiB) as a float.
    """
    global dict_size
    total = _nested_size_bytes(my_dict)
    # Keep the module-level counter in sync for code that reads it directly.
    dict_size = total
    return total / 1024 / 1024
if __name__ == '__main__':
    # NOTE: torch.load unpickles the file; only open checkpoints that come
    # from a trusted source.
    # map_location='cpu' lets a checkpoint saved on a GPU load on a machine
    # that has none.
    model = torch.load('vgg16.pth', map_location='cpu')
    # The checkpoint may hold a whole nn.Module instead of a state dict;
    # dict_traverse needs a mapping, so convert it first.
    if isinstance(model, nn.Module):
        model = model.state_dict()
    # Print the estimated model size in MiB.
    print(dict_traverse(model))