pytorch-scatter解释

pytorch-scatter解释

举个二维的例子,用Python重写一下就很明白了,非常简单。注意:scatter返回新张量,scatter_在原张量上修改。torch中的默认都是这样,例如torch.tensor.data.zero_()就是直接在原张量上操作。

调用官方接口

src = torch.arange(1,7).reshape(2,-1).type(torch.FloatTensor)
target = torch.zeros((2, 3)).type(torch.FloatTensor)
index = torch.tensor([[1, 2, 0], [1, 2, 1]])

target_官方 = target.scatter_(-1, index, src)
'''
tensor([[3., 1., 2.],
        [0., 6., 5.]])
'''

重写:

src = torch.arange(1,7).reshape(2,-1).type(torch.FloatTensor)
target = torch.zeros((2, 3)).type(torch.FloatTensor)
index = torch.tensor([[1, 2, 0], [1, 2, 1]])


for i in range(2):
    for j in range(3):
        try:# 如果i, j超出index范围,那么不做改变。
            c = index[i][j]
            # 如果c可以找到,那么c一定要在target索引范围内,并且i,j也一定要在src范围内,这要是为啥src要比target大的原因。
            target[i][c] = src[i][j]
        
'''
tensor([[3., 1., 2.],
        [0., 6., 5.]])

本质就是dim=1, i[index[i][j]] for all j
'''

用这个特性,将labels改成one-hot编码

# 5个样本,6分类:0-5, 
labels = torch.tensor([0, 2, 5, 4, 1])
src = torch.ones((5, 6)) 或者 src = 1也可以
target = torch.zeros((5,6))

# 将lablels作为index使用,因此需要将labels改为二维,具体因为啥逻辑实际写一遍重写部分, 把(i,c)配对打出来。
labels = labels.unsqueeze(-1)

one_hot = target.scatter(-1, labels, src) # 或者 one_hot = target.scatter(-1, labels, 1)
'''
tensor([[1., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 1., 0.],
        [0., 1., 0., 0., 0., 0.]])
'''
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值