Explanation and code examples of ⊕, ⊗, and ⊙ ("circled plus", "circled times", and the element-wise product) in computer vision (CV) papers (broadcasting, element-wise operations, matrices, matrix multiplication, vectors)

0. Introduction

We often see the symbols $\oplus$, $\otimes$, and $\odot$ in papers, which raises two questions:

  1. What do these three symbols mean?
  2. How do we use these three mathematical symbols correctly in papers?

1. Explanation of the three symbols

1.1 Element-wise addition: $\oplus$

In papers, $\oplus$ denotes element-wise addition. Written out for two matrices:

$$\left[\begin{matrix}a & b \cr c & d\end{matrix}\right]_{2\times2}+\left[\begin{matrix}e & f \cr g & h\end{matrix}\right]_{2\times2}=\left[\begin{matrix}a+e & b+f \cr c+g & d+h\end{matrix}\right]_{2\times2}$$

As the formula shows, $\oplus$ adds corresponding elements, so the two matrices must have the same shape.
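In PyTorch this is just + (or torch.add); a minimal sketch with concrete numbers (the values are arbitrary, chosen only for illustration):

import torch

# ⊕: corresponding elements are added; both operands are 2x2
print(torch.tensor([[1, 2], [3, 4]]) + torch.tensor([[10, 20], [30, 40]]))
# tensor([[11, 22],
#         [33, 44]])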

1.2 Matrix multiplication: $\otimes$

The circled times $\otimes$ denotes ordinary matrix multiplication from linear algebra, i.e.:

$$\left[\begin{matrix}a_{11} & a_{12} & a_{13} \cr a_{21} & a_{22} & a_{23}\end{matrix}\right]_{2\times3} \times \left[\begin{matrix}b_{11} & b_{12} \cr b_{21} & b_{22} \cr b_{31} & b_{32}\end{matrix}\right]_{3\times2} = \left[\begin{matrix}a_{11}\times b_{11} + a_{12}\times b_{21} + a_{13}\times b_{31} & a_{11}\times b_{12} + a_{12}\times b_{22} + a_{13}\times b_{32} \cr a_{21}\times b_{11} + a_{22}\times b_{21} + a_{23}\times b_{31} & a_{21}\times b_{12} + a_{22}\times b_{22} + a_{23}\times b_{32}\end{matrix}\right]_{2\times2}$$

As you can see, this is just ordinary matrix multiplication, which requires the second dimension of A to equal the first dimension of B.
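As a quick check of this shape rule (a minimal sketch; the tensor values are random and only the shapes matter):

import torch

# A is 2x3, B is 3x2: the inner dimensions (3) match, so the product is 2x2
A = torch.randn(2, 3)
B = torch.randn(3, 2)
print(torch.matmul(A, B).shape)  # torch.Size([2, 2])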

1.3 Element-wise (Hadamard) product: $\odot$

The element-wise product $\odot$ multiplies the elements at corresponding positions of two matrices, for example:

$$\left[\begin{matrix}a & b \cr c & d\end{matrix}\right]_{2\times2} \odot \left[\begin{matrix}e & f \cr g & h\end{matrix}\right]_{2\times2}=\left[\begin{matrix}a\times e & b \times f \cr c \times g & d \times h\end{matrix}\right]_{2\times2}$$

As with element-wise addition $\oplus$, the two matrices must have the same shape.
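In PyTorch this corresponds to * (or torch.mul); a minimal sketch mirroring the addition example above (arbitrary illustrative values):

import torch

# ⊙: corresponding elements are multiplied; both operands are 2x2
print(torch.tensor([[1, 2], [3, 4]]) * torch.tensor([[10, 20], [30, 40]]))
# tensor([[ 10,  40],
#         [ 90, 160]])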

2. Code for the three symbols

| Name | Symbol | PyTorch code | Meaning | Condition |
| --- | --- | --- | --- | --- |
| Matrix multiplication | $\otimes$ | torch.mm(), torch.matmul() | Matrix multiplication (torch.matmul also broadcasts batch dimensions) | The second dimension of A equals the first dimension of B |
| Element-wise addition | $\oplus$ | +, torch.add(A, B) | Corresponding elements of the two matrices are added | Same shape, or shapes that satisfy broadcasting |
| Element-wise product | $\odot$ | *, torch.mul(A, B) | Corresponding elements of the two matrices are multiplied (with automatic broadcasting) | Same shape, or shapes that satisfy broadcasting |

element-wise means the operation is applied to each element individually.

3. PyTorch broadcasting

3.1 Definition

In PyTorch, element-wise operators such as $\oplus$ and $\odot$ support broadcasting: the tensor with the smaller shape is automatically expanded to match the larger one (without copying data).

3.2 Usage

If tensors A and B satisfy the broadcasting conditions, the computation proceeds as follows:

  1. If A and B do not have the same number of dimensions, dimensions of size 1 are prepended to the tensor with fewer dimensions until both have the same number of dimensions.
  2. For each dimension, the size of the result is the maximum of the corresponding sizes of A and B. In other words, column vectors are expanded along columns and row vectors along rows until the two tensors have the same shape, and the element-wise operation is then applied. In pseudocode (a quick way to verify this is shown in the sketch below):
out.shape = max(A.shape, B.shape)
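A minimal sketch of these rules using torch.broadcast_shapes (available in PyTorch 1.8 and later), which returns the broadcast result shape or raises an error if the shapes are incompatible:

import torch

# (3, 1) and (1, 3) broadcast to (3, 3): each size-1 dimension is expanded
print(torch.broadcast_shapes((3, 1), (1, 3)))  # torch.Size([3, 3])

# (3, 3) and (1, 2) are incompatible: 3 vs 2 in the last dimension
try:
    torch.broadcast_shapes((3, 3), (1, 2))
except RuntimeError as e:
    print(e)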

4. Code demos

4.1 Element-wise addition: $\oplus$

4.1.1 Element-wise addition of a scalar (Scalar) and a matrix (Matrix) (broadcasting is triggered automatically)

import torch


A = torch.randint(
    high=100, 
    size=(3, 3)
)

print(f"{A = }")
print(f"{A.shape = }")

# 1. Element-wise addition of a matrix and a scalar
print(
    f"-----------------------------------------------\n"
    f"Matrix + scalar, element-wise: \n"
    f"{A + 10 = }"
    f"\n-----------------------------------------------"
)

Output:

A = tensor([[79, 20, 39],
            [31, 47,  9],
            [22, 54, 81]])
A.shape = torch.Size([3, 3])
-----------------------------------------------
Matrix + scalar, element-wise: 
A + 10 = tensor([[89, 30, 49],
                 [41, 57, 19],
                 [32, 64, 91]])
-----------------------------------------------

4.1.2 Element-wise addition of a vector (Vector) and a vector (Vector)

import torch


# A = torch.tensor([1], [2], [3])  # tensor() takes 1 positional argument but 3 were given
A = torch.tensor(
    data=(1, 2, 3)
)  # A.shape = torch.Size([3])

# Add a dimension to A
A = A.unsqueeze(-1)  # [3] -> [3, 1]

# Transpose A to get B (A.T also works)
B = A.t()  # [3, 1] -> [1, 3]

print(f"{A = }")
print(f"{A.shape = }")
print(f"-----------------------------------------------")
print(f"{B = }")
print(f"{B.shape = }")
print(f"-----------------------------------------------")

# Element-wise addition of a vector and a vector
print(f"Vector + vector, element-wise:\n"
      f"{A + B = }\n"
      f"{(A + B).shape = }")

Output:

A = tensor([[1],
            [2],
            [3]])
A.shape = torch.Size([3, 1])
-----------------------------------------------
B = tensor([[1, 2, 3]])
B.shape = torch.Size([1, 3])
-----------------------------------------------
Vector + vector, element-wise:
A + B = tensor([[2, 3, 4],
                [3, 4, 5],
                [4, 5, 6]])
(A + B).shape = torch.Size([3, 3])

4.1.3 Element-wise addition of a vector (Vector) and a matrix (Matrix)

import torch


A = torch.tensor(
    data=([1], [2], [3])
)

B = torch.randint(
    high=100, 
    size=(3, 3)
)

print(f"{A = }")
print(f"{A.shape = }")
print(f"-----------------------------------------------")
print(f"{B = }")
print(f"{B.shape = }")
print(f"-----------------------------------------------")

# Element-wise addition of a vector and a matrix
print(f"Vector + matrix, element-wise:\n"
      f"{A + B = }\n"
      f"{(A + B).shape = }")

Output:

A = tensor([[1],
            [2],
            [3]])
A.shape = torch.Size([3, 1])
-----------------------------------------------
B = tensor([[21, 69, 84],
            [59, 46, 54],
            [68, 25, 95]])
B.shape = torch.Size([3, 3])
-----------------------------------------------
Vector + matrix, element-wise:
A + B = tensor([[22, 70, 85],
                [61, 48, 56],
                [71, 28, 98]])
(A + B).shape = torch.Size([3, 3])

4.1.4 Element-wise addition of a matrix (Matrix) and a matrix (Matrix)

1. Without broadcasting
import torch


A = torch.randint(
    high=100, 
    size=(3, 3)
)

B = torch.randint(
    high=50, 
    size=(3, 3)
)

print(f"-----------------------------------------------")
print(f"{A = }")
print(f"{A.shape = }")
print(f"-----------------------------------------------")
print(f"{B = }")
print(f"{B.shape = }")
print(f"-----------------------------------------------")

print(f"矩阵与矩阵对应位置相加:\n"
      f"{A + B = }\n"
      f"{(A + B).shape = }")
-----------------------------------------------
A = tensor([[10, 64, 54],
            [51, 13, 99],
            [35, 62, 11]])
A.shape = torch.Size([3, 3])
-----------------------------------------------
B = tensor([[26, 20,  5],
            [40,  1, 26],
            [ 4, 20, 47]])
B.shape = torch.Size([3, 3])
-----------------------------------------------
矩阵与矩阵对应位置相加:
A + B = tensor([[ 36,  84,  59],
                [ 91,  14, 125],
                [ 39,  82,  58]])
(A + B).shape = torch.Size([3, 3])
2. With broadcasting
import torch


A = torch.randint(
    high=100, 
    size=(3, 3)
)

B = torch.randint(
    high=50, size=(1, 3)
)

print(f"-----------------------------------------------")
print(f"{A = }")
print(f"{A.shape = }")
print(f"-----------------------------------------------")
print(f"{B = }")
print(f"{B.shape = }")
print(f"-----------------------------------------------")

# [Broadcasting] Element-wise addition of a matrix and a matrix
print(f"[Broadcasting] Matrix + matrix, element-wise:\n"
      f"{A + B = }\n"
      f"{(A + B).shape = }")

Output:

-----------------------------------------------
A = tensor([[82, 43, 47],
            [14, 18, 57],
            [27, 45, 71]])
A.shape = torch.Size([3, 3])
-----------------------------------------------
B = tensor([[34, 40, 22]])
B.shape = torch.Size([1, 3])
-----------------------------------------------
[Broadcasting] Matrix + matrix, element-wise:
A + B = tensor([[116,  83,  69],
                [ 48,  58,  79],
                [ 61,  85,  93]])
(A + B).shape = torch.Size([3, 3])
3. [Invalid] broadcasting (incompatible shapes)
import torch


A = torch.randint(
    high=100, 
    size=(3, 3)
)

B = torch.randint(
    high=50, 
    size=(1, 2)
)

print(f"-----------------------------------------------")
print(f"{A = }")
print(f"{A.shape = }")
print(f"-----------------------------------------------")
print(f"{B = }")
print(f"{B.shape = }")
print(f"-----------------------------------------------")

print(f"[错误的广播机制] 矩阵与矩阵对应位置相加")
try:
    print(f"{A + B = }")
except Exception as e:
    print(f"A + B = {e}")
try:
    print(f"{(A + B).shape = }")
except Exception as e:
    print(f"(A + B).shape = {e}")
-----------------------------------------------
A = tensor([[82,  6, 22],
            [71, 56, 10],
            [76, 74, 80]])
A.shape = torch.Size([3, 3])
-----------------------------------------------
B = tensor([[37, 30]])
B.shape = torch.Size([1, 2])
-----------------------------------------------
[错误的广播机制] 矩阵与矩阵对应位置相加
A + B = The size of tensor a (3) must match the size of tensor b (2) at non-singleton dimension 1
(A + B).shape = The size of tensor a (3) must match the size of tensor b (2) at non-singleton dimension 1

4.2 Element-wise multiplication: $\odot$

import torch


A = torch.randint(
    high=50, 
    size=(3, 3)
)

B = torch.randint(
    high=100,
    size=(3, 3)
)

print(f"-----------------------------------------------")
print(f"{A = }")
print(f"{A.shape = }")
print(f"-----------------------------------------------")
print(f"{B = }")
print(f"{B.shape = }")
print(f"-----------------------------------------------")

print(f"[torch.mul(A, B)] 矩阵与矩阵进行对应位置元素乘法:\n"
      f"{torch.mul(A, B) = }\n"
      f"{torch.mul(A, B).shape = }")
print(f"-----------------------------------------------")
print(f"[A*B] 矩阵与矩阵进行对应位置元素乘法:\n"
      f"{A * B = }\n"
      f"{(A * B).shape = }")
-----------------------------------------------
A = tensor([[36, 37, 31],
        [42, 15,  8],
        [ 8, 31, 18]])
A.shape = torch.Size([3, 3])
-----------------------------------------------
B = tensor([[46, 85, 34],
        [25, 14, 24],
        [89,  1, 71]])
B.shape = torch.Size([3, 3])
-----------------------------------------------
[torch.mul(A, B)] 矩阵与矩阵进行对应位置元素乘法:
torch.mul(A, B) = tensor([[1656, 3145, 1054],
        [1050,  210,  192],
        [ 712,   31, 1278]])
torch.mul(A, B).shape = torch.Size([3, 3])
-----------------------------------------------
[A*B] 矩阵与矩阵进行对应位置元素乘法:
A * B = tensor([[1656, 3145, 1054],
        [1050,  210,  192],
        [ 712,   31, 1278]])
(A * B).shape = torch.Size([3, 3])

⚠️ Note: in code, A * B denotes $\odot$ (element-wise product), not $\otimes$ (matrix multiplication).

4.3 Matrix multiplication: $\otimes$

import torch


A = torch.randint(
    high=50,
    size=(3, 3)
)

B = torch.randint(
    high=100, 
    size=(3, 3)
)

print(f"-----------------------------------------------")
print(f"{A = }")
print(f"{A.shape = }")
print(f"-----------------------------------------------")
print(f"{B = }")
print(f"{B.shape = }")
print(f"-----------------------------------------------")

print(f"[torch.matmul(A, B)] 矩阵与矩阵进行矩阵乘法:\n"
      f"{torch.matmul(A, B) = }\n"
      f"{torch.matmul(A, B).shape = }")
print(f"-----------------------------------------------")
print(f"[torch.mm(A, B)] 矩阵与矩阵进行矩阵乘法:\n"
      f"{torch.mm(A, B) = }\n"
      f"{torch.mm(A, B).shape = }")
-----------------------------------------------
A = tensor([[14, 26, 24],
        [24, 31, 18],
        [43, 19, 14]])
A.shape = torch.Size([3, 3])
-----------------------------------------------
B = tensor([[38, 25, 26],
        [73, 75, 49],
        [17, 83, 44]])
B.shape = torch.Size([3, 3])
-----------------------------------------------
[torch.matmul(A, B)] 矩阵与矩阵进行矩阵乘法:
torch.matmul(A, B) = tensor([[2838, 4292, 2694],
        [3481, 4419, 2935],
        [3259, 3662, 2665]])
torch.matmul(A, B).shape = torch.Size([3, 3])
-----------------------------------------------
[torch.mm(A, B)] 矩阵与矩阵进行矩阵乘法:
torch.mm(A, B) = tensor([[2838, 4292, 2694],
        [3481, 4419, 2935],
        [3259, 3662, 2665]])
torch.mm(A, B).shape = torch.Size([3, 3])
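Python's @ operator dispatches to matmul, so A @ B is a third way of writing $\otimes$ in PyTorch; a minimal check (random values, only the equivalence matters):

import torch

A = torch.randn(3, 3)
B = torch.randn(3, 3)

# The @ operator calls matmul, so all three spellings compute the same matrix product
assert torch.allclose(A @ B, torch.matmul(A, B))
assert torch.allclose(A @ B, torch.mm(A, B))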

5. Summary

  • $\oplus$ denotes element-wise addition
  • $\odot$ denotes element-wise multiplication
  • $\otimes$ denotes matrix multiplication (as in linear algebra)