from torchsummary import summary
# Print a per-layer summary and the total parameter count of `net`.
# torchsummary's `device` argument defaults to "cuda" and only accepts
# "cuda"/"cpu"; take the device type from the model's own parameters so the
# dummy input torchsummary builds lands on the same device as `net`.
# batch_size=-1 (the default) prints the batch dimension of shapes as -1.
summary(
    net,
    input_size=(3, 256, 256),
    batch_size=-1,
    device=next(net.parameters()).device.type,
)
# torchsummary 输出的 Total params 是原始参数个数；除以 1,000,000（1e6）即为以 M（百万）为单位的参数量。
import torch
from fvcore.nn import FlopCountAnalysis

# Count FLOPs for one forward pass of `net` on a single 3x256x256 input.
# Build the dummy input on whatever device the model's parameters live on
# instead of hard-coding .cuda(), so this also works for CPU-resident models.
device = next(net.parameters()).device
inputs = torch.randn(1, 3, 256, 256, device=device)
flop_counter = FlopCountAnalysis(net, inputs)

# total() returns a raw count (fvcore counts one fused multiply-add as one
# flop). FLOP figures conventionally use decimal prefixes: G = 1e9, T = 1e12.
total_flops = flop_counter.total()
print(f"FLOPs: {total_flops}")
print(f"GFLOPs: {total_flops / 1e9:.3f}")
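# flop_counter.total() 输出的是原始计数（fvcore 将一次乘加 fused multiply-add 计为 1 次 FLOP，实际更接近 MACs）；除以 1e9 得 G，除以 1e12 得 T。FLOPs 习惯使用十进制前缀，而不是 1024 的幂（用 /1024/1024/1024 会使 G 值偏小约 7%，T 值偏小约 9%）。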