torch.distributions.Normal(mean, std).sample()
:
- 用于从正态分布(高斯分布)中采样。
mean
和 std
分别是正态分布的均值和标准差。
- 采样的结果是一个连续值,可以是任意实数(在理论上),但在实际应用中,由于计算机的数值表示限制,采样值将是浮点数。
- 通常用于连续动作空间的强化学习任务,或者任何需要连续随机变量的场景。
import torch
from torch.distributions import Normal
# 设定正态分布的均值和标准差
mean = 0.0
std = 1.0
# 创建一个正态分布对象
normal_dist = Normal(mean, std)
# 从正态分布中采样一个值
sample = normal_dist.sample()
print(f"从正态分布中采样的值: {sample}")
torch.multinomial()
:
- 用于从多项分布中采样。
- 需要一个权重向量(或称为概率向量),表示每个类别被选中的概率。
- 采样的结果是离散的,表示选中的类别索引。
- 通常用于分类任务或者任何需要从一组离散选项中选择的场景。
import torch
# 设定多项分布的概率权重
# 假设有三个类别,分别对应的概率为0.1, 0.3, 0.6
weights