triplet loss
PyTorch already ships a class that implements the triplet loss criterion, TripletMarginLoss. Its definition and docstring in torch/nn/modules/loss.py read:
class TripletMarginLoss(Module):
    r"""Creates a criterion that measures the triplet loss given input
    tensors x1, x2, x3 and a margin with a value greater than 0.
    This is used for measuring the relative similarity between samples. A triplet
    is composed of `a`, `p` and `n`: the anchor, positive example and negative
    example, respectively. The shape of all input variables should be
    :math:`(N, D)`.

    The distance swap is described in detail in the paper `Learning shallow
    convolutional feature descriptors with triplet losses`_ by
    V. Balntas, E. Riba et al.

    Args:
        anchor: anchor input tensor
        positive: positive input tensor
        negative: negative input tensor
        margin: the margin, greater than 0. Default: 1.0
        p: the norm degree. Default: 2
        swap: whether to use the distance swap. Default: False

    Shape:
        - Input: :math:`(N, D)` where `D = vector dimension`
        - Output: a scalar (by default, the per-triplet losses averaged over the batch)
    """
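Written out, the loss this criterion computes for the i-th triplet of a batch is given below. Here a_i, p_i and n_i are the i-th rows of the (N, D) anchor, positive and negative inputs, and m is the margin. The subscript p on the norm is the norm degree, not the positive example. The small eps that PyTorch adds inside the distance for numerical stability is omitted. This is a sketch of the standard formulation, not a transcription of the source.

```latex
L(a_i, p_i, n_i) = \max\{\, d(a_i, p_i) - d(a_i, n_i) + m,\ 0 \,\}, \qquad d(x_i, y_i) = \lVert x_i - y_i \rVert_p

\text{with swap:}\quad d(a_i, n_i) \;\to\; \min\{\, d(a_i, n_i),\ d(p_i, n_i) \,\}

\text{loss} = \frac{1}{N} \sum_{i=1}^{N} L(a_i, p_i, n_i) \quad \text{(default averaging over the batch)}
```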
Usage example:
>>> import torch
>>> import torch.nn as nn
>>> triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)
>>> # anchor, positive and negative embeddings: a batch of 100 vectors of dimension 128
>>> anchor = torch.randn(100, 128, requires_grad=True)
>>> positive = torch.randn(100, 128, requires_grad=True)
>>> negative = torch.randn(100, 128, requires_grad=True)
>>> output = triplet_loss(anchor, positive, negative)
>>> output.backward()
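To make the criterion less of a black box, the sketch below re-implements it with torch.nn.functional.pairwise_distance and compares the result with nn.TripletMarginLoss, with and without the distance swap mentioned in the docstring. It makes the following assumptions:

- The name manual_triplet_loss is a hypothetical helper, not part of PyTorch.
- The default mean reduction is used.
- pairwise_distance adds the same eps (1e-6) that the built-in uses.

Under these assumptions the two values should agree up to floating-point rounding.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def manual_triplet_loss(anchor, positive, negative, margin=1.0, p=2.0, swap=False):
    """Hypothetical helper: mean triplet margin loss over a batch of (N, D) tensors."""
    # Row-wise p-norm distances; pairwise_distance adds eps=1e-6 to the difference
    d_ap = F.pairwise_distance(anchor, positive, p=p)
    d_an = F.pairwise_distance(anchor, negative, p=p)
    if swap:
        # Distance swap: if the positive is closer to the negative than the anchor is,
        # use that smaller (harder) distance as the negative distance
        d_pn = F.pairwise_distance(positive, negative, p=p)
        d_an = torch.minimum(d_an, d_pn)
    # Hinge: a triplet contributes only while d_an - d_ap < margin
    losses = torch.clamp(margin + d_ap - d_an, min=0)
    # Assumed default reduction: average over the batch
    return losses.mean()


torch.manual_seed(0)
anchor = torch.randn(100, 128)
positive = torch.randn(100, 128)
negative = torch.randn(100, 128)

for swap in (False, True):
    builtin = nn.TripletMarginLoss(margin=1.0, p=2, swap=swap)(anchor, positive, negative)
    manual = manual_triplet_loss(anchor, positive, negative, swap=swap)
    print(f"swap={swap}: builtin={builtin.item():.6f} manual={manual.item():.6f}")
```

Writing the hinge as margin + d_ap - d_an only penalizes triplets whose negative is not yet at least `margin` farther from the anchor than the positive. Triplets that already satisfy the margin contribute zero loss and zero gradient.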
Reference:
https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/loss.py