While implementing a DQN agent in reinforcement learning, I am training on the LunarLander-v2 environment from gym. Part of my code is shown below. First, the Agent's transition storage:
import numpy as np

class Agent:
    def __init__(self, gamma, epsilon, lr, input_dims, batch_size, n_actions,
        # .... remaining parameters and earlier initialization omitted ....
        # replay buffer: preallocated arrays of mem_size slots
        self.state_memory = np.zeros((self.mem_size, *input_dims), dtype=np.float32)
        self.new_state_memory = np.zeros((self.mem_size, *input_dims), dtype=np.float32)
        self.action_memory = np.zeros(self.mem_size, dtype=np.int32)
        self.reward_memory = np.zeros(self.mem_size, dtype=np.float32)
        # np.bool has been removed in recent NumPy releases; use the built-in bool instead
        self.terminal_memory = np.zeros(self.mem_size, dtype=bool)

    def store_transitons(self, state, action, reward, state_, done):
        # write the transition at the current ring-buffer position
        index = self.mem_cntr % self.mem_size
        self.state_memory[index] = state
        self.action_memory[index] = action
        self.new_state_memory[index] = state_
        self.terminal_memory[index] = done
        self.reward_memory[index] = reward
        # advance the buffer pointer so the next transition goes to the next slot
        self.mem_cntr += 1
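For context, the learn() step later draws a random minibatch from these arrays. A minimal sketch of that sampling logic, using the attribute names above (the method name sample_batch and its exact layout are my own illustration, not the original learn()):

    def sample_batch(self):
        # only sample from the portion of the buffer that has been filled so far
        max_mem = min(self.mem_cntr, self.mem_size)
        batch = np.random.choice(max_mem, self.batch_size, replace=False)
        states = self.state_memory[batch]
        actions = self.action_memory[batch]
        rewards = self.reward_memory[batch]
        states_ = self.new_state_memory[batch]
        dones = self.terminal_memory[batch]
        return states, actions, rewards, states_, dones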
Below is the learning loop in the main script:
import gym
from DQN import Agent   # the Agent class lives in DQN.py (see the traceback below)

if __name__ == '__main__':
    env = gym.make('LunarLander-v2')
    agent = Agent(gamma=0.99, epsilon=1.0, batch_size=64, n_actions=4,
                  eps_end=0.01, input_dims=[8], lr=0.003)
    scores, eps_history = [], []
    n_games = 10
    for i in range(n_games):
        score = 0
        done = False
        observation = env.reset()
        while not done:
            action = agent.choose_action(observation)
            observation_, reward, done, info, __ = env.step(action)
            score += reward
            agent.store_transitons(observation, action, reward, observation_, done)
            agent.learn()
            observation = observation_
        scores.append(score)
        eps_history.append(agent.epsilon)
# .................#
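The omitted tail of the script typically plots the episode scores against epsilon. A possible sketch with matplotlib (purely illustrative; the file name lunar_lander.png and the 100-episode window are my own choices, not from the original script):

    import matplotlib.pyplot as plt
    import numpy as np

    # running average of the last 100 episode scores
    avg_scores = [np.mean(scores[max(0, i - 100):i + 1]) for i in range(len(scores))]

    fig, ax1 = plt.subplots()
    ax1.plot(avg_scores, color='C0')
    ax1.set_xlabel('episode')
    ax1.set_ylabel('average score (last 100 episodes)')

    # epsilon on a second y-axis to see exploration decay alongside the scores
    ax2 = ax1.twinx()
    ax2.plot(eps_history, color='C1')
    ax2.set_ylabel('epsilon')

    plt.savefig('lunar_lander.png')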
Running this code produces the following error:
Traceback (most recent call last):
  File ".\main_Lunar_lander.py", line 21, in <module>
    agent.store_transitons(observation,action,reward,observation_,done)
  File "D:\College\Projects\Person_Research\DQN_From_Yotube\DQN.py", line 59, in store_transitons
    self.state_memory[index] = state
ValueError: setting an array element with a sequence. The requested array would exceed the maximum number of dimension of 1.
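The ValueError itself is easy to reproduce outside gym: NumPy refuses to write a tuple of mixed objects into a row of a float32 array (the exact wording of the message varies with the NumPy version):

    import numpy as np

    state_memory = np.zeros((10, 8), dtype=np.float32)
    # mimics what env.reset() returns in gym >= 0.26: (observation, info_dict)
    obs_tuple = (np.zeros(8, dtype=np.float32), {})
    state_memory[0] = obs_tuple   # ValueError: setting an array element with a sequence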
The error message indicates a dimension mismatch: the observation obtained from observation = env.reset() cannot be written into the Agent's float32 state buffer. The gym documentation describes reset() as returning the initial observation together with an info dictionary (as of gym 0.26, reset() returns the tuple (observation, info)).
Suspecting the returned observation, I printed it with print('observation:', observation). The output confirms that reset() returns a tuple holding the observation array plus the accompanying info dict, not a bare array.
The fix is therefore to unpack only the array we actually need: observation, __ = env.reset()
for i in range(n_games):
    score = 0
    done = False
    # reset() now returns (observation, info); keep only the observation array
    observation, __ = env.reset()
    while not done:
        action = agent.choose_action(observation)
        # gym >= 0.26: step() returns (obs, reward, terminated, truncated, info)
        observation_, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated
        score += reward
        agent.store_transitons(observation, action, reward, observation_, done)
        agent.learn()
        observation = observation_
After this change, running the script again, the network trains normally.
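If the same script also needs to run against older gym releases, where reset() returns just the observation and step() returns a 4-tuple, a small compatibility shim can hide the API difference. The helper names reset_env and step_env below are hypothetical, not part of the original code:

    def reset_env(env):
        result = env.reset()
        # gym >= 0.26 returns (observation, info); older releases return only the observation
        return result[0] if isinstance(result, tuple) else result

    def step_env(env, action):
        result = env.step(action)
        if len(result) == 5:
            # gym >= 0.26: (obs, reward, terminated, truncated, info)
            obs, reward, terminated, truncated, info = result
            return obs, reward, terminated or truncated, info
        # older gym: (obs, reward, done, info)
        return result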