表格形式的Q_learning算法+经验回放

环境如下:
在这里插入图片描述
这是一个简单的环境,绿色方块代表终点,白色方块代表可行点,灰色方块代表陷阱
Q_learning算法+经验回放训练得到value表格,可以得到比较好的结果
代码如下:
(jupyter notebook上的代码,所以顺序看起来有点儿奇怪)

#注意,下面代码中的价值Q_pi应该是最优价值Q_star,从sarsa算法代码改过来的,忘改了命名了,不过不影响结果
def get_state(row,col):
    if row!=3:
        return 'ground'
    elif col==0:
        return 'ground'
    elif col==11:
        return 'terminal'
    else:
        return 'trap'
for i in range(4):
    for j in range(12):
        print(get_state(i,j),sep='\t',end=' ')
    print('\n')
    
def envirment(row,col,action):
    if action==0:
        row-=1
    elif action==1:
        row+=1
    elif action==2:
        col-=1
    elif action==3:
        col+=1
    next_row=min(max(0,row),3)
    next_col=min(max(0,col),11)
    reward=-1
    if get_state(next_row,next_col)=='trap':
        reward=-100
    elif get_state(next_row,next_col)=='terminal':
        reward=100
    return next_row,next_col,reward
#envirment(3,0,0)

import numpy as np
import random
Q_pi=np.zeros([4,12,4])
def get_action(row,col):#获取下一步的动作
    if random.random()<0.1:
        return random.choice(range(4))#随机选一个动作
    else:
        return Q_pi[row,col].argmax()#选择Q_pi大的动作
        
def TD_Qlearning(row,col,action,reward,next_row,next_col):
#     TD_target=reward+0.9*Q_pi[next_row,next_col,next_action] #sarsa
    TD_target=reward+0.9*Q_pi[next_row,next_col].max()#Q_learn
    TD_error=Q_pi[row,col,action]-TD_target
    return TD_error
    
def jingyanhuifang(data):
    for i in range(20):
        row,col,action,reward,next_row,next_col=random.choice(data)
        TD_error=TD_Qlearning(row,col,action,reward,next_row,next_col)
        Q_pi[row,col,action]-=0.1*TD_error
    while len(data)>=500: 
        data.pop(0)

def train():
    data=[]
    for epoch in range(30000):
        row=random.choice(range(4))
        col=0
        action=get_action(row,col)
        reward_sum=0
#         print(action)
        while get_state(row,col) not in ['terminal','trap']:
            next_row,next_col,reward=envirment(row,col,action)
            reward_sum+=reward
#             print(row,col,next_row,next_col)
            next_action=get_action(next_row,next_col)
            data.append((row,col,action,reward,next_row,next_col))#这里,记录经验
            if len(data)>=100: #如果经验没达到100条,增加经验不训练
                jingyanhuifang(data)
            else:
                row=next_row
                col=next_col
                action=next_action
                continue
            TD_error=TD_Qlearning(row,col,action,reward,next_row,next_col)#Q_learn时可以少传一个变量next_actio
            Q_pi[row,col,action]-=0.1*TD_error 
#             print(row,col,next_row,next_col)
            row=next_row
            col=next_col
            action=next_action
            jingyanhuifang(data)#经验回放
        
#         print("epoch")
        if epoch%150==0:
            print(epoch,reward_sum)
train()

#打印游戏,方便测试
def show(row, col, action):
    graph = [
        '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□',
        '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□',
        '□', '□', '□', '□', '□', '□', '□', '□', '□', '○', '○', '○', '○', '○',
        '○', '○', '○', '○', '○', '❤'
    ]

    action = {0: '↑', 1: '↓', 2: '←', 3: '→'}[action]

    graph[row * 12 + col] = action

    graph = ''.join(graph)

    for i in range(0, 4 * 12, 12):
        print(graph[i:i + 12])


show(1, 1, 0)

from IPython import display
import time


def test():
    #起点
    row = random.choice(range(4))
    col = 0

    #最多玩N步
    for _ in range(200):

        #获取当前状态,如果状态是终点或者掉陷阱则终止
        if get_state(row, col) in ['trap', 'terminal']:
            break

        #选择最优动作
        action = Q_pi[row, col].argmax()

        #打印这个动作
        display.clear_output(wait=True)
        time.sleep(0.1)
        show(row, col, action)

        #执行动作
        row, col, reward = envirment(row, col, action)


test()

#打印所有格子的动作倾向
for row in range(4):
    line = ''
    for col in range(12):
        action = Q_pi[row, col].argmax()
        action = {0: '↑', 1: '↓', 2: '←', 3: '→'}[action]
        line += action
    print(line)

这里简化了,只抽取一条四元组,但是可以随机抽取多个四元组,然后算出的梯度求平均
结果:
value表格指示的action
在这里插入图片描述
这种就是比较好的结果了,都是朝向右方或者下方的。
测试结果如下:
在这里插入图片描述

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

chp的博客

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值