# 总结了一些计算向量相似度的函数，例如 Cosine、BiLinear、TriLinear、Multihead 等
import torch
import torch.nn as nn
import math
class CosineSimilarity(nn.Module):
    """
    Compute the cosine similarity between corresponding vectors of two tensors.

    The similarity is taken along the last dimension. Leading dimensions are
    broadcast against each other exactly as in ``tensor_1 * tensor_2``. The
    module has no learnable parameters.

    Parameters
    ----------
    eps : float, optional (default = 1e-8)
        Lower bound applied to each vector's L2 norm before dividing, so that an
        all-zero vector yields a similarity of 0 instead of NaN.
    """

    def __init__(self, eps=1e-8):
        super(CosineSimilarity, self).__init__()
        self.eps = eps

    def forward(self, tensor_1, tensor_2):
        """
        Return the cosine similarity of ``tensor_1`` and ``tensor_2`` along dim -1.

        The result has the broadcast shape of the inputs with the last
        dimension removed.
        """
        # Clamp the norms so a zero vector does not produce 0 / 0 = NaN;
        # for ordinary (non-tiny) vectors the clamp is a no-op.
        norm_1 = tensor_1.norm(dim=-1, keepdim=True).clamp(min=self.eps)
        norm_2 = tensor_2.norm(dim=-1, keepdim=True).clamp(min=self.eps)
        return ((tensor_1 / norm_1) * (tensor_2 / norm_2)).sum(dim=-1)
class DotProductSimilarity(nn.Module):
    """
    Compute the dot product between corresponding vectors of two tensors.

    The product is taken along the last dimension. Leading dimensions are
    broadcast against each other exactly as in ``tensor_1 * tensor_2``.

    Parameters
    ----------
    scale_output : bool, optional (default = False)
        If True, divide the result by ``sqrt(d)``, where ``d = tensor_1.size(-1)``,
        to reduce the variance of the output elements.
    """

    def __init__(self, scale_output=False):
        super(DotProductSimilarity, self).__init__()
        self.scale_output = scale_output

    def forward(self, tensor_1, tensor_2):
        """
        Return the (optionally scaled) dot product of the inputs along dim -1.

        The result has the broadcast shape of the inputs with the last
        dimension removed.
        """
        result = (tensor_1 * tensor_2).sum(dim=-1)
        if self.scale_output:
            # This is a division, not a multiplication: if the components are
            # roughly independent with zero mean and unit variance, the sum of
            # d products has variance ~d, so dividing by sqrt(d) brings it back
            # to ~1 (the "scaled dot-product" trick).
            # Out-of-place so integer inputs are promoted to float instead of
            # raising on an in-place Float -> Long cast.
            result = result / math.sqrt(tensor_1.size(-1))
        return result
class ProjectedDotProductSimilarity(nn.Module):
"""
This similarity function does a projection and then computes the dot product. It's computed
as ``x^T W_1 (y^T W_2)^T``
An activation function applied after the calculation. Default is no activation.
"""