【torch】rsample与sample的区别

sample():从概率分布中随机采样。所以,我们不能反向传播,因为它是随机的! (计算图被截断)。

请参阅torch.distributions.normal.Normal中示例的源代码:

def sample(self, sample_shape=torch.Size()):
    shape = self._extended_shape(sample_shape)
    with torch.no_grad():
        return torch.normal(self.loc.expand(shape), self.scale.expand(shape))

torch.normal 返回随机数张量。此外,torch.no_grad() 上下文可以防止计算图进一步增长。

你看,我们不能反向传播。 Sample() 返回的张量仅包含一些数字,而不是整个计算图。

那么,rsample() 是什么?

通过使用 rsample,我们可以反向传播,因为它使计算图保持活动状态。

如何?通过将随机性放在单独的参数中。这称为“重新参数化技巧”。

rsample:使用重新参数化技巧进行采样。

源码中有eps:

def rsample(self, sample_shape=torch.Size()):
    shape = self._extended_shape(sample_shape)
    eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
    return self.loc + eps * self.scale

eps 是负责采样随机性的单独参数。

查看返回值:平均值 + eps * 标准差

eps 不依赖于您想要微分的参数。

所以,现在你可以自由地反向传播(=微分),因为当参数改变时 eps 不会改变。

(如果我们改变参数,重新参数化的样本的分布会因为 self.loc 和 self.scale 改变而改变,但 eps 的分布不会改变。)

请注意,采样的随机性来自于 eps 的随机采样。计算图本身不存在随机性。一旦选择了 eps,它就被固定了。 (eps 元素的分布在采样后是固定的。)

例如,在强化学习中的 SAC(Soft Actor-Critic)算法的实现中,eps 可能由与单个小批量动作相对应的元素组成(并且一个动作可能由许多元素组成)。

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
在 PyTorch 中,`rsample()` 方法主要用于从概率分布中生成样本,并且在生成样本的同时计算梯度。这个方法通常用于需要进行反向传播的情况下,因为它会保留梯度信息。 以下是 `rsample()` 方法的一些常见用法: 1. 生成单个样本:可以直接调用概率分布对象的 `rsample()` 方法来生成一个样本。 ```python import torch from torch.distributions import Normal # 创建正态分布 normal_dist = Normal(0, 1) # 生成单个样本 sample = normal_dist.rsample() ``` 2. 批量生成样本:可以通过指定 `sample_shape` 参数来生成指定数量的样本。 ```python import torch from torch.distributions import Normal # 创建正态分布 normal_dist = Normal(0, 1) # 批量生成样本 samples = normal_dist.rsample(sample_shape=torch.Size([5])) ``` 3. 与其他张量进行运算:由于 `rsample()` 方法生成的样本是张量,因此可以与其他张量进行运算。 ```python import torch from torch.distributions import Normal # 创建正态分布和张量 normal_dist = Normal(0, 1) tensor = torch.tensor([1.0, 2.0, 3.0]) # 生成样本并与张量相乘 sample = normal_dist.rsample() result = sample * tensor ``` 需要注意的是,`rsample()` 方法会返回一个具有相同形状的张量,并且这个张量的值是从概率分布中以反向传播可微的方式生成的。 希望这些示例能够帮助你理解 `rsample()` 方法的用法。如果还有其他问题,请随时提问!

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值