先看一个列子:
import torch
from torch import nn
m = nn.Linear(20, 30)
input = torch.randn(128, 20)
output = m(input)
output.size()
out:
torch.Size([128, 30])
刚开始看这份代码是有点迷惑的,m是类对象,而直接像函数一样调用m,m(input)
重点:
- nn.Module 是所有神经网络单元(neural network modules)的基类
- pytorch在nn.Module中,实现了
__call__
方法,而在__call__
方法中调用了forward函数。
经过以上两点。上述代码就不难理解。
接下来看一下源码:
https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py
再来看一下nn.Linear
https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html
主要看一下forward函数: