tensor的scatter和scatter_add操作,这篇讲得比较详细,这里我就简单总结一下,以a.scatter(dim=dim,index=index,src=src)为例
- len(a.size())=len(index.size())=len(src.size())=dim_num 即三者维度数必须相等
- a与index的关系:a.size(i)≥index.size(i) i≠dim
a.size(dim)与index.size(dim)不存在明确的大小关系。(此条对应原文中的约束3)
比如a的size是(2,3,4),dim=1,index的size是(x,y,z),则x≤2,z≤4,y的取值无所谓,大于0就可以
index中每个位置上的值的取值范围为[0,a.size(1)-1] - index与src的关系:index.size(i)≤src.size(i),即src在每个维度上不小于index就可以
其实感觉就是分两步:
step1. 在src中从左上角开始,按index的size切片或切块(这也是为什么src的在每个维度上都要大于等于index的原因)
step2. 用dim值替换:比如index[2,3,4]的值为0,src[2,3,4]的值为100,dim=1,那么就用100去替代a[2,0,4]的值;再比如,index[1,2,5]的值为3,src[1,2,5]的值为85,dim=2,那么就用85去替代a[1,2,3]的值,若dim=0,那么就用85去替代a[3,2,5]的值。