小白学Pytorch系列–Torch.nn API Distance Functions(13)
方法 | 注释 |
---|---|
nn.CosineSimilarity | 返回沿着dim计算的x1和x2之间的余弦相似度。 |
nn.PairwiseDistance | 计算输入向量之间或输入矩阵列之间的成对距离。 |
nn.CosineSimilarity
>>> input1 = torch.randn(100, 128)
>>> input2 = torch.randn(100, 128)
>>> cos = nn.CosineSimilarity(dim=1, eps=1e-6)
>>> output = cos(input1, input2)
nn.PairwiseDistance
计算输入向量之间或输入矩阵列之间的成对距离。
>>> pdist = nn.PairwiseDistance(p=2)
>>> input1 = torch.randn(100, 128)
>>> input2 = torch.randn(100, 128)
>>> output = pdist(input1, input2)