@dataclass装饰器的作用

33 篇文章 0 订阅
31 篇文章 1 订阅

@dataclass
@dataclass: 这是一个装饰器,用于自动为类生成特殊方法,例如 __init__()__repr__() 等。在这里使用 @dataclass 装饰器可以自动为 ReplayBuffer 类生成初始化方法和字符串表示方法,而无需手动编写。

@dataclass
class ReplayBuffer:
    maxsize: int
    size: int = 0
    state: list = field(default_factory=list)
    action: list = field(default_factory=list)
    next_state: list = field(default_factory=list)
    reward: list = field(default_factory=list)
    done: list = field(default_factory=list)

    def push(self, state, action, reward, done, next_state):
        if self.size < self.maxsize:
            self.state.append(state)
            self.action.append(action)
            self.reward.append(reward)
            self.done.append(done)
            self.next_state.append(next_state)
        else:
            position = self.size % self.maxsize
            self.state[position] = state
            self.action[position] = action
            self.reward[position] = reward
            self.done[position] = done
            self.next_state[position] = next_state
        self.size += 1

    def sample(self, n):
        total_number = self.size if self.size < self.maxsize else self.maxsize
        indices = np.random.randint(total_number, size=n)
        state = [self.state[i] for i in indices]
        action = [self.action[i] for i in indices]
        reward = [self.reward[i] for i in indices]
        done = [self.done[i] for i in indices]
        next_state = [self.next_state[i] for i in indices]
        return state, action, reward, done, next_state

代码解释
这段代码定义了一个名为 ReplayBuffer 的类,它是一个用于存储经验回放数据的缓冲区。

让我们来逐步解释:

  1. @dataclass: 这是一个装饰器,用于自动为类生成特殊方法,例如 __init__()__repr__() 等。在这里使用 @dataclass 装饰器可以自动为 ReplayBuffer 类生成初始化方法和字符串表示方法,而无需手动编写。

  2. 类的属性:

    • maxsize: 缓冲区的最大容量。
    • size: 当前缓冲区中存储的数据量,默认为0。
    • state, action, reward, done, next_state: 分别对应经验回放中的状态、动作、奖励、完成标志和下一个状态,它们都是列表类型,默认为空列表。这些列表将用于存储经验回放数据。
  3. push() 方法:

    • 接受 stateactionrewarddonenext_state 等参数,表示一条经验回放数据。
    • 如果缓冲区未满,将数据添加到列表中;如果缓冲区已满,则按照循环队列的方式覆盖掉最旧的数据。
    • 每次添加数据后,更新 size 属性。
  4. sample() 方法:

    • 接受一个参数 n,表示要从缓冲区中抽取的样本数量。
    • 根据 size 属性和 maxsize 属性,确定要从缓冲区中抽取的样本总数。
    • 使用 numpy.random.randint() 函数生成 n 个随机索引,用于从缓冲区中选择样本。
    • 根据随机索引,从各个列表中抽取对应位置的数据,并返回作为抽样结果。

这段代码实现了一个简单的经验回放缓冲区,用于存储和抽样训练数据,以供深度学习模型进行训练。

  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

喝凉白开都长肉的大胖子

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值