Pytorch分布采样:torch.distributions.Normal(mean, std).sample() 和 torch.multinomial()

  • torch.distributions.Normal(mean, std).sample():
    • 用于从正态分布(高斯分布)中采样。
    • mean 和 std 分别是正态分布的均值和标准差。
    • 采样的结果是一个连续值,可以是任意实数(在理论上),但在实际应用中,由于计算机的数值表示限制,采样值将是浮点数。
    • 通常用于连续动作空间的强化学习任务,或者任何需要连续随机变量的场景。
import torch  
from torch.distributions import Normal  
  
# 设定正态分布的均值和标准差  
mean = 0.0  
std = 1.0  
  
# 创建一个正态分布对象  
normal_dist = Normal(mean, std)  
  
# 从正态分布中采样一个值  
sample = normal_dist.sample()  
  
print(f"从正态分布中采样的值: {sample}")
  • torch.multinomial():
    • 用于从多项分布中采样。
    • 需要一个权重向量(或称为概率向量),表示每个类别被选中的概率。
    • 采样的结果是离散的,表示选中的类别索引。
    • 通常用于分类任务或者任何需要从一组离散选项中选择的场景。
import torch  
  
# 设定多项分布的概率权重  
# 假设有三个类别,分别对应的概率为0.1, 0.3, 0.6  
weights
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值