Counting a Model's Parameters and Size
Take the official PyTorch resnet18 as an example. First, obtain the model's parameters; there are the following two methods:
named_parameters()
This method returns an iterator that yields each parameter in the model together with its name.
parameters()
This method returns an iterator over all of the model's parameters (the same tensors, but without names); by default these are the parameters updated by gradient descent.
import torchvision.models as models
model18 = models.resnet18()

# model18.named_parameters()
for name, param in model18.named_parameters():
    print(name, param.shape)
# Output:
# conv1.weight torch.Size([64, 3, 7, 7])
# bn1.weight torch.Size([64])
# bn1.bias torch.Size([64])
# layer1.0.conv1.weight torch.Size([64, 64, 3, 3])

# model18.parameters()
for params in model18.parameters():
    print(params.shape)
# torch.Size([64, 3, 7, 7])
# torch.Size([64])
# torch.Size([64])
# torch.Size([64, 64, 3, 3])
Computing the Model's Parameter Count and Size
import torch
import torchvision.models as models

model18 = models.resnet18()
total_size_resnet18 = 0  # bytes
total_num_resnet18 = 0   # element count
for param in model18.parameters():
    param_num = param.numel()                      # number of elements in this tensor
    param_size = param_num * param.element_size()  # bytes (float32 -> 4 bytes each)
    total_num_resnet18 += param_num
    total_size_resnet18 += param_size
print(f"Total parameters size: {total_size_resnet18 / (1024 ** 2):.2f} MB")
print(f"Total parameters number: {total_num_resnet18}")
# Total parameters size: 44.59 MB
# Total parameters number: 11689512
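The two printed numbers are consistent with each other: every parameter is stored as float32 (4 bytes), so the size in MB follows directly from the count. A quick arithmetic check in plain Python (no torch required):

```python
param_count = 11_689_512   # parameter count printed above
bytes_per_param = 4        # float32

size_mb = param_count * bytes_per_param / 1024 ** 2
print(f"{size_mb:.2f} MB")  # -> 44.59 MB, matching the output above
```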
The size of the saved .pth file also confirms this. It is not exactly the same, but close; the small difference likely comes from extra keys in the state dict.
The Structure of resnet18
In layers 2, 3, and 4, the downsample branch is applied to the identity (shortcut) input coming from the previous stage, with kernel_size = 1 and stride = 2, so the channel count doubles while the spatial size halves. In the listing below, conv layers are annotated as (in_channels, out_channels, kernel_size, stride, padding).
ResNet(
(conv1): 3, 64, 7, 2, 3 # conv1.weight torch.Size([64, 3, 7, 7])
(bn1): # bn1.weight torch.Size([64]) / bn1.bias torch.Size([64])
(relu):
(maxpool): 3, 2, 1
(layer1): Sequential(
(0): BasicBlock(
(conv1): 64, 64, 3, 1, 1 # layer1.0.conv1.weight torch.Size([64, 64, 3, 3])
(bn1):
(relu):
(conv2): 64, 64, 3, 1, 1
(bn2):
)
(1): BasicBlock(
(conv1): 64, 64, 3, 1, 1
(bn1):
(relu):
(conv2): 64, 64, 3, 1, 1
(bn2):
)
)
(layer2): Sequential(
(0): BasicBlock(
(conv1): 64, 128, 3, 2, 1
(bn1):
(relu):
(conv2): 128, 128, 3, 1, 1
(bn2):
(downsample): Sequential(
(0): 64, 128, 1, 2
(1):
)
)
(1): BasicBlock(
(conv1): 128, 128, 3, 1, 1
(bn1):
(relu):
(conv2): 128, 128, 3, 1, 1
(bn2):
)
)
(layer3): Sequential(
(0): BasicBlock(
(conv1): 128, 256, 3, 2, 1
(bn1):
(relu):
(conv2): 256, 256, 3, 1, 1
(bn2):
(downsample): Sequential(
(0): 128, 256, 1, 2
(1):
)
)
(1): BasicBlock(
(conv1): 256, 256, 3, 1, 1
(bn1):
(relu):
(conv2): 256, 256, 3, 1, 1
(bn2):
)
)
(layer4): Sequential(
(0): BasicBlock(
(conv1): 256, 512, 3, 2, 1
(bn1):
(relu):
(conv2): 512, 512, 3, 1, 1
(bn2):
(downsample): Sequential(
(0): 256, 512, 1, 2
(1):
)
)
(1): BasicBlock(
(conv1): 512, 512, 3, 1, 1
(bn1):
(relu):
(conv2): 512, 512, 3, 1, 1
(bn2):
)
)
(avgpool): 1,1
(fc): 512, 1000
)