转载
核心:
moduleList
当添加 nn.ModuleList 作为 nn.Module 对象的一个成员时(即当我们添加模块到我们的网络时),所有 nn.ModuleList 内部的 nn.Module 的 parameter 也被添加作为 我们的网络的 parameter。
和普通list不一样,它和torch的其他机制结合紧密,继承了nn.Module的网络模型class可以使用nn.ModuleList并识别其中的parameters,而nn.ModuleList则没有顺序性要求,并且也没有实现forward()方法。
Sequential
nn.Sequential定义的网络中各层会按照定义的顺序进行级联,因此需要保证各层的输入和输出之间要衔接。并且nn.Sequential实现了farward()方法,因此可以直接通过类似于x=self.combine(x)的方式实现forward。这是二者之间的区别。
keras就直接一点:莫烦的Transformer
class Encoder(keras.layers.Layer):
def __init__(self, n_head, model_dim, drop_rate, n_layer):
super().__init__()
self.ls = [EncodeLayer(n_head, model_dim, drop_rate) for _ in range(n_layer)]
def call(self, xz, training, mask):
for l in self.ls:
xz = l.call(xz, training, mask)
return xz # [n, step, dim]