PyTorch : nn.Linear() 详解

15 篇文章 1 订阅

线性转换:
在这里插入图片描述
举例:

input1 = torch.randn(128, 20)
input2 = torch.randn(128, 3, 20) #中间 * 可以添加任意维度
input3 = torch.randn(128, 3, 4, 20) 
m = nn.Linear(20, 30)
output1 = m(input1)
output2 = m(input2)
output3 = m(input3)
print(output1.size(), output2.size(), output3.size())
#
torch.Size([128, 30]) torch.Size([128, 3, 30]) torch.Size([128, 3, 4, 30])

中间 * 可以是任意维度,原理解释:

input2 = torch.randn(128, 3, 20)

m = nn.Linear(20, 30)
output2 = m(input2)

input3 = input2.reshape(128 * 3, 20)
output3 = m(input3)

print(output3 == output2.reshape(128 * 3, -1))
#
tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        ...,
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]], dtype=torch.uint8)

可见,将所有前面的维度相乘变为了二维矩阵,nn.Linear() 线性变换,也就是全连接层的变换。

PyTorch官方文档

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值