TDNN代码
核心思想:把语音特征序列中的每一帧与其前后相邻的若干帧拼接(展开)成一个更长的向量,再送入网络。
举例(上下文窗口大小为 3):
原始帧序列:1, 2, 3, 4, 5
展开后的序列:[1,2,3], [2,3,4], [3,4,5]
经过这样的展开,网络在每个时间步看到的不再是单独一帧,而是一段上下文,因此能看到的特征范围更广。
import torch.nn as nn
import torch.nn.functional as F
class TDNN(nn.Module):
    """Time-delay neural network (TDNN) layer.

    Each output frame is computed from a window of ``context_size`` input
    frames (spaced ``dilation`` frames apart) that are concatenated into a
    single vector and passed through a shared linear layer, followed by
    ReLU, optional dropout and optional batch normalisation.

    Example with ``context_size=3`` and ``dilation=1``: the frame sequence
    ``1, 2, 3, 4, 5`` is expanded to the windows ``[1,2,3], [2,3,4],
    [3,4,5]``, so the layer sees a wider temporal range per output step.

    Args:
        input_dim: Feature dimension of every input frame.
        output_dim: Feature dimension of every output frame.
        context_size: Number of input frames combined into one window.
        stride: Stored on the instance as ``self.stride``.
            NOTE(review): ``forward`` always slides the window one frame at
            a time; this value is not applied to the unfolding step.
        dilation: Spacing, in frames, between the frames of one window.
        batch_norm: If truthy, apply ``nn.BatchNorm1d`` over the output
            features.
        dropout: Dropout probability. A falsy value (``0``, ``0.0``,
            ``None``) disables dropout, in which case no ``self.dropout``
            module is created.
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        context_size,
        stride,
        dilation,
        batch_norm,
        dropout
    ):
        super(TDNN, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.context_size = context_size
        self.stride = stride
        self.dilation = dilation
        self.batch_norm = batch_norm
        # Keep the probability under its own name so that it is not
        # clobbered by the nn.Dropout module stored in ``self.dropout``.
        self.dropout_p = dropout

        # One linear map shared by every window of concatenated frames.
        self.kernel = nn.Linear(input_dim * context_size, output_dim)
        self.nonlinearity = nn.ReLU()
        if self.batch_norm:
            self.bn = nn.BatchNorm1d(output_dim)
        if self.dropout_p:
            self.dropout = nn.Dropout(p=self.dropout_p)

    def forward(self, x):
        """Apply the TDNN layer.

        Args:
            x: Tensor of shape ``(batch, time, input_dim)``.

        Returns:
            Float tensor of shape ``(batch, new_time, output_dim)`` where
            ``new_time = time - dilation * (context_size - 1)``.

        Raises:
            ValueError: If the last dimension of ``x`` is not ``input_dim``.
        """
        _, _, d = x.shape
        if d != self.input_dim:
            raise ValueError(
                "Input dimension was wrong. Expected ({}), got ({})".format(
                    self.input_dim, d
                )
            )

        # Add a channel axis: F.unfold works on (N, C, H, W) input, here
        # H is time and W is the feature dimension.
        x = x.unsqueeze(1)

        # Link each frame with its neighbouring frames: every window covers
        # ``context_size`` frames (``dilation`` apart) over the full feature
        # width and is flattened into one column of the result.
        x = F.unfold(
            x,
            (self.context_size, self.input_dim),
            stride=(1, self.input_dim),
            dilation=(self.dilation, 1)
        )

        # (N, context_size * input_dim, new_time) -> (N, new_time, features)
        # so that the linear layer acts on the flattened window.
        x = x.transpose(1, 2)
        x = self.kernel(x.float())
        x = self.nonlinearity(x)

        if self.dropout_p:
            x = self.dropout(x)

        if self.batch_norm:
            # BatchNorm1d normalises over dim 1, so move the features there
            # and restore the (N, new_time, output_dim) layout afterwards.
            x = x.transpose(1, 2)
            x = self.bn(x)
            x = x.transpose(1, 2)
        return x
关于unfold函数的理解
关于transpose函数的理解