Computing the FLOPs of nn.Linear
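import torch
import torch.nn as nn

m = nn.Linear(20, 30)            # in_features=20, out_features=30
input = torch.randn(128, 3, 20)  # the last dimension must equal in_features
output = m(input)
print(output.size())             # torch.Size([128, 3, 30])

# Number of output elements times in_features (input[0].size(1) == 20):
# each output element is a dot product of length 20, so this counts
# multiply-accumulates and ignores the bias.
flops = (torch.prod(torch.LongTensor(list(output.size())))
         * input[0].size(1)).item()
print(list(output.size()))
pri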
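The snippet counts one operation per multiply-accumulate (output elements times in_features) and leaves out the bias. Conventions differ, so the sketch below reports both numbers. It assumes a plain Python helper named `linear_op_counts` (a hypothetical name, not part of PyTorch) that works only from shapes. For its second figure, it counts multiplies and additions separately.

```python
from math import prod


def linear_op_counts(input_shape, in_features, out_features, bias=True):
    """Hypothetical helper: count operations for y = x @ W.T + b.

    input_shape is the shape of x; its last dimension must equal in_features.
    Returns (macs, flops): macs counts multiply-accumulates, flops counts
    multiplies and additions separately.
    """
    if input_shape[-1] != in_features:
        raise ValueError("last dimension of input must equal in_features")
    # All leading dimensions act as batch dimensions; each yields out_features outputs.
    num_outputs = prod(input_shape[:-1]) * out_features
    # Each output element is a dot product of length in_features.
    macs = num_outputs * in_features
    # A dot product of length k needs k multiplies and k - 1 additions;
    # adding the bias contributes one more addition per output element.
    adds_per_output = in_features if bias else in_features - 1
    flops = num_outputs * (in_features + adds_per_output)
    return macs, flops


# Same setting as the snippet above: nn.Linear(20, 30) on an input of shape (128, 3, 20).
macs, flops = linear_op_counts((128, 3, 20), 20, 30)
print(macs, flops)  # 230400 460800
```

Here `macs` matches the value the original snippet computes (128 × 3 × 30 × 20). Many papers report this multiply-accumulate count as "FLOPs", while others roughly double it, so it helps to state which convention a number follows.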
Original post · 2020-06-13 17:35:05