Tensor的维度操作
1.unsqueeze(增加维度)
用法:torch.unsqueeze(input, dim)
或者 tensor.unsqueeze(dim)
input为输入的tensor,dim为增加的维度,如果指定tensor的话就可以只输入dim
import torch
x = torch.randn(5, 6)
y = torch.unsqueeze(x, 0)
z = torch.unsqueeze(x, 2)
d = x.unsqueeze(1)
print(x.shape)
print(y.shape)
print(z.shape)
print(d.shape)
# 输出:
torch.Size([5, 6])
torch.Size([1, 5, 6])
torch.Size([5, 6, 1])
torch.Size([5, 1, 6])
2.squeeze(减少维度)
- 和
unsqueeze()
方法相反,除去数值为1的维度需要注意的是删除的维度只能为1。
import torch
x = torch.randn(1,5, 6)
y = torch.squeeze(x, 0)
z = torch.squeeze(x, 1)
d = x.squeeze(0)
print(x.shape)
print(y.shape)
print(z.shape)
print(d.shape)
#输出:
torch.Size([1, 5, 6])
torch.Size([5, 6])
torch.Size([1, 5, 6])
torch.Size([5, 6])
3.repeat(维度扩张)
-
repeat(*size) 复制维度上的张量*size指定每个维度复制多少次
import torch
x = torch.randn(5, 6)
y = x.repeat(3,4)
z = x.squeeze(0).repeat(3,4,5)
print(x.shape)
print(y.shape)
print(z.shape)
#输出:
torch.Size([5, 6])
torch.Size([15, 24])
torch.Size([3, 20, 30])
4.narrow(获取张量在指定维度上的子张量)
- narrowed_tensor = tensor.narrow(dim, start, length)
import torch
x = torch.tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
y = x.narrow(0, 1, 2)
print(x.shape)
print(y)
print(y.shape)
#输出:
torch.Size([3, 3])
tensor([[4, 5, 6],
[7, 8, 9]])
torch.Size([2, 3])
5.reshape/view(改变维度的形状)
- 从功能上来看,它们的作用是相同的,都是用来重塑 Tensor 的 shape 的。view 只适合对满足连续性条件 (contiguous) 的 Tensor进行操作,而reshape 同时还可以对不满足连续性条件的 Tensor 进行操作,具有更好的鲁棒性。view 能干的 reshape都能干,如果 view 不能干就可以用 reshape 来处理。建议无脑使用reshape
- reshape(input,shape) 将输入的张量改变为指定的shape,注意改变前后所有维度的乘机应该相同
import torch
x = torch.randn(5,6)
y = torch.reshape(x,(15,2))
print(x.shape)
print(y.shape)
#输出:
torch.Size([5, 6])
torch.Size([15, 2])
6.permute(对制定维度的顺序进行重拍)
Tensor.permute(*dims)
接受一个可变数量的参数dims,用于指定新的维度顺序
a = torch.arange(12).reshape(1,3,4)
b = a.permute(1,0,2)
print(a)
print(b)
print(a.shape)
print(b.shape)
#输出
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]]])
tensor([[[ 0, 1, 2, 3]],
[[ 4, 5, 6, 7]],
[[ 8, 9, 10, 11]]])
torch.Size([1, 3, 4])
torch.Size([3, 1, 4])
7. transpose(
用于对张量进行转置操作。它可以交换张量的两个维度,从而改变张量的形状。)
- 语法:
torch.transpose(input, dim0, dim1)
input
:要进行转置操作的输入张量。dim0、dim1
:要交换的两个维度。
8.flatten(把指定张量打平为一维)
9.cat(
将多个张量沿指定维度进行拼接(连接))
- 语法:
torch.cat(tensors, dim=0, out=None)
tensors
是一个要拼接的张量序列,可以是一个张量列表或元组。dim
是指定拼接的维度,默认为0。out
是一个可选的输出张量,用于指定结果张量的存储位置。
10.stack(沿着新的维度对给定序列的张量进行堆叠)
- 语法:
torch.stack(tensors, dim=0, out=None)
tensors
:要堆叠的张量序列。dim
:沿着哪个维度进行堆叠的维度,是一个整数值。out
:(可选)输出张量。如果指定了该参数,结果将存储在这个张量中。如果没有指定,将创建一个新的张量。