scatter_(input, dim, index, src)将src中数据根据index中的索引按照dim的方向填进input中.
1 >>> x = torch.rand(2, 5) 2 >>> x 3 4 0.4319 0.6500 0.4080 0.8760 0.2355 5 0.2609 0.4711 0.8486 0.8573 0.1029 6 [torch.FloatTensor of size 2x5]
1) dim = 0,分别对每列填充:
>>> torch.zeros(3, 5).scatter_(0, torch.LongTensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x) 0.4319 0.4711 0.8486 0.8760 0.2355 0.0000 0.6500 0.0000 0.8573 0.0000 0.2609 0.0000 0.4080 0.0000 0.1029 [torch.FloatTensor of size 3x5]
实现原理:
对于LoneTensor内的矩阵,暂且称为 tmp = [[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]];将最终的 3*5的矩阵,暂且称为result。result初始为全0,需要经过scatter_处理。
举例:
对于tmp[0][0] = 0 -> 取x中x[0][0] = 0.4319,将其插入到result第0列的第0个位置,result[0][0] = 0.4319;
对于tmp[0][1] = 1 -> 取x中x[0][1] =