官方文档:
def narrow(self: Tensor, dim: _int, start: _int, length: _int) -> Tensor:
函数作用:
:将第dim维缩短,也就是切片。
比如是5行5列的矩阵,我们只需要其中的第3到4行
start为2,起始是第三行
length为2,一共需要切出两行
函数参数:
input (Tensor) – 需要被操作的Tensor
dim (int) – 需要被压缩的维度(可以用行,列来类比)
start (int) – 从哪一维(可以用行,列来类比)开始
length (int) – 需要切片的长度
举例实现:
首先输入5行五列的矩阵,然后输出看一下
x = torch.tensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10], [11, 12, 13, 14, 15],[16, 17, 18, 19, 20],[21, 22, 23, 24, 25]])
print(x)
print(x.size())
输出结果:
我们使用torch.narrow()的函数来取出这个矩阵的第三四行
相当于对矩阵 x 将第0维切片,从第三行开始,长度为2
y = torch.narrow(x,0,2,2)
print(y)
print(y.size())
输出结果:
如果要对列进行切片,比如需要将1,2,1列切出来
只需要修改一下参数,dim改为1,start从0开始,length长度为3
z = torch.narrow(x,1,0,3)
print(z)
print(z.size())
输出结果:
总结:
对输入的Tensor x 的第 dim 维度进行压缩/切片,
从start开始,长度为length。其他维度不变