# 注意:一定不要忘记初始化参数(新建的参数张量只分配了内存、值未初始化),否则可能会出错!
class NewMatrices(nn.Module):
    """Learnable per-sample matrices applied with a batched matrix multiply.

    Holds one ``(input_dim, output_dim)`` matrix per batch slot and computes
    ``torch.bmm(self.weight, x)``. Because ``weight`` has a fixed leading
    dimension of ``batch_size``, every input batch must have exactly that many
    samples. Batches of any other size must be discarded by the caller. The
    original note gives ``[16, 1, 400]`` as the example shape, i.e. batch_size
    must be 16 (presumably input_dim=1, output_dim=400 — confirm).

    Args:
        batch_size: Leading dimension of ``weight``; must equal ``x.size(0)``.
        input_dim: Rows of each per-sample matrix. Despite the name, this
            becomes dimension 1 of the *output* of ``forward``.
        output_dim: Columns of each per-sample matrix; must equal
            ``x.size(1)`` (the dimension contracted by ``bmm``).
        bias_dim: Length of the ``bias`` parameter.
        dropout: Accepted for interface compatibility; currently unused
            (dropout is disabled).
    """

    def __init__(
        self,
        batch_size: int,
        input_dim: int,
        output_dim: int,
        bias_dim: int = 128,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        # torch.empty only allocates memory; the values are garbage until
        # reset_parameters() runs, so it must always be called here.
        self.weight = nn.Parameter(torch.empty(batch_size, input_dim, output_dim))
        # NOTE(review): bias is initialized but never used by forward(); it is
        # kept so existing state_dicts ("weight", "bias") still load.
        self.bias = nn.Parameter(torch.empty(bias_dim))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """(Re)initialize ``weight`` and ``bias`` in place.

        Uses the same scheme as ``nn.Linear.reset_parameters``:
        Kaiming-uniform with ``a=sqrt(5)`` for the weight (bound
        ``1/sqrt(fan_in)``) and ``U(-1/sqrt(fan_in), 1/sqrt(fan_in))`` for
        the bias.
        """
        torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        # bias is always created in __init__, so this guard is currently
        # always true; it is kept for parity with nn.Linear.
        if self.bias is not None:
            # NOTE(review): for a 3-D tensor this helper computes
            # fan_in = size(1) * size(2) = input_dim * output_dim (conv-style),
            # not the bmm-contracted output_dim. The two agree when
            # input_dim == 1. Kept as-is to preserve the initialization;
            # confirm the intent.
            fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            torch.nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Multiply each sample by its own matrix.

        Args:
            x: Tensor of shape ``(batch_size, output_dim, k)``.

        Returns:
            Tensor of shape ``(batch_size, input_dim, k)``. ``bias`` is not
            added.

        Raises:
            RuntimeError: From ``torch.bmm`` if ``x`` does not match the
                weight's batch size or contracted dimension.
        """
        return torch.bmm(self.weight, x)