计算模型参数量,FLOPs方法汇总
第一类:模型一个入口一个出口。
if __name__=='__main__':
import torch
from thop import profile
model = IRWArt() # 替换为你的PyTorch模型
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Total parameters: {:.2f}M".format(total_params / 1e6))
print("Trainable parameters: {:.2f}M".format(trainable_params / 1e6))
image=torch.randn(1,24,64,64)
flops, params = profile(model, (image,))
print(f"GFLOPS: {flops}---params;{params}")
print("GFLOPS: {:.2f}G".format(flops / 1e9))
第二类:一个模型多个入口
例如 model下面有两个方法 encoder decoder 这时用上面的方法无法分别进行计算。
使用下面的方法进行计算。
- 计算参数量一般可用:
total_params = sum(p.numel() for p in model.decoder.parameters())
trainable_params = sum(p.numel() for p in model.decoder.parameters() if p.requires_grad)
print("Total parameters: {:.2f}M".format(total_params / 1e6))
print("Trainable parameters: {:.2f}M".format(trainable_params / 1e6))
- 计算GFLOPS可以使用下面的方法:
from torchsummary import summary
z_embed, _ = model(z, None, secret)
stego = model.decode_first_stage(z_embed) # 1, 3, 256, 256
profile_encoder = torchprofile.profile_macs(model.encode_first_stage, cover)
print(profile_encoder)
print("GFLOPS: {:.22f}G".format(profile_encoder / 1e9))
profile_encoder = torchprofile.profile_macs(model.decode_first_stage, z_embed)
print(profile_encoder)
print("GFLOPS: {:.2f}G".format(profile_encoder / 1e9))
- 如果想要分别计算不同编码器解码器的参数量可以使用下面的方法。
# 计算解码器的参数量和可训练参数量
decoder_summary = summary(model.decoder, input_size=stego.size())
decoder_total_params = decoder_summary.total_params # 解码器的参数量
decoder_trainable_params = sum(p.numel() for p in model.decoder.parameters() if p.requires_grad) # 解码器的可训练参数量
print("解码器的参数量: {:.2f} millions".format(decoder_total_params / 1e6))
print("解码器的可训练参数量: {:.2f} millions".format(decoder_trainable_params / 1e6))