三元组损失的具体计算过程(手写)

文章探讨了在深度学习中如何计算TripletMarginLoss,通过PyTorch和PaddlePaddle的API示例解释了其具体用法。作者在寻找triplet_margin_loss的详细计算过程时遇到困难,因此决定通过代码演示来帮助读者理解这一损失函数的计算。
摘要由CSDN通过智能技术生成

写这篇博客的原因就是,我想知道triplet_margin_loss的具体计算过程,结果找了半天没找到,[○・`Д´・ ○],全是这么一通就算出来的

anchor = torch.randn(20, 20, requires_grad=True)
positive = torch.randn(20, 20, requires_grad=True)
negative = torch.randn(20, 20, requires_grad=True)

torch.nn.functional.triplet_margin_loss(anchor, positive, negative,reduction='none')
>>> 
tensor([1.0158, 0.0975, 2.1613, 1.4658, 0.7332, 1.5604, 1.0034, 0.3777, 0.1616,
        0.7618, 0.9989, 0.0000, 3.4407, 1.0938, 0.3333, 0.0000, 0.0000, 0.4422,
        1.1857, 1.7083], grad_fn=<ClampMinBackward>)

torch.nn.functional.triplet_margin_loss(anchor, positive, negative,reduction='mean')
>>>
tensor(0.9271, grad_fn=<MeanBackward0>)

把paddle的三元组损失的链接放
https://www.paddlepaddle.org.cn/documentation/docs/zh/2.4/api/paddle/nn/functional/triplet_margin_loss_cn.html

paddle.nn.functional.triplet_margin_loss(input, positive, negative, margin: float = 1.0, p: float = 2.0, epsilon: float = 1e-6, swap: bool = False, reduction: str = 'mean', name: str = None)

该 api 计算输入 input 和 positive 和 negative 间的 triplet margin loss 损失,测量 input 与 positive examples 和 negative examples 之间的相对相似性。所有输入张量的形状都为 (N,∗),* 是任意其他维度。

import paddle
import paddle.nn.functional as F

input = paddle.to_tensor([[1, 5, 3], [0, 3, 2], [1, 4, 1]], dtype=paddle.float32)
positive= paddle.to_tensor([[5, 1, 2], [3, 2, 1], [3, -1, 1]], dtype=paddle.float32)
negative = paddle.to_tensor([[2, 1, -3], [1, 1, -1], [4, -2, 1]], dtype=paddle.float32)
loss = F.triplet_margin_loss(input, positive, negative, margin=1.0, reduction='none')
print(loss)
# Tensor([0.        , 0.57496738, 0.        ])


loss = F.triplet_margin_loss(input, positive, negative, margin=1.0, reduction='mean')
print(loss)
# Tensor([0.19165580])

在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值