在 PyTorch 中,“column-wise” 可以通过 torch.stack
函数来实现,“row-wise” 可以通过 torch.cat
函数来实现。
以下是一个简单的示例,说明如何将模块堆叠在一起,行和列的方式不同:
import torch.nn as nn
import torch
# 定义一个简单的模块类
class MyModule(nn.Module):
def __init__(self, in_dim, out_dim):
super(MyModule, self).__init__()
self.linear = nn.Linear(in_dim, out_dim)
self.relu = nn.ReLU()
def forward(self, x):
x = self.linear(x)
x = self.relu(x)
return x
# 定义一组模块
module_list = [MyModule(10, 20) for i in range(5)]
# 将这些模块以列的方式堆叠起来
column_wise_module = nn.Sequential(*module_list)
# 将这些模块以行的方式堆叠起来
row_wise_module = nn.Sequential()
for module in module_list:
row_wise_module.add_module(str(len(row_wise_module)), module)
# 将输入数据通过两种方式的模块分别执行
x = torch.randn(5, 10)
y1 = column_wise_module(x)
y2 = row_wise_module(x)
# 比较两种方式的输出是否相同
print(torch.all(y1.eq(y2)))
在上面的代码中,module_list
中包含了 5 个相同的 MyModule
实例,在 column_wise_module
中,这 5 个实例以列的方式堆叠在一起,形成了一个包含 5 层的神经网络;在 row_wise_module
中,这些实例以行的方式堆叠,形成了一个同样包含 5 层的神经网络。两种方式训练出来的模型都是一样的,只不过其中参数的排列方式不同。