周博磊强化学习纲要(cliffwalk)q_learning与SARSA代码分析

2 篇文章 1 订阅

分析周博磊老师强化学习纲要第三节课相关代码。

0.Python基础知识

0.0 Python函数相关

当Python函数返回多个值时,将多个值组成元组返回。

print("测试python函数返回多个值:")
def measure():
    Size = 17;
    Length = 28;
    return Size,Length;
num = measure()
print(num)
print(type(num))
Size,Length = measure()
print(Size,Length)

在这里插入图片描述

0.1 zip()函数相关

zip() 函数用于将可迭代的对象作为参数,将对象中对应的元素打包成一个个元组,然后返回由这些元组组成的对象,这样做的好处是节约了不少的内存。利用 * 号操作符,可以将元组解压为列表。

在Python3中zip()函数返回的是zip类,可以通过list()完成转换。

在这里插入图片描述

  1. 分析程序的调试结果可知,zip(a,b)返回的结果在调试器中不能查看具体的值 ;列表a与列表b的长度相同,对应位置组成元组即(1,4),(2,5),(3,6)。
  2. zip(a,c)中a列表的长度小于c列表的长度,a列表只能与c列表中前三个对应。
  3. zip(*zip(a,b))将zip(a,b)的结果解压为列表,zip(a,b)等价于(1,4),(2,5),(3,6)再进行解压缩操作后得到两个元组即,(1,2,3)与(4,5,6)。

在这里插入图片描述

0.2 列表(List)与Numpy

列表 (List):可以包含不同类型的元素,但一般情况下,各个元素的类型相同。

Numpy:NumPy 包的核心是ndarray对象。这封装了同质数据类型的n维数组,许多操作在编译代码中执行以提高性能。

np.array()主要用于将列表list或元组tuple转换为ndarray数组。

listTest = []
for i in range(10):
    listTest.append(i)
print(listTest)
npNum = np.array(listTest)
print(npNum)
print(npNum.shape)

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-dfCxxPpb-1642215389895)(/Users/lybing/Desktop/截屏2022-01-13 下午4.08.43.png)]

0.3 列表(list)与for循环语句,if语句合并使用

list与for循环结合二维列表每行的最小值与最大值

list_2 = [[1,2,3],
         [4,5,6],
         [7,8,9]]
maxRow = [max(row) for row in list_2]
print(maxRow) #[3,6,9]

list与for循环,if语句结合在列表找查找名字开头为T字母,且将结果的首字母转换为大写。

names_list = ["Washington", "Trump", "Obama", "bush", "Clinton", "Reagan"]
result_list = [name.capitalize() for name in names_list if name.startswith('T')]
print(result_list)
names_list = ["Washington", "Trump", "Obama", "bush", "Clinton", "Reagan","Biden"]
presidentName = [name if name.startswith('Biden') else "not president" for name in names_list]
print(presidentName)

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-OXVAwubX-1642215389896)(/Users/lybing/Desktop/截屏2022-01-13 下午5.03.09.png)]

分析上述代码可知,若将name放在if之前,则name已经属于是列表的元素,故必须要有else。

0.4 关于Python中axis

使用axis = 0表示沿着每一行标签\索引值向下执行方法。

使用axis = 1表示沿着每一列标签\索引值向右执行方法。

axis解释

当axis = 1时,相当于将每行的元素全部相加,再求均值,得到结果。

当axis = 0时,相当于将每列的元素全部相加,再求均值,得到结果。

在这里插入图片描述

1.分析相关代码

### Q-learning for cliff walk
q_learning_rewards, q_values = q_learning(env, gamma=0.9, learning_rate=0.5, render=False)
env.render(q_values, colorize_q=True)

middleNum = [q_learning(env,render=False,exploration_rate=0.1,learning_rate=0.5) for _ in range(10)]
q_learning_rewards, q_values_tuple = zip( *[q_learning(env, render=False, exploration_rate=0.1,learning_rate=0.5) for _ in range(10)])


avg_rewards = np.mean(q_learning_rewards, axis=0)
print([np.mean(avg_rewards)]*len(avg_rewards))
mean_reward = [np.mean(avg_rewards)] * len(avg_rewards)
fig, ax = plt.subplots()
ax.set_xlabel('Episodes using Q-learning')
ax.set_ylabel('Rewards')
ax.plot(avg_rewards)
ax.plot(mean_reward, 'g--')
plt.show()
print('Mean Reward using Q-Learning: {}'.format(mean_reward[0]))
  1. q_learning函数默认运行500次episode,函数的返回为q_learning_rewards(类型为list,共有500个reward),q_values为通过numpy创建的48行4列的二维数组,48表示有48个状态,4表示每个状态有4个动作。
  2. middleNum = [q_learning(env,render=False,exploration_rate=0.1,learning_rate=0.5) for _ in range(10)] 返回值middleNum是列表,列表由10个元组组成,每个元组由两个元素组成,其中一个元素为500个reward组成的list,另一个为q_values二维数组。[(r1,v1),(r2,v2),……(r10,v10)]为middleNum的数据结构。
  3. 通过zip的解压操作,将结果转换为q_learning_rewards和q_values_tuple,q_learing_rewards的类型为大小为10的元组,元组中每个数据类型为大小为500的list,q_values_tuple的类型为大小为10的元组,元组中每个数据类型为大小为(48,4)的二维数组。middleNum的r1,r2,r3,r4……r10组成元组得到q_learning_rewards。
  4. avg_rewards的类型为numpy,返回结果为一行500组的数组,因为是在q_learning_rewards的行方向求均值,q_learning_rewards相当于10行500列,在方向上求均值就得到1行500列的均值。
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值