pytorch系列 ----暂时就叫5的番外吧: nn.Modlue及nn.Linear 源码理解

先看一个列子:

import torch
from torch import nn

m = nn.Linear(20, 30)
input = torch.randn(128, 20)
output = m(input)

output.size()

out:

torch.Size([128, 30])

刚开始看这份代码是有点迷惑的,m是类对象,而直接像函数一样调用m,m(input)

重点

  • nn.Module 是所有神经网络单元(neural network modules)的基类
  • pytorch在nn.Module中,实现了__call__方法,而在__call__方法中调用了forward函数。

经过以上两点。上述代码就不难理解。

接下来看一下源码:
https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py

在这里插入图片描述

再来看一下nn.Linear
https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html
主要看一下forward函数:

评论 14
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值