文章目录
一、nn.Linear
nn.Linear 是 PyTorch 中的一个类,用于定义线性变换(全连接层)。它是神经网络中常用的一种层类型,作为输入张量与权重矩阵之间的线性变换。
nn.Linear(in_features, out_features, bias=True)
参数说明:
- in_features:输入特征的大小,即输入张量的最后一维大小。
- out_features:输出特征的大小,即输出张量的最后一维大小。
- bias:是否使用偏置项,默认为 True,表示使用偏置项。
import torch
import torch.nn as nn
# 创建一个线性层,输入特征大小为 3,输出特征大小为 2
linear_layer = nn.Linear(3, 2)
# 输入张量
x = torch.tensor([[1, 2, 3],
[4, 5, 6]])
# 进行线性变换
output = linear_layer(x)
print(output)
tensor([[ 1.5323, -0.2660],
[ 4.5969, -1.0649]], grad_fn=<AddmmBackward>)