【深入浅出强化学习-编程实战】3 基于动态规划的方法
2.1 鸳鸯系统——基于动态规划的方法
左上为雄鸟,右上为雌鸟,中间有两道障碍物。目标:雄鸟找到雌鸟。
2.1.1 基于策略迭代代码展示
yuanyang_env.py
import pygame
from resource.load import *
import math
import time
import random
import numpy as np
class YuanYangEnv:
def __init__(self):
# 状态空间
self.states = [] # 0-99
for i in range(0, 100):
self.states.append(i)
self.actions = ['e', 'w', 'n', 's']
self.gamma = 0.8
# 值函数
self.value = np.zeros((10, 10)) # 10*10的表格
# 设置渲染属性
self.viewer = None # 一个渲染窗口
# 帧速率是指程序每秒在屏幕山绘制图像的数目,我们可以用FPS来表示它。一般的计算机都能达到每秒60帧的速度。如果我们把帧速率讲得比较低,那么游戏也会看上去较为卡顿。
self.FPSCLOCK = pygame.time.Clock()
# 屏幕大小
self.screen_size = (1200, 900)
# 雄鸟当前位置
self.bird_position = (0, 0)
# 雄鸟在x方向每走一次像素为120
self.limit_distance_x = 120
# 雄鸟在y方向每走一次像素为90
self.limit_distance_y = 90
# 每个障碍物大小为120像素*90像素
self.obstacle_size = [120, 90]
# 一共有两个障碍物墙,每个障碍物墙由8个小障碍物组成
self.obstacle1_x = []
self.obstacle1_y = []
self.obstacle2_x = []
self.obstacle2_y = []
for i in range(8):
# 第一个障碍物
self.obstacle1_x.append(360)
if i <= 3:
self.obstacle1_y.append(90 * i)
else:
self.obstacle1_y.append(90 * (i + 2))
# 第二个障碍物
self.obstacle2_x.append(720)
if i <= 4:
self.obstacle2_y.append(90 * i)
else:
self.obstacle2_y.append(90 * (i + 2))
# 雄鸟初始位置
self.bird_male_init_position = [0, 0]
# 雄鸟当前位置
self.bird_male_position = [0, 0]
# 雌鸟初始位置
self.bird_female_init_position = [1080, 0]
# 声明路径
self.path = []
# 雄鸟碰撞检测子函数
def collide(self, state_position):
# 用标志flag,flag1,flag2分别表示是否与障碍物、障碍物墙1、障碍物墙2发生碰撞
flag = 1
flag1 = 1
flag2 = 1
# 检测雄鸟是否与第一个障碍物墙发生碰撞
# 找到雄鸟与第一个障碍物所有障碍物x方向和y方向最近的障碍物的坐标差
# 并判断最近的坐标差是否大于一个最小运动距离
# 如果大于等于 就不会发生碰撞
dx = []
dy = []
for i in range(8):
dx1 = abs(self.obstacle1_x[i] - state_position[0])
dx.append(dx1)
dy1 = abs(self.obstacle1_y[i] - state_position[1])
dy.append(dy1)
mindx = min(dx)
mindy = min(dy)
if mindx >= self.limit_distance_x or mindy >= self.limit_distance_y:
flag1 = 0 # 没碰
# 是否与第二个障碍物墙碰撞
second_dx = []
second_dy = []
for i in range(8):
dx2 = abs(self.obstacle2_x[i] - state_position[0])
second_dx.append(dx2)
dy2 = abs(self.obstacle2_y[i] - state_position[1])
second_dy.append(dy2)
mindx = min(second_dx)
mindy = min(second_dy)
if mindx >= self.limit_distance_x or mindy >= self.limit_distance_y:
flag2 = 0 # 没碰
if flag1 == 0 and flag2 == 0:
flag = 0 # 没碰
# 是否超出边界,如果是,也认为发生碰撞
if state_position[0] > 1080 or state_position[0] < 0 or state_position[1] > 810 or state_position[1] < 0:
flag = 1 # 碰了
# 返回碰撞标志位
return flag
# 雄鸟是否找到雌鸟子函数
def find(self, state_position):
# 设置标志位flag
# 判断雄鸟当前位置和雌鸟位置坐标差,雄安与最小运动距离则为找到
flag = 0
if abs(state_position[0] - self.bird_female_init_position[0]) < self.limit_distance_x and abs(
state_position[1] - self.bird_female_init_position[1]) < self.limit_distance_y:
flag = 1
return flag
# 状态转化为像素坐标子函数
def state_to_position(self, state):
i = int(state / 10)
j = state % 10
position = [0, 0]
position[0] = 120 * j
position[1] = 90 * i
return position
# 像素转化为状态坐标子函数
def position_to_state(self, position):
i = position[0] / 120
j = position[1] / 90
return int(i + 10 * j)
def reset(self):
#随机产生初始状态
flag1=1
flag2=1
while flag1 or flag2 ==1:
#随机产生初始状态,0~99,randoom.random() 产生一个0~1的随机数
state=self.states[int(random.random()*len(self.states))]
state_position = self.state_to_position(state)
flag1 = self.collide(state_position)
flag2 = self.find(state_position)
return state
def transform(self,state, action):
#将当前状态转化为坐标
current_position=self.state_to_position(state)
next_position = [0,0]
flag_collide=0
flag_find=0
#判断当前坐标是否与障碍物碰撞
flag_collide=self.collide(current_position)
#判断状态是否是终点
flag_find=self.find(current_position)
if flag_collide==1 or flag_find==1:
return state, 0, True
#状态转移
if action=='e':
next_position[0]=current_position[0]+120
next_position[1]=current_position[1]
if action=='s':
next_position[0]=current_position[0]
next_position[1]=current_position[1]+90
if action=='w':
next_position[0] = current_position[0] - 120
next_position[1] = current_position[1]
if action=='n':
next_position[0] = current_position[0]
next_position[1] = current_position[1] - 90
#判断next_state是否与障碍物碰撞
flag_collide = self.collide(next_position)
#如果碰撞,那么回报为-1,并结束
if flag_collide==1:
return self.position_to_state(current_position),-1,True
#判断是否终点
flag_find = self.find(next_position)
if flag_find==1:
return self.position_to_state(next_position),1,True
return self.position_to_state(next_position), 0, False
def gameover(self):
for event in pygame.event.get():
if event.type == QUIT:
exit()
def render(self):
if self.viewer is None:
pygame.init()
#画一个窗口
self.viewer=pygame.display.set_mode(self.screen_size,0,32)
pygame.display.set_caption("yuanyang")
#下载图片
self.bird_male = load_bird_male()
self.bird_female = load_bird_female()
self.background = load_background()
self.obstacle = load_obstacle()
#self.viewer.blit(self.bird_male, self.bird_male_init_position)
#在幕布上画图片
self.viewer.blit(self.bird_female, self.bird_female_init_position)
self.viewer.blit(self.background, (0, 0))
self.font = pygame.font.SysFont('times', 15)
self.viewer.blit(self.background,(0,0))
#画直线
for i in range(11):
pygame.draw.lines(self.viewer, (255, 255, 255), True, ((120*i, 0), (120*i, 900)), 1)
pygame.draw.lines(self.viewer, (255, 255, 255), True, ((0, 90* i), (1200, 90 * i)), 1)
self.viewer.blit(self.bird_female, self.bird_female_init_position)
#画障碍物
for i in range(8):
self.viewer.blit(self.obstacle, (self.obstacle1_x[i], self.obstacle1_y[i]))
self.viewer.blit(self.obstacle, (self.obstacle2_x[i], self.obstacle2_y[i]))
#画小鸟
self.viewer.blit(self.bird_male, self.bird_male_position)
# 画值函数
for i in range(10):
for j in range(10):
surface = self.font.render(str(round(float(self.value[i, j]), 3)), True, (0, 0, 0))
self.viewer.blit(surface, (120 * i + 5, 90 * j + 70))
# 画路径点
for i in range(len(self.path)):
rec_position = self.state_to_position(self.path[i])
pygame.draw.rect(self.viewer, [255, 0, 0], [rec_position[0], rec_position[1], 120, 90], 3)
surface = self.font.render(str(i), True, (255, 0, 0))
self.viewer.blit(surface, (rec_position[0] + 5, rec_position[1] + 5))
pygame.display.update()
self.gameover()
# time.sleep(0.1)
self.FPSCLOCK.tick(30)
if __name__=="__main__":
yy=YuanYangEnv()
yy.render()
while True:
for event in pygame.event.get():
if event.type == QUIT:
exit()
DP_Poliy_Iter.py
# 便于生成随机数和时间延迟函数
import random
import time
# 导入YuanYangEnv
from yuanyang_env import YuanYangEnv
class DP_Policy_Iter:
def __init__(self,yuanyang):
# yuanyang为类的初始化参数,用于调用yuanyang系统
# 将鸳鸯游戏系统的状态空间和动作空间赋给当前类的状态和动作
self.states = yuanyang.states
self.actions = yuanyang.actions
# v 表示值函数
self.v = [0.0 for i in range(len(self.states)+1)]
# 声明一个数据结构为字典以便存储策略
self.pi = dict()
self.yuanyang = yuanyang
self.gamma = yuanyang.gamma
# 初始化策略
for state in self.states:
flag1 = 0
flag2 = 0
flag1 = yuanyang.collide(yuanyang.state_to_position(state))
flag2 = yuanyang.find(yuanyang.state_to_position(state))
if flag1 == 1 or flag2 == 1: continue
# 利用随机函数初始化策略
self.pi[state] = self.actions[int(random.random()*len(self.actions))]
# 策略评估子函数
# 策略评估包括两个训练,内层循环遍历状态空间中的每个状态,利用贝尔曼算子更新每个状态处的值函数;
# 外层循环则迭代计算整个值函数。
# 利用新旧值函数的累积和来控制策略评估是否收敛,如果累积和小于1e-6,那么说明收敛,则退出策略评估程序
def policy_evaluate(self):
# 策略评估计算值函数
for i in range(100):
delta = 0.0
for state in self.states:
flag1 = 0
flag2 = 0
flag1 = yuanyang.collide(yuanyang.state_to_position(state))
flag2 = yuanyang.find(yuanyang.state_to_position(state))
if flag1 == 1 or flag2 ==1 : continue
action = self.pi[state]
s,r,t = yuanyang.transform(state,action)
# 更新值
new_v = r + self.gamma*self.v[s]
delta += abs(self.v[state]-new_v)
# 更新值替换原来的值函数
self.v[state] = new_v
if delta < 1e-6:
print("策略评估迭代次数:",i)
break
# 策略改善子函数
# 有了值函数,就可以用贪婪策略对当前策略改善
# 外循环是个状态遍历,对每个状态处的策略进行改善
# 在每个状态处,利用当前的值函数找到对应的使之最大的动作
def policy_improve(self):
# 利用更新后值函数进行策略改善
for state in self.states:
flag1 = 0
flag2 = 0
flag1 = yuanyang.collide(yuanyang.state_to_position(state))
flag2 = yuanyang.find(yuanyang.state_to_position(state))
if flag1 == 1 or flag2 == 1: continue
a1 = self.actions[0]
s,r,t = yuanyang.transform(state,a1)
v1 = r + self.gamma*self.v[s]
# 找状态s时,采用哪种动作,值函数最大
for action in self.actions:
s,r,t = yuanyang.transform(state,action)
if v1 < r + self.gamma*self.v[s]:
a1 = action
v1 = r + self.gamma*self.v[s]
# 贪婪策略进行更新
self.pi[state] = a1
# 策略迭代子函数
# 循环进行策略评估和策略改善,当策略不再发生变化时,结束
def policy_iterate(self):
for i in range(100):
# 策略评估,变的是v
self.policy_evaluate()
# 策略改善
pi_old = self.pi.copy()
# 变的是pi
self.policy_improve()
if(self.pi == pi_old):
print("策略改善次数:",i)
break
# 主函数测试
# 首先,实例化一个鸳鸯游戏YuanYang;
# 然后将之作为参数传入策略爹地啊类中实例化一个策略迭代类policy_value
# 再利用该类调用策略迭代子函数policy_iterate()完成对策略的学习
if __name__=="__main__":
yuanyang = YuanYangEnv()
policy_value = DP_Policy_Iter(yuanyang)
policy_value.policy_iterate()
# 对学到的策略进行测试,初始状态为0,当前的路径还不存在
# 将策略迭代中学到的值函数给到游戏中,渲染出来
flag = 1 # 标志位
s = 0
path = []
# 将v值打印出来
for state in range(100):
i = int(state/10)
j = state % 10
yuanyang.value[j,i] = policy_value.v[state]
step_num = 0
# 下面是agent利用学到的策略pi与游戏环境进行交互
# 并将交互结果渲染出来,在雄鸟移动过程中,我们将移动的状态和动作都打印出来
# 将最优路径打印出来
while flag:
# 渲染路径点
path.append(s)
yuanyang.path = path
a = policy_value.pi[s]
print('%d->%s\t'%(s,a))
yuanyang.bird_male_position = yuanyang.state_to_position(s)
yuanyang.render()
time.sleep(0.2)
step_num+=1
s_,r,t=yuanyang.transform(s,a)
if t == True or step_num > 200: # t == True则结束transform
flag = 0
s = s_
# 渲染最后的路径点
yuanyang.bird_male_position = yuanyang.state_to_position(s)
path.append(s)
yuanyang.render()
while True:
yuanyang.render()
结果
2.1.2 基于值函数迭代代码展示
#值迭代算法策略评估和策略盖上放在一起,首先进行策略评估再取贪婪策略
def value_iteration(self):
for i in range(1000):
delta = 0.0
for state in self.states:
flag1 = 0
flag2 = 0
flag1 = yuanyang.collide(yuanyang.state_to_position(state))
flag2 = yuanyang.find(yuanyang.state_to_position(state))
a1 = self.actions[int(random.random()*4)]
s,r,t = yuanyang.transform(state,a1)
# 策略评估
v1 = r + self.gamma*self.v[s]
# 策略改善
for action in self.actions:
s,r,t = yuanyang.transform(state,action)
if v1 < r +self.gamma*self.v[s]:
a1 = action
v1 = r+self.gamma*self.v[s]
delta+= abs(v1-self.v[state])
self.pi[state]=a1
self.v[state] = v1
if delta<1e-6:
print("迭代次数为:",i)
break
结果
2.1.3 部分代码解析
- yuanyang_env.py——line 216
pygame.draw.rect(self.viewer,[255,0,0],[rec_position[0],rec_position[1],120,90],3)
- 原型:
pygame.draw.rect(Surface, color, Rect, width=0): return Rect
- 用途:在Surface上绘制矩形,第二个参数是线条(或填充)的颜色,第三个参数Rect的形式是
((x, y), (width, height))
,表示的是所绘制矩形的区域,其中第一个元组(x, y)
表示的是该矩形左上角的坐标,第二个元组(width, height)
表示的是矩形的宽度和高度。width表示线条的粗细,单位为像素;默认值为0,表示填充矩形内部。
- DP_Policy_Iter——line 14
self.v = [0.0 for i in range(len(self.states)+1)]
- Python在方括号中使用for循环,叫做列表解析List Comprehensions。
列表解析 - 根据已有列表,高效创建新列表的方式。
- 列表解析是Python迭代机制的一种应用,它常用于实现创建新的列表,因此用在
[]
中。 - 语法:
[expression for iter_val in iterable]
[expression for iter_val in iterable if cond_expr]
- 原式等于
self.v = []
for i in range(len(self.states)+1):
v.append(0.0)
print v