自注意力机制和全连接的图卷积网络(GCN)有什么区别联系?

 
 

点击上方“小白学视觉”,选择加"星标"或“置顶

重磅干货,第一时间送达

50bc74d405e8b74b07bf601533b5f9a0.png

本文整理自知乎问答,仅用于学术分享,著作权归作者所有。如有侵权,请联系后台作删文处理。

观点一

作者|Guohao Li

https://www.zhihu.com/question/366088445/answer/1023290162

来说一下自己的理解。

首先结论是大部分GCN和Self-attention都属于Message Passing(消息传递)。GCN中的Message从节点的邻居节点传播来,Self-attention的Message从Query的Key-Value传播来。

d1be872a45132552c1a52810cc144b8b.jpeg

Message Passing[4]

先看看什么是Message Passing。我们知道在实现和设计GCN的时候很多时候都是采用Message Passing的框架[3],其思想是把每个节点的领域的特征信息传递到节点上。在这里举例描述一个节点i在第k层GCN卷积的过程:

1)把节点i的每一个邻居j与该节点的特征经过函数变换后形成一条Message(对应公示里函数\phi里面的操作);

2)经过一个Permutation Invariant(置换不变性)函数把该节点领域的所有Message聚合在一起(对应函数\square);

3)再经过函数\gamma把聚合的领域信息和节点特征做一次函数变化,得到该节点在第k层图卷积后的特征X_i。

那么Self-attention是否也落在Message Passing的框架内呢?我们先回顾一下Self-attention一般是怎么计算的[2],这里举例一个Query i的经过attention的计算过程:

1】Query i的特征x_i会和每一个Key j的特征计算一个相似度e_ij;

3b773761de77c4e316c746d5f06e0617.png

2】得到Query i与所有Key的相似度后经过SoftMax得到Attention coefficient(注意力系数)\alpha_ij;

4e71cbe9b5d68a9d82cd0a9f867f0fd2.jpeg

3】通过Attention coefficient加权Value j计算出Query i最后的输出z_j。

955cc9057db959ba17d709c1d38556ac.png

好了,那么我们来看看它们之间的对应关系。首先结论是Self-attention计算中的1】2】3】是对应Message Passing里的1)2)的。

如果用Message Passing来实现Self-attention,那么我们可以这么一一对应:

-1 每个Key-Value j可以看作是Query i的邻居;

-2 相似度和注意力系数的计算和最后3】中Value j与注意力系数相乘的操作可以对应为Message Passing中第一步构成Message的过程;

-3 最后Self-attention的求和运算对应Message Passing中第二步的Permutation Invariant函数,也就是说这里聚合领域信息的过程是通过Query对Key-Value聚合而来。

那么也就是说,Attention的过程是把每一个Query和所有Key相连得到一个Complete Bipartite Graph(左边是Query右边的Key-Value),然后在这图上去对所有Query节点做Message Passing。当然Query和Key-Value一样的Self-attention就是在一般的Complete Graph上做Message Passing了。

4f549b1f06305490e2636ffc962d7a30.jpeg

Complete Bipartite Graph

看到这里大家可能疑问那么为什么Self attention里面没有了Message Passing中第三步把聚合的信息和节点信息经过\gamma函数做变换的过程呢。是的,如果没有了这一步很可能学习过程中Query的原来特征会丢失,其实这一步在Attention is all your need[1]里还是有的,不信你看:

bb2e922efaf449a7213738980be286cc.jpeg

在每一次经过Self-Attention之后基本上都是有Skip connection+MLP的,这里某种程度上对应了Message Passing里的\gamma函数不是吗?

那么说白了GCN和Self-attention都落在Message Passing(消息传递)框架里。GCN中的Message从节点的邻居节点传播来,Self-attention的Message从Query的Key-Value传播来。如果称所有的Message Passing函数都是GCN的话,那么Self-attention也就是GCN作用Query和Key-Value所构成Complete Garph上的一种特例。也正如乃岩@Naiyan Wang的回答一样。

可以说NLP中GCN应该大有可为,毕竟Self-attention可以看出是GCN一种,那么肯定存在比Self-attention表达能力更强和适用范围更广的GCN。

感谢评论里 @叶子豪的补充,DGL团队写了个很详尽的用Message Passing实现Transformer的教程。对具体实现感兴趣的同学可以去读一下:DGL Transformer Tutorial。

Reference:

1. Attention is All You Need

2. Self-Attention with Relative Position Representations

3. Pytorch Geometric

4. DeepGCNs for Representation Learning on Graphs

观点二

作者|Houye

https://www.zhihu.com/question/366088445/answer/1022692208

来说一下自己的理解。

GAT中的Attention就是self-attention,作者在论文中已经说了

6524c0fa92eee6ebde833d8483b7212c.jpeg

下面说说个人理解:

GNN,包括GCN,是一种聚合邻居信息来更新节点表示的神经网络模型。下图取自GraphSAGE,感觉比较好的说明了聚合过程。这里我们只关注图2:红色的点u是我们最终需要关注的点,3个蓝色的点{v1,v2,v3}是红色点的一阶邻居。每次更新在红色节点表示的时候,GNN都会收集3个蓝色点的信息并将其聚合,然后通过神经网络来更新红色节点的表示。这里神经网络可以是一个mean-pooling,也是对邻居进行平均,这时候 v1,v2,v3的权重都是1/3。

b84cf7d7f359f91d1575d462f7635244.jpeg

那这里就有问题了,3个蓝色的点都是邻居,直观的想,不同邻居对于红色的点的重要性是不同的。那么,在GNN聚合邻居的时候能不能考虑到邻居的重要性来加权聚合呢(比如,0.8v1+0.19v2+0.01v3)?手动加权肯定是不实际的。虽然加权感觉好一些,但是不加权也是可以做GNN的,在有些数据集上,不加权的效果甚至更好。

个人感觉在深度学习领域,“加权=attention”。我们这里可以设计一种attention机制来实现对邻居的加权。这里的权重可以理解为边的权重,是针对于一对节点来说的(比如u和v1).

那这里为啥是self-attention,因为GNN在聚合的时候会把自身也当做邻居。也就是说,上图中u的邻居集合实际是{u,v1,v2,v3}。这是很自然的,邻居的信息只能算是补充信息,节点自身的信息才是最重要的。

现在问题转化成了:给定{u,v1,v2,v3}作为输入,如何将u更好的表示出来?这个就很像NLP里面的self-attention了,见下图(引自川陀学者:Attention机制详解(二)——Self-Attention与Transformer)

a2b3e9e4ce2050d0c15505a98de97811.jpeg

最后总结一下:GCN和self-attention甚至attention都没有必然联系。对邻居加权来学习更好的节点表示是一个可选项。

 
 

好消息!

小白学视觉知识星球

开始面向外开放啦👇👇👇

 
 

6337af973b236bbf458d34590f1be844.jpeg

下载1:OpenCV-Contrib扩展模块中文版教程

在「小白学视觉」公众号后台回复:扩展模块中文教程,即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。


下载2:Python视觉实战项目52讲
在「小白学视觉」公众号后台回复:Python视觉实战项目,即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。


下载3:OpenCV实战项目20讲
在「小白学视觉」公众号后台回复:OpenCV实战项目20讲,即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。


交流群

欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器、自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN、算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~
  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
以下是一个使用Pytorch Geometric (PyG)库训练卷积网络(GCN)进行骨骼识别的代码示例: ```python import torch import torch.nn.functional as F from torch_geometric.datasets import Human36M from torch_geometric.nn import GCNConv # 加载数据集 train_dataset = Human36M(root='/path/to/dataset', train=True) test_dataset = Human36M(root='/path/to/dataset', train=False) # 定义卷积网络模型 class GCN(torch.nn.Module): def __init__(self): super(GCN, self).__init__() self.conv1 = GCNConv(54, 128) # 第一层GCN卷积 self.conv2 = GCNConv(128, 128) # 第二层GCN卷积 self.fc1 = torch.nn.Linear(128, 64) # 全连接层 self.fc2 = torch.nn.Linear(64, 32) # 全连接层 self.fc3 = torch.nn.Linear(32, 17) # 全连接层 def forward(self, x, edge_index): # x: 特征向量 # edge_index: 的邻接矩阵 x = F.relu(self.conv1(x, edge_index)) # GCN卷积层1 x = F.relu(self.conv2(x, edge_index)) # GCN卷积层2 x = F.relu(self.fc1(x)) # 全连接层1 x = F.relu(self.fc2(x)) # 全连接层2 x = self.fc3(x) # 全连接层3, 输出17维向量 return x # 实例化模型 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = GCN().to(device) # 定义损失函数和优化器 criterion = torch.nn.MSELoss() optimizer = torch.optim.Adam(model.parameters(), lr=0.001) # 训练模型 model.train() for epoch in range(50): train_loss = 0.0 for batch in train_dataset: x, edge_index, y = batch.x.to(device), batch.edge_index.to(device), batch.y.to(device) optimizer.zero_grad() out = model(x, edge_index) loss = criterion(out, y) loss.backward() optimizer.step() train_loss += loss.item() * batch.num_graphs train_loss /= len(train_dataset) print('Epoch: {:03d}, Train Loss: {:.7f}'.format(epoch, train_loss)) # 测试模型 model.eval() test_loss = 0.0 for batch in test_dataset: x, edge_index, y = batch.x.to(device), batch.edge_index.to(device), batch.y.to(device) out = model(x, edge_index) loss = criterion(out, y) test_loss += loss.item() * batch.num_graphs test_loss /= len(test_dataset) print('Test Loss: {:.7f}'.format(test_loss)) ``` 在这个示例中,我们使用了Human3.6M数据集进行骨骼识别。该数据集包含了大量的人体骨骼姿态数据,每个姿态由17个关键点组成。我们使用GCN对每个关键点进行分类,输出17维向量,每一维代表一个关键点的分类得分。我们使用均方误差(MSE)作为损失函数,使用Adam优化器进行优化。在训练过程中,我们使用了50个epoch进行训练,每个epoch中遍历整个训练集。在测试过程中,我们仅仅计算了测试集上的损失,没有进行预测。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值