import numpy as np
import matplotlib.pyplot as plt
'''
显示在训练过程中总回报随着玩家摇动次数而变化的曲线
ε-greedy策略,玻尔兹曼策略,UCB策略
ε-greedy策略是最常用的,UCB策略是回报最高的
'''
class KBGame:
# 初始化
# def __init__(self, *args, **kwargs):
def __init__(self):
self.q = np.array([0.0, 0.0, 0.0]) # 每个臂的平均回报,假设臂的数目为3个,初始值都为0
self.action_counts = np.array([0, 0, 0]) # 摇动每个臂的次数,初始值为0
self.current_cumulative_rewards = 0.0 # 当前累计回报总数,初始值为0.0
self.actions = [1, 2, 3] # 动作空间,用1、2、3分别表示3个不同的摇臂
self.counts = 0 # 玩家玩游戏的次数
self.counts_history = [] # 玩家玩游戏的次数记录
self.cumulative_rewards_history = [] # 累计回报的记录
self.a = 1 # 玩家当前动作,这里去摇第一个臂
self.reward = 0 # 当前回报,初始值为0
@staticmethod
# 用于模拟多臂赌博机如何给出回报
# def step(self, a):
def step(a):
r = 0
if a == 1:
r = np.random.normal(1, 1) # 摇动摇臂1,得到的回报符合均值为1、标准差为1的正态分布
if a == 2:
r = np.random.normal(2, 1) # 摇动摇臂2,得到的回报符合均值为3、标准差为1的正态分布
if a == 3:
r = np.random.normal(1.5, 1) # 摇动摇臂3,得到的回报符合均值为1.5、标准差为1的正态分布
return r
# 实现3种选择动作的策略
# 字典**kwargs用于传递相应的策略所对应的超参数
# 如e_greedy策略中的epsilon,UCB策略中的e_greedy,玻尔兹曼策略中的temperature
def choose_action(self, policy, **kwargs):
action = 0
if policy == 'e_greedy':
if np.random.random() < kwargs['epsilon']: # 生成一个0-1的数<ε
action = np.random.randint(1, 4) # [1,4)
else:
action = np.argmax(self.q)+1 # 取出q中元素最大值所对应的索引(位置,从0开始)
if policy == 'ucb':
c_ratio = kwargs['c_ratio']
# 由于UCB策略中每个动作的次数在分母中,UCB算法的第一步是依次摇动每个臂
# 如果有等于零的,那么选择该动作
if 0 in self.action_counts:
action = np.where(self.action_counts == 0)[0][0]+1
else:
value = self.q + c_ratio * np.sqrt(np.log(self.counts) / self.action_counts)
action = np.argmax(value)+1 # 取出q中元素最大值所对应的索引(位置,从0开始)
if policy == 'boltzmann':
tau = kwargs['temperature']
p = np.exp(self.q / tau) / (np.sum(np.exp(self.q / tau)))
action = np.random.choice([1, 2, 3], p=p.ravel()) # 从数组中随机抽取元素;扁平化操作
return action
# 交互进行学习训练
'''
智能体通过要学习的策略选择动作,然后将动作传给step()方法,
相当于跟多臂赌博机进行了一次交互,从多臂赌博机中获得回报r,
智能体根据立即回报更新每个动作的平均回报q,计算当前的累计回报并做相应的保存。
'''
def train(self, play_total, policy, **kwargs):
reward_1 = []
reward_2 = []
reward_3 = []
for i in range(play_total):
action = 0
if policy == 'e_greedy':
action = self.choose_action(policy, epsilon=kwargs['epsilon'])
if policy == 'ucb':
action = self.choose_action(policy, c_ratio=kwargs['c_ratio'])
if policy == 'boltzmann':
action = self.choose_action(policy, temperature=kwargs['temperature'])
self.a = action
# print(self.a)
# 与环境交互一次
self.reward = self.step(self.a) # 当前回报
self.counts += 1 # 玩家玩游戏的次数
# 更新值函数
self.q[self.a-1] = \
(self.q[self.a-1]*self.action_counts[self.a-1]+self.reward) / \
(self.action_counts[self.a-1]+1) # self.action_counts摇动每个臂的次数
self.action_counts[self.a - 1] += 1 # 摇动每个臂的次数
reward_1.append([self.q[0]]) # 在list最后添加q[0]的值
reward_2.append([self.q[1]])
reward_3.append([self.q[2]])
self.current_cumulative_rewards += self.reward # 当前累计回报总数
self.cumulative_rewards_history.append(self.current_cumulative_rewards) # 累计回报的记录,元组
self.counts_history.append(i) # 玩家玩游戏的次数记录,元组
# 成员变量进行重置
def reset(self):
self.q = np.array([0.0, 0.0, 0.0]) # 每个臂的平均回报,假设臂的数目为3个,初始值都为0
self.action_counts = np.array([0, 0, 0]) # 摇动每个臂的次数,初始值为0
self.current_cumulative_rewards = 0.0 # 当前累计回报总数,初始值为0.0
self.counts = 0 # 玩家玩游戏的次数
self.counts_history = [] # 玩家玩游戏的次数记录
self.cumulative_rewards_history = [] # 累计回报的记录
self.a = 1 # 玩家当前动作,这里去摇第一个臂
self.reward = 0 # 当前回报,初始值为0
# 画图显示
def plot(self, colors, policy, style):
plt.figure(1) # 新建一个名叫Figure1的画图窗口
plt.plot(self.counts_history, self.cumulative_rewards_history,
colors, label=policy, linestyle=style)
plt.legend() # 主要的作用就是给图加上图例
plt.xlabel('n', fontsize=18) # 参数1:坐标轴显示的字,参数2:字体大小,【参数3:字体类型】
plt.ylabel('total rewards', fontsize=18)
if __name__ == '__main__':
np.random.seed(0) # 设置相同的seed,每次生成的随机数相同
k_gamble = KBGame()
total = 2000
# e_greedy
k_gamble.train(play_total=total, policy='e_greedy', epsilon=0.05)
k_gamble.plot(colors='r', policy='e_greedy', style='-.')
k_gamble.reset()
# boltzmann
k_gamble.train(play_total=total, policy='boltzmann', temperature=1)
k_gamble.plot(colors='b', policy='boltzmann', style='--')
k_gamble.reset()
# ucb
k_gamble.train(play_total=total, policy='ucb', c_ratio=0.5)
k_gamble.plot(colors='g', policy='ucb', style='-')
plt.show()
深入浅出强化学习编程实践-01多臂赌博机
最新推荐文章于 2023-10-24 12:02:16 发布