CartPole-v1 50行python实现
背景
很久没有写文章了,github上维护的博客 https://blog.dong.black/ 上篇还是今年5月15号的,距离现在也有一个多月了。
之前在优达学过一小段时间的机器学习,感觉挺有意思,最近又看到了gym,想动手再尝试一把。
CartPole-v1是gym中比较经(jian)典(dian)的题目,号称机器学习中的 hello world
,比较适合我这种小白。趁着周末的闲功夫,求解一下。
题目
官方其实已经给出解释了:
Reinforcement learning Q-learning approach to OpenAI Gym’s CartPole environment.
这本质上是一个Q-learning问题,但是作为强化学习的 hello world
,其实也有很多其它的解法。
作者尝试过使用DQN解,但是收敛速度和稳定性差强人意。个人电脑吱吱转,算法却死活不收敛。
罢了,使用线性模型蒙一下吧。
线性模型
基本思路,是使用单个神经元。这里也不反向传导了,直接在当前空间随机探索,然后查看效果。
激活函数就是根据结果符号输出action,可以简单理解为 int(input > 0)
。
基本的过程如下:
- 随机选取 weights + bias
- 生成随机步长 delta_weights
- 计算更新后 weights 可以得到的回报 rewards
- 如果 rewards 相较之前增大了,应用 delta_weights;减小了,反向应用 delta_weights
算法简单粗暴,但是针对这个简单粗暴的题目,效果挺好。
代码
代码量只有50行,也没有比较复杂的逻辑,所以这里就直接贴出来了。
import gym
import numpy as np
import matplotlib.pyplot as plt
def predict(state, weight):</