Today I implemented the inverted pendulum with Q-learning, and I'm sharing the code here for everyone to learn from:
The pendulum environment
# a few packages we need to import
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import matplotlib.animation as animation
import IPython
class Pendulum:
"""
This class describes an inverted pendulum and provides some helper functions
"""
def __init__(self):
"""
constructor of the class
"""
#gravity constant
self.g=9.81
# number of dimensions (angle and angular velocity)
self.state_dims = 2
# the maximum velocity
self.vmax = 6.
# the range of allowable states
self.state_range = np.array([[0, 2*np.pi],[-self.vmax, self.vmax]])
#simulation step
self.delta_t = 0.1
# internal integration step
self._internaldt = 0.01
self._integration_ratio = int(round(self.delta_t / self._internaldt))
def next_state(self,x,u):
"""
This function integrates the pendulum for one step of self.delta_t seconds
Inputs:
x: state of the pendulum (x,v) as a 2D numpy array
u: control as a scalar
Output:
the state of the pendulum as a 2D numpy array at the end of the integration
"""
x_next = x[0]
v_next = x[1]
for i in range(self._integration_ratio):
xx_next = (x_next + self._internaldt * v_next)%(2*np.pi)
v_next = np.clip(v_next + self._internaldt * (u-self.g*np.sin(x_next)), -self.vmax, self.vmax)
x_next = xx_next
return np.array([x_next,v_next])
def simulate(self, x0, policy, T):
"""
This function simulates the pendulum for T seconds from initial state x0 using a policy
(policy is called as policy(x) and returns one control)
Inputs:
x0: the initial conditions of the pendulum as a 2D array (angle and velocity)
T: the time to integrate for
Output:
x (2D array) and u (1D array) containing the time evolution of states and control
"""
horizon_length = int(T/self.delta_t)
x=np.empty([2, horizon_length+1])
x[:,0] = x0
u=np.empty([horizon_length])
for i in range(horizon_length):
u[i] = policy(x[:,i])
x[:,i+1] = self.next_state(x[:,i], u[i])
return x, u
def animate_robot(self, x, dt = 0.01):
"""
This function makes an animation showing the behavior of the pendulum
takes as input the result of a simulation - dt is the sampling time (0.1s normally)
"""
# here we check if we need to down-sample the data for display
#downsampling (we want 100ms DT or higher)
min_dt = 0.1
if(dt < min_dt):
steps = int(min_dt/dt)
use_dt = int(min_dt * 1000)
else:
steps = 1
use_dt = int(dt * 1000)
plotx = x[:,::steps]
fig = matplotlib.figure.Figure(figsize=[6,6])
matplotlib.backends.backend_agg.FigureCanvasAgg(fig)
ax = fig.add_subplot(111, autoscale_on=False, xlim=[-1.3,1.3], ylim=[-1.3,1.3])
ax.grid()
list_of_lines = []
#create the cart pole
line, = ax.plot([], [], 'k', lw=2)
list_of_lines.append(line)
line, = ax.plot([], [], 'o', lw=2)
list_of_lines.append(line)
cart_height = 0.25
def animate(i):
for l in list_of_lines: #reset all lines
l.set_data([],[])
x_pend = np.sin(plotx[0,i])
y_pend = -np.cos(plotx[0,i])
list_of_lines[0].set_data([0., x_pend], [0., y_pend])
list_of_lines[1].set_data([x_pend, x_pend], [y_pend, y_pend])
return list_of_lines
def init():
return animate(0)
ani = animation.FuncAnimation(fig, animate, np.arange(0, len(plotx[0,:])),
interval=use_dt, blit=True, init_func=init)
plt.close(fig)
plt.close(ani._fig)
IPython.display.display_html(IPython.core.display.HTML(ani.to_html5_video()))
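Before doing any learning, a quick sanity check of the environment can be useful (this snippet is my own addition, not part of the original post): simulate from the hanging-down state with zero torque, and the pendulum should simply stay at rest.
# sanity check (my addition): zero-torque policy from the hanging-down state
robot_check = Pendulum()
x_check, u_check = robot_check.simulate(np.array([0., 0.]), lambda x: 0., 5.)
robot_check.animate_robot(x_check, robot_check.delta_t)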
Q-learning implementation
import random
class QLearningTable:
"""
Skeleton class to help implement Q learning with a table
"""
def __init__(self, model, cost,num_states,nu, discount_factor=0.99, learning_rate=0.1, epsilon_greedy=0.1):
# we create tables to store value and policy functions
self.value_function = np.zeros([num_states])
self.policy = np.zeros([num_states])
# we create the Q table
self.q_function = np.zeros([num_states, nu])
self.model = model
self.cost = cost
# other parameters
self.epsilon = epsilon_greedy
self.gamma = discount_factor
self.alpha = learning_rate
self.num_states=num_states
self.nu=nu
def iterate(self, u_table,num_iter=1):
        q_Last = np.zeros([self.num_states, self.nu])
for i in range(num_iter):#
print('iteration {}'.format(i))
# choose initial state x0
x_0 = np.array([0,0])
x_index = get_index(x_0)#random.randint(0, self.model.num_states-1)
for j in range(2000):
# choose an action using E-greedy policy
if random.uniform(0, 1)>self.epsilon :
u_index = np.argmin(self.q_function[x_index,:])
else:
u_index = random.randint(0,self.nu-1)
# observe x_t+1
next_index = next_state_index[x_index,u_index]
# compute g(x_t,u(x_t))
x=get_states(x_index)
u=u_table[u_index]
# compute TDerror
TDerror=self.cost(x, u)+self.gamma*min(self.q_function[next_index,:])-self.q_function[x_index, u_index]
self.q_function[x_index, u_index]=self.q_function[x_index, u_index]+self.alpha*TDerror
x_index=next_index
# we update the current Q function if there is any change otherwise we are done
if ((q_Last-self.q_function)**2 < 10e-2).all() :
break
else:
q_Last = self.q_function.copy()
for k in range(self.num_states):
self.policy[k]=u_table[np.argmin(self.q_function[k,:])]
self.value_function[k]=min(self.q_function[k, :])
Training
Cost function
def cost(x,u):
"""
a cost function for the inverted pendulum
"""
return (x[0]-np.pi)**2 + 0.01*x[1]**2 + 0.0001*u**2
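Note that this cost is zero exactly at the upright equilibrium (angle = pi, zero velocity, zero torque), so minimizing the discounted cost drives the learned policy to swing the pendulum up. A quick check (my own addition):
print(cost(np.array([np.pi, 0.]), 0.))  # 0.0 at the upright goal state
print(cost(np.array([0., 0.]), 0.))     # about 9.87 (pi**2) in the hanging-down state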
Basic configuration
nq=50
nv=50
nu = 3
v_max = 6
u_max=5
# create lookup tables for discretized states
u_table = np.linspace(-u_max, u_max, nu)
q_table = np.linspace(0., 2*np.pi, nq, endpoint=False)
v_table = np.linspace(-v_max, v_max, nv)
num_states = nq * nv
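The iterate() method above calls get_index and get_states, which the post doesn't show. Here is a minimal sketch of how they could be written for this discretization (my own guess at the missing helpers; the names and nearest-neighbor rounding are assumptions):
def get_index(x):
    # map a continuous state (angle, velocity) to the index of the nearest grid point
    iq = int(np.round(x[0] * nq / (2*np.pi))) % nq
    iv = int(np.round((x[1] + v_max) * (nv - 1) / (2*v_max)))
    iv = min(max(iv, 0), nv - 1)
    return iq * nv + iv

def get_states(index):
    # inverse mapping: grid index -> continuous state (angle, velocity)
    return np.array([q_table[index // nv], v_table[index % nv]])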
Training
# we can create a robot
robot = Pendulum()
# we instantiate a Q learning object for a pendulum model and a cost function
Q_Learning = QLearningTable(robot,cost,num_states,nu)
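The post also doesn't show the precomputed transition table next_state_index used inside iterate(), nor the call that actually runs the training. Assuming the helpers sketched above, the missing pieces could look like this (the number of iterations is my own guess):
# precompute the discretized transition table: next state index for every (state, control) pair
next_state_index = np.zeros([num_states, nu], dtype=int)
for s in range(num_states):
    for a in range(nu):
        next_state_index[s, a] = get_index(robot.next_state(get_states(s), u_table[a]))
# run tabular Q-learning
Q_Learning.iterate(u_table, num_iter=100)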
Visualization
def plot_results(robot, value_function, policy,my_policy, animate=True):
"""
This function plots the results. It displays the value function, the policy for all states.
Then it integrates the pendulum from state [0,0] and displays the states and control as a function of time
Finally it shows an animation of the result
"""
x0 = np.array([0.,0.])
x, u = robot.simulate(x0, policy, 20)
if animate:
robot.animate_robot(x, robot.delta_t)
my_policy=Q_Learning.policy
def policy(x):
# print(x)
index=get_index(x)
return my_policy[index]
plot_results(robot, Q_Learning.value_function, policy,my_policy, animate=True)
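The plot_results docstring mentions displaying the value function and policy, but that part isn't in the snippet above. One simple way to visualize the learned value function over the discretized state grid (my own sketch, assuming the index layout index = iq*nv + iv used by the helpers above):
plt.figure()
plt.imshow(Q_Learning.value_function.reshape(nq, nv).T, origin='lower',
           extent=[0., 2*np.pi, -v_max, v_max], aspect='auto')
plt.xlabel('angle [rad]')
plt.ylabel('angular velocity [rad/s]')
plt.title('learned value function')
plt.colorbar()
plt.show()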
If the code doesn't run for you, feel free to message me privately; it has been sitting around for a while and I've been too lazy to re-run it, haha. But the core code is all here.