These are reinforcement learning notes, based mainly on:
- Reinforcement Learning: An Introduction
- Code from GitHub
- Exercise solutions from GitHub
Consider the $4 \times 4$ gridworld shown below.
The nonterminal states are $S = \{1, 2, \ldots, 14\}$. There are four actions possible in each state, $A = \{\text{up}, \text{down}, \text{right}, \text{left}\}$, which deterministically cause the corresponding state transitions, except that actions that would take the agent off the grid in fact leave the state unchanged.
This is an undiscounted, episodic task. The reward is $-1$ on all transitions until the terminal state is reached. The terminal state is shaded in the figure (although it is shown in two places, it is formally one state). The expected reward function is thus $r(s, a, s') = -1$ for all states $s, s'$ and actions $a$.
Suppose the agent follows the equiprobable random policy. The left side of Figure 4.1 shows the sequence of value functions $\{v_k\}$ computed by iterative policy evaluation. The final estimate is in fact $v_\pi$, which in this case gives for each state the negation of the expected number of steps from that state until termination.
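For reference, each sweep of iterative policy evaluation applies the Bellman expectation update. With the equiprobable policy $\pi(a \mid s) = 1/4$, deterministic transitions, $\gamma = 1$, and reward $-1$ everywhere, the update for this gridworld reduces to

$v_{k+1}(s) = \sum_a \pi(a \mid s) \sum_{s', r} p(s', r \mid s, a)\,[r + \gamma v_k(s')] = \frac{1}{4} \sum_a [-1 + v_k(s'_a)],$

where $s'_a$ is the state reached from $s$ by action $a$ (which is $s$ itself if the move would leave the grid).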
The last row shows an example of policy improvement for stochastic policies. The states with multiple arrows in the $\pi'$ diagram are those in which several actions achieve the maximum in (4.9); any apportionment of probability among these actions is permitted.
The last policy is guaranteed only to be an improvement over the random policy, but in this case it, and all policies after the third iteration, are optimal.
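To make the improvement step concrete, here is a minimal sketch (not part of the original repository code) of extracting such a greedy policy from a converged value table. greedy_policy is a hypothetical helper name; it reuses WORLD_SIZE, ACTIONS, is_terminal, and step from the Code section below.
def greedy_policy(state_values):
    # hypothetical sketch: for each nonterminal state, keep every action
    # whose one-step lookahead value attains the maximum (ties correspond
    # to the multiple arrows in the pi' diagram)
    policy = {}
    for i in range(WORLD_SIZE):
        for j in range(WORLD_SIZE):
            if is_terminal([i, j]):
                continue
            returns = []
            for action in ACTIONS:
                (next_i, next_j), reward = step([i, j], action)
                returns.append(reward + state_values[next_i, next_j])
            best = max(returns)
            # any apportionment of probability among these actions is permitted
            policy[(i, j)] = [k for k, r in enumerate(returns) if abs(r - best) < 1e-9]
    return policy
Each entry lists the indices into ACTIONS (left, up, right, down) that are greedy in that state.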
Code
#######################################################################
# Copyright (C)                                                       #
# 2016-2018 Shangtong Zhang(zhangshangtong.cpp@gmail.com)             #
# 2016 Kenta Shimada(hyperkentakun@gmail.com)                         #
# Permission given to modify the code as long as you keep this        #
# declaration at the top                                              #
#######################################################################
import matplotlib
matplotlib.use('Agg')  # select the non-interactive backend before pyplot is imported
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.table import Table
Settings
WORLD_SIZE = 4
# left, up, right, down, as (row, column) offsets
ACTIONS = [np.array([0, -1]),
           np.array([-1, 0]),
           np.array([0, 1]),
           np.array([1, 0])]
# equiprobable random policy: each action is chosen with probability 1/4
ACTION_PROB = 0.25
Environment
def is_terminal(state):
    x, y = state
    return (x == 0 and y == 0) or (x == WORLD_SIZE - 1 and y == WORLD_SIZE - 1)

def step(state, action):
    if is_terminal(state):
        return state, 0
    next_state = (np.array(state) + action).tolist()
    x, y = next_state
    # actions that would take the agent off the grid leave the state unchanged
    if x < 0 or x >= WORLD_SIZE or y < 0 or y >= WORLD_SIZE:
        next_state = state
    reward = -1
    return next_state, reward
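A quick sanity check of step (not part of the original code), assuming the definitions above: moving left from state (0, 1) enters the terminal corner, while moving up from it would leave the grid and so leaves the state unchanged.
print(step([0, 1], np.array([0, -1])))  # ([0, 0], -1): transition into the terminal corner
print(step([0, 1], np.array([-1, 0])))  # ([0, 1], -1): off-grid move, state unchanged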
Visualization
def draw_image(image):
    fig, ax = plt.subplots()
    ax.set_axis_off()
    tb = Table(ax, bbox=[0, 0, 1, 1])

    nrows, ncols = image.shape
    width, height = 1.0 / ncols, 1.0 / nrows

    # Add cells
    for (i, j), val in np.ndenumerate(image):
        tb.add_cell(i, j, width, height, text=val,
                    loc='center', facecolor='white')

    # Row and column labels...
    for i in range(len(image)):
        tb.add_cell(i, -1, width, height, text=i+1, loc='right',
                    edgecolor='none', facecolor='none')
        tb.add_cell(-1, i, width, height/2, text=i+1, loc='center',
                    edgecolor='none', facecolor='none')
    ax.add_table(tb)
Iterative policy evaluation
def compute_state_value(in_place=True, discount=1.0):
    new_state_values = np.zeros((WORLD_SIZE, WORLD_SIZE))
    iteration = 0
    while True:
        if in_place:
            # in-place version: updates are written back immediately, so
            # later states in the same sweep already see the fresh values
            state_values = new_state_values
        else:
            # synchronous version: the whole sweep reads from a frozen copy
            state_values = new_state_values.copy()
        old_state_values = state_values.copy()

        for i in range(WORLD_SIZE):
            for j in range(WORLD_SIZE):
                value = 0
                # expected update under the equiprobable random policy
                for action in ACTIONS:
                    (next_i, next_j), reward = step([i, j], action)
                    value += ACTION_PROB * (reward + discount * state_values[next_i, next_j])
                new_state_values[i, j] = value

        max_delta_value = abs(old_state_values - new_state_values).max()
        if max_delta_value < 1e-4:
            break

        iteration += 1

    return new_state_values, iteration
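As a quick check (not in the original code), the converged table should match the final $v_\pi$ shown in Figure 4.1, whose first row is 0, -14, -20, -22:
values, _ = compute_state_value(in_place=False)
print(np.round(values, 1)[0])  # expected from Figure 4.1: [  0. -14. -20. -22.]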
def figure_4_1():
    _, async_iteration = compute_state_value(in_place=True)
    values, sync_iteration = compute_state_value(in_place=False)
    draw_image(np.round(values, decimals=2))
    print('In-place: {} iterations'.format(async_iteration))
    print('Synchronous: {} iterations'.format(sync_iteration))

    plt.savefig('../images/figure_4_1.png')
    plt.close()

if __name__ == '__main__':
    figure_4_1()
output:
In-place: 113 iterations
Synchronous: 172 iterations
The in-place sweep converges in fewer iterations because each update immediately reuses the freshest values of already-updated neighboring states.