1. torch.bmm
(Batch Matrix Multiplication)
功能:对两个三维张量进行批量矩阵乘法。
数学形式:
对于输入张量
A
∈
R
b
×
m
×
n
A \in \mathbb{R}^{b \times m \times n}
A∈Rb×m×n和
B
∈
R
b
×
n
×
p
B \in \mathbb{R}^{b \times n \times p}
B∈Rb×n×p,输出
C
∈
R
b
×
m
×
p
C \in \mathbb{R}^{b \times m \times p}
C∈Rb×m×p,其中:
C
[
i
,
:
,
:
]
=
A
[
i
,
:
,
:
]
⋅
B
[
i
,
:
,
:
]
C[i,:,:] = A[i,:,:] \cdot B[i,:,:]
C[i,:,:]=A[i,:,:]⋅B[i,:,:]
参数:
• input
(Tensor): 第一个三维张量。
• mat2
(Tensor): 第二个三维张量(最后一维必须与 input
的倒数第二维匹配)。
代码示例:
import torch
# 定义两个三维张量(batch_size=2, 矩阵维度 3x4 和 4x5)
A = torch.randn(2, 3, 4)
B = torch.randn(2, 4, 5)
# 批量矩阵乘法
C = torch.bmm(A, B) # 输出形状: [2, 3, 5]
应用:
bb = torch.bmm(b[0].view(-1, self.classes, 1), b[1].view(-1, 1, self.classes))
• 目的:计算两个视图的归一化证据
b
b
b的相似性矩阵(即
b
0
⋅
b
1
b^0 \cdot b^1
b0⋅b1)。
• 形状变换:
• b[0].view(-1, self.classes, 1)
:将 b[0]
变形为 [batch_size, classes, 1]
。
• b[1].view(-1, 1, self.classes)
:将 b[1]
变形为 [batch_size, 1, classes]
。
• 结果:bb
的形状为 [batch_size, classes, classes]
,表示两两类别之间的相关性。
2. torch.mul
(Element-wise Multiplication)
功能:逐元素乘法(支持广播机制)。
数学形式:
对于输入张量
A
A
A和
B
B
B,输出
C
=
A
⊙
B
C = A \odot B
C=A⊙B,其中
C
C
C的每个元素满足:
C
i
,
j
=
A
i
,
j
×
B
i
,
j
C_{i,j} = A_{i,j} \times B_{i,j}
Ci,j=Ai,j×Bi,j
参数:
• input
(Tensor): 第一个张量。
• other
(Tensor/Number): 第二个张量或标量(自动广播)。
代码示例:
import torch
# 张量与张量逐元素相乘
A = torch.tensor([[1, 2], [3, 4]])
B = torch.tensor([[2, 2], [1, 0]])
C = torch.mul(A, B) # 输出: [[2, 4], [3, 0]]
# 张量与标量相乘
D = torch.mul(A, 3) # 输出: [[3, 6], [9, 12]]
应用:
bu = torch.mul(b[0], u[1].expand(b[0].shape))
ub = torch.mul(b[1], u[0].expand(b[0].shape))
b_a = (torch.mul(b[0], b[1]) + bu + ub) / (1 - C.view(-1, 1))
• 目的:
bu
和ub
:计算视图间的证据-不确定性交互项( b 0 ⋅ u 1 b^0 \cdot u^1 b0⋅u1和 b 1 ⋅ u 0 b^1 \cdot u^0 b1⋅u0)。b_a
:合成归一化证据时,结合原始证据、交互项和冲突调整因子 1 − C 1 - C 1−C。
• 广播机制:u[1].expand(b[0].shape)
将不确定性张量扩展为与证据张量相同的形状。
对比总结
函数 | 作用 | 输入形状 | 输出形状 | 典型应用场景 |
---|---|---|---|---|
torch.bmm | 批量矩阵乘法 | [b,m,n] 和 [b,n,p] | [b,m,p] | 多视图相关性计算、注意力机制 |
torch.mul | 逐元素乘法(支持广播) | 任意形状(需可广播) | 同输入形状(广播后) | 加权融合、动态调整参数 |
结构示意图(DST例子)
- 输入:两个视图的狄利克雷参数
alpha1
和alpha2
。 - 中间计算:
• 通过torch.bmm
计算视图间相关性矩阵bb
。
• 通过torch.mul
计算证据与不确定性的交互项bu
和ub
。 - 输出:合成后的参数
alpha_a
。