torch.matmul() 详解

最近在准备做 HW04,在读 transformer 的源码的时候发现 attention score 的 torch.matmul() 的奇妙设置,故有此篇文章进行分享。

前言碎碎念:

一开始我以为 torch.matmul 所做的工作就是简单的矩阵相乘,即:假设我们有两个矩阵 AB,它们的 size 分别为 (m, n)(n, p),那么 A x B 的 size 为 (m, p)。然后我看了眼官方文档的例子:

tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(10, 4, 5)
torch.matmul(tensor1, tensor2).size()
>> torch.Size([10, 3, 5])

大大的问号冒了出来 : ),这也能乘?

文章的代码文件:notebook 代码


下面结合官方文档提供一些例子给大家理解。

torch.matmul(input, other, ***, out=None) → Tensor

两个张量的矩阵乘积,具体行为取决于张量的维度,如下所示。

这里为了描述方便,用 input_dother_d 分别指代 input.dim()other.dim(),使用 torch.randint() 替代 torch.randn() 方便印证。

前期工作

import torch

# 固定 torch 的随机数种子,以便重现结果
torch.manual_seed(0)

# 打印信息
def print_info(A, B):
    print(f"A: {A}\nB: {B}")
    print(f"A 的维度: {A.dim()},\t B 的维度: {B.dim()}")
    print(f"A 的元素总数: {A.numel()},\t B 的元素总数: {B.numel()}")
    print(f"torch.matmul(A, B): {torch.matmul(A, B)}")
    print(f"torch.matmul(A, B).size(): {torch.matmul(A, B).size()}")
    

input_d = other_d = 1(两个 Tensor 皆为 1 维)

此时就是我们常说的点积(dot product),返回标量。注意,这里是维度为 1,而不是元素总数。

A = torch.randint(0, 5, size=(2,))
B = torch.randint(0, 5, size=(2,))

print_info(A, B)
>> A: tensor([4, 4])
>> B: tensor([3, 0])
>> A 的维度: 1,	 B 的维度: 1
>> A 的元素总数: 2,	 B 的元素总数: 2
>> torch.matmul(A, B) = 12
>> torch.matmul(A, B).size() = torch.Size([])

input_d = other_d = 2 (两个 Tensor 皆为 2 维)

返回矩阵乘积的结果。

A = torch.randint(0, 5, size=(2, 1))
B = torch.randint(0, 5, size=(1, 2))

print_info(A, B)
>> A: tensor([[3],
>>         [4]])
>> B: tensor([[2, 3]])
>> A 的维度: 2,	 B 的维度: 2
>> A 的元素总数: 2,	 B 的元素总数: 2
>> torch.matmul(A, B) = tensor([[ 6,  9],
>>         [ 8, 12]])
>> torch.matmul(A, B).size() = torch.Size([2, 2])

input_d = 1, other_d = 2

按照广播机制(boardcasting)进行处理,即:从 size 的尾部开始一一比对,如果维度不够,则扩展一维,令初始值为 1 再进行计算。计算完之后移除扩展的维度,用下面的例子来说就是扩展成 (1, 2) 后,(1, 2) * (2, 2) => (1, 2) => (2, )

A = torch.randint(0, 5, size=(2, ))
B = torch.randint(0, 5, size=(2, 2))

print_info(A, B)
>> A: tensor([2, 3])
>> B: tensor([[1, 1],
>>         [1, 4]])
>> A 的维度: 1,	 B 的维度: 2
>> A 的元素总数: 2,	 B 的元素总数: 4
>> torch.matmul(A, B) = tensor([ 5, 14])
>> torch.matmul(A, B).size() = torch.Size([2])

input_d = 2, other_d = 1

返回矩阵与向量的乘积。

# 这里使用上一次的矩阵和向量,方便对照
print_info(B, A)
>> A: tensor([[1, 1],
>>         [1, 4]])
>> B: tensor([2, 3])
>> A 的维度: 2,	 B 的维度: 1
>> A 的元素总数: 4,	 B 的元素总数: 2
>> torch.matmul(A, B) = tensor([ 5, 14])
>> torch.matmul(A, B).size() = torch.Size([2])

input_d > 2 or other_d > 2

以 input_d > 2 为例,维度不匹配就通过广播机制扩展,最后结果上删除掉扩展的维度。

个人理解:对于 dim >= 2 的 tensor 来说最后两维被看作矩阵的行和列,其余(如果存在)被看作 batch。

对于非矩阵(non-matrix)维度也是进行广播处理的,以 A.size() = (j, 1, m, n) 和 B.size() =(k, n, m) 为例,j x 1 和 k 是非矩阵维度,也就是 batch 维度,torch.matmul(A, B).size() = (j, k, m, m)。

input_d > 2 and other_d = 2

矩阵部分:(1, 2) * (2, 1)

A = torch.randint(0, 5, size=(2, 1, 2))
B = torch.randint(0, 5, size=(2, 1))

print_info(A, B)
>> A: tensor([[[3, 1]],
>> 
>>         [[1, 3]]])
>> B: tensor([[4],
>>         [3]])
>> A 的维度: 3,	 B 的维度: 2
>> A 的元素总数: 4,	 B 的元素总数: 2
>> torch.matmul(A, B) = tensor([[[15]],
>> 
>>         [[13]]])

input_d > 2 and other_d = 1

这里可以看成单拎出 A 的最后 2 维与 B 做 input_d = 2 和 other_d = 1 的乘法:(1, 2) * (2, ),具体细节可以回看上面对应的部分。

A = torch.randint(0, 5, size=(2, 1, 2))
B = torch.randint(0, 5, size=(2, ))

print_info(A, B)
>> A: tensor([[[1, 4]],
>> 
>>         [[1, 4]]])
>> B: tensor([4, 1])
>> A 的维度: 3,	 B 的维度: 1
>> A 的元素总数: 4,	 B 的元素总数: 2
>> torch.matmul(A, B) = tensor([[8],
>>         [8]])
>> torch.matmul(A, B).size() = torch.Size([2, 1])

input_d > 2 and other_d >2 (多维 Tensor)

广播部分:(2, 1, *, *) => (2, 2, *, *)。矩阵部分:(2, 1) * (1, 2)

A = torch.randint(0, 5, size=(2, 1, 2, 1))
B = torch.randint(0, 5, size=(2, 1, 2))

print_info(A, B)
>> A: tensor([[[[4],
>>           [4]]],
>> 
>> 
>>         [[[4],
>>           [0]]]])
>> B: tensor([[[1, 2]],
>> 
>>         [[3, 0]]])
>> A 的维度: 4,	 B 的维度: 3
>> A 的元素总数: 4,	 B 的元素总数: 4
>> torch.matmul(A, B) = tensor([[[[ 4,  8],
>>           [ 4,  8]],
>> 
>>          [[12,  0],
>>           [12,  0]]],
>> 
>> 
>>         [[[ 4,  8],
>>           [ 0,  0]],
>> 
>>          [[12,  0],
>>           [ 0,  0]]]])
>> torch.matmul(A, B).size() = torch.Size([2, 2, 2, 2])

在往下翻之前不妨思考一下 torch.matmul(B, A).size() 等于多少。

print_info(B, A)
>> A: tensor([[[1, 2]],
>> 
>>         [[3, 0]]])
>> B: tensor([[[[4],
>>           [4]]],
>> 
>> 
>>         [[[4],
>>           [0]]]])
>> A 的维度: 3,	 B 的维度: 4
>> A 的元素总数: 4,	 B 的元素总数: 4
>> torch.matmul(A, B) = tensor([[[[12]],
>> 
>>          [[12]]],
>> 
>> 
>>         [[[ 4]],
>> 
>>          [[12]]]])
>> torch.matmul(A, B).size() = torch.Size([2, 2, 1, 1])

拓展阅读

Broadcasting

`STFT` 是一个 PyTorch 模块,用于计算短时傅里叶变换(Short-Time Fourier Transform,STFT),是一种常用的信号处理技术。下面是这个模块的详细解释: ```python class STFT(torch.nn.Module): def __init__(self, filter_length=2048, hop_length=512, win_length=None, window='hann', center=True, pad_mode='reflect', freeze_parameters=True): super().__init__() self.filter_length = filter_length self.hop_length = hop_length self.center = center self.pad_mode = pad_mode if win_length is None: win_length = filter_length self.win_length = win_length self.window = get_window(window, win_length) # Create filter kernel fft_basis = np.fft.fft(np.eye(filter_length)) kernel = np.concatenate([np.real(fft_basis[:filter_length // 2 + 1, :]), np.imag(fft_basis[:filter_length // 2 + 1, :])], 0) self.register_buffer('kernel', torch.tensor(kernel, dtype=torch.float32)) # Freeze parameters if freeze_parameters: for name, param in self.named_parameters(): param.requires_grad = False def forward(self, waveform): assert (waveform.dim() == 1) # Pad waveform if self.center: waveform = nn.functional.pad(waveform.unsqueeze(0), (self.filter_length // 2, self.filter_length // 2), mode='constant', value=0) else: waveform = nn.functional.pad(waveform.unsqueeze(0), (self.filter_length - self.hop_length, 0), mode='constant', value=0) # Window waveform if waveform.shape[-1] < self.win_length: waveform = nn.functional.pad(waveform, (self.win_length - waveform.shape[-1], 0), mode='constant', value=0) waveform = waveform.squeeze(0) if self.window.device != waveform.device: self.window = self.window.to(waveform.device) windowed_waveform = waveform * self.window # Pad for linear convolution if self.center: windowed_waveform = nn.functional.pad(windowed_waveform, (self.filter_length // 2, self.filter_length // 2), mode='constant', value=0) else: windowed_waveform = nn.functional.pad(windowed_waveform, (self.filter_length - self.hop_length, 0), mode='constant', value=0) # Perform convolution fft = torch.fft.rfft(windowed_waveform.unsqueeze(0), dim=1) fft = torch.cat((fft.real, fft.imag), dim=1) output = torch.matmul(fft, self.kernel) # Remove redundant frequencies output = output[:, :self.filter_length // 2 + 1, :] return output ``` - `__init__` 方法:构造方法,用于初始化模块的各个参数。其中,`filter_length` 表示 STFT 的滤波器长度,`hop_length` 表示 STFT 的帧移(即相邻帧之间的采样点数),`win_length` 表示 STFT 的窗函数长度,`window` 是指定的窗函数类型(默认为汉宁窗),`center` 表示是否需要在信号两端填充 0 以保证 STFT 的中心位置与输入信号的中心位置对齐,`pad_mode` 是指定填充方式(默认为反射填充),`freeze_parameters` 表示是否需要冻结模块的参数。 - `forward` 方法:前向传播方法,用于计算输入信号的 STFT。其中,`waveform` 表示输入信号。首先,根据 `center` 和 `pad_mode` 对输入信号进行填充和窗函数处理,然后进行线性卷积,最后通过傅里叶变换计算 STFT。返回的 `output` 是一个张量,表示 STFT 系数。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Hoper.J

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值