Solving the CartPole problem with DQN (an introduction to deep reinforcement learning)
"""
Created on Mon Nov 22 11:16:50 2021
@author: wss
"""
import numpy as np
import torch
import torch. nn as nn
import torch. nn. functional as F
import collections
import random
import torch. optim as optim
Lr = 0.1
Buffer_size = 10000
Eps = 0.1
GAMMA = 0.99
Transition = collections. namedtuple( 'Transition' ,
( 'state' , 'action' , 'next_state' , 'reward' ) )
class ReplayMemory(object):
    def __init__(self, capacity):
        self.memory = collections.deque([], maxlen=capacity)

    def push(self, *args):
        """Save a transition"""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)
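# A quick sketch (hypothetical values) of how the buffer is used and how
# Transition(*zip(*transitions)) transposes a list of Transitions into a
# single Transition of tuples, which is what update_param below relies on:
#
#   buf = ReplayMemory(100)
#   buf.push([0.0, 0.0, 0.0, 0.0], 0, [0.1, 0.0, 0.0, 0.0], 1.0)
#   buf.push([0.1, 0.0, 0.0, 0.0], 1, [0.2, 0.0, 0.0, 0.0], 1.0)
#   batch = Transition(*zip(*buf.sample(2)))
#   # batch.state  -> (state_1, state_2)
#   # batch.action -> (action_1, action_2)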
class Net(nn.Module):
    def __init__(self, n_in, n_hidden, n_out):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(n_in, n_hidden)
        self.fc2 = nn.Linear(n_hidden, n_hidden)
        self.fc3 = nn.Linear(n_hidden, n_out)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        out = self.fc3(x)
        return out
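# Shape check (a minimal sketch): for CartPole the observation is 4-dimensional
# and there are 2 discrete actions, so Net(4, 256, 2) maps a batch of states of
# shape (batch_size, 4) to Q-values of shape (batch_size, 2):
#
#   q = Net(4, 256, 2)(torch.zeros(32, 4))
#   assert q.shape == (32, 2)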
class DQN(object):
    def __init__(self, n_in, n_hidden, n_out):
        self.net = Net(n_in, n_hidden, n_out)         # online Q-network
        self.target_net = Net(n_in, n_hidden, n_out)  # target Q-network
        self.optimizer = optim.Adam(self.net.parameters(), lr=Lr)
        self.loss_func = nn.MSELoss()
        self.target_net.load_state_dict(self.net.state_dict())
        self.buffer = ReplayMemory(Buffer_size)

    def select_action(self, state):
        # Epsilon-greedy: with probability Eps take a random action,
        # otherwise take the action with the highest predicted Q-value.
        threshold = random.random()
        with torch.no_grad():
            Q_actions = self.net(torch.Tensor(state))
        if threshold < Eps:
            return np.random.randint(0, Q_actions.shape[0])
        else:
            return torch.argmax(Q_actions).item()
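    # A common refinement (just a sketch, not used in this code): decay epsilon
    # from a high value toward a small floor, so the agent explores heavily at
    # first and mostly exploits later. eps_start, eps_end, eps_decay and
    # steps_done are hypothetical attributes, not defined above:
    #
    #   eps = self.eps_end + (self.eps_start - self.eps_end) * \
    #         np.exp(-self.steps_done / self.eps_decay)
    #   self.steps_done += 1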
    def update_param(self, batch_size):
        if len(self.buffer) < batch_size:
            return
        transitions = self.buffer.sample(batch_size)
        # Transpose a list of Transitions into one Transition of tuples.
        batch = Transition(*zip(*transitions))
        state_batch = torch.Tensor(batch.state)
        action_batch = torch.LongTensor(np.vstack(batch.action).astype(int))
        reward_batch = torch.Tensor(batch.reward)
        next_state_batch = torch.Tensor(batch.next_state)
        # Bootstrapped value max_a' Q_target(s', a'), detached from the graph.
        q_pred_s1 = torch.max(self.target_net(next_state_batch).detach(), dim=1,
                              keepdim=True)[0]
        # Q(s, a) for the actions actually taken.
        q_pred_s0 = self.net(state_batch).gather(1, action_batch)
        # TD target: r + GAMMA * max_a' Q_target(s', a')
        q_td_tar = reward_batch.unsqueeze(1) + GAMMA * q_pred_s1
        loss = self.loss_func(q_pred_s0, q_td_tar)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
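# Note: update_param bootstraps from next_state for every sample, including
# transitions that ended an episode. The textbook DQN target zeroes out the
# bootstrap term at terminal states:
#
#     y = r + GAMMA * max_a' Q_target(s', a') * (1 - done)
#
# A sketch of that variant, assuming the Transition tuple were extended with a
# 'done' field (it is not in the code above):
#
#   done_batch = torch.Tensor(batch.done).unsqueeze(1)
#   q_td_tar = reward_batch.unsqueeze(1) + GAMMA * q_pred_s1 * (1 - done_batch)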
if __name__ == '__main__':
    import gym

    num_episode = 10000
    batch_size = 32
    target_update = 20  # sync the target network every 20 episodes
    env = gym.make('CartPole-v0').unwrapped
    Agent = DQN(env.observation_space.shape[0], 256, env.action_space.n)
    average_time = 0
    for i_episode in range(num_episode):
        state = env.reset()
        total_time = 0
        while True:
            env.render()
            action = Agent.select_action(state)
            next_state, reward, done, _ = env.step(action)
            total_time += 1
            if done:
                average_time += total_time
                break
            Agent.buffer.push(state, action, next_state, reward)
            state = next_state
            Agent.update_param(batch_size)
        if i_episode % target_update == 0:
            Agent.target_net.load_state_dict(Agent.net.state_dict())
        if (i_episode + 1) % 100 == 0:
            print("average episode length over the last 100 episodes:",
                  average_time / 100)
            average_time = 0
    print('Complete')
    env.render()
    env.close()
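# Note: in the loop above, the done branch breaks before push is called, so the
# terminal transition of each episode never reaches the replay buffer. A sketch
# of the more common ordering (store the transition first, then break):
#
#   Agent.buffer.push(state, action, next_state, reward)
#   if done:
#       average_time += total_time
#       break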
I have only just started with deep learning and reinforcement learning, and I don't understand why this DQN does not get better and better as training goes on?