"""Report the parameter count and compute cost of a PyTorch network.

Adapted ("copied") from the D3D net code. Replace ``build_net`` with the
constructor of the network you want to measure, then run this file.
"""
import torch


def count_parameters(net):
    """Return the total number of scalar elements in all of *net*'s parameters."""
    return sum(p.nelement() for p in net.parameters())


def profile_net(net, input_shape=(1, 1, 7, 320, 180), device=None):
    """Profile *net* on a random input and return ``(macs, num_params)``.

    Args:
        net: the ``torch.nn.Module`` to measure. It is moved to *device*.
        input_shape: shape of the random dummy input. The default is the
            shape the original script used. It is presumably
            (batch, channels, frames, height, width) for a 3D-conv model;
            confirm against your network.
        device: device to run on. ``None`` picks CUDA when available and
            falls back to CPU otherwise (the original always called
            ``.cuda()`` and crashed without a GPU).

    Returns:
        A tuple ``(macs, num_params)``. ``macs`` is the multiply-accumulate
        count reported by ``thop.profile``. ``num_params`` is counted
        directly from ``net.parameters()``.

    Raises:
        ImportError: if the third-party ``thop`` package is not installed.
    """
    # Imported lazily so the helpers above stay usable without thop installed.
    from thop import profile

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    net = net.to(device)
    num_params = count_parameters(net)
    dummy_input = torch.randn(*input_shape, device=device)
    # thop returns (MACs, params). Its params figure is discarded in favour
    # of the direct count above.
    macs, _ = profile(net, inputs=(dummy_input,))
    return macs, num_params


def build_net():
    """Return the network to measure. Replace this body with your own model."""
    raise NotImplementedError(
        "replace build_net() with the constructor of your network"
    )


if __name__ == "__main__":
    net = build_net()
    macs, total = profile_net(net)
    print(' Number of params: %.2fM' % (total / 1e6))
    # thop counts multiply-accumulates, so FLOPs are roughly 2 * MACs.
    print(' Number of MACs: %.2fG (~%.2f GFLOPs)' % (macs / 1e9, 2 * macs / 1e9))
# 计算网络参数量 (Compute the number of network parameters)
# 最新推荐文章于 2023-10-30 20:35:32 发布 (Source-page note: latest recommended article published 2023-10-30 20:35:32)