对 torch.nn.Linear 的理解

本文详细介绍了PyTorch中的torch.nn.Linear层,该层用于实现线性变换,不仅限于常见的二维输入,而是可以接受任意维度的张量,只要最后一维对应`in_features`。通过一个实例展示了即使输入张量为三维,依然能正确应用Linear层,并得到期望的输出形状。这表明Linear层的灵活性,可以在各种复杂的神经网络结构中发挥作用。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

torch.nn.Linear 是 pytorch 的线性变换层,定义如下:

Linear(in_features: int, out_features: int, bias: bool = True, device: Any | None = None, dtype: Any | None = None)

全连接层 Fully Connect 一般就就用这个函数来实现。因此在潜意识里,变换的输入张量的 shape 为 (batchsize, in_features),输出张量的 shape 为 (batchsize, out_features)。

当然这是常用的方式,但是 Linear 的输入张量的维度其实并不需要必须为上述的二维,多维也是完全可以的,Linear 仅是对输入的最后一维做线性变换,不影响其他维。

可以看下官网的解释:Linear — PyTorch 1.11.0 documentation

一个例子如下:

import torch
input = torch.randn(30, 20, 10)  # [30, 20, 10]
linear = torch.nn.Linear(10, 15)  # (*, 10) --> (*, 15)
output = linear(input)
print(output.size()) # 输出 [30, 20, 15]

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

地球被支点撬走啦

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值