Pytorch分布采样:torch.distributions.Normal(mean, std).sample() 和 torch.multinomial()

  • torch.distributions.Normal(mean, std).sample():
    • 用于从正态分布(高斯分布)中采样。
    • mean 和 std 分别是正态分布的均值和标准差。
    • 采样的结果是一个连续值,可以是任意实数(在理论上),但在实际应用中,由于计算机的数值表示限制,采样值将是浮点数。
    • 通常用于连续动作空间的强化学习任务,或者任何需要连续随机变量的场景。
import torch  
from torch.distributions import Normal  
  
# 设定正态分布的均值和标准差  
mean = 0.0  
std = 1.0  
  
# 创建一个正态分布对象  
normal_dist = Normal(mean, std)  
  
# 从正态分布中采样一个值  
sample = normal_dist.sample()  
  
print(f"从正态分布中采样的值: {sample}")
  • torch.multinomial():
    • 用于从多项分布中采样。
    • 需要一个权重向量(或称为概率向量),表示每个类别被选中的概率。
    • 采样的结果是离散的,表示选中的类别索引。
    • 通常用于分类任务或者任何需要从一组离散选项中选择的场景。
import torch  
  
# 设定多项分布的概率权重  
# 假设有三个类别,分别对应的概率为0.1, 0.3, 0.6  
weights = torch.tensor([0.1, 0.3, 0.6])  
  
# 从多项分布中采样一个类别  
# 这里我们进行一次采样,选择1个样本(即num_samples=1)  
sample = torch.multinomial(weights, num_samples=1, replacement=True)  
  
print(f"从多项分布中采样的类别索引: {sample}")

主要区别:

  • 分布类型:Normal 是连续的正态分布,而 multinomial 是离散的多项分布。
  • 采样值类型:Normal.sample() 返回连续值(浮点数),而 torch.multinomial() 返回离散值(通常是整数索引)。
  • 应用场景:Normal.sample() 常用于需要连续动作或参数的场合,如机器人控制中的连续移动;torch.multinomial() 则常用于分类或需要从一组选项中选择一个的场景,如自然语言处理中的词汇选择。

在强化学习中,如果动作空间是连续的(例如,调整机器人的关节角度),可以选择 Normal(mean, std).sample() 来采样动作。如果动作空间是离散的(例如,选择向左、向右、跳跃或蹲下),可以使用 torch.multinomial() 来根据每个动作的概率采样一个动作。

  • 5
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值