Tensor Operation
1. 张量拼接与切分
torch.cat():将张量的维度dim进行拼接,不会扩张张量的维度
如dim=0,则两个向量将在第0维进行拼接:(3,4)concat(3,4)-->(6,4)
torch.stack():在新创建的维度dim上进行拼接
如dim=0,则(3,4)stack(3,4)-->(2,3,4)
如dim=2,则(3,4)stack(3,4)-->(3,4,2)
torch.chunk():将张量维度dim进行平均切分,返回张量列表。
若不能整除,最后一份张量小于其他张量
chunk:要切的份数
torch.split():将张量按维度dim进行切分,返回张量列表
split_size_or_sections:当为int时,表示每一份的长度;当为list时,按list元素切分
2. 张量索引
torch.index_select():在维度dim上,按index索引数据。依index索引数据拼接的张量。
index:是dtype为torch.long的tensor
t = torch.randint(0, 9, size=(3, 3))
idx = torch.tensor([0, 1], dtype=torch.long)
t_select = torch.index_select(t, dim=0, index=idx)
torch.mask_select():按mask中的True进行索引,返回一维张量。
t = torch.randint(0, 9, size=(3, 3))
mask = t.ge(5) # >=5 return true; else false
t_select = torch.masked_select(t, mask)
3. 张量变换
torch.reshape():变换张量形状。当张量在内存中是连续时,新张量与input共享数据内存。
torch.transpose():变换张量的两个维度
torch.t():二维张量的转置
torch.squeeze():压缩长度为1的维度。
dim为None时,移除所有长度为1的轴;若指定维度,当且仅当该轴长度为1时,可以被移除
torch.unqueeze():依据dim扩展维度
Tensor Math Operation
torch.add():逐元素计算input+alpha+other
torch.addcdiv():
torch.addcmul():