pytorch中的索引与数据筛选

案例:

import torch

a = torch.rand(4, 4)
b = torch.rand(4, 4)
print(a)
print(b)
out = torch.where(a > 0.5, a, b)
print(out)
print('*' * 100)

a = torch.rand(4, 4)
print(a)
out = torch.index_select(a, dim=0, index=torch.tensor([0, 3]))
print(out)
print(out.shape)
print('*' * 100)

a = torch.linspace(1, 16, 16).view(size=[4, 4])
print(a)
out = torch.gather(a, dim=0, index=torch.tensor([[0,1,1,1],
                                                 [0,1,2,1]]))
print(out)

print('*' * 100)

a = torch.linspace(1, 16, 16).view(size=[4, 4])
print(a)
mask = torch.gt(a, 8)
print(mask)
out = torch.masked_select(a, mask)
print(out)

print('*' * 100)
a = torch.linspace(1, 16, 16).view(size=[4, 4])
print(a)
out = torch.take(a, index=torch.tensor([0,2]))
print(out)
# exit()

print('*' * 100)
a = torch.tensor([[0, 1], [3, 5]])
print(a)
out = torch.nonzero(a)
print(out)


"""
tensor([[0.8969, 0.0789, 0.3952, 0.9113],
        [0.7224, 0.0376, 0.0088, 0.6201],
        [0.0165, 0.6198, 0.6357, 0.1959],
        [0.2630, 0.7964, 0.9304, 0.4399]])
tensor([[0.9669, 0.6546, 0.5789, 0.4978],
        [0.2343, 0.1951, 0.2673, 0.4190],
        [0.6410, 0.3289, 0.6181, 0.8436],
        [0.0032, 0.4834, 0.2422, 0.2723]])
tensor([[0.8969, 0.6546, 0.5789, 0.9113],
        [0.7224, 0.1951, 0.2673, 0.6201],
        [0.6410, 0.6198, 0.6357, 0.8436],
        [0.0032, 0.7964, 0.9304, 0.2723]])
*****************************************************************************************
tensor([[0.2883, 0.7950, 0.1790, 0.0426],
        [0.7390, 0.5961, 0.6238, 0.5012],
        [0.3688, 0.5733, 0.4973, 0.2533],
        [0.8975, 0.1249, 0.0760, 0.1148]])
tensor([[0.2883, 0.7950, 0.1790, 0.0426],
        [0.8975, 0.1249, 0.0760, 0.1148]])
torch.Size([2, 4])
*****************************************************************************************
tensor([[ 1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12.],
        [13., 14., 15., 16.]])
tensor([[ 1.,  6.,  7.,  8.],
        [ 1.,  6., 11.,  8.]])
*****************************************************************************************
tensor([[ 1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12.],
        [13., 14., 15., 16.]])
tensor([[False, False, False, False],
        [False, False, False, False],
        [ True,  True,  True,  True],
        [ True,  True,  True,  True]])
tensor([ 9., 10., 11., 12., 13., 14., 15., 16.])
*****************************************************************************************
tensor([[ 1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12.],
        [13., 14., 15., 16.]])
tensor([1., 3.])
*****************************************************************************************
tensor([[0, 1],
        [3, 5]])
tensor([[0, 1],
        [1, 0],
        [1, 1]])


Process finished with exit code

"""

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
PyTorch数据流的处理通常涉及以下几个关键组件: 1. 数据集(Dataset):数据集是用来存储和组织原始数据的对象。PyTorch的`torch.utils.data.Dataset`是一个抽象类,你可以自定义一个类来继承它,并实现`__len__`和`__getitem__`方法。`__len__`方法返回数据集的大小,`__getitem__`方法返回给定索引数据样本。 2. 数据加载器(DataLoader):数据加载器是用来加载数据集并生成可迭代的数据批次的对象。PyTorch的`torch.utils.data.DataLoader`提供了一个简单易用的接口,可以将数据集包装成数据加载器。你可以指定每个批次的大小、是否打乱数据以及并行加载等参数。 3. 数据转换(Data Transformation):数据转换是在数据加载过程数据进行预处理或增强的操作。PyTorch的`torchvision.transforms`模块提供了一系列常用的图像转换操作,例如裁剪、缩放、翻转、归一化等。你可以使用这些转换函数来构建一个转换管道,并将其应用于数据集或数据加载器。 4. 设备选择(Device Selection):在PyTorch,你可以选择将张量和模型放在CPU或GPU上进行计算。通过调用`to`方法,你可以将张量或模型转移到特定设备上。例如,`tensor.to('cuda')`将张量转移到GPU上。 5. 迭代数据流(Iterating Data Flow):一旦数据加载器准备好了,你可以使用`for`循环迭代数据加载器的输出来遍历数据批次。每个数据批次都是一个包含输入数据和对应标签的元组,你可以将它们传递给模型进行训练或推断。 6. 批次处理(Batch Processing):在训练过程,通常会对一个批次的数据进行处理。这包括将输入数据传递给模型进行前向计算、计算损失、计算梯度、更新模型参数等操作。PyTorch提供了灵活的接口,可以轻松地进行这些操作。 总结起来,PyTorch数据流处理通常包括准备数据集、构建数据加载器、定义数据转换、选择设备、迭代数据加载器输出以及处理批次数据等步骤。这些步骤的具体实现可以根据你的任务和需求进行适当调整和扩展。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

浅蓝的风

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值