计算nn.Linear的计算量FLOPs

import torch
import torch.nn as nn
m = nn.Linear(20, 30)
input = torch.randn(128, 3,20)
output = m(input)
print(output.size())
flops = (torch.prod(torch.LongTensor(list(output.size()))) \
         * input[0].size(1)).item()
print((list(output.size())))
print(input[0].size(1))
print(flops)

官网对nn.Linear的介绍:

Applies a linear transformation to the incoming data:
y = x A T + b y = xA^T + b y=xAT+b
其函数签名为
torch.nn.Linear(in_features, out_features, bias=True)
其参数为:

  • in_features - size of each input sample
  • out_features -size of each output sample
  • bias If set to false, the layer will not learn an additive bias.
    其输入输出的维度为:
  • Input: ( N , ∗ , H i n ) (N, *, H_{in}) (N,,Hin), where *∗ means any number of additional dimensions and H i n H_{in} Hin=in_features
  • Output: ( N , ∗ , H o u t ) (N, *, H_{out}) (N,,Hout), where all but the last dimension are the same shape as the input and H o u t H_{out} Hout=out_features

故代码运行得到的计算量FLOPS为 128×3×20×30 = 230400,即计算量FLOPs为batch_size×特征维度值×输入特征数×输出特征数。

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值