Pytorch框架学习记录10——线性层
torch.nn.Linear
(in_features, out_features, bias=True, device=None, dtype=None)
参数:
- in_features – 每个输入样本的大小
- out_features – 每个输出样本的大小
- bias——如果设置为
False
,该层将不会学习附加偏差。默认:True
import torch
from torch import nn
input = torch.tensor([[1, 2, 3],
[1, 0, 3],
[3, 5, 2]], dtype=torch.float32)
input = torch.flatten(input)
class Test(nn.Module):
def __init__(self):
super(Test, self).__init__()
self.fc = nn.Linear(in_features=9, out_features=3)
def forward(self, input):
output = self.fc(input)
return output
test = Test()
output = test(input)
print(output)