pytorch小记(十九):深入理解 PyTorch 的 `torch.randint()` 与 `.long()` 转换


在使用 PyTorch 进行深度学习建模或数据处理时,常常需要生成随机整数张量作为索引、伪标签或其它用途。本文将深入讲解 PyTorch 中的 torch.randint() 函数,以及为什么/如何结合 .long() 方法将张量转换为 64 位整型(LongTensor)。文末还会给出多种典型场景的实战示例,帮助你在项目中快速上手。


一、torch.randint() 基本概念

torch.randint() 用来在指定范围内均匀随机生成整数张量。它的函数签名如下:

torch.randint(
    low: int = 0,
    high: int,
    size: Tuple[int, ...],
    *,
    dtype: torch.dtype = torch.int64,
    layout: torch.layout = torch.strided,
    device: Optional[torch.device] = None,
    requires_grad: bool = False
) → Tensor
  • low:随机整数的下界(包含),默认为 0。
  • high:随机整数的上界(不包含),必须指定。
  • size:输出张量的形状,例如 (batch_size,)(2, 3)(B, C, H, W)
  • dtype:输出张量的数据类型,默认是 torch.int64(LongTensor)。
  • device:生成张量所在设备,如 'cpu' 或者 'cuda'

示例:生成一个二维随机整型张量

import torch

# 在 [0, 10) 范围内,生成 2×3 的随机整数张量
x = torch.randint(0, 10, (2, 3))
print(x)
# 可能输出:
# tensor([[2, 7, 1],
#         [5, 0, 9]])
print(x.dtype)   # torch.int64 (默认 LongTensor)

二、为什么需要调用 .long()

虽然 torch.randint 默认即可生成 torch.int64 的张量,但在以下场景中,我们仍常见到 .long() 的调用:

  1. 确保索引类型
    PyTorch 中,张量索引用的必须是 LongTensor(torch.int64)。如果手动指定了其它整型(如 torch.int32torch.uint8),则需要 .long() 转换:

    idx32 = torch.randint(0, 100, (16,), dtype=torch.int32)
    print(idx32.dtype)  # torch.int32
    
    idx64 = idx32.long()
    print(idx64.dtype)  # torch.int64
    # 这样才能用 idx64 在其它张量上进行索引
    
  2. 满足损失函数要求
    例如 torch.nn.CrossEntropyLoss 要求标签(targets)是 LongTensor:

    num_classes = 10
    batch_size = 32
    
    labels = torch.randint(0, num_classes, (batch_size,))  # 默认就是 int64
    # labels = labels.long()  # 如果你不确定 dtype,可以显式调用
    
    logits = torch.randn(batch_size, num_classes)
    loss_fn = torch.nn.CrossEntropyLoss()
    loss = loss_fn(logits, labels)
    
  3. 统一数据类型
    在复杂模型或数据管道中,手动控制 dtype 能避免莫名的类型不一致错误。显式地在生成后调用 .long(),可以给下游代码带来更好的可读性和健壮性。


三、典型场景示例

1. 随机索引采样

在自定义采样、数据重排或分批时,需要一组随机索引:

import torch

num_samples = 1000
batch_size = 64

# 生成 [0, num_samples) 范围内,大小为 batch_size 的随机索引
indices = torch.randint(0, num_samples, (batch_size,)).long()

# 假设 data 是一个形状为 [num_samples, ...] 的张量
data = torch.randn(num_samples, 3, 224, 224)
batch = data[indices]  # 用 long 类型索引

2. 伪标签生成

在无监督或对抗训练中,有时需要生成伪标签(fake labels):

import torch
import torch.nn as nn

num_classes = 5
batch_size = 16

# 随机生成伪标签
fake_labels = torch.randint(0, num_classes, (batch_size,)).long()

# 用 CrossEntropyLoss 计算损失
logits = torch.randn(batch_size, num_classes, requires_grad=True)
criterion = nn.CrossEntropyLoss()
loss = criterion(logits, fake_labels)
loss.backward()

3. 直接在 GPU 上生成 LongTensor

如果希望生成的随机张量直接存放在 GPU 上,同样可以指定 device,并明确 dtype:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size, num_classes = 32, 10

# 一步到位生成 GPU 上的 LongTensor
labels = torch.randint(0, num_classes, (batch_size,),
                       device=device, dtype=torch.int64)
print(labels.device, labels.dtype)  # cuda:0 torch.int64

四、.long() 的几种等价写法

  • tensor.long()
  • tensor.to(torch.int64)
  • tensor.type(torch.int64)

它们的效果相同,大家可根据个人或团队习惯任选其一。通常推荐使用 .long(),因为更简洁。


五、小结

  • torch.randint(low, high, size):生成位于 [low, high) 的均匀随机整数张量,默认 dtype 是 torch.int64

  • .long():将任意整型或浮点型张量转换为 torch.int64(LongTensor),常用于索引、标签或保证数据类型一致。

  • 典型用途

    1. 随机采样索引
    2. 生成分类伪标签
    3. 在 GPU 上直接生成 long 型张量
  • 最佳实践:在不确定 dtype 时显式调用 .long(),或通过 dtype=torch.int64device='cuda' 一次性完成生成。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值