torch.nn.Linear类(nn.Module)详解(包含内部逻辑)

torch.nn.Linear类(nn.Module)

  • torch.nn.Linear(in_features, out_features, bias=True, device=None, dtype=None)

    对输入使用线性变换: y = x A T + b y = xA^T + b y=xAT+b

    其实上面这个不是重点,这个函数的重点在于其可以针对维度如 ( 2 × 2 × 3 ) (2\times2\times3) (2×2×3)这种多维度的输入,与我以前所一直认为的只能针对二维输入不一样,其可以支持多维

    还是先说说其参数:

    • in_features:输入的最后一维度的大小,如 ( 2 × 2 × 3 ) (2\times2\times3) (2×2×3)就是3
    • out_features:你想要输出的最后一维度的大小,比如,如果这个参数是4,那么 ( 2 × 2 × 3 ) (2\times2\times3) (2×2×3)会变为 ( 2 × 2 × 4 ) (2\times2\times4) (2×2×4)
    • bias:是否使用偏差,即上面公式是否+b

    ps:这个组件的默认权重与偏差的初始化使用如下方式

    • 权重A使用 U ( − k , k ) \mathcal{U}(-\sqrt{k}, \sqrt{k}) U(k ,k )初始化,即在 ( − k , k ) (-\sqrt{k}, \sqrt{k}) (k ,k )范围内的均匀分布权重初始化, k = 1 in_features k = \frac{1}{\text{in\_features}} k=in_features1
    • 偏差b使用与权重一样的初始化方式
  • 举例:(内部逻辑)

    >>> x = torch.arange(12,dtype=torch.float32).view(2,2,3)
    >>> x
    Out: 
    tensor([[[ 0.,  1.,  2.],
             [ 3.,  4.,  5.]],
            [[ 6.,  7.,  8.],
             [ 9., 10., 11.]]])
    >>> linear=torch.nn.Linear(3,1)  # in_features=3,out_features=1
    >>> linear.weight  # 展示权重
    Out: 
    Parameter containing:
    tensor([[ 0.4956, -0.0664, -0.5577]], requires_grad=True)
    >>> linear.bias  # 展示偏差
    Out: 
    Parameter containing:
    tensor([0.2490], requires_grad=True)
    >>> y = linear(x)
    >>> y.shape
    Out: torch.Size([2, 2, 1])
    >>> y
    Out: 
    tensor([[[-0.9327],  # 下面会展示如何得到这个-0.9327
             [-1.3183]],
            [[-1.7040],
             [-2.0896]]], grad_fn=<ViewBackward0>)
    >>> torch.sum(torch.dot(x[0,0,:],linear.weight.t().view(3))+linear.bias)  
    ... # 即[ 0.,  1.,  2.]与权重[ 0.4956, -0.0664, -0.5577]的转置做矩阵乘法,再加上偏差0.2490
    Out: tensor(-0.9327, grad_fn=<SumBackward0>)
    
  • 2
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
`STFT` 是一个 PyTorch 模块,用于计算短时傅里叶变换(Short-Time Fourier Transform,STFT),是一种常用的信号处理技术。下面是这个模块的详细解释: ```python class STFT(torch.nn.Module): def __init__(self, filter_length=2048, hop_length=512, win_length=None, window='hann', center=True, pad_mode='reflect', freeze_parameters=True): super().__init__() self.filter_length = filter_length self.hop_length = hop_length self.center = center self.pad_mode = pad_mode if win_length is None: win_length = filter_length self.win_length = win_length self.window = get_window(window, win_length) # Create filter kernel fft_basis = np.fft.fft(np.eye(filter_length)) kernel = np.concatenate([np.real(fft_basis[:filter_length // 2 + 1, :]), np.imag(fft_basis[:filter_length // 2 + 1, :])], 0) self.register_buffer('kernel', torch.tensor(kernel, dtype=torch.float32)) # Freeze parameters if freeze_parameters: for name, param in self.named_parameters(): param.requires_grad = False def forward(self, waveform): assert (waveform.dim() == 1) # Pad waveform if self.center: waveform = nn.functional.pad(waveform.unsqueeze(0), (self.filter_length // 2, self.filter_length // 2), mode='constant', value=0) else: waveform = nn.functional.pad(waveform.unsqueeze(0), (self.filter_length - self.hop_length, 0), mode='constant', value=0) # Window waveform if waveform.shape[-1] < self.win_length: waveform = nn.functional.pad(waveform, (self.win_length - waveform.shape[-1], 0), mode='constant', value=0) waveform = waveform.squeeze(0) if self.window.device != waveform.device: self.window = self.window.to(waveform.device) windowed_waveform = waveform * self.window # Pad for linear convolution if self.center: windowed_waveform = nn.functional.pad(windowed_waveform, (self.filter_length // 2, self.filter_length // 2), mode='constant', value=0) else: windowed_waveform = nn.functional.pad(windowed_waveform, (self.filter_length - self.hop_length, 0), mode='constant', value=0) # Perform convolution fft = torch.fft.rfft(windowed_waveform.unsqueeze(0), dim=1) fft = torch.cat((fft.real, fft.imag), dim=1) output = torch.matmul(fft, self.kernel) # Remove redundant frequencies output = output[:, :self.filter_length // 2 + 1, :] return output ``` - `__init__` 方法:构造方法,用于初始化模块的各个参数。其中,`filter_length` 表示 STFT 的滤波器长度,`hop_length` 表示 STFT 的帧移(即相邻帧之间的采样点数),`win_length` 表示 STFT 的窗函数长度,`window` 是指定的窗函数型(默认为汉宁窗),`center` 表示是否需要在信号两端填充 0 以保证 STFT 的中心位置与输入信号的中心位置对齐,`pad_mode` 是指定填充方式(默认为反射填充),`freeze_parameters` 表示是否需要冻结模块的参数。 - `forward` 方法:前向传播方法,用于计算输入信号的 STFT。其中,`waveform` 表示输入信号。首先,根据 `center` 和 `pad_mode` 对输入信号进行填充和窗函数处理,然后进行线性卷积,最后通过傅里叶变换计算 STFT。返回的 `output` 是一个张量,表示 STFT 系数。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

大侠月牙

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值