手写交叉熵损失

一、定义
二元交叉熵:loss=−1/n ∑​[ylna+(1−y)ln(1−a)]
多元交叉熵:loss=−1/n ∑​y​lna

二、实现

import torch
import torch.nn.functional as F
#创建数据集
y=torch.randint(0,2,size=(10,1)).to(torch.float)
p=torch.rand(size=(10,1),requires_grad=True)
print(y)
print(p)
def cross_entry(p,y):                  
    res=-1*torch.sum(y*torch.log(p)+(1-y)*torch.log(1-p))/y.shape[0]
    return res
print(cross_entry(p,y))
print(F.binary_cross_entropy(p,y))     #二元交叉熵
def mul_cross_entry(p,y):
    p=F.softmax(p)
    res=-1*torch.sum(F.one_hot(y)*torch.log(p+0.000001))/y.shape[0]
    return res
y=torch.randint(0,3,size=(10,),dtype=torch.int64)
p=torch.randn(10,3)
print(mul_cross_entry(p,y))            #多元交叉熵
print(F.cross_entropy(p,y))
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值