下面我们接着说代码的第二部分:
二、定义经验回放
import random
from collections import deque
class ReplayBufferQue:
'''DQN的经验回放池,每次采样batch_size个样本,这里分两类采样方式,用的是第二类PGReplay'''
def __init__(self, capacity: int) -> None:
self.capacity = capacity
self.buffer = deque(maxlen=self.capacity)
def push(self, transitions):
'''_summary_
Args:
transitions (tuple): _description_
'''
self.buffer.append(transitions)
def sample(self, batch_size: int, sequential: bool = False):
if batch_size > len(self.buffer):
batch_size = len(self.buffer)
if sequential: # sequential sampling
rand = random.randint(0, len(self.buffer) - batch_size)
batch = [self.buffer[i] for i in range(rand, rand + batch_size)]
return zip(*batch)
else:
batch = random.sample(self.buffer, batch_size)
return zip(*batch)
def clear(self):
self.buffer.clear()
def __len__(self):
return len(self.buffer)
class PGReplay(ReplayBufferQue):
'''PG的经验回放池,每次采样所有样本,因此只需要继承ReplayBufferQue,重写sample方法即可
'''
def __init__(self):
self.buffer = deque()
def sample(self):
''' sample all the transitions
'''
batch = list(self.buffer)
return zip(*batch)
这是深度学习必备的经验回放池,说就简单点就是存取数据的,这里分了两个大类:
第一类是ReplayBufferQue。首先初始化,参数 capacity表示回放缓冲区的容量。然后创建一个双向队列 self.buffer 作为回放缓冲区,并设置最大长度为 capacity,超出部分会自动移除。之后定义一个push方法,将数据transitions添加到回放缓冲区 self.buffer 中。再定义了一个sample,用于从回放缓冲区中采样数据。之后根据参数 sequential 的取值,进行不同的采样方式。如果 sequential 为 True,则进行顺序采样;否则进行随机采样。clear 用于清空回放缓冲区。Len用于返回回放缓冲区的长度。最后会返回打包好的随机数据。
(关于双向队列deque,可以学习这篇文章:python双向队列deque)
第二类是PGReplay ,继承自 ReplayBufferQue 类。(这一类是采样所有数据)这样就可以避免代码重写。子类具有经验回放池的基本功能,并且可以重写父类中的方法(如 sample() 方法)。首先初始化,这里并没有传入参数,但会创建一个空的双向队列 self.buffer。之后用list将整个缓冲区的数据转换为列表,并传出去。最后返回zip对象。
(上面提到了一个知识点,就是子类和父类,子类如果继承父类,很多重复代码都可以不用写了,具体解释和深刻理解可以看这篇文章:Python:类的继承,调用父类的属性和方法基础详解)
经验回放技术的核心思想是将Agent与环境交互得到的经验数据(在每一步中收集到状态、动作、奖励、下一个状态等经验数据。)存储在一个经验回放缓冲区(Replay Buffer)中,并从中随机抽样一批数据用于训练深度神经网络。
回放缓冲区:通常使用循环缓冲区(Circular Buffer)或双端队列(deque)来实现。
注意:
(1)zip() 函数用于将多个可迭代对象中对应位置的元素打包成元组。而*符号用于解压参数,将包含多个可迭代对象的列表拆分为独立的参数传递给函数。打个比方:batch 应该是一个包含多个列表(或其他可迭代对象)的列表,如 [[1, 2, 3], [4, 5, 6], [7, 8, 9]]。通过 zip(*batch),将这些子列表中对应位置的元素打包成元组,例如 (1, 4, 7)、(2, 5, 8)、(3, 6, 9)。
(2)函数注解中使用 -> None 表示该函数的返回值类型为 None,即函数不返回任何有意义的值,类似于在函数体中显式地使用 return None。