torch_geometric.nn.conv. MessagePassing( )
继承这个类,可以自定义节点信息传播机制
例子
import torch
from torch_geometric.utils import add_self_loops
from torch_geometric.nn.conv import MessagePassing
class GCNConv(MessagePassing):
def __init__(self):
#选择相加的方式进行邻居节点信息聚合
super().__init__(aggr='add')
def forward(self, x, edge_index):
#给图添加自环
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
print(edge_index)
out = self.propagate(edge_index, x=x)
print(out)
def message(self, x_j):
print(x_j)
return x_j
edge_index = torch.tensor([[0, 1],
[1, 0],
[1, 2],
[2, 1]], dtype=torch.long)
x = torch.tensor([[-1], [6], [1]], dtype=torch.float)
edge_index = edge_index.permute(1, 0)
model = GCNConv( )
out = model(x, edge_index)
运行截图
MessagePassing的运行机制就是用行坐标[0,1,1,2,0,1,2]计算每个节点要汇聚的feature,然后用列坐标[1,0,2,1,0,1,2]进行 add 聚合信息,x_j其实就是根据行坐标得来的,行坐标里面的每一个元素其实就是一个节点标号,它告诉我们当前聚合信息时,每一个节点的信息应该是怎么样,在这里我没有转换节点feature,直接就是初始feature进行聚合,然后列坐标的元素进行聚合,如:列坐标中0节点与行节点对应的元素为1,0,所以在x_j对应位置找到元素6,-1然后相加得5,同理,1节点为-1+1+6 =6,2节点为6+1=7;
需要注意的是def message(self, x_j)中x_j的参数名字不能随便改变,不然会出错;其实x_j可以变为x_i,x_i代表以列坐标[0,1,1,2,0,1,2]计算每个节点要汇聚的feature,但仍以列坐标[0,1,1,2,0,1,2]汇聚信息;最后得到的结果如下:
也可以def message(self, x_j,x_i) ,其中x_j,x_i同时返回,可以根据具体应用进行灵活操作;