Shift-GCN中Shift的实现细节笔记,通过torch.index_select实现

Shift-GCN中Shift的实现细节笔记,通过torch.index_select实现
FesianXu 20201112 at UESTC

前言

近期在看Shift-GCN的论文[1],该网络是基于Shift卷积算子[2]在图结构数据上的延伸。在阅读源代码[3]的时候发现了其对于Non-Local Spatial Shift Graph Convolution有意思的实现方法,在这里简要记录一下。如有谬误请联系指出,转载请联系作者并注明出处,谢谢

∇ \nabla 联系方式:

e-mail: FesianXu@gmail.com

QQ: 973926198

github: https://github.com/FesianXu

知乎专栏: 计算机视觉/计算机图形理论与应用

微信公众号
qrcode


在讨论代码本身之前,简要介绍下Non-Local Spatial Shift Graph Convolution的操作流程,具体介绍可见博文[1]。对于一个时空骨骼点序列而言,如Fig 1所示,将单帧的骨骼点图视为是完全图,因此任何一个节点都和其他所有节点有所连接,其shift卷积策略为:

对于一个特征图 F ∈ R N × C \mathbf{F} \in \mathbb{R}^{N \times C} FRN×C而言,其中 N N N是骨骼点数量, C C C是特征通道数。对于第 i i i个通道的shift距离为 i   m o d   N i \bmod N imodN

non_local_spatial_shift

Fig 1. 在全局空间Shift图卷积中,将骨骼点图视为是完全图,其shift策略因此需要考虑本节点与其他所有节点之间的关系。

根据这种简单的策略,如Fig 1所示,形成了类似于螺旋上升的特征图样。那么我们要如何用代码描绘这个过程呢?作者公开的源代码给予了我们一种思路,其主要应用了pytorch中的torch.index_select函数。先简单介绍一下这个函数。

torch.index_select()是一个用于索引给定张量中某一个维度中某些特定索引元素的方法,其API手册如:

torch.index_select(input, dim, index, out=None) → Tensor
Parameters:	
	input (Tensor) – 输入张量,需要被索引的张量
	dim (int) – 在某个维度被索引
	index (LongTensor) – 一维张量,用于提供索引信息
	out (Tensor, optional) – 输出张量,可以不填

其作用很简单,比如我现在的输入张量为1000 * 10的尺寸大小,其中1000为样本数量,10为特征数目,如果我现在需要指定的某些样本,比如第1-100,300-400等等样本,我可以用一个index进行索引,然后应用torch.index_select()就可以索引了,例子如:

>>> x = torch.randn(3, 4)
>>> x
tensor([[ 0.1427,  0.0231, -0.5414, -1.0009],
        [-0.4664,  0.2647, -0.1228, -1.1068],
        [-1.1734, -0.6571,  0.7230, -0.6004]])
>>> indices = torch.tensor([0, 2])
>>> torch.index_select(x, 0, indices) # 按行索引
tensor([[ 0.1427,  0.0231, -0.5414, -1.0009],
        [-1.1734, -0.6571,  0.7230, -0.6004]])
>>> torch.index_select(x, 1, indices) # 按列索引
tensor([[ 0.1427, -0.5414],
        [-0.4664, -0.1228],
        [-1.1734,  0.7230]])

注意到有一个问题是,pytorch似乎在使用GPU的情况下,不检查index是否会越界,因此如果你的index越界了,但是报错的地方可能不在使用index_select()的地方,而是在后续的代码中,这个似乎就需要留意下你的index了。同时,index是一个LongTensor,这个也是要留意的。

我们先贴出主要代码,看看作者是怎么实现的:

class Shift_gcn(nn.Module):
    def __init__(self, in_channels, out_channels, A, coff_embedding=4, num_subset=3):
        super(Shift_gcn, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        if in_channels != out_channels:
            self.down = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1),
                nn.BatchNorm2d(out_channels)
            )
        else:
            self.down = lambda x: x

        self.Linear_weight = nn.Parameter(torch.zeros(in_channels, out_channels, requires_grad=True), requires_grad=True)

        self.Linear_bias = nn.Parameter(torch.zeros(1,1,out_channels,requires_grad=True),requires_grad=True)

        self.Feature_Mask = nn.Parameter(torch.ones(1,25,in_channels, requires_grad=True),requires_grad=True)

        self.bn = nn.BatchNorm1d(25*out_channels)
        self.relu = nn.ReLU()
        index_array = np.empty(25*in_channels).astype(np.int)
        for i in range(25):
            for j in range(in_channels):
                index_array[i*in_channels + j] = (i*in_channels + j + j*in_channels) % (in_channels*25)
        self.shift_in = nn.Parameter(torch.from_numpy(index_array),requires_grad=False)

        index_array = np.empty(25*out_channels).astype(np.int)
        for i in range(25):
            for j in range(out_channels):
                index_array[i*out_channels + j] = (i*out_channels + j - j*out_channels) % (out_channels*25)
        self.shift_out = nn.Parameter(torch.from_numpy(index_array),requires_grad=False)
        

    def forward(self, x0):
        n, c, t, v = x0.size()
        x = x0.permute(0,2,3,1).contiguous()
        # n,t,v,c
        # shift1
        x = x.view(n*t,v*c)
        x = torch.index_select(x, 1, self.shift_in)
        x = x.view(n*t,v,c)
        x = x * (torch.tanh(self.Feature_Mask)+1)

        x = torch.einsum('nwc,cd->nwd', (x, self.Linear_weight)).contiguous() # nt,v,c
        x = x + self.Linear_bias

        # shift2
        x = x.view(n*t,-1) 
        x = torch.index_select(x, 1, self.shift_out)
        x = self.bn(x)
        x = x.view(n,t,v,self.out_channels).permute(0,3,1,2) # n,c,t,v

        x = x + self.down(x0)
        x = self.relu(x)
        # print(self.Feature_Mask.shape)
        return x

我们把forward()里面的分为三大部分,分别是:1> shift_in操作;2> 卷积操作;3> shift_out操作;其中指的shift_inshift_out只是shift图卷积算子的不同形式而已,其主要是一致的。整个结构图如Fig 2(c)所示。

conv

Fig 2. Shift-Conv-Shift模组需要两个shift操作,代码中称之为shift_in和shift_out。

其中的卷积操作代码由爱因斯坦乘积[4]形式表示,其实本质上就是一种矩阵乘法,其将 x ∈ R N × W × C \mathbf{x} \in \mathbb{R}^{N \times W \times C} xRN×W×C W ∈ R C × D \mathbf{W} \in \mathbb{R}^{C \times D} WRC×D矩阵相乘,得到输出张量为 O ∈ R N × W × D \mathbf{O} \in \mathbb{R}^{N \times W \times D} ORN×W×D

x = torch.einsum('nwc,cd->nwd', (x, self.Linear_weight)).contiguous() # nt,v,c
x = x + self.Linear_bias

而进行的掩膜操作代码如下所示,这代码不需要太多仔细思考。

x = x * (torch.tanh(self.Feature_Mask)+1)

那么我们着重考虑以下的代码:

x = x.view(n*t,v*c)
x = torch.index_select(x, 1, self.shift_in)
x = x.view(n*t,v,c)

第一行代码将特征图展开,如Fig 3所示,得到了 25 × C 25 \times C 25×C大小的特征向量。通过torch.index_select对特征向量的不同分区进行选择得到最终的输出特征向量,选择的过程如Fig 4所示。

flatten

Fig 3. 将特征图进行拉平后得到特征向量。

那么可以知道,对于某个关节点 i i i而言,给定通道 j j j,当遍历不同通道时,会存在一个 C C C周期,因此是 ( j + j × C ) (j+j\times C) (j+j×C),比如对于第0号节点的第1个通道,其需要将 ( 1 + 1 × C ) (1+1\times C) (1+1×C)的值移入,如Fig 4的例子所示。而第2个通道则是需要考虑将 ( 2 + 2 × C ) (2+2\times C) (2+2×C)的值移入,我们发现是以 C C C为周期的。这个时候假定的是关节点都是同一个的时候,当遍历关节点时,我们最终的索引规则是 ( i × C + j × C + j ) (i\times C + j\times C + j) (i×C+j×C+j),因为考虑到了溢出的问题,因此需要求余,有 ( i × C + j × C + j )   m o d   ( 25 × C ) (i \times C + j \times C + j) \bmod (25 \times C) (i×C+j×C+j)mod(25×C)。这个对应源代码的第23-32行,如上所示。

shift_vector

Fig 4. 将特征图拉直之后的shift操作示意图,因此需要寻找一种特殊的索引规则,以将特征图shift问题转化为特征向量的shift问题。

在以这个举个代码例子,例子如下所示:

import numpy as np
import torch
array = np.arange(0,15).reshape(3,5)
array = torch.tensor(array)
index = np.zeros(15)
for i in range(3):
    for j in range(5):
        index[i*5+j] = (i*5+j*5+j) % (15)
index = torch.tensor(index).long()
out = torch.index_select(array.view(1,-1), 1, index).view(3,5)
print(array)
print(out)

输出为:

tensor([[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14]])
tensor([[ 0,  6, 12,  3,  9],
        [ 5, 11,  2,  8, 14],
        [10,  1,  7, 13,  4]])

我们把这种正向移入的称之为shift-in,反过来移入则称之为shift-out,其索引公式有一点小变化,为: ( i × C − j × C + j )   m o d   ( 25 × C ) (i \times C - j \times C + j) \bmod (25 \times C) (i×Cj×C+j)mod(25×C)。代码例子如下:

import numpy as np
import torch
array = np.arange(0,15).reshape(3,5)
array = torch.tensor(array)
index = np.zeros(15)
for i in range(3):
    for j in range(5):
        index[i*5+j] = (i*5-j*5+j) % (15)
index = torch.tensor(index).long()
out = torch.index_select(array.view(1,-1), 1, index).view(3,5)
print(array)
print(out)

输出为:

tensor([[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14]])
tensor([[ 0, 11,  7,  3, 14],
        [ 5,  1, 12,  8,  4],
        [10,  6,  2, 13,  9]])

输入和shift-in只是因为平移方向反过来了而已。

当然,进行了特征向量的shift还不够,还需要将其reshape回一个特征矩阵,因此会有:

 x = x.view(n*t,v,c)

这样的代码段出现。


Reference

[1]. https://fesian.blog.csdn.net/article/details/109563113

[2]. https://fesian.blog.csdn.net/article/details/109474701

[3]. https://github.com/kchengiva/Shift-GCN

[4]. https://blog.csdn.net/LoseInVain/article/details/81143966

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

FesianXu

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值