这个就类似于python中的list,不过里面存储的是nn.Module。里面的模型可以被遍历,在某些特殊场景下相当好用。
(比方说在GAT中,我们有多头机制,就可以使用nn.ModuleList)
看看官方文档:
我使用ModuleList写的GAT,因为GAT里面有多头,这个头的数量应该作为参数传入,所以代码如下:
import torch
import torch.nn as nn
import torch.nn.functional as F
from layer import GATLayer
from torch.nn import ModuleList
class GAT_NET(nn.Module):
def __init__(self, num_input, num_hidden, num_classes, num_heads=3, dropout=0.5):
super(GAT_NET, self).__init__()
# gat1中存储着num_heads个GATLayer层
self.gat1 = ModuleList(GATLayer(num_input, num_hidden) for _ in range(num_heads))
self.gat2 = GATLayer(num_hidden * num_heads, num_classes)
self.dropout = dropout
def forward(self, adj, H):
H = torch.cat([gat(adj, H) for gat in self.gat1], dim=1)
H = F.dropout(H, self.dropout, training=self.training)
H = self.gat2(adj, H)
return F.softmax(H, dim=1)