Cliff Walking: Sarsa and Q-Learning Implementations

Sarsa updates more conservatively but learns a safer path; Q-Learning learns faster but its ε-greedy behaviour policy is more likely to step off the cliff during training. In CliffWalking-v0 every step costs -1 and stepping into the cliff costs -100 (the agent is sent back to the start), so exploring along the cliff edge is expensive. The only algorithmic difference between the two methods is the bootstrap target, sketched below.
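
A minimal sketch of the two tabular update rules (these helper functions and their names are illustrative only and are not part of the scripts below; they assume a q_table indexed by [state, action] and scalar alpha/gamma as in those scripts):

import numpy as np


def sarsa_update(q_table, state, action, reward, next_state, next_action, alpha, gamma):
    # on-policy: bootstrap from the action the behaviour policy actually takes next
    td_target = reward + gamma * q_table[next_state, next_action]
    q_table[state, action] += alpha * (td_target - q_table[state, action])


def q_learning_update(q_table, state, action, reward, next_state, alpha, gamma):
    # off-policy: bootstrap from the greedy action at next_state,
    # regardless of what the behaviour policy will actually do
    td_target = reward + gamma * np.max(q_table[next_state, :])
    q_table[state, action] += alpha * (td_target - q_table[state, action])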

A quick implementation of Cliff Walking follows.
Sarsa:

import gym
import numpy as np
import matplotlib.pyplot as plt


def epsilon_greedy(state, epsilon):
    # explore: take a random action with probability epsilon
    if np.random.uniform(0, 1) < epsilon:
        return env.action_space.sample()

    # exploit: take the greedy action w.r.t. the current Q-table
    else:
        return np.argmax(q_table[state, :])


def Sarsa():
    for i in range(num_episodes):

        # print("Episode %s" % i)
        state = env.reset()

        # linearly decay epsilon from 0.9 to 0.1 over training
        epsilon = np.linspace(0.9, 0.1, num=num_episodes)[i]

        action = epsilon_greedy(state, epsilon)

        r = 0

        while True:
            next_state, reward, done, _ = env.step(action)
            # print("当前第 %s 回合" % i)
            # print("从第 %s 步移动到 第 %s 步" % (state, next_state))
            next_action = epsilon_greedy(next_state, epsilon)
            # print("在第 %s 步选择 %s 动作" % (next_state, next_action))

            q_table[state, action] += alpha * (reward + gamma * q_table[next_state, next_action] - q_table[state, action])

            # move on: update the current state and action
            state = next_state
            action = next_action

            r += reward
            if done:
                print("第 %s 回合获得总奖励为 %s" % (i, r))
                break
        rewards.append(r)

def printBestRoute():
    best_route = []
    state = env.reset()
    best_route.append(state)
    while True:
        action = np.argmax(q_table[state, :])
        print(action)
        next_state, _, done, _ = env.step(action)
        state = next_state
        best_route.append(state)

        if done:
            break

    print("Greedy route (state indices):", best_route)


def drawRewards():
    plt.figure(figsize=(16, 7))
    plt.plot(rewards, 'b-', label='Sarsa')
    # plt.plot(rewards, 'r-', label='Q-learning')
    plt.legend(loc='best', fontsize=15)  # legend position
    plt.tick_params(labelsize=15)  # tick label size
    plt.xlabel('Episode', fontsize=15)
    plt.ylabel('Cumulative reward', fontsize=15)
    plt.title('Sarsa', fontsize=20)
    plt.show()


if __name__ == '__main__':
    env = gym.make('CliffWalking-v0')

    rewards = []
    alpha = 0.8
    gamma = 0.95
    # one row per state (48 grid cells), one column per action (up/right/down/left)
    q_table = np.zeros([env.observation_space.n, env.action_space.n])
    num_episodes = 600

    Sarsa()
    # printBestRoute()
    drawRewards()
    print(rewards)

Q-Learning:

import gym
import numpy as np
import matplotlib.pyplot as plt


def epsilon_greedy(state, epsilon):
    # explore: take a random action with probability epsilon
    if np.random.uniform(0, 1) < epsilon:
        return env.action_space.sample()

    # exploit: take the greedy action w.r.t. the current Q-table
    else:
        return np.argmax(q_table[state, :])


def q_learning():
    for i in range(num_episodes):

        state = env.reset()

        # linearly decay epsilon from 0.9 to 0.1 over training
        epsilon = np.linspace(0.9, 0.1, num_episodes)[i]

        episode_reward = 0

        while True:
            action = epsilon_greedy(state, epsilon)
            next_state, reward, done, _ = env.step(action)
            print("当前第 %s 回合" % i)
            print("从第 %s 步移动到 第 %s 步" % (state, next_state))

            print("在第 %s 步选择 %s 动作" % (next_state, action))
            q_table[state, action] += alpha * (reward + gamma * max(q_table[next_state]) - q_table[state, action])

            state = next_state
            epsides_reward += reward

            if done:
                print("第 %s 回合获得总奖励为 %s" % (i, epsides_reward))
                break

        rewards.append(episode_reward)

def printBestRoute():
    best_route = []
    state = env.reset()
    best_route.append(state)
    while True:
        action = np.argmax(q_table[state, :])
        print(action)
        next_state, _, done, _ = env.step(action)
        state = next_state
        best_route.append(state)

        if done:
            break

    print("Greedy route (state indices):", best_route)


def drawRewards():
    plt.figure(figsize=(16, 7))
    # plt.plot(rewards, 'b-', label='Sarsa')
    plt.plot(rewards, 'r-', label='Q-learning')
    plt.legend(loc='best', fontsize=15)  # legend position
    plt.tick_params(labelsize=15)  # tick label size
    plt.xlabel('Episode', fontsize=15)
    plt.ylabel('Cumulative reward', fontsize=15)
    plt.title('Q-learning', fontsize=20)
    plt.show()


if __name__ == '__main__':
    # the 4-tuple env.step() unpacking above follows the pre-0.26 gym API,
    # so the env is created the same way as in the Sarsa script
    env = gym.make('CliffWalking-v0')
    np.random.seed(0)

    rewards = []
    alpha = 0.8
    gamma = 0.95
    # one row per state (48 grid cells), one column per action (up/right/down/left)
    q_table = np.zeros([env.observation_space.n, env.action_space.n])
    num_episodes = 600

    q_learning()
    printBestRoute()
    drawRewards()
    print(rewards)

Comparing the two reward curves, Q-Learning converges faster than Sarsa, while Sarsa behaves more safely during training. Plotting both runs on one figure makes this easier to see; a sketch follows.
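
The per-episode returns are noisy, so a moving average helps when comparing the curves. A minimal sketch of such a comparison plot (this helper is not part of the scripts above; its two arguments are the `rewards` lists collected by the Sarsa and Q-Learning runs):

import numpy as np
import matplotlib.pyplot as plt


def compare_rewards(sarsa_rewards, q_rewards, window=20):
    # smooth both reward curves with a simple moving average and plot them together
    kernel = np.ones(window) / window
    plt.figure(figsize=(16, 7))
    plt.plot(np.convolve(sarsa_rewards, kernel, mode='valid'), 'b-', label='Sarsa')
    plt.plot(np.convolve(q_rewards, kernel, mode='valid'), 'r-', label='Q-learning')
    plt.legend(loc='best', fontsize=15)
    plt.xlabel('Episode', fontsize=15)
    plt.ylabel('Smoothed cumulative reward', fontsize=15)
    plt.show()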
