第一章主要介绍pytorch对张量的索引,切片,连接,换位
一、torch.cat
1.语法
torch.cat(input, dimention=0) --> Tensor
参数:
inputs (sequence of Tensors) – 可以是任意相同Tensor 类型的python 序列
dimension (int, optional) – 沿着此维连接张量序列。
2.举例(分别按行、列进行连接)
按行——表示dim=0
按列——表示dim=1
import torch
x=torch.tensor([[1,2],[3,4]]) #生成一个2*2的张量
print(x)
row_x=torch.cat((x,x,x), dim=0) #按行连接x,x,x
print(row_x)
col_x=torch.cat((x,x,x), dim=1) #按列连接x,x,x
print(col_x)
输出如下,依次是x、row_x、col_x
二、torch.index_select
1.语法
torch.index_select(input, dimention=0, index=indice) -->Tensor
参数:
input 待切片的张量
dim 按照dim维度进行切片
index 待切片索引
沿着指定维度对输入进行切片,取index中指定的相应项(index为一个LongTensor),然后返回到一个新的张量, 返回的张量与原始张量_Tensor_有相同的维度(在指定轴上)。
2.举例(分别按行、列进行切片)
按行——表示dim=0
按列——表示dim=1
x=torch.randn([3,4]) #随机生成一个3行4列的张量
print(x)
indice=LongTensor([0,2]) #设置index值为0、2
row_x=torch.index_select(x, dim=0, index=indice) #按行切片
print(row_x)
col_x=torch.index_select(x, dim=1, index=indice) #按列切片
print(col_x)
输出结果如下,依次是x、row_x、col_x
三、torch.nonzero
1.语法
torch.nonzero(input, out=None) → LongTensor
参数:
input (Tensor) – 源张量
out (LongTensor, optional) – 包含索引值的结果张量
返回一个包含输入input中非零元素索引的张量。输出张量中的每行包含输入中非零元素的索引。
2.举例
import torch
x=torch.tensor([[0.6, 0.0, 0.0, 0.0],
[0.0, 0.4, 0.0, 0.0],
[0.0, 0.0, 1.2, 0.0],
[0.0, 0.0, 0.0,-0.4]]) #生成一个4*4的张量
nonz_x=torch.nonzero(x) #返回张量x的非零项索引
print(nonz_x)
输出结果如下:
四、torch.split
1.语法
torch.split(tensor, split_size, dim=0)
参数:
tensor #待分块张量
split_size #分块尺寸
dim #按dim维度进行分块
将输入张量分割成相等形状的chunks(如果可分)。 如果沿指定维的张量形状大小不能被split_size 整分, 则最后一个分块会小于其它分块。
2.举例(分别按行、按列进行分块)
x=torch.randn([3,4])
print(x)
row_x=torch.split(x, 2, dim=0)
print(row_x)
col_x=torch.split(x, 2, dim=1)
print(col_x)
输出结果如下,依次为x、row_x、col_x
五、torch.squeeze
1.语法
torch.squeeze(input, dim=None, out=None)
参数:
input (Tensor) – 输入张量
dim (int, optional) – 如果给定,则input只会在给定维度挤压
out (Tensor, optional) – 输出张量
将输入张量形状中的1去除并返回。 如果输入是形如(A×1×B×1×C×1×D),那么输出形状就为: (A×B×C×D)
当给定dim时,那么挤压操作只在给定维度上。例如,输入形状为: (A×1×B), squeeze(input, 0) 将会保持张量不变,只有用 squeeze(input, 1),形状会变成 (A×B)
2.举例
x = torch.zeros(2,1,2,1,2) #生成一个2*1*2*1*2的张量
print(x.shape)
sq_x=torch.squeeze(x) #对x张量形状中的1去除并返回
print(sq_x)
sq_x_dim=torch.squeeze(x, dim=3) #对x张量中第3维中的1去除并返回
print(sq_x_dim)
输出结果如下,依次是x、sq_x、sq_x_dim张量的形状