pytorch 01 关于分割任务中 onehot 编码转换的问题

在分割任务中,我们拿到的label通常是由数字类别组成的,但是在应用某些损失函数时,我们需要把label转换成 one—hot编码的形式。

例如:原始label维度 224*224*1(由数字0-2组成) ,为一个三类别的分割任务,在onehot编码后维度为 224*224*3,(可以看成3张224*224*1的切片)。

 

代码:

一:当维度为 N  1 *
one-hot后 N C *

def make_one_hot(input, num_classes):
    """Convert class index tensor to one hot encoding tensor.
    Args:
         input: A tensor of shape [N, 1, *]
         num_classes: An int of number of class
    Returns:
        A tensor of shape [N, num_classes, *]
    """
    shape = np.array(input.shape)
    shape[1] = num_classes
    shape = tuple(shape)
    result = torch.zeros(shape)
    result = result.scatter_(1, torch.LongTensor(input), 1)

    return result
二:当维度为 1 * 
one_hot后 N *
def make_one_hot(input, num_classes):
    """Convert class index tensor to one hot encoding tensor.
    Args:
         input: A tensor of shape [N, 1, *]
         num_classes: An int of number of class
    Returns:
        A tensor of shape [N, num_classes, *]
    """
    shape = np.array(input.shape)
    shape[0] = num_classes
    shape = tuple(shape)
    result = torch.zeros(shape)
    result = result.scatter_(0, torch.LongTensor(input), 1)

    return result

* 代表图像大小 例如 224 x 224

 

  • 2
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值