小白学Pytorch系列–Torch.nn API Linear Layers(10)
方法 | 注释 |
---|---|
nn.Identity | 不区分参数的占位符标识操作符。 |
nn.Linear | 对传入数据应用线性转换 y = x A T + b y=xA^T+b y=xAT+b |
nn.Bilinear | 对传入数据应用双线性转换 y = x 1 T A x 2 + b y=x_1^T A x_2+b y=x1TAx2+b |
nn.LazyLinear | 一个torch.nn.Linear 模块,其中的特征被推断出来。 |
nn.Identity
>>> m = nn.Identity(54, unused_argument1=0.1, unused_argument2=False)
>>> input = torch.randn(128, 20)
>>> output = m(input)
>>> print(output.size())
torch.Size([128, 20])
nn.Linear
对传入数据应用线性转换
y
=
x
A
T
+
b
y=x A^T+b
y=xAT+b
>>> m = nn.Linear(20, 30)
>>> input = torch.randn(128, 20)
>>> output = m(input)
>>> print(output.size())
torch.Size([128, 30])
nn.Bilinear
>>> m = nn.Bilinear(20, 30, 40)
>>> input1 = torch.randn(128, 20)
>>> input2 = torch.randn(128, 30)
>>> output = m(input1, input2)
>>> print(output.size())
torch.Size([128, 40])