Torch-nn学习:Tabel Layer

本文介绍了Torch-nn库中的Table Layer,包括ConcatTable、ParallelTable、MapTable、SplitTable、JoinTable、NarrowTable、FlattenTable、PairwiseDistance、DotProduct、CosineDistance以及用于计算的CAddTable等操作。通过实例展示了这些模块如何处理输入数据,特别是它们在处理多输入和输出时的灵活性。
摘要由CSDN通过智能技术生成

1.ConcatTable:对每个成员模块应用相同输入。

如图:

                  +-----------+
             +----> {member1, |
+-------+    |    |           |
| input +----+---->  member2, |
+-------+    |    |           |
   or        +---->  member3} |
 {input}          +-----------+
示例:

mlp = nn.ConcatTable()
mlp:add(nn.Linear(5, 2))
mlp:add(nn.Linear(5, 3))

pred = mlp:forward(torch.randn(5))
for i, k in ipairs(pred) do print(i, k) end


2.ParallelTable:对每个成员模块应用与之对应的输入(第i个模块应用第i个输入)。

如图:

+----------+         +-----------+
| {input1, +---------> {member1, |
|          |         |           |
|  input2, +--------->  member2, |
|          |         |           |
|  input3} +--------->  member3} |
+----------+         +-----------+

mlp = nn.ParallelTable()
mlp:add(nn.Linear(10, 2))
mlp:add(nn.Linear(5, 3))

x = torch.randn(10)
y = torch.rand(5)

pred = mlp:forward{x, y}
for i, k in pairs(pred) do print(i, k) end


3.MapTable:对所有输入应用,不够的就clone。参数共享(weightbiasgradWeight and gradBias

eg:

+----------+         +-----------+
| {input1, +---------> {member,  |
|          |         |           |
|  input2, +--------->  clone,   |
|          |         |           |
|  input3} +--------->  clone}   |
+----------+         +-----------+
map = nn.MapTable()
map:add(nn.Linear(10, 3))

x1 = torch.rand(10)
x2 = torch.rand(10)
y = map:forward{x1, x2}

for i, k in pairs(y) do print(i, k) end



4.SplitTable:这个不用解释,index可以为负数

module = SplitTable(dimension, nInputDims)


                
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值