Pytorch一些不常见函数解析(持续更新)

1. Categorical()

torch.distributions.Categorical()

可以按照一定概率产生具体数字,比如:

import torch
from torch.distributions import Categorical

rand = Categorical(torch.tensor([0.25,0.25,0.25,0.25]))
print(rand.sample())
# tensor(3)

这个Categorical()还有一些有趣的功能,比如可以求策略梯度REINFORCE,有个小例子:

probs = policy_network(state)
# Note that this is equivalent to what used to be called multinomial
m = Categorical(probs)
action = m.sample()
next_state, reward = env.step(action)
loss = -m.log_prob(action) * reward
loss.backward()

注意,这里的策略梯度REINFORCE的公式为:
在这里插入图片描述

  • 左边是(神经网络)的参数
  • Alpha是学习速率,r是奖励(reward),p则是在状态s以及给定策略pi中执行动作a的概率

而上述代码中的m.log_prob(value)函数则是公式中的log部分

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值