pytorch标签onehot编码_pytorch将标签转为onehot

由于想多分类中使用Diceloss,所以需要将[0,1,2,..N]类型的标签转化为onehot类型。

1、在cpu上处理

input数据类型: torch.LongTensor()

数据形状:[bs, 1, *]         可为2D或3D数据

def make_one_hot(input, num_classes):

"""Convert class index tensor to one hot encoding tensor.

Args:

input: A tensor of shape [bs, 1, *]

num_classes: An int of number of class

Returns:

A tensor of shape [bs, num_classes, *]

"""

shape = np.array(input.shape)

shape[1] = num_classes

shape = tuple(shape)

result = torch.zeros(shape)

result = result.scatter_(1, input.cpu(), 1)

return result

2、在GPU上处理

input数据类型: torch.LongTensor().cuda()

数据形状:[bs, 1, *]         可为2D或3D数据

def make_one_hot(input, num_classes):

"""Convert class index tensor to one hot encoding tensor.

Args:

input: A tensor of shape [bs, 1, *]

num_classes: An int of number of class

Returns:

A tensor of shape [bs, num_classes, *]

"""

shape = np.array(input.shape)

shape[1] = num_classes

shape = tuple(shape)

result = torch.zeros(shape).cuda()

result = result.scatter_(1, input, 1)

return result

3、温馨提示

1、FloatTensor转化为LongTensor:

# 此时的输入label为FloatTensor,可在cuda,也可是cpu

label_long = label.long()

2、 Tensor增加一个维度

label_onehot = label_onehot.unsqueeze(1) #在第一维增加一个维度

3、多分类交叉熵是不需要将标签转为onehot的

详情请查看  https://blog.csdn.net/longshaonihaoa/article/details/105253553

4、最近版pytorch有直接的转化为onehot的代码,了解之后更新。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值