pytorch索引加一出错
为了加速运算,写了一行对应索引加一的代码,大概长这个样子:
#注意data是pytorch里的tensor
data[index] += 1
少量数据的测试中,上面的代码运行正常,结果也正确。
当index长度大于data的长度时候,运行也正常,但是加和总数对不上。
后来发现,在pytorch中上述的索引加一操作在一次调用中只会对单个索引执行一次。具体讲个例子:
比如index=[1, 1, 1]
预想的结果是data[1] = data[1]+3, 但是实际上只加了一次也就是data[1]=data[1]+1。
猜测这可能是pytorch的速度优化导致的。
numpy与pytorch的负数步长
原意是想找个翻转tensor的操作,但是transformer里面要求是PIL的Image格式。所以想自己写了一个。
第一个版本长这样
# 注意data是tensor
data =