torch.cat
在给定的维度上对张量序列进行连接操作
参数
inputs 任意相同Tensor序列
dimension 沿着该维连接张量序列
a = torch.randn(2,3)
print(a)
print('*'*20)
print(torch.cat((a,a,a),1))#沿列拼接
print('*'*20)
print(torch.cat((a,a,a),0))#沿行拼接
torch.chunk
在给定维度上将张量进行分块儿
参数:
- tensor (Tensor) – 待分块的输入张量
- chunks (int) – 分块的个数
- dim (int) – 沿着此维度进行分块
a = torch.randn(2,7)
print(a)
print('*'*20)
# 如不能整除,最后一个取余
print(torch.chunk(a,3,1))
torch.gather
沿给定轴dim,将张量索index指定位置得值进行聚合
参数:
- input (Tensor) – 源张量
- dim (int) – 索引的轴
- index (LongTensor) – 聚合元素的下标
- out (Tensor, optional) – 目标张量
t = torch.Tensor([[1,2],[3,4]])
print(t)
print('*'*20)
print(torch.gather(t, 1, torch.LongTensor([[0, 0], [1, 0]])))
具体来说,索引张量 [[0, 0], [1, 0]] 的含义是:
对于第一行(索引为 0),我们要从 t 的第一行(索引为 0)中收集第 0 列和第 0 列的元素。
对于第二行(索引为 1),我们要从 t 的第二行(索引为 1)中收集第 1 列和第 0 列的元素。
由于 t 是 [[1, 2], [3, 4]],根据上面的索引,我们收集到的元素分别是:
第一行:t[0, 0] 和 t[0, 0],即 1 和 1。
第二行:t[1, 1] 和 t[1, 0],即 4 和 3。
torch.index_select
参数:
- input (Tensor) – 源张量
- dim (int) – 索引的轴
- index (LongTensor) – 包含索引下标的一维张量
- out (Tensor, optional) – 目标张量
t = torch.randn(3,4)
print(t)
print('*'*20)
index = torch.LongTensor([0,2])
# 第0行和第2行
print(torch.index_select(t, 0, index))
torch.split
将输入张量分割成相等形状的chunks(如果可分)。 如果沿指定维的张量形状大小不能被split_size 整分, 则最后一个分块会小于其它分块。
参数:
- tensor (Tensor) – 待分割张量
- split_size (int) – 单个分块的形状大小
- dim (int) – 沿着此维进行分割