One-hot function in PyTorch

One-hot function implementation in PyTorch


Method1: use scatter_ function

labels = [0, 1, 4, 7, 3, 2]
one_hot = torch.zeros(6, 8).scatter_(dim = 1, index = labels, value = 1)

Method2: use index_select() funtion

labels = [0, 1, 4, 7, 3, 2]
index = torch.eye(8)
one_hot = torch.index_select(index, dim = 0, index = labels)

Method3: use Embedding module

emb = nn.Embedding(8, 8)
emb.weight.data = torch.eye(8)

then we can get

emb(Variable(torch.LongTensor([1, 2], [3, 4])))
Variable containing:
(0,.,.) = 
0 1 0 0 0 0 0 0 
0 0 1 0 0 0 0 0
(1 ,.,.) =
0 0 0 1 0 0 0 0
0 0 0 0 1 0 0 0

Method4: create a module

class One_Hot(nn.Module):
    def __init__(self, depth):
        super(One_Hot,self).__init__()
        self.depth = depth
        self.ones = torch.sparse.torch.eye(depth)
    def forward(self, X_in):
        X_in = X_in.long()
        return Variable(self.ones.index_select(0,X_in.data))
    def __repr__(self):
        return self.__class__.__name__ + "({})".format(self.depth)
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值