Using the Reinforcement Learning Algorithms Q-learning and Sarsa in AirSim to Fly a Drone to a Target Point


Preface

This column presents a series of classic reinforcement learning algorithms applied to AirSim drone simulation, mainly following the 莫烦Python video tutorials.
This post is a short record of using Q-learning and Sarsa to control a drone so that it reaches a specified coordinate point. To keep things simple, only movement along the x-axis is controlled for now.


Below is the result after training with the Q-learning algorithm for 100 episodes.

0. Preparation

Project structure

The data folder holds the configuration file and other data, and a second folder holds the code files.
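Based on the module names imported later (Q_learning_and_Sarsa_brain, drone_position_ctrl_env) and the relative path ../data/configs.yaml used by the main script, the layout looks roughly like this (the name of the code folder and of the main script are my own assumptions):

project/
├── data/
│   └── configs.yaml
└── scripts/                          # code folder, name assumed
    ├── main.py                       # main script, name assumed
    ├── Q_learning_and_Sarsa_brain.py # Q-learning / Sarsa agents
    └── drone_position_ctrl_env.py    # AirSim environment wrapper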

Writing the configs file

Create a configs.yaml file under the data folder and put the following in it:

base_name: 'Drone'

vehicle_index: [0]

# index of the target drone; the environment code reads configs["target_vehicle_index"]
# in takeoff() and move_by_*, so this key is required (assumed to be 0 here)
target_vehicle_index: 0

# multi-machine setup: IP address of the machine running UE4
multi_computing: True
simulation_address: "192.168.3.4"
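The environment code addresses the vehicle by name (base_name plus the vehicle index, i.e. Drone0), so the AirSim settings.json on the UE4 machine must define a multirotor with exactly that name. A minimal sketch of such a settings.json (an assumption on my part, adjust to your own setup):

{
  "SettingsVersion": 1.2,
  "SimMode": "Multirotor",
  "Vehicles": {
    "Drone0": {
      "VehicleType": "SimpleFlight",
      "X": 0, "Y": 0, "Z": 0
    }
  }
}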

1. Writing the Main Function Logic

import yaml
import time
# from Q_learning_brain import QLearningTable
from drone_position_ctrl_env import DronePosCtrl
from Q_learning_and_Sarsa_brain import SarsaTable
from Q_learning_and_Sarsa_brain import QLearningTable

max_episodes = 100


def q_learning_start():
    for episode in range(max_episodes):
        # initial observation
        env.AirSim_client.reset()
        env.env_setting()
        time.sleep(2)
        env.takeoff()
        time.sleep(3)
        observation = env.reset()

        while True:
            # environment refresh
            env.render()

            # choose action based on observation
            action = q_learning_client.choose_action(str(observation))
            print("observation: ", observation)

            # take action and get next observation and reward
            next_observation, reward, done = env.step(action)

            print("next observation: ", next_observation)
            print("reward: ", reward)

            # to learn from this transition
            q_learning_client.learn(str(observation), action, reward, str(next_observation))

            # refresh observation
            observation = next_observation

            if done:
                break

    print("Learning process over!")
    env.reset()


def sarsa_learning_start():
    for episode in range(max_episodes):
        # initial observation
        env.AirSim_client.reset()
        env.env_setting()
        time.sleep(2)
        env.takeoff()
        time.sleep(3)
        observation = env.reset()

        action = sarsa_learning_client.choose_action(str(observation))

        while True:
            # environment refresh
            env.render()

            # take action and get next observation and reward
            next_observation, reward, done = env.step(action)

            print("next observation: ", next_observation)
            print("reward: ", reward)

            # choose action based on observation
            next_action = sarsa_learning_client.choose_action(str(next_observation))

            # to learn from this transition
            sarsa_learning_client.learn(str(observation), action, reward, str(next_observation), next_action)

            # refresh observation
            observation = next_observation
            action = next_action

            if done:
                break

    print("Learning process over!")
    env.reset()


if __name__ == "__main__":
    with open('../data/configs.yaml', "r", encoding='utf-8') as configs_file:
        _configs = yaml.load(configs_file.read(), Loader=yaml.FullLoader)

    env = DronePosCtrl(configs=_configs, vehicle_index=0)
    q_learning_client = QLearningTable(actions=list(range(env.n_actions)))
    sarsa_learning_client = SarsaTable(actions=list(range(env.n_actions)))

    q_learning_start()
    # sarsa_learning_start()
    q_learning_client.show_q_table()
    # sarsa_learning_client.show_q_table()
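A note on running it: the script imports Q_learning_and_Sarsa_brain and drone_position_ctrl_env (the modules shown in the next two sections) and opens ../data/configs.yaml through a relative path, so it should be started from inside the code folder, with the UE4 + AirSim simulation already running. After the 100 training episodes it prints the learned Q-table via show_q_table().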

2. Implementing Q-learning and Sarsa
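Both classes below keep a Q-table indexed by the stringified state and differ only in the update target used in learn(). With learning rate $\alpha$ (learning_rate) and discount factor $\gamma$ (reward_decay), the updates are

$$Q(s,a) \leftarrow Q(s,a) + \alpha \big[ r + \gamma \max_{a'} Q(s',a') - Q(s,a) \big] \qquad \text{(Q-learning, off-policy)}$$

$$Q(s,a) \leftarrow Q(s,a) + \alpha \big[ r + \gamma\, Q(s',a') - Q(s,a) \big] \qquad \text{(Sarsa, on-policy)}$$

where in Sarsa $a'$ is the action actually chosen for the next state by the same $\varepsilon$-greedy policy.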

import numpy as np
import pandas as pd


class BaseRL(object):
    def __init__(self, action_spaces, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        self.actions = action_spaces
        self.learning_rate = learning_rate
        self.gamma = reward_decay
        self.epsilon = e_greedy

        self.q_table = pd.DataFrame(columns=self.actions, dtype=np.float64)

    def check_state_exist(self, state):
        if state not in self.q_table.index:
            # append this state to the table as a row of zeros
            # (DataFrame.append was removed in pandas 2.0, so use .loc enlargement instead)
            self.q_table.loc[state] = [0.0] * len(self.actions)

    def choose_action(self, observation):
        self.check_state_exist(observation)

        if np.random.rand() < self.epsilon:
            # choose the optimal action
            state_action = self.q_table.loc[observation, :]
            # several actions may share the maximum value, so randomly choose one of them
            action = np.random.choice(state_action[state_action == np.max(state_action)].index)
        else:
            # randomly select an action
            action = np.random.choice(self.actions)

        return action

    def learn(self, *args):
        pass

    def show_q_table(self):
        print("Q-table:\n", self.q_table)


# off-policy
class QLearningTable(BaseRL):
    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        super(QLearningTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)

    def learn(self, state, action, reward, next_state):
        # make sure the next state already has a row before its Q values are read below
        self.check_state_exist(next_state)
        q_predict = self.q_table.loc[state, action]
        if next_state != 'terminal':
            q_target = reward + self.gamma * self.q_table.loc[next_state, :].max()
        else:
            q_target = reward

        self.q_table.loc[state, action] += self.learning_rate * (q_target - q_predict)


# on-policy
class SarsaTable(BaseRL):
    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        super(SarsaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)

    def learn(self, state, action, reward, next_state, next_action):
        self.check_state_exist(next_state)
        q_predict = self.q_table.loc[state, action]
        if next_state != 'terminal':
            # next state is not terminal
            q_target = reward + self.gamma * self.q_table.loc[next_state, next_action]
        else:
            q_target = reward

        self.q_table.loc[state, action] += self.learning_rate * (q_target - q_predict)
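A minimal sanity check of the agents on one hand-made transition, assuming the classes above are saved as Q_learning_and_Sarsa_brain.py (the module name imported by the main script); states are passed as strings, exactly as the training loop does with str(observation):

from Q_learning_and_Sarsa_brain import QLearningTable

agent = QLearningTable(actions=[0, 1])                 # two actions, like the environment's action space
state, next_state = "[0.0, 0.0, 0.0]", "[1.5, 0.0, 0.0]"

action = agent.choose_action(state)                    # epsilon-greedy choice; also creates the row for `state`
agent.learn(state, action, reward=-10, next_state=next_state)
agent.show_q_table()                                   # the updated entry is now 0.01 * (-10) = -0.1

SarsaTable is used in the same way, except that learn() additionally takes the next action, as in sarsa_learning_start().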

3. Implementing the Environment Interaction

import sys
import time
import yaml
import airsim
import random
import threading
import numpy as np

sys.path.append('..')


class DronePosCtrl(object):
    def __init__(self, configs, vehicle_index):

        self.configs = configs

        # >---------------->>>  label for threading   <<<----------------< #
        # makes it easy to control each drone from its own thread
        self.drone_index = vehicle_index
        self.base_name = configs["base_name"]
        self.now_drone_name = self.base_name + str(vehicle_index)
        # >---------------->>>   --------------------------------------    <<<----------------< #

        # >---------------->>>  position settings   <<<----------------< #
        self.target_position = [8.0, 0.0, 2.0]
        self.myself_position = {"x": 0, "y": 0, "z": 0, "yaw": 0}

        # polar radius of the Earth (constant, in metres)
        self.polar_radius = 6356725
        # equatorial radius of the Earth (constant, in metres)
        self.equatorial_radius = 6378137
        # GPS latitude, longitude and altitude recorded at the origin
        self.origin_info = {"latitude": 0.0, "longitude": 0.0, "altitude": 0.0}

        # >---------------->>>   --------------------------------------    <<<----------------< #

        self.action_spaces = ['move_front', 'move_back']
        self.n_actions = len(self.action_spaces)

        if configs["multi_computing"]:
            # create API client for ctrl
            self.AirSim_client = airsim.MultirotorClient(str(configs["simulation_address"]))

        else:
            self.AirSim_client = airsim.MultirotorClient()

        self.AirSim_client.confirmConnection()
        self.env_setting()
        # self.takeoff()

    def env_setting(self):
        if self.drone_index == -1:
            for index in self.configs["vehicle_index"]:
                self.AirSim_client.enableApiControl(True, vehicle_name=self.base_name+str(index))
                self.AirSim_client.armDisarm(True, vehicle_name=self.base_name + str(index))
        else:
            self.AirSim_client.enableApiControl(True, vehicle_name=self.now_drone_name)
            self.AirSim_client.armDisarm(True, vehicle_name=self.now_drone_name)

    def reset(self):
        # self.AirSim_client.reset()
        # for index in self.configs["vehicle_index"]:
        #     self.AirSim_client.enableApiControl(False, vehicle_name=self.base_name + str(index))
        #     self.AirSim_client.armDisarm(False, vehicle_name=self.base_name + str(index))

        gt_dict = self.get_ground_truth_pos(vehicle_name=self.now_drone_name)
        return gt_dict['position']

    def takeoff(self):

        if self.AirSim_client.getMultirotorState().landed_state == airsim.LandedState.Landed:
            print(f"Drone{self.drone_index} is taking off now···")
            if self.drone_index == -1:
                for index in self.configs["vehicle_index"]:
                    # only block (.join()) on the last drone so the takeoff commands run in parallel
                    if not index == self.configs["vehicle_index"][len(self.configs["vehicle_index"]) - 1]:
                        self.AirSim_client.takeoffAsync(timeout_sec=10, vehicle_name=self.base_name+str(index))
                    else:
                        self.AirSim_client.takeoffAsync(timeout_sec=10, vehicle_name=self.base_name+str(index)).join()
            elif self.drone_index == self.configs["target_vehicle_index"]:
                self.AirSim_client.takeoffAsync(timeout_sec=10, vehicle_name=self.now_drone_name).join()
            else:
                self.AirSim_client.takeoffAsync(timeout_sec=10, vehicle_name=self.now_drone_name)
        else:
            print(f"Drone{self.drone_index} is flying··· ")
            if self.drone_index == -1:
                for index in self.configs["vehicle_index"]:
                    # only block (.join()) on the last drone so the hover commands run in parallel
                    if not index == self.configs["vehicle_index"][len(self.configs["vehicle_index"]) - 1]:
                        self.AirSim_client.hoverAsync(vehicle_name=self.base_name+str(index))
                    else:
                        self.AirSim_client.hoverAsync(vehicle_name=self.base_name+str(index)).join()
            else:
                self.AirSim_client.hoverAsync(vehicle_name=self.now_drone_name).join()

    def get_ground_truth_pos(self, vehicle_name="Drone0"):
        temp_pos = [0.0, 0.0, 0.0]
        temp_vel = [0.0, 0.0, 0.0]

        vehicle_state = self.AirSim_client.simGetGroundTruthKinematics(vehicle_name=vehicle_name)

        temp_pos[0] = round(vehicle_state.position.x_val, 1)
        temp_pos[1] = round(vehicle_state.position.y_val, 1)
        temp_pos[2] = round(vehicle_state.position.z_val, 1)

        temp_vel[0] = vehicle_state.linear_velocity.x_val
        temp_vel[1] = vehicle_state.linear_velocity.y_val
        temp_vel[2] = vehicle_state.linear_velocity.z_val

        ground_truth_dict = {
            "position": temp_pos,
            "velocity": temp_vel
        }
        return ground_truth_dict

    def move_by_position(self, position_3d, vehicle_name="Drone0"):
        print("position input: ", position_3d)
        # an index of -1 means control all drones
        if self.drone_index == -1:
            for drone_index in self.configs["vehicle_index"]:
                # only control the drones other than the target drone
                if not drone_index == self.configs["target_vehicle_index"]:
                    self.AirSim_client.moveToPositionAsync(position_3d[0], position_3d[1], position_3d[2], timeout_sec=2,
                                                           velocity=2, vehicle_name=self.base_name + str(drone_index))
                else:
                    pass
        else:
            # send a position command to the drone controlled by the current thread
            if vehicle_name != self.now_drone_name:
                self.AirSim_client.moveToPositionAsync(position_3d[0], position_3d[1], position_3d[2], timeout_sec=2,
                                                       velocity=2, vehicle_name=vehicle_name)
            else:
                self.AirSim_client.moveToPositionAsync(position_3d[0], position_3d[1], position_3d[2], timeout_sec=2,
                                                       velocity=2, vehicle_name=self.now_drone_name)

    def move_by_velocity(self, velocity_3d, vehicle_name="Drone0"):
        print("velocity: ", velocity_3d)
        # an index of -1 means control all drones
        if self.drone_index == -1:
            for drone_index in self.configs["vehicle_index"]:
                # only control the drones other than the target drone
                if not drone_index == self.configs["target_vehicle_index"]:
                    self.AirSim_client.moveByVelocityAsync(velocity_3d[0], velocity_3d[1], velocity_3d[2],
                                                           duration=0.6, drivetrain=airsim.DrivetrainType.ForwardOnly,
                                                           yaw_mode=airsim.YawMode(is_rate=True, yaw_or_rate=0.0),
                                                           vehicle_name=self.base_name + str(drone_index))
                else:
                    pass
        else:
            # apply a velocity command lasting 0.6 s to the drone controlled by the current thread
            if vehicle_name != self.now_drone_name:
                self.AirSim_client.moveByVelocityAsync(velocity_3d[0], velocity_3d[1], velocity_3d[2], duration=0.6,
                                                       vehicle_name=vehicle_name)
            else:
                self.AirSim_client.moveByVelocityAsync(velocity_3d[0], velocity_3d[1], velocity_3d[2], duration=0.6,
                                                       vehicle_name=self.now_drone_name)

    def step(self, action):
        status = self.get_ground_truth_pos()
        now_position = status['position']
        desired_velocity = [0.0, 0.0, 0.0]
        # copy the list so that editing the desired position does not also modify now_position
        desired_position = list(now_position)
        desired_position[2] = 0.0

        # move ahead
        if self.action_spaces[action] == self.action_spaces[0]:
            if now_position[0] < self.target_position[0]:
                desired_velocity[0] = 2.0
                desired_position[0] += 1.5
            else:
                desired_velocity[0] = 0.0
                desired_position[0] += 0.0

        # move back
        elif self.action_spaces[action] == self.action_spaces[1]:
            if now_position[0] > 0:
                desired_velocity[0] = -2.0
                desired_position[0] -= 1.5
            else:
                desired_velocity[0] = 0.0
                desired_position[0] -= 0.0

        # self.move_by_velocity(desired_velocity)
        self.move_by_position(desired_position)
        time.sleep(2)
        self.AirSim_client.hoverAsync(vehicle_name=self.now_drone_name).join()

        status = self.get_ground_truth_pos()
        next_position = status['position']

        if now_position[0] >= self.target_position[0]:
            reward = 100
            done = True
            next_position = 'terminal'
            print("task finished!")
        else:
            if next_position[0] - now_position[0] < 0:
                reward = -10
            else:
                reward = 0

            done = False

            if now_position[0] <= -1:
                reward = -100
                done = True
                next_position = 'terminal'

        return next_position, reward, done

    def render(self):
        pass

    # def env_test(self):
    #     # state = env.reset()
    #
    #     for i in range(10):
    #         action_index = random.randint(0, len(self.action_spaces)-1)
    #         action = self.action_spaces[action_index]
    #         state, reward, done = env.step(action)
    #
    #         if done:
    #             env.reset()
    #             return None


# if __name__ == "__main__":
#     with open('../data/configs.yaml', "r", encoding='utf-8') as configs_file:
#         _configs = yaml.load(configs_file.read(), Loader=yaml.FullLoader)
#
#     env = DronePosCtrl(configs=_configs, vehicle_index=0)
#     env.env_test()
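As a quick summary of the reward design in step(): the drone's x position is read before and after each move command; reaching the target x of 8.0 m ends the episode with a reward of +100, a step that moves the drone backwards along x is punished with -10 (otherwise the reward is 0), and drifting below x = -1 ends the episode with a reward of -100.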



Summary

This post only implements control along a single axis; the next post will improve on that and use more reinforcement learning algorithms.
Since I have been fairly busy recently, the post is not explained in much detail, but feel free to leave questions in the comments and I will reply as soon as I see them.
