pip install torch-summary   # imported as torchsummary
pip install thop
Example:
# one sample from the UBFC dataset, with a batch dimension of size 1 added in front
stmap = ubfc_dataset[0]["stmap"].unsqueeze(dim=0).cuda()
model = BVP_estimator().cuda()
# summary
from torchsummary import summary
summary(model, stmap)
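# compute MACs/FLOPs with thop
from thop import profile
# profile() calls model(*inputs), so inputs must be a tuple of positional arguments
macs, params = profile(model, inputs=(stmap,))
# thop counts multiply-accumulate operations (MACs); one MAC is about 2 FLOPs
print(f"Total MACs: {macs / 1e9:.3f} G (~{2 * macs / 1e9:.3f} GFLOPs)")
print(f"Total params: {params / 1e6:.3f} M")
Inside thop's profile.py, the forward pass is run as
with torch.no_grad():
    model(*inputs)
If the bare tensor is passed (inputs=stmap), *inputs unpacks it along its first (batch) dimension, so the model receives stmap[0] and the batch dimension is swallowed. An alternative is to edit profile.py to
with torch.no_grad():
    model(inputs)
which makes inputs=stmap work, but it patches the installed package and breaks models that take more than one input. Wrapping the input in a tuple avoids both problems.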
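The lost dimension comes from ordinary Python argument unpacking, which iterates over the first axis (a PyTorch tensor behaves the same way under `*`). The sketch below reproduces it without PyTorch: nested lists stand in for tensors, `profile_like` is a hypothetical stand-in that only mimics the `model(*inputs)` call thop makes, and `ShapeRecorder` is a made-up dummy model that records the shapes it receives.

```python
def shape(x):
    """Shape of a rectangular nested list, e.g. [[0, 0, 0]] -> (1, 3)."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0]
    return tuple(dims)


def make_tensor(dims, fill=0.0):
    """Build a nested list with the given shape (stand-in for a tensor)."""
    if not dims:
        return fill
    return [make_tensor(dims[1:], fill) for _ in range(dims[0])]


class ShapeRecorder:
    """Hypothetical dummy model: records the shape of every positional argument."""

    def __init__(self):
        self.seen = []

    def __call__(self, *args):
        self.seen = [shape(a) for a in args]


def profile_like(model, inputs):
    """Hypothetical stand-in for thop's profile(): only mimics model(*inputs)."""
    model(*inputs)
    return model.seen


stmap = make_tensor([1, 3, 4, 4])  # batch of 1, like stmap after unsqueeze(dim=0)
print(profile_like(ShapeRecorder(), stmap))     # [(3, 4, 4)]: batch dim swallowed
print(profile_like(ShapeRecorder(), (stmap,)))  # [(1, 3, 4, 4)]: shape preserved
```

With a batch of 2, the bare tensor would instead be unpacked into two separate arguments, so a single-input model would raise an error rather than silently run on the wrong shape.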
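To sanity-check thop's totals, the cost of a single convolution can be worked out by hand. This is a minimal sketch, assuming batch size 1, a square kernel, and the common convention of one multiply-accumulate per weight per output element with bias additions not counted; `conv2d_cost` and `conv2d_out_size` are hypothetical helpers, and thop's own per-layer hooks may differ in details (such as bias handling) between versions.

```python
def conv2d_out_size(size, kernel, stride=1, padding=0, dilation=1):
    """Output height/width of a Conv2d along one spatial axis (PyTorch formula)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def conv2d_cost(in_c, out_c, h, w, k, stride=1, padding=0, groups=1, bias=True):
    """Return (macs, params) for one square-kernel Conv2d on a batch of 1.

    Convention (an assumption, not thop's exact hook): each output element
    costs (in_c / groups) * k * k MACs; bias additions are not counted.
    """
    oh = conv2d_out_size(h, k, stride, padding)
    ow = conv2d_out_size(w, k, stride, padding)
    macs = out_c * oh * ow * (in_c // groups) * k * k
    # weights plus one bias per output channel
    params = out_c * (in_c // groups) * k * k + (out_c if bias else 0)
    return macs, params


macs, params = conv2d_cost(in_c=3, out_c=16, h=64, w=64, k=3, padding=1)
print(macs, params)  # 1769472 448
print(f"{macs / 1e9:.4f} GMACs, about {2 * macs / 1e9:.4f} GFLOPs, {params / 1e6:.4f} M params")
```

Summing such per-layer figures over a convolution-heavy model should land near thop's total; a large gap usually points to custom modules that thop has no counting rule for and therefore treats as zero cost.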