RL-Ch4-Policy Gradient
Policy Gradient
Examples of Reinforcement Learning
| Scene | Agent | Environment | Reward Function |
|---|---|---|---|
| Video game | Gamepad | Game console | 20 points per monster killed |
| Go | AlphaGo | Lee Sedol | The rules of Go |
In the examples above, the concrete form of the policy $\pi$ can be taken to be the parameter matrices $\theta$ of a neural network, from its input layer to its output layer.
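As a concrete illustration, the policy can be realized as a small network whose weights are exactly $\theta$. This is a minimal sketch, assuming a discrete action space; the class name `PolicyNet`, the layer sizes, and the arguments `obs_dim`/`n_actions` are illustrative choices, not from the course.

```python
import torch
import torch.nn as nn

class PolicyNet(nn.Module):
    """A policy pi_theta: maps a state to a distribution over actions.

    The parameters theta are exactly the weight matrices of this network.
    """
    def __init__(self, obs_dim: int, n_actions: int, hidden: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden),
            nn.Tanh(),
            nn.Linear(hidden, n_actions),
        )

    def forward(self, state: torch.Tensor) -> torch.distributions.Categorical:
        logits = self.net(state)  # unnormalized log-probabilities over actions
        return torch.distributions.Categorical(logits=logits)
```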
The figure below shows a Markov chain with actions added.
Denote a trajectory as $\tau=\{s_1,a_1,\dots,s_T,a_T\}$; then
$$p_\theta(\tau)=p(s_1)p_\theta(a_1|s_1)p(s_2|s_1,a_1)p_\theta(a_2|s_2)p(s_3|s_2,a_2)\cdots=p(s_1)\prod_{t=1}^{T}p_\theta(a_t|s_t)p(s_{t+1}|s_t,a_t)\tag{1}$$
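In practice Eq. (1) is evaluated in log space, turning the product into a sum. A minimal sketch, assuming the trajectory is stored as `(s, a, s_next)` transitions; the callables `log_p_s1`, `policy_log_prob`, and `dynamics_log_prob` are hypothetical stand-ins:

```python
def log_prob_trajectory(traj, log_p_s1, policy_log_prob, dynamics_log_prob):
    """log p_theta(tau) = log p(s1) + sum_t [log p_theta(a_t|s_t) + log p(s_{t+1}|s_t,a_t)].

    traj: list of (s_t, a_t, s_{t+1}) transitions for t = 1..T.
    """
    s1 = traj[0][0]
    total = log_p_s1(s1)
    for s, a, s_next in traj:
        total += policy_log_prob(a, s)            # log p_theta(a_t | s_t): depends on theta
        total += dynamics_log_prob(s_next, s, a)  # log p(s_{t+1} | s_t, a_t): theta-free
    return total
```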
Decision making also needs to take the return into account, so we add rewards to the diagram above:
$$R(\tau)=\sum_{t=1}^{T}r_t\tag{2}$$
The expected return is then
$$\bar{R}_\theta=\sum_\tau R(\tau)p_\theta(\tau)=\mathbb{E}_{\tau\sim p_\theta(\tau)}[R(\tau)]\tag{3}$$
Taking the gradient of the expected return gives
$$\begin{aligned}\nabla\bar{R}_\theta&=\sum_\tau R(\tau)\nabla p_\theta(\tau)=\sum_\tau R(\tau)p_\theta(\tau)\frac{\nabla p_\theta(\tau)}{p_\theta(\tau)}\\&=\sum_\tau R(\tau)p_\theta(\tau)\nabla\log p_\theta(\tau)\\&=\mathbb{E}_{\tau\sim p_\theta(\tau)}[R(\tau)\nabla\log p_\theta(\tau)]\end{aligned}\tag{4}$$
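The "Eq. (1)" step in Eq. (5) below relies on a fact worth spelling out: in Eq. (1) only the policy terms depend on $\theta$, so the initial-state and transition terms drop out of the gradient:

$$\nabla\log p_\theta(\tau)=\nabla\left[\log p(s_1)+\sum_{t=1}^{T}\log p_\theta(a_t|s_t)+\sum_{t=1}^{T}\log p(s_{t+1}|s_t,a_t)\right]=\sum_{t=1}^{T}\nabla\log p_\theta(a_t|s_t)$$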
Computing the Policy Gradient
There are two ways to compute it:
- Monte Carlo sampling: Eq. (4) can be rewritten as Eq. (5) below (a code sketch follows after this list).
$$\nabla\bar{R}_\theta\approx\frac{1}{N}\sum_{n=1}^{N}R(\tau^n)\nabla\log p_\theta(\tau^n)\overset{\text{Eq. (1)}}{=}\frac{1}{N}\sum_{n=1}^{N}\sum_{t=1}^{T_n}R(\tau^n)\nabla\log p_\theta(a_t^n|s_t^n)\tag{5}$$
- Temporal-difference updating: Eq. (4) can be rewritten as Eq. (6) below (likewise sketched after the list).
$$\nabla\bar{R}_\theta\approx\frac{1}{N}\sum_{n=1}^{N}\sum_{t=1}^{T_n}Q^n(s_t^n,a_t^n)\nabla\log p_\theta(a_t^n|s_t^n)\tag{6}$$
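A minimal sketch of the Monte Carlo estimator in Eq. (5), assuming a Gymnasium-style `env` (`reset()` returning `(obs, info)`, five-value `step()`) and a policy like the hypothetical `PolicyNet` above; ascending Eq. (5) is implemented by descending the negative surrogate loss:

```python
import torch

def reinforce_step(policy, optimizer, env, n_episodes: int = 10):
    """One gradient step of Eq. (5): grad ~ (1/N) sum_n R(tau^n) sum_t grad log p(a_t|s_t)."""
    losses = []
    for _ in range(n_episodes):                 # N sampled trajectories
        state, _ = env.reset()
        log_probs, rewards = [], []
        done = False
        while not done:
            dist = policy(torch.as_tensor(state, dtype=torch.float32))
            action = dist.sample()              # a_t ~ p_theta(. | s_t)
            log_probs.append(dist.log_prob(action))
            state, reward, terminated, truncated, _ = env.step(action.item())
            done = terminated or truncated
            rewards.append(reward)
        R = sum(rewards)                        # R(tau^n): whole-trajectory return
        losses.append(-R * torch.stack(log_probs).sum())
    loss = torch.stack(losses).mean()           # minimizing this ascends Eq. (5)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```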
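For Eq. (6), the only change is the weight on each log-probability: the whole-trajectory return $R(\tau^n)$ becomes a per-step estimate $Q^n(s_t^n,a_t^n)$. A sketch under the assumption that a critic already provides a `q_estimate` callable; how that critic is trained by temporal-difference updates is outside this snippet:

```python
import torch

def pg_loss_with_q(policy, trajectory, q_estimate):
    """Surrogate loss for Eq. (6): weight each log p_theta(a_t|s_t) by Q^n(s_t, a_t).

    trajectory: list of (state, action) pairs from one rollout.
    q_estimate: callable (state, action) -> float, e.g. a TD-trained critic.
    """
    terms = []
    for state, action in trajectory:
        dist = policy(torch.as_tensor(state, dtype=torch.float32))
        log_p = dist.log_prob(torch.as_tensor(action))
        q = q_estimate(state, action)  # per-step weight replacing R(tau)
        terms.append(-q * log_p)       # minimizing this ascends Eq. (6)
    return torch.stack(terms).sum()
```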