按照周志华西瓜书第16章K-摇臂赌博机的伪码编的程序:
# -*- coding: utf-8 -*-
"""
e贪心和Softmax
2-摇臂赌博机
摇臂1:0.4概率奖励1,0.6-0
摇臂2:0.2-1, 0.8-1
@author: y1064
"""
import numpy as np
import matplotlib.pyplot as plt
K = 2 # 摇臂数
R = [[1,0],[1,0]] # 奖赏函数
probs = [[0.4,0.6],[0.2,0.8]] # 对应的概率
T = 5000 # 尝试次数
e = 0.1 # 探索概率
def e_greedy(e):
"""伊普西龙贪心"""
Q = [0,0] # 记录摇臂的平均奖赏
count = [0,0] # 记录摇臂的探索次数
r = 0
r_list = []
for t in range(T):
if np.random.uniform()<e:
k = np.random.choice(range(K)) # 从摇臂中均匀地选择
else:
k = np.argmax(Q)
v = np.random.choice(R[k],p=probs[k])
r+=v
Q[k] = (Q[k]*count[k]+v)/(count[k]+1) # 更新平均奖赏
count[k]+=1
r_list.append(r/(t+1)) # 平均累积奖赏
return r_list
def softmax(tau):
"""softmax """
Q = [0,0]
count = [0,0]
r = 0
r_list = []
for t in range(T):
sum_p=sum([np.exp(i/tau) for i in Q])
P=[np.exp(i/tau)/sum_p for i in Q]
k=int(np.random.choice([0,1],p=P))
v = np.random.choice(R[k], p=probs[k])
r+=v
Q[k] = (Q[k]*count[k]+v)/(count[k]+1) # 更新平均奖赏
count[k]+=1
r_list.append(r/(t+1)) # 平均累积奖赏
return r_list
plt.plot(e_greedy(e=0.1),label='e-greedy,e=0.1')
plt.plot(e_greedy(e=0.01),label='e-greedy,e=0.01')
plt.plot(softmax(tau=0.01),label='softmax,tau=0.01')
plt.plot(softmax(tau=0.1),label='softmax,tau=0.1')
plt.legend()
结果跟书中差距很大,暂时看不出代码哪里错了,求助!刚开始强化学习,积极性就要被无情打消了。