Pytorch中scatter_ 的使用详细解读
先看一个例子:
torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 1, 1, 1]]), 2)
tensor([[2., 0., 0., 0., 0.],
[0., 2., 2., 2., 2.],
[0., 0., 0., 0., 0.]])
首先是定义了一个3行5列的数组,_scatter中第一个参数0.表示沿着第0轴, 后面第二个参数是坐标,第三个是对应坐标的值,整个意思就是给torch.zeros(3, 5)对
原创
2021-07-20 12:07:38 ·
1265 阅读 ·
0 评论