使用 PyTorch 实现 Kronecker Product

使用 PyTorch 实现 Kronecker Product

原始文档:https://www.yuque.com/lart/ugkv9f/gb2h93

计算过程

A = [ a 11 a 12 a 13 a 21 a 22 a 23 ] , B = [ b 11 b 12 b 21 b 22 b 31 b 32 ] , A ⊗ B = [ a 11 b 11 a 11 b 12 a 12 b 11 a 12 b 12 a 13 b 11 a 13 b 12 a 11 b 21 a 11 b 22 a 12 b 21 a 12 b 22 a 13 b 21 a 13 b 22 a 11 b 31 a 11 b 32 a 12 b 31 a 12 b 32 a 13 b 31 a 13 b 32 a 21 b 11 a 21 b 12 a 22 b 11 a 22 b 12 a 23 b 11 a 23 b 12 a 21 b 21 a 21 b 22 a 22 b 21 a 22 b 22 a 23 b 21 a 23 b 22 a 21 b 31 a 21 b 32 a 22 b 31 a 22 b 32 a 23 b 31 a 23 b 32 ] \begin{matrix} A &=\begin{bmatrix} a_{11} & a_{12} & a_{13} \\ a_{21} & a_{22} & a_{23} \end{bmatrix}, \\ B &=\begin{bmatrix} b_{11} & b_{12} \\ b_{21} & b_{22} \\ b_{31} & b_{32} \end{bmatrix}, \\ A \otimes B &=\begin{bmatrix} a_{11}b_{11} & a_{11}b_{12} & a_{12}b_{11} & a_{12}b_{12} & a_{13}b_{11} & a_{13}b_{12} \\ a_{11}b_{21} & a_{11}b_{22} & a_{12}b_{21} & a_{12}b_{22} & a_{13}b_{21} & a_{13}b_{22} \\ a_{11}b_{31} & a_{11}b_{32} & a_{12}b_{31} & a_{12}b_{32} & a_{13}b_{31} & a_{13}b_{32} \\ a_{21}b_{11} & a_{21}b_{12} & a_{22}b_{11} & a_{22}b_{12} & a_{23}b_{11} & a_{23}b_{12} \\ a_{21}b_{21} & a_{21}b_{22} & a_{22}b_{21} & a_{22}b_{22} & a_{23}b_{21} & a_{23}b_{22} \\ a_{21}b_{31} & a_{21}b_{32} & a_{22}b_{31} & a_{22}b_{32} & a_{23}b_{31} & a_{23}b_{32} \end{bmatrix} \end{matrix} ABAB=[a11a21a12a22a13a23],=b11b21b31b12b22b32,=a11b11a11b21a11b31a21b11a21b21a21b31a11b12a11b22a11b32a21b12a21b22a21b32a12b11a12b21a12b31a22b11a22b21a22b31a12b12a12b22a12b32a22b12a22b22a22b32a13b11a13b21a13b31a23b11a23b21a23b31a13b12a13b22a13b32a23b12a23b22a23b32

PyTorch 实现

起因是看到了这篇文章:https://zhuanlan.zhihu.com/p/79295551介绍了一种新颖的卷积方式,其中使用了 kronecker product 方法来实现。这种计算理解很容易,但是实现起来该如何编程,这是一个值得思考的问题。文章结尾作者推荐的代码中给出了一种实现https://github.com/d-li14/dgconv.pytorch/blob/master/dgconv.py#L26

def kronecker_product(mat1, mat2):
    # 在pytorch1.7版本之后,torch.ger就被torch.outer所代替了
    out_mat = torch.ger(mat1.view(-1), mat2.view(-1))
    # 这里的(mat1.size() + mat2.size())表示的是将两个list拼接起来
    out_mat = out_mat.reshape(*(mat1.size() + mat2.size())).permute([0, 2, 1, 3])
    out_mat = out_mat.reshape(mat1.size(0) * mat2.size(0), mat1.size(1) * mat2.size(1))
    return out_mat

这里应该参考的是这里的方法https://discuss.pytorch.org/t/kronecker-product/3919/7

但是该帖子后面给出了一种更简单的方法https://discuss.pytorch.org/t/kronecker-product/3919/10

def kronecker(A, B):
    AB = torch.einsum("ab,cd->acbd", A, B)
    AB = AB.view(A.size(0)*B.size(0), A.size(1)*B.size(1))
    return AB

二者实际上是一致的,也就是说这里的

out_mat = torch.ger(mat1.view(-1), mat2.view(-1))
out_mat = out_mat.reshape(*(mat1.size() + mat2.size())).permute([0, 2, 1, 3])

与基于 enisum 的实现

AB = torch.einsum("ab,cd->acbd", A, B)

表示的是一样的行为。

在前者中,假设 A=mat1B=mat2 ,二者分别为 axb 和 cxd 大小的矩阵。计算过程可以表述如下:

  1. 对于二者先通过矢量化(.view(-1))。
  2. 利用外积操作(torch.ger)计算出各个元素之间的乘积构成的矩阵,大小为 abxcd。
  3. 将结果调整为 axbxcxd 大小的形状。
  4. 利用 permute 操作变成 axcxbxd 大小的形状。

这个过程实际上与 einsum 中的维度索引的调整 'ab,cd->acbd' 是一致的。

殊途同归。

从另一个角度看einsum的模式设定

我们最终的目标是得到这样一个结果:

  A={lab} a,b={1,2}
  B={mcd} c,d={1,2}

      b: ┌──────1──────┐  ┌──────2──────┐
      d: │  1       2  │  │  1       2  │
  a,c ┌──┴─────────────┴──┴─────────────┴─►
  : : │
  1 1 │  l11xm11 l11xm12  l12xm11 l12xm12
  │ 2 │  l11xm21 l11xm22  l12xm21 l12xm22
  │   │
  2 1 │  l21xm11 l21xm12  l22xm11 l22xm12
  │ 2 │  l21xm21 l21xm22  l22xm21 l22xm22
  ▼ ▼ ▼

这里我标注了 A 和 B 对应的下标的变化范围。

由于我们利用 PyTorch 实现这些处理,那我们必定是要使用对应的矩阵运算和形状变换的。

从这里的图可以看出来,若是对于这里最终的结果形式反着使用变形操作就可以得到这样的结果:

  A={lab} a,b={1,2}
  B={mcd} c,d={1,2}

       d:   1       2
   a,c ┌─────────────────► b
   : : │                 │ :
   │ 1 │ l11xm11 l11xm12 │ 1
   │   │ l12xm11 l12xm12 │ 2
   1───┤                 │
   │ 2 │ l11xm21 l11xm22 │ 1
   │   │ l12xm21 l12xm22 │ 2
   ├───┤                 │
   │ 1 │ l21xm11 l21xm12 │ 1
   │   │ l22xm11 l22xm12 │ 2
   2───┤                 │
   │ 2 │ l21xm21 l21xm22 │ 1
   │   │ l22xm21 l22xm22 │ 2
   ▼   ▼                 ▼

这实际上也就是索引为 (i,m,j,n) 的 2D 矩阵,也是对应于 torch.einsum("ab,cd->acbd", A, B) 结果的表示。

前面的两种形式的代码实际上都是为了得到后面这个结果。再通过一个reshape操作从而获得最终结果。。

参考资料

欢迎关注

在这里插入图片描述

  • 4
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值