1. Categorical()
torch.distributions.Categorical()
可以按照一定概率产生具体数字,比如:
import torch
from torch.distributions import Categorical
rand = Categorical(torch.tensor([0.25,0.25,0.25,0.25]))
print(rand.sample())
# tensor(3)
这个Categorical()还有一些有趣的功能,比如可以求策略梯度REINFORCE,有个小例子:
probs = policy_network(state)
# Note that this is equivalent to what used to be called multinomial
m = Categorical(probs)
action = m.sample()
next_state, reward = env.step(action)
loss = -m.log_prob(action) * reward
loss.backward()
注意,这里的策略梯度REINFORCE的公式为:
- 左边是(神经网络)的参数
- Alpha是学习速率,r是奖励(reward),p则是在状态s以及给定策略pi中执行动作a的概率
而上述代码中的m.log_prob(value)函数则是公式中的log部分