总说
写Lua时经常会碰到一些坑,随手记录,防止一坑不起。2017.2.20第四次更新。
Tensor
Tensor的小问题
- 获取某一维长度
当然可以用x:size(dim), 或是(#x)[dim] - #x得到的是tensor x的长度的tensor。不是一个num
所以在for中,如果x是一维的Tensor,那么应该这样写
for 1,(#x)[1] do
...
end
如果x是table的话,那么直接这样写
for 1,#tl do
...
end
- clone()很重要
凡是tensor的赋值,一般最好后面加上clone(), 否则就是引用!!因为torch为了加快速度, 如果不加clone(),默认是引用。 - 内存空间的连续
有些时候,进行某些操作时,必须要求Tensor是内存空间连续的。那什么情况下,得到的Tensor会是不连续的呢?其实一般是“截取”Tensor造成的。 比如
x = torch.Tensor(5,7)
y = x[{ {
2,-2},{
1,4} }] --负号是倒数的意思
t = y:contiguous()
contiguous函数如果y是内存连续的,就是引用。如果内存不连续,那么就相当于clone()。clone()不管y是不是内存连续的,总是开辟连续的空间存储tensor。
a[{ {index},{},{} }]与a[index]选哪个?
a = torch.Tensor(3,4,5)
a[1] -- size 4*5
a[{
{
1},{},{}}] -- size 1*4*5
-- 当然如果要赋值的话
b = torch.Tensor(4,5)
a[1] = b
a[{ {
1},{},{} }] = b
--上述两种方法等价。
Tensor中选择元素时索引必须用number类型
简单来说就是,tensor的索引必须用number!!!!!!!!!!!!!!!
来个例子:
...
_,sq = torch.sort(pow2dist,1)
鸡蛋来说上面的pow2dist是3*1的Tensor。然后,直接sort。值得注意的是此时sq是一个tensor,好了。然后我们再用
i = 1
a = torch.Tensor(3,4,5)
k = self.a[{ {sq[i]},{},{} }]:clone()
print(#k)
啥?????竟然输出的也是3*4*5。。当时真的是天昏地暗,这是什么鬼?,更吓人的是,
local i = 1
local a = torch.Tensor(3,4,5)
local k = a[{ {sq[i]},{},{} }]:clone()
print(#k)
--查看了 sq[1]就是1的时候,我就试试
print(#a[{ {
1},{},{} }] -- 1*4*5
好吧,这是为啥啊,明明sq[1]就是1啊,为什么直接代入1就正常。于是查阅torch.sort到底返回的是什么。
y, i = torch.sort(x) returns a Tensor y where all entries are sorted along the last dimension, in ascending order. It also returns a Tensor i that provides the corresponding indices from x.
也就是说,返回的sq是一个Tensor没毛病,that provides the corresponding indices from x
,既然是indices,所以就不能sq[i]就不能是一个值,而是一个tensor!
所以对于pow2dist是3*1的情况。可以:
local