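import torch
from thop import profile
# GroupMixFormer must be imported from wherever it is defined in your project
x = torch.randn(1, 1, 64, 64).cuda()  # dummy input: batch 1, 1 channel, 64x64
model = GroupMixFormer().to('cuda')
model.eval()
with torch.no_grad():  # inference only, no autograd graph needed
    output = model(x)
    # thop.profile returns (MACs, params): it counts multiply-accumulates, not FLOPs
    macs, params = profile(model, inputs=(x,))
print("input_shape:", x.size())
print("output_shape:", output.size())
print(' Number of parameters: %.2fM' % (params / 1e6))
# one MAC = one multiply + one add, so FLOPs ≈ 2 * MACs
print(' Number of FLOPs: %.2fG' % (macs * 2 / 1e9))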
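To sanity-check what these numbers mean, the parameter and MAC counts of a single convolution layer can be worked out by hand. The sketch below is pure Python and does not need GroupMixFormer, a GPU or thop; the helpers `conv2d_out_size` and `conv2d_cost` are hypothetical names introduced for illustration. It uses the standard analytical count (bias additions excluded from MACs), so thop's own per-layer rules may differ slightly depending on its version.

```python
# Hand-count parameters and MACs/FLOPs for a single Conv2d layer.
# conv2d_out_size and conv2d_cost are hypothetical helpers, not part of thop or PyTorch.

def conv2d_out_size(size, kernel, stride=1, padding=0, dilation=1):
    """Spatial output size, using the formula documented for nn.Conv2d."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

def conv2d_cost(c_in, c_out, h, w, kernel, stride=1, padding=0, groups=1, bias=True):
    """Return (params, macs, flops) for a square-kernel Conv2d on an h x w input."""
    assert c_in % groups == 0 and c_out % groups == 0
    h_out = conv2d_out_size(h, kernel, stride, padding)
    w_out = conv2d_out_size(w, kernel, stride, padding)
    # weight shape is (c_out, c_in // groups, kernel, kernel), plus one bias per output channel
    params = c_out * (c_in // groups) * kernel * kernel + (c_out if bias else 0)
    # each output element needs (c_in // groups) * kernel * kernel multiply-accumulates;
    # bias additions are not counted here
    macs = c_out * h_out * w_out * (c_in // groups) * kernel * kernel
    flops = 2 * macs  # same MAC -> FLOP convention as the "* 2" in the script
    return params, macs, flops

# Example: a 3x3 conv, 1 -> 64 channels, padding 1, on a 1x1x64x64 input
params, macs, flops = conv2d_cost(c_in=1, c_out=64, h=64, w=64, kernel=3, padding=1)
print("params:", params)              # 640
print("MACs:", macs)                  # 2359296
print("FLOPs: %.4fG" % (flops / 1e9)) # 0.0047G
```

Summing such per-layer counts over a whole network is essentially what a profiler like thop automates by hooking each module during the forward pass.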
Checking a model's parameter count (Param) and floating-point operations (FLOPs)
First published on 2022-05-27 12:46:15