题目
'''
Description: Linear
Author: 365JHWZGo
Date: 2021-12-15 15:17:55
LastEditors: 365JHWZGo
LastEditTime: 2021-12-15 16:06:23
'''
代码
import torch
import torch.nn as nn
class Demo(nn.Module):
    """Apply one shared ``nn.Linear`` layer to two separate inputs.

    Both inputs are passed through the *same* layer object, so they share a
    single weight matrix and bias.

    Args:
        input_size: Number of input features (``in_features`` of the layer).
        output_size: Number of output features (``out_features`` of the layer).
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        # One layer, reused for both inputs in forward().
        self.linear = nn.Linear(input_size, output_size)

    def forward(self, inp, out):
        """Return ``(self.linear(inp), self.linear(out))``.

        Despite its name, ``out`` is simply a second input tensor; both
        arguments are transformed by the same layer, ``inp`` first.
        """
        layer = self.linear
        return layer(inp), layer(out)
if __name__ == '__main__':
    # Small demo: push two batches of different sizes through one shared layer.
    input_size = 4
    output_size = 5
    demo = Demo(input_size, output_size)
    # Derive the feature dimension from input_size instead of hard-coding 4,
    # so editing input_size above cannot desync the inputs from the layer.
    x = torch.randn(2, input_size)
    y = torch.randn(3, input_size)
    out1, out2 = demo(x, y)
    print(out1)  # shape (2, output_size)
    print(out2)  # shape (3, output_size)
运行结果
（输出为随机值，每次运行都不同；out1 的形状为 (2, 5)，out2 的形状为 (3, 5)。）
总结
Linear 的作用是对输入做线性变换，实质上是一次矩阵乘法再加上偏置：
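$$y = Wx + b$$
其中 $W$ 是权重，$b$ 是偏置。注意：PyTorch 的输入按行存放（形状为 (N, in_features)，N 为样本数），实际计算的是 $y = xW^{\top} + b$，输出形状为 (N, out_features)。例如上面代码中 x 的形状为 (2, 4)，out1 的形状为 (2, 5)。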
特征 | W | b |
---|---|---|
维度 | (out_features, in_features) | (out_features,)（一维向量） |
特征 | x | y |
---|---|---|
维度 | (in_features, N) | (out_features, N) |

（N 为样本数；上表按列向量写法 $y = Wx + b$ 给出形状。PyTorch 中输入实际按行存放，形状为 (N, in_features)，输出为 (N, out_features)。）
`Linear` 类中 `__init__` 的定义
# Excerpt of PyTorch's ``torch.nn.Linear`` (class header and ``__init__`` only),
# quoted for reading. It is not runnable on its own here: ``Module``,
# ``Parameter`` and ``reset_parameters`` are defined elsewhere in PyTorch.
class Linear(Module):
    def __init__(self, in_features: int, out_features: int, bias: bool = True,device=None, dtype=None) -> None:
        """Create the weight and, optionally, the bias parameters.

        Args:
            in_features: Number of input features; second dimension of ``weight``.
            out_features: Number of output features; first dimension of
                ``weight`` and length of ``bias``.
            bias: If True, allocate a ``bias`` parameter; otherwise register
                ``bias`` with value None.
            device: Forwarded to ``torch.empty`` when allocating parameters.
            dtype: Forwarded to ``torch.empty`` when allocating parameters.
        """
        # Bundled once so both torch.empty calls below use the same device/dtype.
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(Linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Weight is stored as (out_features, in_features).
        self.weight = Parameter(torch.empty((out_features, in_features), **factory_kwargs))
        if bias:
            # Bias is a 1-D tensor of length out_features.
            self.bias = Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            # No bias term: the name 'bias' is still registered, with value None.
            self.register_parameter('bias', None)
        # torch.empty only allocates; reset_parameters (not shown in this
        # excerpt) is presumably what assigns the initial values.
        self.reset_parameters()
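从中可以看出，Linear 会自行创建权重 W（形状为 (out_features, in_features)）。是否创建偏置取决于参数 `bias`：若为 True，则执行 `self.bias = Parameter(torch.empty(out_features, **factory_kwargs))`；否则执行 `self.register_parameter('bias', None)`，此时 `self.bias` 为 None。由于 `torch.empty` 只分配内存而不赋初值，最后调用的 `self.reset_parameters()` 负责对参数进行初始化。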