Solving the CartPole problem with DQN (an introduction to deep reinforcement learning)
"""
Created on Mon Nov 22 11:16:50 2021
@author: wss
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import collections
import random
import torch.optim as optim
Lr = 0.1             # learning rate for Adam
Buffer_size = 10000  # replay buffer capacity
Eps = 0.1            # epsilon for epsilon-greedy exploration
GAMMA = 0.99         # discount factor
Transition = collections.namedtuple('Transition',
('state', 'action', 'next_state', 'reward'))
class ReplayMemory(object):
def __init__(self, capacity):
self.memory = collections.deque([],maxlen=capacity)
def push(self, *args):
"""Save a transition"""
self.memory.append(Transition(*args))
def sample(self, batch_size):
return random.sample(self.memory, batch_size)
def __len__(self):
return len(self.memory)
class Net(nn.Module):
def __init__(self,n_in,n_hidden,n_out):
super(Net,self).__init__()
self.fc1 = nn.Linear(n_in, n_hidden)
self.fc2 = nn.Linear(n_hidden, n_hidden)
self.fc3 = nn.Linear(n_hidden, n_out)
def forward(self,x):
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
out = self.fc3(x)
return out
class DQN(object):
    def __init__(self, n_in, n_hidden, n_out):
        self.net = Net(n_in, n_hidden, n_out)         # online Q-network
        self.target_net = Net(n_in, n_hidden, n_out)  # target Q-network
        self.optimizer = optim.Adam(self.net.parameters(), lr=Lr)
        self.loss_func = nn.MSELoss()
        self.target_net.load_state_dict(self.net.state_dict())
        self.buffer = ReplayMemory(Buffer_size)
    def select_action(self, state):
        # epsilon-greedy: explore with probability Eps, otherwise act greedily
        threshold = random.random()
        with torch.no_grad():
            Q_actions = self.net(torch.Tensor(state))
        if threshold < Eps:
            return np.random.randint(0, Q_actions.shape[0])
        else:
            return torch.argmax(Q_actions).item()
def update_param(self,batch_size):
if self.buffer.__len__() < batch_size:
return
        transitions = self.buffer.sample(batch_size)
        # transpose the batch: list of Transitions -> Transition of tuples
        batch = Transition(*zip(*transitions))
        tmp = np.vstack(batch.action)
        state_batch = torch.Tensor(batch.state)
        action_batch = torch.LongTensor(tmp.astype(int))
        reward_batch = torch.Tensor(batch.reward)
        next_state_batch = torch.Tensor(batch.next_state)
        # max_a' Q_target(s', a'), detached so no gradient flows into the target net
        q_pred_s1 = torch.max(self.target_net(next_state_batch).detach(), dim=1,
                              keepdim=True)[0]
        # Q(s, a) of the actions actually taken, from the online net
        q_pred_s0 = self.net(state_batch).gather(1, action_batch)
        # TD target: r + GAMMA * max_a' Q_target(s', a')
        q_td_tar = reward_batch.unsqueeze(1) + GAMMA * q_pred_s1
        loss = self.loss_func(q_pred_s0, q_td_tar)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
if __name__ == '__main__':
import gym
    num_episode = 10000
    batch_size = 32
    target_update = 20   # sync the target net with the online net every 20 episodes
    env = gym.make('CartPole-v0').unwrapped
    Agent = DQN(env.observation_space.shape[0], 256, env.action_space.n)
    average_time = 0
for i_episode in range(num_episode):
        state = env.reset()
        total_time = 0
        while True:
            env.render()
            action = Agent.select_action(state)
            next_state, reward, done, _ = env.step(action)
            total_time += 1
            if done:
                average_time += total_time
                break
            Agent.buffer.push(state, action, next_state, reward)
            state = next_state
            Agent.update_param(batch_size)
        if i_episode % target_update == 0:
            Agent.target_net.load_state_dict(Agent.net.state_dict())
        if (i_episode + 1) % 100 == 0:
            print("average episode length over the last 100 episodes:", average_time / 100)
            average_time = 0
print('Complete')
env.render()
env.close()
I have only just started with deep learning and reinforcement learning, and I don't understand why this DQN doesn't get better and better as training goes on?
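For reference, the changes most commonly suggested for this kind of symptom are: (1) push the transition of the final step instead of breaking out of the loop before the push, and record a done flag with it, since otherwise the replay buffer never sees the end-of-episode signal and the value estimates have little reason to prefer one action over another; (2) use that flag to mask the bootstrap term GAMMA * max_a' Q_target(s', a') for terminal states; (3) use a much smaller learning rate for Adam (around 1e-3 rather than 0.1). Below is a minimal sketch of those changes, written as fragments against the script above rather than as a runnable replacement; the extra 'done' field, the name done_batch, and the learning-rate value are my additions, not part of the original code.

# 1. record whether the episode ended in each transition
Transition = collections.namedtuple('Transition',
                                    ('state', 'action', 'next_state', 'reward', 'done'))

# 2. in the episode loop: push every transition, including the terminal one
next_state, reward, done, _ = env.step(action)
total_time += 1
Agent.buffer.push(state, action, next_state, reward, done)
if done:
    average_time += total_time
    break
state = next_state
Agent.update_param(batch_size)

# 3. in update_param: zero out the bootstrap term when the next state is terminal
done_batch = torch.Tensor(batch.done).unsqueeze(1)   # 1.0 if the episode ended, else 0.0
q_td_tar = reward_batch.unsqueeze(1) + GAMMA * q_pred_s1 * (1 - done_batch)

# 4. in DQN.__init__: a smaller step size is usually more stable with Adam
self.optimizer = optim.Adam(self.net.parameters(), lr=1e-3)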