Reinforcement Learning: Q-Learning and Sarsa, with Python Implementations

These notes come from a reinforcement learning course whose assignment involved the Q-Learning and Sarsa algorithms, so I am recording them here.
A favourite blogger's reinforcement learning column explains both algorithms very clearly and is highly recommended. Link: 机器学习+深度学习+强化学习

Q-Learning

Theory
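The code below implements the standard one-step tabular Q-learning update, with learning rate \alpha = ALPHA and discount factor \gamma = GAMMA; actions are selected ε-greedily using EPSILON:

Q(s_t, a_t) \leftarrow Q(s_t, a_t) + \alpha \Big[ r_{t+1} + \gamma \max_{a} Q(s_{t+1}, a) - Q(s_t, a_t) \Big]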

Code Implementation

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import time

ALPHA = 0.1          # learning rate
GAMMA = 0.95         # discount factor
EPSILON = 0.9        # probability of acting greedily (explore with probability 1 - EPSILON)
N_STATE = 20         # number of states in the 1-D corridor; the treasure 'T' sits at the right end
ACTIONS = ['left', 'right']
MAX_EPISODES = 200   # number of training episodes
FRESH_TIME = 0.1     # pause (seconds) between environment refreshes when rendering


#############  1. Define Q table  ##############

def build_q_table(n_state, actions):
    # Q table: one row per state, one column per action, initialised to zero
    q_table = pd.DataFrame(
        np.zeros((n_state, len(actions))),
        index=np.arange(n_state),
        columns=actions
    )
    return q_table
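# A quick sanity check (illustration only, not part of the assignment code):
# build_q_table(3, ACTIONS) returns an all-zero DataFrame with one row per state
# and one column per action, roughly:
#        left  right
#     0   0.0    0.0
#     1   0.0    0.0
#     2   0.0    0.0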

#############  2. Define action  ##############

def choose_action(state, q_table):
    # epsilon-greedy policy: act greedily with probability EPSILON,
    # otherwise (or while the row is still all zeros) pick a random action
    state_action = q_table.loc[state, :]
    if np.random.uniform() > EPSILON or (state_action == 0).all():
        action_name = np.random.choice(ACTIONS)
    else:
        action_name = state_action.idxmax()
    return action_name

#############  3. Environment feedback  ##############

def get_env_feedback(state, action):
    # One step in the 1-D corridor: reaching the treasure 'T' (by moving right
    # from state N_STATE-2) yields reward +1; every other move costs -0.5.
    if action == 'right':
        if state == N_STATE - 2:
            next_state = 'terminal'
            reward = 1
        else:
            next_state = state + 1
            reward = -0.5
    else:
        if state == 0:
            next_state = 0      # already at the left wall, stay in place
        else:
            next_state = state - 1
        reward = -0.5
    return next_state, reward
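# For example (illustration only), with N_STATE = 20:
#   get_env_feedback(0, 'right')  -> (1, -0.5)        # step toward the treasure
#   get_env_feedback(0, 'left')   -> (0, -0.5)        # bump into the left wall
#   get_env_feedback(18, 'right') -> ('terminal', 1)  # state N_STATE-2 steps onto 'T'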


#############  4. Update environment   ##############

def update_env(state, episode, step_counter):
    # Render the corridor as a string such as '--*-----T', where '*' marks the agent
    env = ['-'] * (N_STATE - 1) + ['T']
    if state == 'terminal':
        print("Episode {}, the total step is {}".format(episode + 1, step_counter))
        return True, step_counter
    else:
        env[state] = '*'
        env = ''.join(env)
        print(env)
        time.sleep(FRESH_TIME)
        return False, step_counter
        
#############  5. Agent   #################
'''
Agent: tabular Q-learning.
Return values:
-- q_table : the learned Q table (see 'build_q_table')
-- step_counter_times : list of the total number of steps taken in each episode
'''
def q_learning():
    q_table = build_q_table(N_STATE, ACTIONS)
    step_counter_times = []
    # main loop over training episodes
    for episode in range(MAX_EPISODES):
        state = 0                       # initialize the agent at the left end
        is_terminal = False             # whether the current episode has ended
        step_counter = 0
        update_env(state, episode, step_counter)
        while not is_terminal:
            A = choose_action(state, q_table)
            next_state, reward = get_env_feedback(state, A)
            q_predict = q_table.loc[state, A]

            if next_state != 'terminal':
                # off-policy target: bootstrap from the best action in the next state
                q_target = reward + GAMMA * q_table.loc[next_state, :].max()
            else:
                q_target = reward
                is_terminal = True
            q_table.loc[state, A] += ALPHA * (q_target - q_predict)
            state = next_state
            update_env(state, episode, step_counter + 1)
            step_counter += 1
        step_counter_times.append(step_counter)

    return q_table, step_counter_times



def main():
    q_table, step_counter_times = q_learning()
    print("Q table\n{}\n".format(q_table))
    print('end')

    plt.plot(step_counter_times, 'g-')
    plt.ylabel("steps")
    plt.savefig('QLearning_figure.png')
    plt.show()
    print("The step_counter_times is {}".format(step_counter_times))

main()
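Once training finishes, the greedy policy can be read directly off the learned table. A minimal sketch (assuming q_table is the DataFrame returned by q_learning()):

greedy_policy = q_table.idxmax(axis=1)   # best action for each state
print(greedy_policy)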
  • Training result plot: [figure omitted: steps per episode, saved as QLearning_figure.png]

Sarsa

Theory
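Sarsa uses the same tabular update but is on-policy: the bootstrap term uses the action that the ε-greedy behaviour policy actually selects in the next state, instead of the max over actions:

Q(s_t, a_t) \leftarrow Q(s_t, a_t) + \alpha \Big[ r_{t+1} + \gamma Q(s_{t+1}, a_{t+1}) - Q(s_t, a_t) \Big]

where a_{t+1} is the next action returned by choose_action in the code below.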

Code Implementation

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import time

ALPHA = 0.1          # learning rate
GAMMA = 0.95         # discount factor
EPSILON = 0.9        # probability of acting greedily (explore with probability 1 - EPSILON)
N_STATE = 6          # number of states in the 1-D corridor; the treasure 'T' sits at the right end
ACTIONS = ['left', 'right']
MAX_EPISODES = 200   # number of training episodes
FRESH_TIME = 0.1     # pause (seconds) between environment refreshes when rendering

#############  1. Define Q table  ##############

def build_q_table(n_state, actions):
    # Q table: one row per state, one column per action, initialised to zero
    q_table = pd.DataFrame(
        np.zeros((n_state, len(actions))),
        index=np.arange(n_state),
        columns=actions
    )
    return q_table

#############  2. Define action  ##############

def choose_action(state, q_table):
    # epsilon-greedy policy: act greedily with probability EPSILON,
    # otherwise (or while the row is still all zeros) pick a random action
    state_action = q_table.loc[state, :]
    if np.random.uniform() > EPSILON or (state_action == 0).all():
        action_name = np.random.choice(ACTIONS)
    else:
        action_name = state_action.idxmax()
    return action_name

#############  3. Environment feedback  ##############

def get_env_feedback(state, action):
    # One step in the 1-D corridor: reaching the treasure 'T' (by moving right
    # from state N_STATE-2) yields reward +1; every other move costs -0.5.
    if action == 'right':
        if state == N_STATE - 2:
            next_state = 'terminal'
            reward = 1
        else:
            next_state = state + 1
            reward = -0.5
    else:
        if state == 0:
            next_state = 0      # already at the left wall, stay in place
        else:
            next_state = state - 1
        reward = -0.5
    return next_state, reward

#############  4. Update environment   ##############

def update_env(state, episode, step_counter):
    # Render the corridor as a string such as '--*--T', where '*' marks the agent
    env = ['-'] * (N_STATE - 1) + ['T']
    if state == 'terminal':
        print("Episode {}, the total step is {}".format(episode + 1, step_counter))
        return True, step_counter
    else:
        env[state] = '*'
        env = ''.join(env)
        print(env)
        time.sleep(FRESH_TIME)
        return False, step_counter
        
#############  5. Agent   #################
'''
Agent: tabular Sarsa.
Return values:
-- q_table : the learned Q table (see 'build_q_table')
-- step_counter_times : list of the total number of steps taken in each episode
'''
def sarsa_learning():
    q_table = build_q_table(N_STATE, ACTIONS)
    step_counter_times = []

    # main loop over training episodes
    for episode in range(MAX_EPISODES):
        state = 0
        is_terminal = False
        step_counter = 0
        update_env(state, episode, step_counter)
        A = choose_action(state, q_table)    # Sarsa is on-policy: choose the first action up front
        while not is_terminal:
            next_state, reward = get_env_feedback(state, A)
            q_predict = q_table.loc[state, A]

            if next_state != 'terminal':
                # on-policy target: bootstrap from the action actually taken next
                next_A = choose_action(next_state, q_table)
                q_target = reward + GAMMA * q_table.loc[next_state, next_A]
            else:
                next_A = None
                q_target = reward
                is_terminal = True
            q_table.loc[state, A] += ALPHA * (q_target - q_predict)
            state = next_state
            A = next_A
            update_env(state, episode, step_counter + 1)
            step_counter += 1
        step_counter_times.append(step_counter)

    return q_table, step_counter_times

def main():
    q_table, step_counter_times = sarsa_learning()
    print("Q table\n{}\n".format(q_table))
    print('end')

    plt.plot(step_counter_times, 'g-')
    plt.ylabel("steps")
    plt.savefig('Sarsa_figure.png')
    plt.show()
    print("The step_counter_times is {}".format(step_counter_times))

main()
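The per-episode step counts are fairly noisy, so a moving average makes the convergence trend easier to read. A minimal sketch (assuming step_counter_times is the list returned by sarsa_learning(); the same works for q_learning()):

smoothed = pd.Series(step_counter_times).rolling(window=10).mean()   # 10-episode moving average
plt.plot(smoothed, 'b-')
plt.ylabel("steps (10-episode moving average)")
plt.show()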
  • Training result plot: [figure omitted: steps per episode, saved as Sarsa_figure.png]