torch.Generator 随机数生成器

PyTorch 通过 torch.Generator 类来操作随机数的生成


我们通常不会手动实例化 torch.Generator 随机数生成器, 当需要操作随机数时, PyTorch 会自动创建一个全局的 torch.Generator 实例, 随后随机数的生成默认使用该 torch.Generator 实例

使用默认随机数生成器

使用 torch.xxx 的形式操作随机数, 使用的是默认随机数生成器

import torch

# 设置默认的随机数种子
torch.manual_seed(0)

# 查看默认的随机数种子
torch.initial_seed()
使用指定随机数生成器

通过 torch.Generator 随机数生成器实例调用相应的方法:

实例化: g = torch.Generator()

  • g.manual_seed(int): 设置随机数种子
  • g.initial_seed(): 获取随机数的种子
  1. 获取默认的随机数生成器实例
g_1 = torch.default_generator

# 查看指定随机数生成器的种子(结果也为 0)
g_1.initial_seed()

通常使用的函数 torch.manual_seed() 会作用到默认的 torch.Generator 实例上, 并返回该默认实例

g_2 = torch.manual_seed(0)

# 结果为 True
g_1 is g_2

  1. 实例化并手动指定随机数种子
>>> import torch
>>> g = torch.Generator()
>>> g.manual_seed(0)
<torch._C.Generator object at 0x7f36e1d43850>
>>> g.initial_seed()
0

如果没有手动指定随机数种子, 系统会自动调用 g.seed() 生成种子

>>> import torch
>>> g = torch.Generator()
>>> g.seed()
4181592217041883240

  1. 在使用需要随机数的函数时, 如果没有指定 torch.Generator 实例, 则会使用全局默认的随机数生成器, 可以通过关键字参数 generator 指定随机数生成器
# 1. 使用默认随机数生成器
torch.manual_seed(1)

# 结果 tensor([0, 4, 2, 3, 1])
torch.randperm(5)


# 2. 手动创建随机数生成器
g = torch.Generator()
g.manual_seed(1)

# 结果也为 tensor([0, 4, 2, 3, 1])
torch.randperm(5, generator=g)

在 GPU 设备上使用随机数生成器

torch.Generator 随机数生成器实例会区分 CPU 与 GPU 两种设备

  1. 给当前 GPU 设备的默认随机数生成器设置种子
torch.cuda.manual_seed(0)
torch.cuda.initial_seed()
  1. 获取 GPU 设备的默认随机数生成器
>>> torch.cuda.default_generators
(<torch._C.Generator at 0x7f744f8359f0>,)

因为一台电脑可以有多个 GPU 设备, 所以返回了 torch.Generator 元组

  1. 实例化一个 GPU 类型的随机数生成器
import torch

g = torch.Generator(device='cuda')
g.manual_seed(1)

t = torch.randperm(5, device='cuda:0', generator=g)
print(t)

输出:

tensor([1, 3, 2, 4, 0], device='cuda:0')
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值