PyTorch中torch.nn.functional.pad函数使用详解

顾明思义,这个函数是用来扩充张量数据的边界的。但是PyTorch中,pad的函数和numpy以及tensorflow的pad用法都不一样。今天就带来这个函数简明的用法解释。

首先跳到函数定义中,看一下有哪些参数。

def pad(input, pad, mode=‘constant’, value=0)

  • input : 输入张量
  • pad: 指定padding的维度和数目,形式是元组,稍后讲。
  • mode: 填充模式,不一样的模式,填充的值也不一样,
  • value: 仅当mode为‘constant’时有效,意思是填充的值是常亮,且值为value

重点就是讲一下这个pad参数。

假设现在有一个tensor的shape为 [ 3 , 3 , 32 , 40 ] [3,3,32,40] [3,3,32,40],四维张量。
假设pad为:

(2,2
3,4,
1,2,
1,1 )

第一行的(2,2),意义是对最低的维度(dim=-1)前面填充2个单位,后面填充2个单位。
第二行的(3,4),意义是对
倒数第二个维度(dim=-2)
,前面填充3个单位,后面填充4个单位

第三行第四行的意义以此类推。重点就是pad里面每两个元素为1组,指定了由低维到高维,每一维度,前面填充和后面填充的数值单位。

如果对于一个四维张量,pad里面有4个元素,又是啥情况?
当然是只对最后两个维度pading了。
下面就看一个例子。

import torch
from torch.nn import functional as F

a = torch.randn([2,3,4,5])  # torch.Size([2, 3, 4, 5])
padding = (
    1,2,   # 前面填充1个单位,后面填充两个单位,输入的最后一个维度则增加1+2个单位,成为8
    2,3,
    3,4
)
print(a.shape)
b = F.pad(a, padding)
print(b.shape)  # torch.Size([2, 10, 9, 8])  

从上面的例子看出,之后后三个维度发生了扩增,因为我们输入的padding长度为6,只能影响后三个维度。

`STFT` 是一个 PyTorch 模块,用于计算短时傅里叶变换(Short-Time Fourier Transform,STFT),是一种常用的信号处理技术。下面是这个模块的详细解释: ```python class STFT(torch.nn.Module): def __init__(self, filter_length=2048, hop_length=512, win_length=None, window='hann', center=True, pad_mode='reflect', freeze_parameters=True): super().__init__() self.filter_length = filter_length self.hop_length = hop_length self.center = center self.pad_mode = pad_mode if win_length is None: win_length = filter_length self.win_length = win_length self.window = get_window(window, win_length) # Create filter kernel fft_basis = np.fft.fft(np.eye(filter_length)) kernel = np.concatenate([np.real(fft_basis[:filter_length // 2 + 1, :]), np.imag(fft_basis[:filter_length // 2 + 1, :])], 0) self.register_buffer('kernel', torch.tensor(kernel, dtype=torch.float32)) # Freeze parameters if freeze_parameters: for name, param in self.named_parameters(): param.requires_grad = False def forward(self, waveform): assert (waveform.dim() == 1) # Pad waveform if self.center: waveform = nn.functional.pad(waveform.unsqueeze(0), (self.filter_length // 2, self.filter_length // 2), mode='constant', value=0) else: waveform = nn.functional.pad(waveform.unsqueeze(0), (self.filter_length - self.hop_length, 0), mode='constant', value=0) # Window waveform if waveform.shape[-1] < self.win_length: waveform = nn.functional.pad(waveform, (self.win_length - waveform.shape[-1], 0), mode='constant', value=0) waveform = waveform.squeeze(0) if self.window.device != waveform.device: self.window = self.window.to(waveform.device) windowed_waveform = waveform * self.window # Pad for linear convolution if self.center: windowed_waveform = nn.functional.pad(windowed_waveform, (self.filter_length // 2, self.filter_length // 2), mode='constant', value=0) else: windowed_waveform = nn.functional.pad(windowed_waveform, (self.filter_length - self.hop_length, 0), mode='constant', value=0) # Perform convolution fft = torch.fft.rfft(windowed_waveform.unsqueeze(0), dim=1) fft = torch.cat((fft.real, fft.imag), dim=1) output = torch.matmul(fft, self.kernel) # Remove redundant frequencies output = output[:, :self.filter_length // 2 + 1, :] return output ``` - `__init__` 方法:构造方法,用于初始化模块的各个参数。其中,`filter_length` 表示 STFT 的滤波器长度,`hop_length` 表示 STFT 的帧移(即相邻帧之间的采样点数),`win_length` 表示 STFT 的窗函数长度,`window` 是指定的窗函数类型(默认为汉宁窗),`center` 表示是否需要在信号两端填充 0 以保证 STFT 的中心位置与输入信号的中心位置对齐,`pad_mode` 是指定填充方式(默认为反射填充),`freeze_parameters` 表示是否需要冻结模块的参数。 - `forward` 方法:前向传播方法,用于计算输入信号的 STFT。其中,`waveform` 表示输入信号。首先,根据 `center` 和 `pad_mode` 对输入信号进行填充和窗函数处理,然后进行线性卷积,最后通过傅里叶变换计算 STFT。返回的 `output` 是一个张量,表示 STFT 系数。
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值