【零基础强化学习】100行代码教你训练——基于SARSA的CliffWalking爬悬崖游戏


更多代码: gitee主页:https://gitee.com/GZHzzz
博客主页CSDN:https://blog.csdn.net/gzhzzaa

写在前面

show me code, no bb

import gym
import numpy as np
import time

class SarsaAgent(object):
    def __init__(self, obs_n, act_n, learning_rate=0.01, gamma=0.9, e_greed=0.1):
        self.act_n = act_n      # 动作维度,有几个动作可选
        self.lr = learning_rate # 学习率
        self.gamma = gamma      # reward的衰减率
        self.epsilon = e_greed  # 按一定概率随机选动作
        self.Q = np.zeros((obs_n, act_n))

    # 根据输入观察值,采样输出的动作值,带探索
    def sample(self, obs):
        if np.random.uniform(0, 1) < (1.0 - self.epsilon): #根据table的Q值选动作
            action = self.predict(obs)
        else:
            action = np.random.choice(self.act_n) #有一定概率随机探索选取一个动作
        return action

    # 根据输入观察值,预测输出的动作值
    def predict(self, obs):
        Q_list = self.Q[obs, :]
        maxQ = np.max(Q_list)
        action_list = np.where(Q_list == maxQ)[0]  # maxQ可能对应多个action
        action = np.random.choice(action_list)
        return action

    # 学习方法,也就是更新Q-table的方法
    def learn(self, obs, action, reward, next_obs, next_action, done):
        """ on-policy
            obs: 交互前的obs, s_t
            action: 本次交互选择的action, a_t
            reward: 本次动作获得的奖励r
            next_obs: 本次交互后的obs, s_t+1
            next_action: 根据当前Q表格, 针对next_obs会选择的动作, a_t+1
            done: episode是否结束
        """
        predict_Q = self.Q[obs, action]
        if done:
            target_Q = reward # 没有下一个状态了
        else:
            target_Q = reward + self.gamma * self.Q[next_obs, next_action] # Sarsa
        self.Q[obs, action] += self.lr * (target_Q - predict_Q) # 修正q

    # 保存Q表格数据到文件
    def save(self):
        npy_file = './q_table.npy'
        np.save(npy_file, self.Q)
        print(npy_file + ' saved.')

    # 从文件中读取Q值到Q表格中
    def restore(self, npy_file='./q_table.npy'):
        self.Q = np.load(npy_file)
        print(npy_file + ' loaded.')

def run_episode(env, agent, render=False):
    total_steps = 0 # 记录每个episode走了多少step
    total_reward = 0
    obs = env.reset() # 重置环境, 重新开一局(即开始新的一个episode)
    action = agent.sample(obs) # 根据算法选择一个动作

    while True:
        next_obs, reward, done, _ = env.step(action) # 与环境进行一个交互
        next_action = agent.sample(next_obs) # 根据算法选择一个动作
        # 训练 Sarsa 算法
        agent.learn(obs, action, reward, next_obs, next_action, done)

        action = next_action
        obs = next_obs  # 存储上一个观察值
        total_reward += reward
        total_steps += 1 # 计算step数
        if render:
            env.render() #渲染新的一帧图形
        if done:
            break
    return total_reward, total_steps


def test_episode(env, agent):
    total_reward = 0
    obs = env.reset()
    while True:
        action = agent.predict(obs) # greedy
        next_obs, reward, done, _ = env.step(action)
        total_reward += reward
        obs = next_obs
        time.sleep(0.5)
        env.render()
        if done:
            break
    return total_reward


# 使用gym创建悬崖环境
env = gym.make("CliffWalking-v0")  # 0 up, 1 right, 2 down, 3 left

# 创建一个agent实例,输入超参数
agent = SarsaAgent(
        obs_n=env.observation_space.n,
        act_n=env.action_space.n,
        learning_rate=0.1,
        gamma=0.9,
        e_greed=0.1)


# 训练500个episode,打印每个episode的分数
for episode in range(500):
    ep_reward, ep_steps = run_episode(env, agent, False)
    print('Episode %s: steps = %s , reward = %.1f' % (episode, ep_steps, ep_reward))

# 全部训练结束,查看算法效果
test_reward = test_episode(env, agent)
print('test reward = %.1f' % (test_reward))
agent.save()

  • 代码全部亲自跑过,你懂的!😝

结果展示

  • S是起点,C是障碍,G是目标
  • agent从S开始走,目标是找到到G的最短路径
  • 每走一步的reward可以建模成-1,最终目标是让累计奖励最大,也就是路径最短

在这里插入图片描述
在这里插入图片描述

  • 代表智能体一步一步向前走!😎
  • 可以看到智能体是远离障碍走的,显得有点胆小

SARSA与Q-learning

在这里插入图片描述

  1. sarsa下一步的Q对应的action是经过贪婪-探索的实际与环境交互的动作(属于on-policy),加了探索的动作会对环境中reward比较低的状态很敏感,所以实验结果很胆小
  2. q-learning下一步的Q对应的action是直接选取最大值,不是实际与环境交互的动作(属于off-policy),只选最大值的总动作意味着只关心高奖励的状态,低奖励影响不大,所以实验结果贴着障碍物走,很大胆

写在最后

十年磨剑,与君共勉!
更多代码gitee主页:https://gitee.com/GZHzzz
博客主页CSDN:https://blog.csdn.net/gzhzzaa

  • Fighting!😎

基于pytorch的经典模型基于pytorch的典型智能体模型
强化学习经典论文强化学习经典论文
在这里插入图片描述

while True:
	Go life

在这里插入图片描述

谢谢点赞交流!(❁´◡`❁)

  • 7
    点赞
  • 20
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 7
    评论
多智能体强化学习 (Multi-agent Reinforcement Learning, MARL) 是指一个由多个个体组成的环境中,每个个体都有自己的决策空间,目标是通过与环境的交互,获取最大的累积奖励。MARL 的特点是不同个体之间相互影响,一个个体的决策将会影响其他个体的决策,因此 MARL 的复杂度比单智能体强化学习要高。其主要应用于博弈论、自动驾驶、机器人、智能交通等领域。 基于Sarsa的多智能体强化学习算法可以通过如下步骤实现: 1. 初始化每个智能体的策略,价值函数以及环境模型。 2. 每个智能体与环境交互进学习,按照如下步骤进: a. 根据当前状态,每个智能体选择一个为。这里使用 $\epsilon$-贪心策略,即以一定概率随机选择为,以一定概率选择当前最优为。 b. 执为,更新环境状态。 c. 获取奖励,用于更新价值函数。 d. 根据新状态和价值函数更新智能体的策略。这里使用Sarsa(state-action-reward-state-action)算法,即使用当前策略选择一个为,然后观察下一个状态及奖励,利用下一个状态和奖励更新当前价值函数,再根据新的价值函数更新策略。 e. 将状态更新为新状态,继续执下一个动作。 3. 迭代多次执以上步骤,直到收敛。 下面是基于Sarsa的多智能体强化学习的Python代码: ```python import numpy as np import random #定义环境 class Gridworld: def __init__(self, size): self.size = size self.state = np.zeros(2, dtype=np.int32) self.actions = np.array([[0,1],[0,-1],[1,0],[-1,0]]) self.rewards = np.array([[0,-10],[-10,0],[0,-10],[0,-10]]) #判断当前状态是否终止状态 def is_terminal(self, state): if ((state == [0,0]).all() or (state == [self.size-1,self.size-1]).all()): return True else: return False #获取当前状态的所有可选为 def get_actions(self): return self.actions #更新状态 def update_state(self, action): new_state = self.state + action if new_state[0] < 0 or new_state[0] >= self.size or new_state[1] < 0 or new_state[1] >= self.size: return False else: self.state = new_state return True #获取当前状态的奖励 def get_reward(self): return self.rewards[np.where(np.all(self.actions == self.action, axis=1))[0][0]] #定义智能体 class Agent: def __init__(self, id, grid): self.id = id self.grid = grid self.q_table = np.zeros((grid.size, grid.size, 4)) #价值函数 self.epsilion = 0.1 #探索概率 self.alpha = 0.5 #学习率 self.gamma = 0.9 #衰减系数 #根据当前状态选择一个为 def choose_action(self, state): if random.uniform(0,1) < self.epsilion: action = random.choice(self.grid.get_actions()) else: action = self.greedy_policy(state) return action #根据epsilon-greedy策略选择一个为 def greedy_policy(self, state): values = self.q_table[state[0], state[1], :] max_value = np.max(values) actions = self.grid.get_actions() candidate_actions = [a for a in actions if values[np.where(np.all(self.grid.actions == a, axis=1))[0][0]] == max_value] return random.choice(candidate_actions) #执一个周期,包括选择为、执为、更新价值函数和策略 def run_cycle(self, state): self.action = self.choose_action(state) self.grid.update_state(self.action) reward = self.grid.get_reward() next_state = self.grid.state next_action = self.choose_action(next_state) value = self.q_table[state[0], state[1], np.where(np.all(self.grid.actions == self.action, axis=1))[0][0]] next_value = self.q_table[next_state[0], next_state[1], np.where(np.all(self.grid.actions == next_action, axis=1))[0][0]] td_error = reward + self.gamma * next_value - value self.q_table[state[0], state[1], np.where(np.all(self.grid.actions == self.action, axis=1))[0][0]] += self.alpha * td_error self.epsilion *= 0.99 #探索概率指数衰减 #执多个周期 def run_cycles(self, num_cycles): for i in range(num_cycles): if self.grid.is_terminal(self.grid.state): self.grid.state = np.zeros(2, dtype=np.int32) state = self.grid.state self.run_cycle(state) #定义多智能体 class MultiAgent: def __init__(self, num_agents, grid): self.grid = grid self.agents = [Agent(i, grid) for i in range(num_agents)] #执一个周期,让每个智能体分别执一个周期 def run_cycle(self): for agent in self.agents: if self.grid.is_terminal(self.grid.state): self.grid.state = np.zeros(2, dtype=np.int32) state = self.grid.state agent.run_cycle(state) #执多个周期 def run_cycles(self, num_cycles): for i in range(num_cycles): self.run_cycle() #设定环境大小和智能体数量 size = 4 num_agents = 2 #初始化环境和多智能体 grid = Gridworld(size) multi_agent = MultiAgent(num_agents, grid) #执多个周期 multi_agent.run_cycles(1000) #输出每个智能体的价值函数 for agent in multi_agent.agents: print('agent', agent.id) print(agent.q_table) ```
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

北郭zz

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值