DQN网络代码调用env.reset()后存储观测信息报错问题ValueError: setting an array element with a sequence.解决处理

在强化学习DQN网络代码实现过程中,针对gym环境中的LunarLander-v2模拟游戏进行学习任务。我的部分代码如下,首先是Agent的状态存储过程:

def __init__(self,gamma,epsilon,lr,input_dims,batch_size,n_actions,    
        # ....此前省略 ....
        self.state_memory = np.zeros((self.mem_size,*input_dims),dtype=np.float32)
        self.new_state_memory = np.zeros((self.mem_size,*input_dims),dtype=np.float32)
        
        self.action_memory = np.zeros(self.mem_size,dtype=np.int32)
        self.reward_memory = np.zeros(self.mem_size,dtype=np.float32)
        self.terminal_memory = np.zeros(self.mem_size,dtype=np.bool)
    def store_transitons(self,state,action,reward,state_,done):
        index = self.mem_cntr % self.mem_size
        self.state_memory[index] = state
        self.action_memory[index] = action
        self.new_state_memory[index] = state_
        self.terminal_memory[index] = done
        self.reward_memory[index] = reward

下面是main代码中的学习循环过程:

if __name__ == '__main__':
    env = gym.make('LunarLander-v2')
    agent = Agent(gamma= 0.99 ,epsilon=1.0 , batch_size=64, n_actions=4,
    eps_end= 0.01 ,input_dims=[8], lr=0.003)
    scores,eps_history = [],[]
    n_games = 10
    for i in range(n_games):
        score = 0
        done = False
        observation  = env.reset() 
        while not done:
            action = agent.choose_action(observation)
            observation_, reward, done, info, __ = env.step(action)
            score += reward
            agent.store_transitons(observation,action,reward,observation_,done)
            agent.learn()
            observation = observation_
        scores.append(score)
        eps_history.append(agent.epsilon)
        # .................#

执行上述代码后报错:

Traceback (most recent call last):
  File ".\main_Lunar_lander.py", line 21, in <module>
    agent.store_transitons(observation,action,reward,observation_,done)
  File "D:\College\Projects\Person_Research\DQN_From_Yotube\DQN.py", line 59, in store_transitons
    self.state_memory[index] = state
ValueError: setting an array element with a sequence. The requested array would exceed the maximum number of dimension of 1.

报错信息提示为数据维度不对应,也即最初通过observation = env.reset()得到的的变量observation类型与Agent存储时的np.float32类型不匹配,而gym官方文档reset()函数的描述为:在这里插入图片描述
因此怀疑返回的observation有问题,于是通过print('observation:',observation)打印观察返回的observation,显示如下:在这里插入图片描述
reset()函数返回的是一个array类型以及其中数据的type!
因此需要将observation指定为真正需要的array信息即可,observation,__ = env.reset()

    for i in range(n_games):
        score = 0
        done = False
        observation,__  = env.reset() 
        while not done:
            action = agent.choose_action(observation)
            observation_, reward, done, info, __ = env.step(action)
            score += reward
            agent.store_transitons(observation,action,reward,observation_,done)
            agent.learn()
            observation = observation_

此后再次运行,网络就可以正常工作了:
在这里插入图片描述

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值