Blog: https://yanpanlau.github.io/2016/07/10/FlappyBird-Keras.html
Code: https://github.com/yanpanlau/Keras-FlappyBird
I. Code Overview
flappy_bird_utils.py mainly loads the images, sounds, and other assets the game needs; wrapped_flappy_bird.py mainly provides the GameState class, which takes an action and returns the updated game state; qlearn.py handles receiving the image input, preprocessing it, selecting actions with a convolutional neural network, and training that network with DQN (the network will be trained millions of times, via an algorithm called Q-learning, to maximize the future expected reward).
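As a rough sketch of how these pieces fit together (the GameState interface below follows the repo; treat the exact shapes and return values as assumptions):

import numpy as np
import wrapped_flappy_bird as game

# The game exposes a GameState that advances one frame per action.
game_state = game.GameState()

# An action is a one-hot vector over ACTIONS = 2:
# [1, 0] = do nothing, [0, 1] = flap.
do_nothing = np.zeros(2)
do_nothing[0] = 1

# frame_step returns the raw screen image, the reward, and a terminal flag.
image, reward, terminal = game_state.frame_step(do_nothing)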
II. Core Idea
The heart of Q-learning is the Q-function; DQN simply replaces the Q-table with a neural network.
1. What the Q-function is
The Q-function Q(s, a) represents the maximum discounted future reward when we perform action a in state s. In other words, Q(s, a) gives you an estimate of how good it is to choose action a in state s.
2. What the Q-function is for
Suppose you are in state s and you need to decide whether to take action a or b. If you have this magical Q-function, the answer becomes really simple: pick the action with the highest Q-value! That is, the policy is \pi(s) = \arg\max_a Q(s, a).
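For intuition, here is a minimal sketch of this greedy policy over a toy Q-table (the numbers are invented purely for illustration):

import numpy as np

# Toy Q-table: rows are states, columns are actions (0 = do nothing, 1 = flap).
Q = np.array([[0.2, 0.9],
              [0.7, 0.1]])

def policy(state):
    # pi(s) = argmax_a Q(s, a): pick the action with the highest Q-value.
    return np.argmax(Q[state])

print(policy(0))  # 1 (flap)
print(policy(1))  # 0 (do nothing)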
3. How to obtain the Q-function
That is where Q-learning comes in.
3.1 The recursive form of the Q-function
Define the total discounted future reward from time step t as R_t; it can be written recursively as

R_t = r_t + \gamma R_{t+1}

Recall the definition of the Q-function (the maximum discounted future reward if we choose action a in state s):

Q(s_t, a_t) = \max R_{t+1}
Therefore, we can rewrite the Q-function as below:

Q(s, a) = r + \gamma \max_{a'} Q(s', a')
We can now use an iterative method to solve for the Q-function. Given a transition (s, a, r, s'), we convert it into a training example for the network: we want Q(s, a) to be equal to r + \gamma \max_{a'} Q(s', a'). You can think of finding the Q-value as a regression task now: I have a target r + \gamma \max_{a'} Q(s', a') and a prediction Q(s, a), so I can define the mean squared error (MSE), our loss function, as below:

L = \left[ r + \gamma \max_{a'} Q(s', a') - Q(s, a) \right]^2
As L gets smaller, we know the Q-function is converging to the optimal value, which is our “strategy book”.
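A minimal tabular sketch of this iterative update (the learning rate ALPHA and the toy sizes are assumptions, not part of the formula above; DQN replaces the table with a network, as the next section shows):

import numpy as np

GAMMA = 0.99                  # discount factor, as in the formula above
ALPHA = 0.1                   # learning rate for the update (an assumption)
n_states, n_actions = 16, 2   # toy sizes, purely for illustration

Q = np.zeros((n_states, n_actions))

def q_update(s, a, r, s_next, terminal):
    # Move Q(s, a) toward the target r + GAMMA * max_a' Q(s', a');
    # when the episode terminates, the target is just the reward r.
    target = r if terminal else r + GAMMA * np.max(Q[s_next])
    Q[s, a] += ALPHA * (target - Q[s, a])

q_update(s=0, a=1, r=1.0, s_next=3, terminal=False)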
3.2 Learning the Q-function with a neural network
The idea of DQN is to use a neural network to COMPRESS this Q-table, using some parameters θ (we call them weights in a neural network). So instead of handling a large table, I just need to worry about the weights of the neural network. By smartly tuning the weight parameters, I can find the optimal Q-function via the usual neural network training algorithms.
Q(s, a) = f_\theta(s)
where f is our neural network with input s and weight parameters θ.
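For concreteness, here is a sketch of such a network in current Keras syntax (the repo builds an equivalent model with the older Keras 1.x API; the layer sizes follow the DeepMind-style DQN architecture and should be treated as an approximation of the repo's buildmodel()):

from keras.models import Sequential
from keras.layers import Conv2D, Flatten, Dense
from keras.optimizers import Adam

ACTIONS = 2  # flap or do nothing

model = Sequential([
    # Input: a stack of 4 preprocessed 80x80 frames.
    Conv2D(32, (8, 8), strides=(4, 4), padding='same', activation='relu',
           input_shape=(80, 80, 4)),
    Conv2D(64, (4, 4), strides=(2, 2), padding='same', activation='relu'),
    Conv2D(64, (3, 3), strides=(1, 1), padding='same', activation='relu'),
    Flatten(),
    Dense(512, activation='relu'),
    # One output per action: the network returns Q(s, a) for every a at once.
    Dense(ACTIONS),
])
model.compile(loss='mse', optimizer=Adam(1e-4))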
And here is the training code that demonstrates how the Q-learning update works:
if t > OBSERVE:
    # Sample a random minibatch of stored transitions to train on
    minibatch = random.sample(D, BATCH)

    inputs = np.zeros((BATCH, s_t.shape[1], s_t.shape[2], s_t.shape[3]))  # 32, 80, 80, 4
    targets = np.zeros((inputs.shape[0], ACTIONS))                        # 32, 2

    # Now we do the experience replay
    for i in range(0, len(minibatch)):
        state_t = minibatch[i][0]
        action_t = minibatch[i][1]   # this is the action index
        reward_t = minibatch[i][2]
        state_t1 = minibatch[i][3]
        terminal = minibatch[i][4]

        inputs[i:i + 1] = state_t    # the saved state s_t

        targets[i] = model.predict(state_t)   # predicted Q-value of each action in state_t
        Q_sa = model.predict(state_t1)        # predicted Q-values for the next state

        if terminal:
            # if the episode terminated, the target is just the reward
            targets[i, action_t] = reward_t
        else:
            targets[i, action_t] = reward_t + GAMMA * np.max(Q_sa)

    loss += model.train_on_batch(inputs, targets)

s_t = s_t1
t = t + 1
4. Experience Replay
It was found that approximating the Q-value with a non-linear function such as a neural network is not very stable. The most important trick for fixing this is called experience replay: during gameplay, all transitions (s, a, r, s') are stored in a replay memory D (the code uses a Python deque to store it). When training the network, random mini-batches drawn from the replay memory are used instead of the most recent transitions, which greatly improves stability.
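A minimal sketch of such a replay memory (REPLAY_MEMORY and BATCH mirror the repo's constants; treat the exact values as assumptions):

import random
from collections import deque

REPLAY_MEMORY = 50000  # memory capacity (exact value is an assumption)
BATCH = 32             # mini-batch size

D = deque()

def store_transition(s_t, action_index, r_t, s_t1, terminal):
    # Append one gameplay transition and evict the oldest once full.
    D.append((s_t, action_index, r_t, s_t1, terminal))
    if len(D) > REPLAY_MEMORY:
        D.popleft()

# When training: sample a random mini-batch rather than the latest transitions,
# e.g. minibatch = random.sample(D, BATCH)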
From the start of training to near invincibility, the blog author spent almost 30 hours. After 50,000 iterations the bird only flies straight up; after 200,000 it has a rough sense of direction and tries to get past the first pipe; after 300,000 it can usually locate the gap in the first pipe and attempt to pass through; after 400,000 it passes the first pipe with high probability and sometimes the second; after 1,000,000 it plays at the level of an ordinary human player, clearing 5 to 8 pipes; after 2,000,000 it is nearly unbeatable.