PyTorch基础——one hot编码转换

主要使用函数:torch.Tensor.scatter_(.....)

【sample】

将[1, 5, 4, 2]四个样本转换为one hot编码,形式如下:
[[0, 1, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 1],
 [0, 0, 0, 0, 1, 0],
 [0, 0, 1, 0, 0, 0]],其中每一行表示一个样本

函数调用:
torch.zeros(4, 6).scatter_(1, torch.tensor([1,5,4,2]).unsqueeze(1), 1)

说明:
1. one hot编码以每一行表示一个样本,创建4行6列零元素矩阵,其中
        每一行表示一个样本,共4个样本,因此创建4行
        元素最大值为5, 因此创建6列,表示索引0~5
2. 调用scatter_函数,其中
        第一个参数表示沿着哪一个轴索引, 1表示沿着列方向,即水平方向
        第二个参数表示需要进行one hot编码转换的tensor,首先转换成4行1列,行数与样本个数一致
        第三个参数表示需要替换的值,根据第二个参数转换成的4行1列矩阵,每行的元素代表4行6列零元素矩阵中每行的索引位置,然后替换为第三个参数

核心在于转换后的one hot编码矩阵
            第一个参数指向的维度大小至少为输入tensor中最大值加1(索引从0开始);
            除第一个参数指向的维度大小不同,其它维度与输入tensor应一致;

 

处理流程:
1. 4行6列零元素矩阵:
[[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0]]

2. [1, 5, 4, 2]转换为4行1列矩阵:
[[1],
 [5],
 [4],
 [2]]

3. 取第2步矩阵每一行元素,将第1步矩阵对应索引位置替换为1

[[0, 0, 0, 0, 0, 0],       [[1],         =====>  [[0, 1, 0, 0, 0, 0], 
[0, 0, 0, 0, 0, 0],         [5],         =====>   [0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 0],         [4],         =====>   [0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 0]]         [2]]         =====>   [0, 0, 1, 0, 0, 0]]
 

=================================================
torch.Tensor.scatter_(.....)函数功能并不仅仅局限于进行one hot编码转换,后续会进行说明.

  • 4
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值