【深入浅出强化学习-编程实战】4 基于时间差分的方法
4.1 鸳鸯系统——基于时间差分的方法
左上为雄鸟,右上为雌鸟,中间有两道障碍物。目标:雄鸟找到雌鸟。
yuanyang_env_td.py
import pygame
from resource.load import *
import math
import time
import random
import numpy as np
class YuanYangEnv:
    """Grid-world environment for the 'yuanyang' (mandarin-duck) task.

    A 10x10 grid rendered at 1200x900 px (cells are 120x90 px). The male
    bird starts at the top-left corner; the female bird sits at the
    top-right. Two vertical obstacle walls (each with a gap) block the way.
    TD agents (Sarsa / Q-learning) interact through transform(state, action).
    """

    def __init__(self):
        # State space: one state per grid cell, numbered 0..99.
        self.states = list(range(100))
        # Action set: east, west, north, south.
        self.actions = ['e', 'w', 'n', 's']
        # Discount factor. Model-free methods evaluate the action-value
        # function from whole trajectories; if gamma is too small, the
        # contribution of later rewards decays too quickly.
        self.gamma = 0.95
        # Action-value table Q(s, a): one row per state, one column per action.
        self.action_value = np.zeros((100, 4))
        # Rendering attributes.
        self.viewer = None  # pygame display surface, created lazily in render()
        # Frame-rate clock: caps how many frames per second are drawn so
        # the animation is watchable (a typical machine can do 60 FPS).
        self.FPSCLOCK = pygame.time.Clock()
        # Window size in pixels.
        self.screen_size = (1200, 900)
        # Current male-bird position (pixels).
        self.bird_position = (0, 0)
        # One move covers 120 px in x ...
        self.limit_distance_x = 120
        # ... and 90 px in y.
        self.limit_distance_y = 90
        # Each obstacle tile is 120x90 px.
        self.obstacle_size = [120, 90]
        # Two obstacle walls, each built from 8 tiles.
        self.obstacle1_x = []
        self.obstacle1_y = []
        self.obstacle2_x = []
        self.obstacle2_y = []
        self.path = []
        for i in range(8):
            # First wall: column x=360; rows 4-5 are skipped, leaving a gap.
            self.obstacle1_x.append(360)
            if i <= 3:
                self.obstacle1_y.append(90 * i)
            else:
                self.obstacle1_y.append(90 * (i + 2))
            # Second wall: column x=720; the gap sits one row lower.
            self.obstacle2_x.append(720)
            if i <= 4:
                self.obstacle2_y.append(90 * i)
            else:
                self.obstacle2_y.append(90 * (i + 2))
        # Male bird initial position (pixels).
        self.bird_male_init_position = [0, 0]
        # Male bird current position (pixels).
        self.bird_male_position = [0, 0]
        # Female bird (goal) position (pixels).
        self.bird_female_init_position = [1080, 0]

    def collide(self, state_position):
        """Return 1 if `state_position` (pixel coords) collides with an
        obstacle wall or leaves the board, else 0.

        For each wall, the x-gap and y-gap to the NEAREST tile are computed
        independently; if either minimum gap is at least one step size, the
        position does not overlap that wall.
        """
        # flag / flag1 / flag2: collision with anything / wall 1 / wall 2.
        flag = 1
        flag1 = 1
        flag2 = 1
        # Wall 1: minimum |dx| and |dy| over all 8 tiles.
        dx = []
        dy = []
        for i in range(8):
            dx.append(abs(self.obstacle1_x[i] - state_position[0]))
            dy.append(abs(self.obstacle1_y[i] - state_position[1]))
        if min(dx) >= self.limit_distance_x or min(dy) >= self.limit_distance_y:
            flag1 = 0  # no collision with wall 1
        # Wall 2: same test.
        second_dx = []
        second_dy = []
        for i in range(8):
            second_dx.append(abs(self.obstacle2_x[i] - state_position[0]))
            second_dy.append(abs(self.obstacle2_y[i] - state_position[1]))
        if min(second_dx) >= self.limit_distance_x or min(second_dy) >= self.limit_distance_y:
            flag2 = 0  # no collision with wall 2
        if flag1 == 0 and flag2 == 0:
            flag = 0  # clear of both walls
        # Leaving the board also counts as a collision.
        if state_position[0] > 1080 or state_position[0] < 0 or state_position[1] > 810 or state_position[1] < 0:
            flag = 1
        return flag

    def find(self, state_position):
        """Return 1 if `state_position` (pixel coords) is within one step
        of the female bird (i.e. the goal is reached), else 0."""
        flag = 0
        if (abs(state_position[0] - self.bird_female_init_position[0]) < self.limit_distance_x
                and abs(state_position[1] - self.bird_female_init_position[1]) < self.limit_distance_y):
            flag = 1
        return flag

    def state_to_position(self, state):
        """Convert a state index (0..99) to pixel coordinates [x, y]."""
        i = int(state / 10)  # row
        j = state % 10       # column
        position = [0, 0]
        position[0] = 120 * j
        position[1] = 90 * i
        return position

    def position_to_state(self, position):
        """Convert pixel coordinates [x, y] back to a state index."""
        i = position[0] / 120  # column
        j = position[1] / 90   # row
        return int(i + 10 * j)

    def reset(self):
        """Sample a random start state that neither collides with an
        obstacle nor already sits on the goal."""
        flag1 = 1
        flag2 = 1
        # Re-sample until the state is collision-free (flag1 == 0) and is
        # not the goal (flag2 == 0). (The original condition
        # `flag1 or flag2 == 1` relied on precedence; this is equivalent
        # for 0/1 flags and explicit.)
        while flag1 == 1 or flag2 == 1:
            # random.random() is uniform in [0, 1) -> a state index 0..99.
            state = self.states[int(random.random() * len(self.states))]
            state_position = self.state_to_position(state)
            flag1 = self.collide(state_position)
            flag2 = self.find(state_position)
        return state

    def transform(self, state, action):
        """One environment step: take `action` in `state`.

        Returns (next_state, reward, done). Rewards are DENSE: -0.1 per
        ordinary step, +10 on reaching the goal, -10 on collision. (With
        the original sparse reward — only at goal/collision — Monte-Carlo
        estimates have very large variance.)
        """
        current_position = self.state_to_position(state)
        next_position = [0, 0]
        # Terminal checks on the CURRENT cell first.
        flag_collide = self.collide(current_position)
        flag_find = self.find(current_position)
        if flag_collide == 1:
            return state, -10, True
        if flag_find == 1:
            return state, 10, True
        # State transition: 120 px per x-step, 90 px per y-step.
        if action == 'e':
            next_position[0] = current_position[0] + 120
            next_position[1] = current_position[1]
        if action == 's':
            next_position[0] = current_position[0]
            next_position[1] = current_position[1] + 90
        if action == 'w':
            next_position[0] = current_position[0] - 120
            next_position[1] = current_position[1]
        if action == 'n':
            next_position[0] = current_position[0]
            next_position[1] = current_position[1] - 90
        # Collision at the next position: reward -10 and terminate, the
        # agent stays where it was.
        flag_collide = self.collide(next_position)
        if flag_collide == 1:
            return self.position_to_state(current_position), -10, True
        # Goal at the next position: reward +10 and terminate.
        flag_find = self.find(next_position)
        if flag_find == 1:
            return self.position_to_state(next_position), 10, True
        # Ordinary step: small negative reward to encourage short paths.
        return self.position_to_state(next_position), -0.1, False

    def gameover(self):
        """Process pending window events; exit if the window is closed."""
        for event in pygame.event.get():
            # Qualified name: bare QUIT only exists if a star import
            # happens to re-export pygame.locals.
            if event.type == pygame.QUIT:
                exit()

    def render(self):
        """Draw the board: background, grid, birds, obstacles, the Q
        values of every cell and the rollout path."""
        if self.viewer is None:
            pygame.init()
            # Create the window.
            self.viewer = pygame.display.set_mode(self.screen_size, 0, 32)
            pygame.display.set_caption("yuanyang")
            # Load the sprite images.
            self.bird_male = load_bird_male()
            self.bird_female = load_bird_female()
            self.background = load_background()
            self.obstacle = load_obstacle()
            self.font = pygame.font.SysFont('times', 15)
        # Background first. (The original blitted the female bird and the
        # background a second time here; those draws were immediately
        # painted over and have been removed.)
        self.viewer.blit(self.background, (0, 0))
        # Grid lines: 10 columns of 120 px and 10 rows of 90 px.
        for i in range(11):
            pygame.draw.lines(self.viewer, (255, 255, 255), True, ((120 * i, 0), (120 * i, 900)), 1)
            pygame.draw.lines(self.viewer, (255, 255, 255), True, ((0, 90 * i), (1200, 90 * i)), 1)
        self.viewer.blit(self.bird_female, self.bird_female_init_position)
        # Obstacle walls.
        for i in range(8):
            self.viewer.blit(self.obstacle, (self.obstacle1_x[i], self.obstacle1_y[i]))
            self.viewer.blit(self.obstacle, (self.obstacle2_x[i], self.obstacle2_y[i]))
        # Male bird at its current position.
        self.viewer.blit(self.bird_male, self.bird_male_position)
        # Q values: one number per action placed at the east / south /
        # west / north edge of each cell.
        for i in range(100):
            y = int(i / 10)
            x = i % 10
            # East value.
            surface = self.font.render(str(round(float(self.action_value[i, 0]), 2)), True, (0, 0, 0))
            self.viewer.blit(surface, (120 * x + 80, 90 * y + 45))
            # South value.
            surface = self.font.render(str(round(float(self.action_value[i, 1]), 2)), True, (0, 0, 0))
            self.viewer.blit(surface, (120 * x + 50, 90 * y + 70))
            # West value.
            surface = self.font.render(str(round(float(self.action_value[i, 2]), 2)), True, (0, 0, 0))
            self.viewer.blit(surface, (120 * x + 10, 90 * y + 45))
            # North value.
            surface = self.font.render(str(round(float(self.action_value[i, 3]), 2)), True, (0, 0, 0))
            self.viewer.blit(surface, (120 * x + 50, 90 * y + 10))
        # Path cells visited so far, outlined in red and numbered in order.
        for i in range(len(self.path)):
            rec_position = self.state_to_position(self.path[i])
            pygame.draw.rect(self.viewer, [255, 0, 0], [rec_position[0], rec_position[1], 120, 90], 3)
            surface = self.font.render(str(i), True, (255, 0, 0))
            self.viewer.blit(surface, (rec_position[0] + 5, rec_position[1] + 5))
        pygame.display.update()
        self.gameover()
        self.FPSCLOCK.tick(30)
if __name__ == "__main__":
    # Smoke test: build the environment, draw it once, then keep the
    # window alive until the user closes it.
    yy = YuanYangEnv()
    yy.render()
    while True:
        for event in pygame.event.get():
            # Qualified name: bare QUIT is only defined if a star import
            # re-exports pygame.locals.
            if event.type == pygame.QUIT:
                exit()
TD_RL.py
import numpy as np
import random
import os
import pygame
import time
import matplotlib.pyplot as plt
from yuanyang_env_td import *
from yuanyang_env_td import YuanYangEnv
class TD_RL:
    """Temporal-difference control agents (Sarsa and Q-learning) for the
    YuanYangEnv grid world."""

    def __init__(self, yuanyang):
        # Discount factor taken from the environment.
        self.gamma = yuanyang.gamma
        self.yuanyang = yuanyang
        # Q table, |S| x |A|, initialised to zero.
        self.qvalue = np.zeros((len(yuanyang.states), len(yuanyang.actions)))

    def greedy_policy(self, qfun, state):
        """Return the greedy action at `state` under value table `qfun`."""
        amax = qfun[state, :].argmax()
        return self.yuanyang.actions[amax]

    def epsilon_greedy_policy(self, qfun, state, epsilon):
        """Return the greedy action with probability 1-epsilon, otherwise
        a uniformly random action."""
        amax = qfun[state, :].argmax()
        if np.random.uniform() < 1 - epsilon:
            # Exploit: current best action.
            return self.yuanyang.actions[amax]
        else:
            # Explore: uniform random action.
            return self.yuanyang.actions[int(random.random() * len(self.yuanyang.actions))]

    def find_anum(self, a):
        """Return the index of action `a` in the env's action list
        (None if absent, matching the original behaviour)."""
        for i in range(len(self.yuanyang.actions)):
            if a == self.yuanyang.actions[i]:
                return i

    def sarsa(self, num_iter, alpha, epsilon):
        """On-policy TD control (Sarsa).

        num_iter -- maximum number of training episodes
        alpha    -- learning rate
        epsilon  -- initial exploration rate, decayed by 0.99 per episode

        Every episode starts from state 0. Before each episode,
        greedy_test() checks whether the current greedy policy reaches
        the goal (prints the episode index the first time) and whether it
        has found the shortest path (prints and stops training).
        Returns the learned Q table.
        """
        iter_num = []
        self.qvalue = np.zeros((len(self.yuanyang.states), len(self.yuanyang.actions)))
        # Outer loop: one iteration per training episode.
        for iter in range(num_iter):
            # Decay exploration as learning progresses.
            epsilon = epsilon * 0.99
            s_sample = []
            # Fixed start state 0 (top-left corner).
            s = 0
            flag = self.greedy_test()
            if flag == 1:
                iter_num.append(iter)
                if len(iter_num) < 2:
                    print("sarsa第1次完成任务需要迭代次数是:", iter_num[0])
            if flag == 2:
                print("sarsa第1次实现最短路径需要的迭代次数是:", iter)
                break
            # Initial action from the epsilon-greedy behaviour policy.
            a = self.epsilon_greedy_policy(self.qvalue, s, epsilon)
            t = False
            count = 0
            # Inner loop: one trajectory s0-s1-...-s_terminal (max 30 steps).
            while not t and count < 30:
                s_next, r, t = self.yuanyang.transform(s, a)
                a_num = self.find_anum(a)
                # Penalise revisiting a state of this trajectory.
                if s_next in s_sample:
                    r = -2
                s_sample.append(s)
                if t:
                    # Terminal state: target is the immediate reward.
                    q_target = r
                else:
                    # On-policy: the successor action is drawn from the
                    # SAME epsilon-greedy policy — this is the Sarsa
                    # update (the original comment wrongly called it the
                    # Q-learning update).
                    a1 = self.epsilon_greedy_policy(self.qvalue, s_next, epsilon)
                    a1_num = self.find_anum(a1)
                    q_target = r + self.gamma * self.qvalue[s_next, a1_num]
                # TD update with step size alpha.
                self.qvalue[s, a_num] = self.qvalue[s, a_num] + alpha * (q_target - self.qvalue[s, a_num])
                # Advance: next state, next behaviour action.
                s = s_next
                a = self.epsilon_greedy_policy(self.qvalue, s, epsilon)
                count += 1
        return self.qvalue

    def qlearning(self, num_iter, alpha, epsilon):
        """Off-policy TD control (Q-learning).

        Identical to sarsa() except for the evaluation step: the TD
        target uses the GREEDY successor action, while behaviour remains
        epsilon-greedy. Returns the learned Q table.
        """
        iter_num = []
        self.qvalue = np.zeros((len(self.yuanyang.states), len(self.yuanyang.actions)))
        # Outer loop: one iteration per training episode.
        for iter in range(num_iter):
            # Fixed start state 0. (The original also called the
            # module-level global `yuanyang.reset()` here — a bug: it
            # bypassed self.yuanyang, and the result was discarded anyway,
            # so the dead call has been removed.)
            s = 0
            flag = self.greedy_test()
            if flag == 1:
                iter_num.append(iter)
                if len(iter_num) < 2:
                    print("qlearning第1次完成任务需要迭代次数是:", iter_num[0])
            if flag == 2:
                print("qlearning第1次实现最短路径需要的迭代次数是:", iter)
                break
            s_sample = []
            # Initial action from the epsilon-greedy behaviour policy.
            a = self.epsilon_greedy_policy(self.qvalue, s, epsilon)
            t = False
            count = 0
            # Inner loop: one trajectory s0-s1-...-s_terminal (max 30 steps).
            while not t and count < 30:
                s_next, r, t = self.yuanyang.transform(s, a)
                a_num = self.find_anum(a)
                # Penalise revisiting a state of this trajectory.
                if s_next in s_sample:
                    r = -2
                s_sample.append(s)
                if t:
                    # Terminal state: target is the immediate reward.
                    q_target = r
                else:
                    # Off-policy: successor action is the GREEDY one —
                    # the Q-learning update.
                    a1 = self.greedy_policy(self.qvalue, s_next)
                    a1_num = self.find_anum(a1)
                    q_target = r + self.gamma * self.qvalue[s_next, a1_num]
                # TD update with step size alpha.
                self.qvalue[s, a_num] = self.qvalue[s, a_num] + alpha * (q_target - self.qvalue[s, a_num])
                # Advance: next state, next behaviour action.
                s = s_next
                a = self.epsilon_greedy_policy(self.qvalue, s, epsilon)
                count += 1
        return self.qvalue

    def greedy_test(self):
        """Roll out the current greedy policy from state 0 for at most 30
        steps.

        Returns 0 if the goal (state 9, top-right cell) is not reached,
        1 if it is reached, and 2 if it is reached in under 21 steps
        (i.e. via the shortest path).
        """
        s = 0
        s_sample = []
        done = False
        flag = 0
        step_num = 0
        while not done and step_num < 30:
            a = self.greedy_policy(self.qvalue, s)
            # Step the environment.
            s_next, r, done = self.yuanyang.transform(s, a)
            s_sample.append(s)
            s = s_next
            step_num += 1
        # State 9 is the goal cell (hard-coded for this environment).
        if s == 9:
            flag = 1
        if s == 9 and step_num < 21:
            flag = 2
        return flag
# Main: instantiate the environment and the TD agent, train with
# Q-learning (the Sarsa call is kept commented out), display the learned
# action-value function, then replay the greedy policy step by step.
if __name__ == "__main__":
    yuanyang = YuanYangEnv()
    brain = TD_RL(yuanyang)
    #qvalue1 = brain.sarsa(num_iter=5000,alpha=0.1,epsilon=0.8)
    qvalue2 = brain.qlearning(num_iter=5000, alpha=0.1, epsilon=0.1)
    # Show the learned action-value function on the board.
    yuanyang.action_value = qvalue2
    # Roll out and render the learned greedy policy from state 0.
    flag = 1
    s = 0
    step_num = 0
    path = []
    while flag:
        # Record and render the path so far.
        path.append(s)
        yuanyang.path = path
        a = brain.greedy_policy(qvalue2, s)
        print('%d->%s\t' % (s, a), qvalue2[s, 0], qvalue2[s, 1], qvalue2[s, 2], qvalue2[s, 3])
        # BUG FIX: the original assigned `yuanyang.brid_male_polistion`
        # (misspelled attribute), so the bird sprite never actually moved
        # during playback.
        yuanyang.bird_male_position = yuanyang.state_to_position(s)
        yuanyang.render()
        time.sleep(0.25)
        step_num += 1
        s_, r, t = yuanyang.transform(s, a)
        # Stop at a terminal state or after 30 steps.
        if t == True or step_num > 30:
            flag = 0
        s = s_
    # Render the final cell of the path.
    yuanyang.bird_male_position = yuanyang.state_to_position(s)
    path.append(s)
    yuanyang.render()
    while True:
        yuanyang.render()
debug 了好久才发现是参数写错(属性名拼写错误)的问题。
4.2 Sarsa结果
sarsa 第一次完成任务需要的迭代次数为: 264
sarsa 第一次实现最短路径需要的迭代次数为: 280
0->e -1.6938601071601547 -8.221688473191929 -8.220110894950007 -1.8571971457047995
1->s -1.7340008224766394 -2.038007579842249 -3.3338431198 -1.6022903503416959
11->s -1.6085773294523065 -1.7433305585501473 -1.6289299188173199 -1.3764075987003737
21->e -1.4769383164864074 -1.8518320508965385 -2.0322819841761706 -1.5975749168759052
22->s -4.68559 -1.6183958896292985 -1.6269864619538834 -1.5773004364468766
32->s -3.439 -1.5916924321007726 -1.5594855599070083 -1.4508816714376356
42->e -1.1480245254915507 -1.4106246620768497 -1.4148957635834822 -1.279606910000288
43->s -1.0792247267523025 -1.3721699137374075 -2.5751 -0.9773194700079009
53->e -0.8346324019271221 -1.3619308157903096 -1.1012524917746531 -3.439
54->e -0.7565364387322464 -1.0293411185845796 -0.9377226272624627 -1.002097626391233
55->e -0.713238475200378 -0.9514563026077908 -0.7924658581183113 -0.8482442919209588
56->e -0.7535210463289848 -0.7582205948144669 -1.9 -0.9525764286130292
57->n -0.5284680185324115 -0.6248481019315834 -0.4949805316257869 -0.5658226414010143
47->e -0.33171221077198093 -1.0 -0.47191739730890003 -0.5806605674804866
48->n -0.39531117024400003 -0.3550962100275 -0.1952710819254177 -0.2178940511973896
38->n -0.24361697875000002 -0.2 -0.026291052079545567 -0.24952853215
28->n -0.20995000000000003 -0.38095 0.5618620385088806 -0.21268220000000002
18->e 2.7328808047990005 -0.019000000000000003 -0.019000000000000003 -0.201805
19->n -1.0 -0.2 6.5132155990000005 0.0
Process finished with exit code 0
4.3 Q-learning结果
qlearning第1次完成任务需要迭代次数是: 182
qlearning第1次实现最短路径需要的迭代次数是: 225
0->e -1.1466365124693265 -3.439 -5.6953279000000006 -1.3170095757382594
1->s -1.4850528594071573 -1.373328849195166 -4.973631022 -1.0405829217429576
11->s -1.2950258387073321 -1.4405129351714803 -1.9032709336787803 -0.9368148708006676
21->s -1.216887055083241 -1.1169924818828254 -1.5709178635585677 -1.0865941893670963
31->s -0.8698042826982872 -0.9444593237747578 -1.126132175654716 -0.8659740148922882
41->e -0.8034340632967804 -1.040258455809505 -0.9673038891621194 -1.021859055555545
42->s -0.7136532940907929 -0.9659892113422109 -0.7497882349649212 -0.7120244968344069
52->e -0.6973799095935546 -0.9019727008712654 -0.9176117651253209 -0.8035184864740251
53->e -0.7317746431115174 -0.784243983991206 -0.8134417607575515 -2.8558000000000003
54->e -0.55819465283893 -0.5966839965779502 -0.5788283193972343 -0.6890649135451383
55->e -0.8134303264107102 -0.9140158233836498 -0.988793989933743 -0.9839117262302468
56->e -0.7403440206746625 -0.7993178340497747 -2.0620000000000003 -0.852972036636197
57->e -0.381464107086051 -0.6096093371529612 -0.3841300486001714 -0.39118811254127495
58->n -0.6231983862735987 -0.5672200023483002 -0.39377190127628464 -0.4873282499108816
48->e -0.0054075198615305825 -0.4573826430441876 -0.24244000000000004 -0.22401596324370995
49->n -1.9 -0.20995000000000003 0.4674122848745747 -0.41231237132472287
39->n -1.9 -0.221893580775 1.6278207538383493 -0.201805
29->n -1.9 -0.04568820587500001 4.176620776599433 -0.20189525000000003
19->n -1.9 -0.012635 7.941088679053509 0.0