1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40
class Biaffine(torch.nn.Module):
    """Biaffine scorer: ``s[..., o, i, j] = x_i^T W_o y_j`` for each output channel ``o``.

    When ``bias_x`` / ``bias_y`` is true, a constant-1 feature is appended to the
    corresponding input before scoring, so each ``W_o`` also covers the linear
    and constant terms.

    Args:
        n_in: Feature size of both ``x`` and ``y`` (before any bias column).
        n_out: Number of output channels.
        bias_x: Append a ones column to ``x`` in ``forward``.
        bias_y: Append a ones column to ``y`` in ``forward``.
    """

    def __init__(self, n_in=768, n_out=2, bias_x=True, bias_y=True):
        super().__init__()
        self.n_in = n_in
        self.n_out = n_out
        self.bias_x = bias_x
        self.bias_y = bias_y
        # torch.empty replaces the legacy torch.Tensor(shape) constructor; the
        # values are overwritten by reset_parameters() right below.
        self.weight = torch.nn.Parameter(
            torch.empty(n_out, n_in + int(bias_x), n_in + int(bias_y))
        )
        self.reset_parameters()

    def reset_parameters(self):
        """Fill ``weight`` with ones."""
        # NOTE(review): an all-ones init makes the scores a deterministic
        # function of the inputs (convenient for demos); confirm this is the
        # intended initialisation before training.
        torch.nn.init.ones_(self.weight)

    def forward(self, x, y):
        """Score every (x-position, y-position) pair.

        Args:
            x: Tensor of shape ``(..., seq_x, n_in)``.
            y: Tensor of shape ``(..., seq_y, n_in)``; its leading dims
                broadcast with those of ``x``.

        Returns:
            Tensor of shape ``(..., n_out, seq_x, seq_y)``, or
            ``(..., seq_x, seq_y)`` when ``n_out == 1``.
        """
        if self.bias_x:
            x = torch.cat((x, torch.ones_like(x[..., :1])), -1)
        if self.bias_y:
            y = torch.cat((y, torch.ones_like(y[..., :1])), -1)
        # One einsum contracts both feature axes. This avoids explicitly
        # expanding the weight to (batch, n_out, d, d), and it accepts any
        # number of leading dims instead of requiring exactly 3-D inputs.
        s = torch.einsum("...xi,oij,...yj->...oxy", x, self.weight, y)
        # Drop the channel axis for single-output scorers (axis 1 for 3-D inputs).
        if s.size(-3) == 1:
            s = s.squeeze(-3)
        return s
# Demo: score two identical (2, 2, 3) tensors holding 0..11 with a
# single-channel biaffine layer that has no bias columns, then print the scores.
model = Biaffine(n_in=3, n_out=1, bias_x=False, bias_y=False)
x = torch.arange(12, dtype=torch.float).view(2, 2, 3)
y = x.clone()  # same values as x, held in a separate tensor
result = model(x, y)
print(result)
|