PyTorch: a detailed guide to the torch.nn.functional.pad function

torch.nn.functional.pad is PyTorch's built-in tensor padding function.

(1). Detailed description of torch.nn.functional.pad:

torch.nn.functional.pad(input, pad, mode='constant', value=0)
Args:
	"""
	input: the tensor to be padded (e.g. a 4-D or 5-D tensor)
	pad: a tuple describing how much to pad, consumed in pairs starting from the last dimension
		1. 4-D tensor: pass a 4-element tuple (pad_l, pad_r, pad_t, pad_b),
		i.e. (left, right, top, bottom); each value is the number of elements padded on that side
		2. 5-D tensor: pass a 6-element tuple (pleft, pright, ptop, pbottom, pfront, pback),
		i.e. (left, right, top, bottom, front, back); each value is the number of elements padded on that side
		Note: mode="constant" accepts any even-length pad tuple up to 2 * input.dim(),
		while "reflect" and "replicate" only pad the last dimensions
	mode: one of 'constant', 'reflect' or 'replicate', i.e. constant, reflection,
		or edge-replication padding
	value: the fill value, used only when mode="constant" (default 0); mode="reflect"
		and mode="replicate" ignore the value parameter
	"""
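The ordering of the pad tuple is the most common source of confusion: the pairs are applied from the last dimension backwards. A quick sketch (the tensor shape here is arbitrary, chosen only for illustration):

```python
import torch
import torch.nn.functional as F

x = torch.zeros(2, 3, 4, 5)  # shape (N, C, H, W)
# Pairs apply from the LAST dimension backwards:
# (1, 2) -> dim 3 (W), (3, 4) -> dim 2 (H), (5, 6) -> dim 1 (C)
y = F.pad(x, (1, 2, 3, 4, 5, 6), mode="constant", value=0)
print(y.shape)  # torch.Size([2, 14, 11, 8])
```

W grows from 5 to 8 (1 + 2 extra columns), H from 4 to 11, and C from 3 to 14; the batch dimension is untouched because no pair was supplied for it.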

(2). Code demonstration (padding a 4-D tensor)

1. Create a random tensor with shape (2, 1, 3, 2)

import torch
import torch.nn.functional as F
original_values = torch.randn([2, 1, 3, 2])
print("original_values: ", original_values, "\n")
print("shape of original_values: ", original_values.shape)

Output:

original_values:  tensor([[[[ 0.8182, -1.2295],
          [ 0.1985,  1.2261],
          [-2.1763, -0.5790]]],
        [[[ 0.4204,  1.5903],
          [ 1.6354, -2.7076],
          [ 0.6119,  1.4595]]]]) 
shape of original_values:  torch.Size([2, 1, 3, 2])

2. Left padding: pad=(1,0,0,0,0,0), padding the last (fourth) dimension

padding_values = F.pad(original_values, pad=(1, 0, 0, 0, 0, 0), mode="constant", value=0)
print("padding_values: ", padding_values, "\n")
print("shape of padding_values: ", padding_values.shape)

Output 1 (mode="constant"):

padding_values with mode="constant":  tensor([[[[ 0.0000,  0.8182, -1.2295],
          [ 0.0000,  0.1985,  1.2261],
          [ 0.0000, -2.1763, -0.5790]]],
        [[[ 0.0000,  0.4204,  1.5903],
          [ 0.0000,  1.6354, -2.7076],
          [ 0.0000,  0.6119,  1.4595]]]]) 
shape of padding_values:  torch.Size([2, 1, 3, 3])

Output 2 (mode="reflect"); note that on a 4-D tensor, "reflect" and "replicate" only pad the last two dimensions, so the call must use a 4-element tuple such as pad=(1, 0, 0, 0):

padding_values with mode="reflect":  tensor([[[[-1.2295,  0.8182, -1.2295],
          [ 1.2261,  0.1985,  1.2261],
          [-0.5790, -2.1763, -0.5790]]],
        [[[ 1.5903,  0.4204,  1.5903],
          [-2.7076,  1.6354, -2.7076],
          [ 1.4595,  0.6119,  1.4595]]]]) 
shape of padding_values:  torch.Size([2, 1, 3, 3])

Output 3 (mode="replicate"), likewise called with pad=(1, 0, 0, 0):

padding_values with mode="replicate":  tensor([[[[ 0.8182,  0.8182, -1.2295],
          [ 0.1985,  0.1985,  1.2261],
          [-2.1763, -2.1763, -0.5790]]],
        [[[ 0.4204,  0.4204,  1.5903],
          [ 1.6354,  1.6354, -2.7076],
          [ 0.6119,  0.6119,  1.4595]]]])
shape of padding_values:  torch.Size([2, 1, 3, 3])
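To make the difference between the two non-constant modes concrete, here is a small sketch on a tiny tensor (values chosen arbitrarily): "reflect" mirrors the values around the edge without repeating the edge itself, while "replicate" repeats the edge value.

```python
import torch
import torch.nn.functional as F

x = torch.arange(6.).reshape(1, 1, 2, 3)  # rows [0, 1, 2] and [3, 4, 5]
# reflect mirrors around the edge (the edge itself is not repeated);
# the pad amount must be smaller than the size of the padded dimension
r = F.pad(x, (2, 0, 0, 0), mode="reflect")
# replicate repeats the edge value and has no such size limit
p = F.pad(x, (4, 0, 0, 0), mode="replicate")
print(r[0, 0, 0].tolist())  # [2.0, 1.0, 0.0, 1.0, 2.0]
print(p[0, 0, 0].tolist())  # [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 2.0]
```

The size restriction is why "reflect" raises an error when the pad amount reaches the dimension size, whereas "replicate" happily pads any amount.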

3. Left padding: pad=(2,0,0,0,0,0), padding the last (fourth) dimension

padding_values = F.pad(original_values, pad=(2, 0, 0, 0, 0, 0), mode="constant", value=0)
print("padding_values: ", padding_values, "\n")
print("shape of padding_values: ", padding_values.shape)

Output (mode="constant"):

padding_values:  tensor([[[[ 0.0000,  0.0000,  0.8182, -1.2295],
          [ 0.0000,  0.0000,  0.1985,  1.2261],
          [ 0.0000,  0.0000, -2.1763, -0.5790]]],
        [[[ 0.0000,  0.0000,  0.4204,  1.5903],
          [ 0.0000,  0.0000,  1.6354, -2.7076],
          [ 0.0000,  0.0000,  0.6119,  1.4595]]]])
shape of padding_values:  torch.Size([2, 1, 3, 4])
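The value argument is not limited to 0; any constant works in constant mode. A minimal sketch with an arbitrary fill value:

```python
import torch
import torch.nn.functional as F

x = torch.ones(1, 1, 2, 2)
# fill the new columns with -9 instead of the default 0
y = F.pad(x, (1, 1, 0, 0), mode="constant", value=-9.0)
print(y[0, 0, 0].tolist())  # [-9.0, 1.0, 1.0, -9.0]
```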

4. Right padding: pad=(0,1,0,0,0,0), padding the last (fourth) dimension

padding_values = F.pad(original_values, pad=(0, 1, 0, 0, 0, 0), mode="constant", value=0)
print("padding_values: ", padding_values)
print("shape of padding_values: ", padding_values.shape)

Output (mode="constant"):

padding_values:  tensor([[[[ 0.8182, -1.2295,  0.0000],
          [ 0.1985,  1.2261,  0.0000],
          [-2.1763, -0.5790,  0.0000]]],
        [[[ 0.4204,  1.5903,  0.0000],
          [ 1.6354, -2.7076,  0.0000],
          [ 0.6119,  1.4595,  0.0000]]]])
shape of padding_values:  torch.Size([2, 1, 3, 3])

5. Top padding: pad=(0,0,1,0,0,0), padding the third dimension

padding_values = F.pad(original_values, pad=(0, 0, 1, 0, 0, 0), mode="constant", value=0)
print("padding_values: ", padding_values)
print("shape of padding_values: ", padding_values.shape)

Output (mode="constant"):

padding_values:  tensor([[[[ 0.0000,  0.0000],
          [ 0.8182, -1.2295],
          [ 0.1985,  1.2261],
          [-2.1763, -0.5790]]],
        [[[ 0.0000,  0.0000],
          [ 0.4204,  1.5903],
          [ 1.6354, -2.7076],
          [ 0.6119,  1.4595]]]])
shape of padding_values:  torch.Size([2, 1, 4, 2])

6. Bottom padding: pad=(0,0,0,1,0,0), padding the third dimension

padding_values = F.pad(original_values, pad=(0, 0, 0, 1, 0, 0), mode="constant", value=0)
print("padding_values: ", padding_values)
print("shape of padding_values: ", padding_values.shape)

Output (mode="constant"):

padding_values:  tensor([[[[ 0.8182, -1.2295],
          [ 0.1985,  1.2261],
          [-2.1763, -0.5790],
          [ 0.0000,  0.0000]]],
        [[[ 0.4204,  1.5903],
          [ 1.6354, -2.7076],
          [ 0.6119,  1.4595],
          [ 0.0000,  0.0000]]]])
shape of padding_values:  torch.Size([2, 1, 4, 2])

7. Front padding: pad=(0,0,0,0,1,0), padding the second dimension

padding_values = F.pad(original_values, pad=(0, 0, 0, 0, 1, 0), mode="constant", value=0)
print("padding_values: ", padding_values)
print("shape of padding_values: ", padding_values.shape)

Output (mode="constant"):

padding_values:  tensor([[[[ 0.0000,  0.0000],
          [ 0.0000,  0.0000],
          [ 0.0000,  0.0000]],
         [[ 0.8182, -1.2295],
          [ 0.1985,  1.2261],
          [-2.1763, -0.5790]]],
        [[[ 0.0000,  0.0000],
          [ 0.0000,  0.0000],
          [ 0.0000,  0.0000]],
         [[ 0.4204,  1.5903],
          [ 1.6354, -2.7076],
          [ 0.6119,  1.4595]]]])
shape of padding_values:  torch.Size([2, 2, 3, 2])

8. Back padding: pad=(0,0,0,0,0,1), padding the second dimension

padding_values = F.pad(original_values, pad=(0, 0, 0, 0, 0, 1), mode="constant", value=0)
print("padding_values: ", padding_values)
print("shape of padding_values: ", padding_values.shape)

Output (mode="constant"):

padding_values:  tensor([[[[ 0.8182, -1.2295],
          [ 0.1985,  1.2261],
          [-2.1763, -0.5790]],
         [[ 0.0000,  0.0000],
          [ 0.0000,  0.0000],
          [ 0.0000,  0.0000]]],
        [[[ 0.4204,  1.5903],
          [ 1.6354, -2.7076],
          [ 0.6119,  1.4595]],
         [[ 0.0000,  0.0000],
          [ 0.0000,  0.0000],
          [ 0.0000,  0.0000]]]])
shape of padding_values:  torch.Size([2, 2, 3, 2])
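The eight directions above can also be combined in a single call; each pair of the 6-tuple contributes to one dimension. A quick sketch using the same shape as the examples:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 1, 3, 2)
# left/right/top/bottom/front/back, all padded by 1 at once
y = F.pad(x, (1, 1, 1, 1, 1, 1), mode="constant", value=0)
print(y.shape)  # torch.Size([2, 3, 5, 4])
```

Each padded dimension grows by 2 (one element on each side); the batch dimension stays at 2 because no pair was supplied for it.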
`STFT` is a PyTorch module that computes the short-time Fourier transform (STFT), a widely used signal-processing technique, and it is a good real-world example of `torch.nn.functional.pad`: padding is used to center each frame on its samples. The code as originally posted was broken (one statement was split mid-expression, and it applied both a hand-built DFT kernel and `torch.fft.rfft` to the same signal); the version below is a cleaned-up reconstruction that keeps the DFT-kernel approach, and should be read as an illustrative sketch rather than a reference implementation:

```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.signal import get_window

class STFT(nn.Module):
    def __init__(self, filter_length=2048, hop_length=512, win_length=None,
                 window='hann', center=True, pad_mode='reflect',
                 freeze_parameters=True):
        super().__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.center = center
        self.pad_mode = pad_mode
        if win_length is None:
            win_length = filter_length
        self.win_length = win_length

        # Window function, zero-padded up to filter_length if shorter
        win = get_window(window, win_length, fftbins=True)
        win = np.pad(win, (0, filter_length - win_length))

        # Real and imaginary DFT basis rows, pre-multiplied by the window
        fft_basis = np.fft.fft(np.eye(filter_length))
        kernel = np.concatenate(
            [np.real(fft_basis[:filter_length // 2 + 1, :]),
             np.imag(fft_basis[:filter_length // 2 + 1, :])], axis=0) * win
        self.register_buffer('kernel',
                             torch.tensor(kernel, dtype=torch.float32))

        # Freeze parameters (kept from the original code; the kernel is a
        # buffer, so this loop is effectively a no-op here)
        if freeze_parameters:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, waveform):
        assert waveform.dim() == 1

        # Center padding with F.pad, so each frame is centered on its samples
        x = waveform.unsqueeze(0).unsqueeze(0)  # (1, 1, L) for 1-D padding
        if self.center:
            x = F.pad(x, (self.filter_length // 2, self.filter_length // 2),
                      mode=self.pad_mode)
        x = x.flatten()

        # Slice into overlapping frames: (n_frames, filter_length)
        frames = x.unfold(0, self.filter_length, self.hop_length)

        # Project each frame onto the windowed DFT basis:
        # (n_frames, filter_length) @ (filter_length, filter_length + 2)
        output = torch.matmul(frames, self.kernel.t())

        # Rows: filter_length // 2 + 1 real coefficients followed by the
        # imaginary coefficients; shape (filter_length + 2, n_frames)
        return output.t()
```

- `__init__` method: the constructor, which initializes the module's parameters. `filter_length` is the FFT size, `hop_length` is the frame shift (the number of samples between adjacent frames), `win_length` is the window length, `window` is the window type (Hann by default), `center` controls whether the signal is padded at both ends so that each STFT frame is centered on its samples, `pad_mode` selects the padding mode used for centering (reflection by default), and `freeze_parameters` controls whether the module's parameters are frozen.
- `forward` method: the forward pass, which computes the STFT of the input signal `waveform`. The signal is first padded with `nn.functional.pad` according to `center` and `pad_mode`, sliced into overlapping frames, and each frame is projected onto the windowed DFT basis; the returned `output` tensor holds the STFT coefficients (real parts stacked on top of imaginary parts).
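The frame-count effect of center padding can be checked with plain F.pad arithmetic; the numbers below (signal length 4096, n_fft 2048, hop 512) are arbitrary and chosen only for illustration:

```python
import torch
import torch.nn.functional as F

n_fft, hop = 2048, 512
x = torch.randn(4096)
# center=True pads n_fft // 2 samples on each side ('reflect' by default)
padded = F.pad(x.unsqueeze(0).unsqueeze(0), (n_fft // 2, n_fft // 2),
               mode="reflect").flatten()
print(padded.shape[-1])  # 6144, i.e. 4096 + n_fft
# so the number of frames is 1 + len(x) // hop = 9
spec = torch.stft(x, n_fft=n_fft, hop_length=hop,
                  window=torch.hann_window(n_fft), center=True,
                  return_complex=True)
print(spec.shape)  # torch.Size([1025, 9])
```

The built-in torch.stft with center=True performs the same padding internally, which is why its output has 1 + 4096 // 512 = 9 frames and n_fft // 2 + 1 = 1025 frequency bins.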