接下来的系列基于《强化学习精要:核心算法与TensorFlow实现》一书。这本书前面的章节部分都有python代码,非常推荐。
1. 蛇棋案例
我们有两个骰子,一个是常规的骰子(1-6各有1/6的概率出现,我们称之为正常骰子),另一个骰子是1~3,每个数字出现两次(也就是说1、2、3各有1/3的概率出现,我们称之为重复骰子)。我们需要选择一个骰子进行投掷。游戏从1出发,每次投到的多大的数字就往前走多少步,但是每次碰到梯子就需要走到另一头,直到走到100为止。如果超过100,则需要按照剩下的步数往回走。要求在最短的步数内到达终点。
2. python代码
下面基于env定义一个SnakeEnv类,用于模拟上面的问题。类中重新定义了reset/step/reward/render方法。
import numpy as np
import gym
from gym.spaces import Discrete
class SnakeEnv(gym.Env):
SIZE=100 # 格子数量
def __init__(self, ladder_num, dices):
self.ladder_num = ladder_num #梯子数量
self.dices = dices # 骰子数量
self.ladders = dict(np.random.randint(1, self.SIZE, \
size=(self.ladder_num, 2))) # 生成梯子,格式类似{78: 33, 52: 97, 71: 64, 51: 32}
self.observation_space=Discrete(self.SIZE+1)
self.action_space=Discrete(len(dices))
keys = self.ladders.keys()
for k in list(keys): # 将梯子反过来存一遍
self.ladders[self.ladders[k]] = k
print ('ladders info:')
print (self.ladders)
self.p = np.zeros([len(dices), self.SIZE+1, self.SIZE+1], dtype=np.float) # P
ladder_move = np.vectorize(lambda x: self.ladders[x] if x in self.ladders else x) # 如果落入梯子区域,则前进到梯子的另一头
# 下面是P的值
for i, dice in enumerate(self.dices):
prob = 1.0 / dice
for src in range(1, 100):
step = np.arange(dice)
step += src
step = np.piecewise(step, [step > 100, step <= 100],
[lambda x: 200 - x, lambda x: x])
step = ladder_move(step)
for dst in step:
self.p[i, src, dst] += prob
self.p[:, 100, 100]=1
self.pos = 1 # 游戏位置
def reset(self):
self.pos = 1 # 将游位置重置为1
return self.pos
def step(self, a):
step = np.random.randint(1, self.dices[a] + 1) # 根据选择的骰子进行投掷
self.pos += step
if self.pos == 100:
return 100, 100, 1, {} # 到达位置100,终止游戏
elif self.pos > 100:
self.pos = 200 - self.pos # 超过100时要向回走
if self.pos in self.ladders:
self.pos = self.ladders[self.pos] # 遇到梯子要前进到梯子的另一头
return self.pos, -1, 0, {}
def reward(self, s):
# 到达位置100则获得100奖励,否则每次-1
if s == 100:
return 100
else:
return -1
def render(self):
pass # 不进行图形渲染
运行下面的代码进行测试:
env = SnakeEnv(10, [3,6]) # 10个梯子,2个筛子最大值分别是3和6
env.reset()
while True:
state, reward, terminate, _ = env.step(1) # 每次都选择正常的骰子
print (reward, state) # 打印r和s
if terminate == 1:
break
我运行了一次的输出如下。由于是投骰子,因此每次输出的结果可能不一样。
ladders info:
{63: 79, 85: 54, 52: 43, 6: 42, 4: 24, 96: 50, 99: 80, 37: 32, 83: 73, 24: 4, 79: 63, 54: 85, 43: 52, 42: 6, 50: 96, 80: 99, 32: 37, 73: 83}
dice ranges:
[3, 6]
-1 24
-1 25
-1 31
-1 37
-1 38
-1 6
-1 11
-1 16
-1 21
-1 26
-1 37
-1 41
-1 46
-1 43
-1 45
-1 48
-1 85
-1 91
-1 92
-1 94
100 100