强化学习——蛇棋游戏gym环境搭建
学习强化学习精要核心算法与Tensorflow实现这本书中,关于蛇棋游戏利用gym搭建。游戏的规则非常简单,详细请参考冯超的书<<强化学习精要核心算法与Tensorflow实现>>。
下面是游戏的具体实现:
import numpy as np
import gym
from gym.spaces import Discrete
class SnakeEnv(gym.Env):
SIZE = 100 # 格子数量
def __init__(self, ladder_num, dices): # 构造函数需要传入两个参数:梯子数量和不同投掷骰子方法的最大值
self.ladder_num = ladder_num # 梯子数量
self.dices = dices # 不同投掷骰子方法的最大值
self.ladders = dict(np.random.randint(1, self.SIZE, size=(self.ladder_num, 2)))
# 生成梯子,格式类似{78: 33, 52: 97, 71: 64, 51: 32}
self.observation_space = Discrete(self.SIZE + 1) # 状态空间
self.action_space = Discrete(len(dices)) # 行为空间
keys = self.ladders.keys()
for k in list(keys): # 将梯子反过来存一遍
self.ladders[self.ladders[k]] = k
print('ladders info:')
print(self.ladders)
# 创造len(dices)个矩阵矩阵维度是100*100
self.p = np.zeros([len(dices), self.SIZE + 1, self.SIZE + 1], dtype=np.float) # P
ladder_move = np.vectorize(lambda x: self.ladders[x] if x in self.ladders else x) # 如果落入梯子区域,则前进到梯子的另一头,否则,还在该位子
# 下面是P的值
for i, dice in enumerate(self.dices):
# print(i,dice)
prob = 1.0 / dice
for src in range(1, 100):
step = np.arange(dice)
step += src
step = np.piecewise(step, [step > 100, step <= 100],
[lambda x: 200 - x, lambda x: x])
step = ladder_move(step)
for dst in step:
self.p[i, src, dst] += prob
self.p[:, 100, 100] = 1
self.pos = 1 # 游戏位置
def reset(self):
self.pos = 1 # 将游位置重置为1
return self.pos
def step(self, a):
step = np.random.randint(1, self.dices[a] + 1) # 根据选择的骰子进行投掷
self.pos += step
if self.pos == 100:
return 100, 100, 1, {} # 到达位置100,终止游戏
elif self.pos > 100:
self.pos = 200 - self.pos # 超过100时要向回走
if self.pos in self.ladders:
self.pos = self.ladders[self.pos] # 遇到梯子要前进到梯子的另一头
return self.pos, -1, 0, {}
def reward(self, s):
# 到达位置100则获得100奖励,否则每次-1
if s == 100:
return 100
else:
return -1
def render(self):
pass # 不进行图形渲染
代码中构造函数需要传入两个参数:梯子数量和不同投掷骰子方法的最大值。用一个dict存储梯子相连的两个格子的关系,用一个list保存可能的骰子可投掷的最大值;reset方法将pos设置为1,也就是游戏开始的位置;step完成一次投掷,参数a表示玩家将采用何种方法。完成位置的更新后,函数将返回玩家的新位置、得分和其他信息。
利用这个环境进行游戏,设定游戏中共有10个梯子,两种投掷骰子的方法分别可以投掷[1,3]和[1,6]的整数值,玩家将完全使用第一种策略斤进行游戏,代码如下:
from snake import SnakeEnv
env = SnakeEnv(10, [3,6]) # 10个梯子,2个筛子最大值分别是3和6
env.reset()
while True:
state, reward, terminate, _ = env.step(1) # 每次都选择正常的骰子
print (reward, state) # 打印r和s
if terminate == 1:
break
运行结果如下:
ladders info:
{57: 96, 27: 55, 9: 43, 34: 52, 82: 30, 75: 61, 70: 41, 91: 12, 90: 94, 96: 57, 55: 27, 43: 9, 52: 34, 30: 82, 61: 75, 41: 70, 12: 91, 94: 90}
-1 5
-1 7
-1 11
-1 14
-1 18
........(以下省略若干中间结果)
-1 99
-1 99
-1 97
-1 98
100 100