Pytorch张量视图(Tensor Views)


前言

张量视图(Tensor Views)是 PyTorch 中的一个重要概念。视图是指与原始张量共享相同底层数据的新张量,但具有不同的形状或步幅。

通过创建张量视图,我们可以在不复制数据的情况下对张量进行形状变换、切片和其他操作,从而实现快速且内存高效的操作。


1.torch.as_strided()

  使用 torch.as_strided 函数创建新的张量,这个新视图与原张量共享内存,但可以以不同的步幅访问这些数据。

import torch

# 创建一个输入张量
input = torch.tensor([1, 2, 3, 4, 5, 6])
output = torch.as_strided(input, size=(3, 2), stride=(2, 1))

# 打印输入张量和输出张量
print("Input tensor:")
print(input)

print("Output tensor:")
print(output)
Input tensor:
tensor([1, 2, 3, 4, 5, 6])
Output tensor:
tensor([[1, 2],
        [3, 4],
        [5, 6]])

2.torch.detach()

 当使用 torch.detach() 函数时,它会创建一个新的张量,该张量与原始张量共享相同的数据,但不共享计算图的梯度。这使得新张量可以被用作不需要梯度的中间结果或者被用于与原始张量相互独立的计算。

import torch

# 创建一个需要梯度的张量
x = torch.tensor([2.0, 3.0], requires_grad=True)

# 进行一些操作,并分离计算图
y = x + 1
z = y * 2
z_detached = z.detach()

# 查看张量的梯度信息
print(x.requires_grad)  # 输出: True
print(y.requires_grad)  # 输出: True
print(z.requires_grad)  # 输出: True
print(z_detached.requires_grad)  # 输出: False

3.torch.diagonal()

  用于提取张量的对角线元素。

torch.diagonal(input, offset=0, dim1=0, dim2=1)
"""
参数说明:

input:输入张量。
offset:对角线的偏移量,表示从主对角线的偏移量,默认为0(主对角线)。
dim1:起始维度。
dim2:结束维度。
"""
import torch

# 创建一个二维张量
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])

# 提取主对角线的元素
diagonal = torch.diagonal(x)

print(diagonal)  # 输出: tensor([1, 5, 9])

# 提取副对角线的元素
diagonal_offset = torch.diagonal(x, offset=1)

print(diagonal_offset)  # 输出: tensor([2, 6])

# 提取指定维度的对角线元素
diagonal_dim = torch.diagonal(x, dim1=1, dim2=0)

print(diagonal_dim)  # 输出: tensor([1, 5, 9])

4.torch.expand()

  torch.expand() 是 PyTorch 中的一个函数,用于扩展张量的形状。

torch.expand(input, size)
"""
参数说明:

input:输入张量。
size:扩展后的目标形状,可以是一个元组或列表。
"""
import torch

# 创建一个一维张量
x = torch.tensor([1, 2, 3])

# 扩展张量的形状
expanded = torch.expand(x, (2, 3))

print(expanded)
# 输出:
# tensor([[1, 2, 3, 1, 2, 3, 1, 2, 3],
#         [1, 2, 3, 1, 2, 3, 1, 2, 3]])

# 创建一个二维张量
y = torch.tensor([[1],
                  [2]])

# 扩展张量的形状
expanded_2 = torch.expand(y, (2, 3))

print(expanded_2)
# 输出:
# tensor([[1, 1, 1],
#         [2, 2, 2],
#         [1, 1, 1],
#         [2, 2, 2]])

5.torch.movedim()

  torch.movedim() 是 PyTorch 中的一个函数,用于移动张量的维度顺序。

torch.movedim(input, source, destination)
"""
参数说明:

input:输入张量。
source:原始维度或维度的序列。
destination:目标维度或维度的序列。
"""
import torch

# 创建一个三维张量
x = torch.tensor([[[1, 2],
                   [3, 4]],
                  
                  [[5, 6],
                   [7, 8]]])

# 移动张量的维度顺序
moved = torch.movedim(x, source=[0, 1, 2], destination=[2, 1, 0])

print(moved.shape)  # 输出: torch.Size([2, 2, 2])
print(moved)
# tensor([[[1, 5],
#          [3, 7]],

#         [[2, 6],
#          [4, 8]]])

6.torch.narrow()

  torch.narrow() 是 PyTorch 中的一个函数,用于沿指定维度缩小张量,通过选择一定范围的索引来实现。

torch.narrow(input, dim, start, length)
"""
参数说明:

input:输入张量。
dim:要缩小的维度。
start:缩小范围的起始索引。
length:缩小范围的长度。
"""
import torch

# 创建一个张量
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])

# 沿着维度0进行缩小
narrowed = torch.narrow(x, 0, 1, 2)

print(narrowed)
# 输出:
# tensor([[4, 5, 6],
#         [7, 8, 9]])

7.torch.permute()

torch.permute() 是 PyTorch 中的一个函数,用于对张量进行维度重排。

torch.permute(*dims)
"""
参数说明:

*dims:要重排的维度顺序,可以是一个整数序列或多个整数参数。
"""
import torch

# 创建一个三维张量
x = torch.tensor([[[1, 2, 3],
                   [4, 5, 6]],
                  
                  [[7, 8, 9],
                   [10, 11, 12]]])

# 对张量进行维度重排
permuted = torch.permute(x, [0, 2, 1])
#permuted = x.permute(0,2,1)
print(permuted.shape)  # 输出: torch.Size([2, 3, 2])
print(permuted)
# 输出:
# tensor([[[ 1,  4],
#          [ 2,  5],
#          [ 3,  6]],
# 
#         [[ 7, 10],
#          [ 8, 11],
#          [ 9, 12]]])

8.torch.select()

torch.select() 是 PyTorch 中的一个函数,用于按照指定的条件选择张量中的元素。

torch.select(input, dim, index)
"""
参数说明:

input:输入张量。
dim:要选择的维度。
index:要选择的索引。
"""
import torch

# 创建一个二维张量
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])

# 在维度0上选择索引为1的元素
selected = torch.select(x, 0, 1)

print(selected)
# 输出:
# tensor([4, 5, 6])

9.torch.squeeze()

  torch.squeeze() 是 PyTorch 中的一个函数,用于去除张量中尺寸为 1 的维度。

torch.squeeze(input, dim=None)
"""
参数说明:

input:输入张量。
dim(可选):指定要挤压的维度。
					如果指定了 dim,则只会挤压该维度上的尺寸为 1 的维度;
					如果未指定 dim,则会挤压所有尺寸为 1 的维度
"""
import torch

# 创建一个具有尺寸为 1 的维度的张量
x = torch.tensor([[[1, 2, 3]]])

# 去除尺寸为 1 的维度
squeezed = torch.squeeze(x)
x = x.squeeze()
print(x.shape)  # 输出: torch.Size([3])
print(x)
# 输出:
# tensor([1, 2, 3])

10.torch.transpose()

  torch.transpose() 是 PyTorch 中的一个函数,用于对张量进行转置操作,即交换维度的顺序。

torch.transpose() 和 torch.permute() 都是用于改变张量的维度顺序,但它们之间有一些区别。

  • 参数形式:torch.transpose() 函数接受两个参数 dim0 和 dim1,用于指定要交换的维度索引。而torch.permute() 函数接受一个可变数量的参数 *dims,用于指定新的维度顺序。
  • 功能范围:torch.transpose() 只能实现维度之间的交换,即将指定的两个维度进行交换。而 torch.permute()可以实现更灵活的维度重新排列,可以同时对多个维度进行重排,不仅仅是交换。
  • 维度顺序:torch.transpose() 只能交换两个维度的顺序,无法改变其他维度的顺序。而 torch.permute()可以在任意维度上进行重排,允许灵活地改变维度的顺序。
import torch

# 创建一个三维张量
x = torch.tensor([[[1, 2, 3],
                   [4, 5, 6]],

                  [[7, 8, 9],
                   [10, 11, 12]]])

# 使用 torch.transpose() 进行维度交换
transposed = torch.transpose(x, 0, 2)

# 使用 torch.permute() 进行维度重排
permuted = x.permute(2, 1, 0)

print(transposed.shape)  # 输出: torch.Size([3, 2, 2])
print(transposed)
# 输出:
# tensor([[[ 1,  4],
#          [ 7, 10]],
#
#         [[ 2,  5],
#          [ 8, 11]],
#
#         [[ 3,  6],
#          [ 9, 12]]])

print(permuted.shape)  # 输出: torch.Size([3, 2, 2])
print(permuted)
# 输出:
# tensor([[[ 1,  4],
#          [ 7, 10]],
#
#         [[ 2,  5],
#          [ 8, 11]],
#
#         [[ 3,  6],
#          [ 9, 12]]])

11.torch.t()

  该函数期望输入为 2 维或更低维的张量,并将维度 0 和维度 1 进行转置。0 维和 1 维张量将按原样返回。当输入为 2 维张量时,这等价于 transpose(input, 0, 1)。

import torch

# 创建一个二维张量
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])

# 对张量进行转置操作
transposed = torch.t(x)

print(transposed.shape)  # 输出: torch.Size([3, 2])
print(transposed)
# 输出:
# tensor([[1, 4],
#         [2, 5],
#         [3, 6]])
import torch

# 创建一个一维张量
x = torch.tensor([1, 2, 3, 4, 5])

# 对张量进行转置操作
transposed = torch.t(x)

print(transposed.shape)  # 输出: torch.Size([5])
print(transposed)
# 输出:
# tensor([1, 2, 3, 4, 5])

12.torch.real和torch.imag

  torch.real() 函数用于提取复数张量中的实部部分。

import torch

# 创建一个复数张量
x = torch.tensor([1+2j, 3+4j, 5+6j])

# 提取复数张量的实部
real_part = torch.real(x)

print(real_part.shape)  # 输出: torch.Size([3])
print(real_part)
# 输出:
# tensor([1., 3., 5.])
import torch

# 创建一个复数张量
x = torch.tensor([1+2j, 3+4j, 5+6j])

# 分别提取复数张量的实部和虚部
real_part = x.real
imag_part = x.imag

print(real_part.shape)  # 输出: torch.Size([3])
print(real_part)
# 输出:
# tensor([1., 3., 5.], dtype=torch.float32)

print(imag_part.shape)  # 输出: torch.Size([3])
print(imag_part)
# 输出:
# tensor([2., 4., 6.], dtype=torch.float32)

13.torch.unflatten()

  torch.unflatten() 函数用于将给定的维度展平的张量恢复成原始的形状。

torch.unflatten(dim, input_sizes)
"""
参数说明:
dim:指定要展平的维度。
input_sizes:一个元组或列表,指定原始张量在指定维度上的各个子张量的大小。
"""
import torch

# 创建一个展平的张量
x = torch.tensor([1, 2, 3, 4, 5, 6])

# 恢复成原始形状
unflattened = torch.unflatten(x,0, (2, 3))

print(unflattened.shape)  # 输出: torch.Size([2, 3])
print(unflattened)
# 输出:
# tensor([[1, 2, 3],
#         [4, 5, 6]])

14.torch.unsqueeze()

  torch.unsqueeze() 函数用于在指定的维度上增加一个维度,从而扩展张量的形状。

torch.unsqueeze(input, dim)
"""
参数说明:
input:输入张量。
dim:指定要扩展的维度索引。
"""
import torch

# 创建一个二维张量
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])

# 在维度 1 上增加一个维度
expanded = torch.unsqueeze(x, dim=1)

print(expanded.shape)  # 输出: torch.Size([2, 1, 3])
print(expanded)
# 输出:
# tensor([[[1, 2, 3]],
#         [[4, 5, 6]]])

15.torch.view()

  torch.view() 是 PyTorch 中用于调整张量形状的函数,它允许你重新定义张量的维度和大小,而不改变原始数据。

torch.view(*shape)
"""
参数说明:
shape:一个整数元组或多个整数参数,用于指定新张量的形状。
"""
import torch

# 创建一个张量
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])

# 改变张量的形状
reshaped = x.view(3, 2)

print(reshaped.shape)  # 输出: torch.Size([3, 2])
print(reshaped)
# 输出:
# tensor([[1, 2],
#         [3, 4],
#         [5, 6]])

16.torch.unbind()

  torch.unbind() 函数用于按维度将张量拆分为多个张量序列。

torch.unbind(input, dim=0)
"""
参数说明:
input:输入张量。
dim:指定的维度,沿该维度进行拆分。默认值为 0。
"""
import torch

# 创建一个张量
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])

# 按维度 1 拆分张量
unpacked = torch.unbind(x, dim=1)

print(len(unpacked))  # 输出: 3
print(unpacked[0])  # 输出: tensor([1, 4])
print(unpacked[1])  # 输出: tensor([2, 5])
print(unpacked[2])  # 输出: tensor([3, 6])
import torch

# 创建一个张量
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])

# 按维度 0 拆分张量
unpacked = torch.unbind(x, dim=0)

print(len(unpacked))  # 输出: 2
print(unpacked[0])  # 输出: tensor([1, 2, 3])
print(unpacked[1])  # 输出: tensor([4, 5, 6])

17.torch.split()

  torch.split() 函数用于按指定维度将张量分割为多个子张量。

torch.split(tensor, split_size_or_sections, dim=0)
"""
tensor:输入张量。
split_size_or_sections:指定分割的大小(整数)或分割的位置(整数列表)。
dim:指定的维度,沿该维度进行分割。默认值为 0。
"""
import torch

# 创建一个张量
x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])

# 按维度 0 使用整数列表分割张量
splits = torch.split(x, split_size_or_sections=[2, 3, 1, 2], dim=0)

print(len(splits))  # 输出: 4
print(splits[0])  # 输出: tensor([1, 2])
print(splits[1])  # 输出: tensor([3, 4, 5])
print(splits[2])  # 输出: tensor([6])
print(splits[3])  # 输出: tensor([7, 8])
import torch

# 创建一个张量
x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])

# 按维度 0 分割张量
splits = torch.split(x, split_size_or_sections=3, dim=0)

print(len(splits))  # 输出: 3
print(splits[0])  # 输出: tensor([1, 2, 3])
print(splits[1])  # 输出: tensor([4, 5, 6])
print(splits[2])  # 输出: tensor([7, 8])

18.torch.chunk()

  torch.chunk() 函数用于按指定维度将张量分割为多个块(chunks)。

torch.chunk(tensor, chunks, dim=0)
"""
参数说明:
tensor:输入张量。
chunks:要分割的块数。
dim:指定的维度,沿该维度进行分割。默认值为 0。
"""
import torch

# 创建一个张量
x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])

# 按维度 0 分割张量为两个块
chunks = torch.chunk(x, chunks=2, dim=0)

print(len(chunks))  # 输出: 2
print(chunks[0])  # 输出: tensor([1, 2, 3, 4])
print(chunks[1])  # 输出: tensor([5, 6, 7, 8])
import torch

# 创建一个二维张量
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])

# 按维度 1 分割张量为三个块
chunks = torch.chunk(x, chunks=3, dim=1)

print(len(chunks))  # 输出: 3
print(chunks[0])  # 输出: tensor([[1],
                 #         [4]])
print(chunks[1])  # 输出: tensor([[2],
                 #         [5]])
print(chunks[2])  # 输出: tensor([[3],
                 #         [6]])
  • 23
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值