self.layers
是一个用于存储网络层的属性。它是一个 nn.ModuleList
对象,这是PyTorch中用于存储 nn.Module
子模块的特殊列表。
为什么使用 nn.ModuleList
?
在PyTorch中,当需要处理多个神经网络层时,通常使用 nn.ModuleList
或 nn.Sequential
。这些容器类能够确保其中包含的所有模块(层)都被正确注册,这样PyTorch就可以跟踪它们的参数,实现自动梯度计算和参数更新。
self.layers
的作用
class UserDefined(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads, dim_head, dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout))
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x)