Background
Keras wraps TensorFlow; keras-rl adds the reinforcement-learning machinery on top of Keras; OpenAI Gym modularizes the Environment concept in reinforcement learning, which turns the idea of "isolating the application from the algorithm" into something concrete: a Gym Env can be plugged straight into a Keras/keras-rl Agent. Because TensorFlow has moved on to the 2.x era, only the Python 3.6 (or 3.7) pip repositories still carry mutually compatible versions of these packages, which is why Python 3.6 matters here.
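To make that isolation concrete: a Gym Env is just an object that exposes observation_space, action_space, reset() and step(), and anything with that interface can be handed to a keras-rl Agent. A minimal sketch (the MyEnv class is hypothetical, not used in the script below, and follows the pre-0.26 Gym API that these pinned versions use):

import gym
import numpy as np
from gym import spaces

class MyEnv(gym.Env):
    ''' A do-nothing environment that only demonstrates the interface '''
    def __init__(self):
        # what the agent observes and which actions it may take
        self.observation_space = spaces.Box(low=-1.0, high=1.0, shape=(4,), dtype=np.float32)
        self.action_space = spaces.Discrete(2)

    def reset(self):
        # return the initial observation
        return np.zeros(4, dtype=np.float32)

    def step(self, action):
        # return (observation, reward, done, info) -- the 4-tuple keras-rl consumes
        return np.zeros(4, dtype=np.float32), 0.0, True, {}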
Environment
- Install conda: https://docs.conda.io/en/latest/miniconda.html
conda create --name keras python=3.6
conda activate keras
python -m pip install keras==2.3.1 tensorflow==1.13.1 keras-rl==0.4.2 gym==0.19.0
python -m pip install click
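A quick sanity check, assuming the conda env above is active: confirm the pinned packages import together and report their versions (one possible check, not part of the setup itself).
python -c "import rl, tensorflow, keras, gym; print(tensorflow.__version__, keras.__version__, gym.__version__)"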
CartPole
Command: python <xx>.py --train
import click
import gym
import random
import numpy as np

from keras.layers import Dense, Flatten
from keras.models import Sequential
from keras.optimizers import Adam

env = gym.make('CartPole-v1')
states = env.observation_space.shape[0]
print('States', states)   # should be 4: position, velocity, angular position, angular velocity
actions = env.action_space.n
print('Actions', actions)  # should be 2: left, right
def run_demo():
    ''' Naive Gym run: take random actions for a few episodes '''
    episodes = 10
    for episode in range(1, episodes + 1):
        state = env.reset()
        done = False
        score = 0
        while not done:
            env.render()
            action = random.choice([0, 1])
            n_state, reward, done, info = env.step(action)
            score += reward
        print('episode {} score {}'.format(episode, score))
    env.close()
# end of naive demo
# SARSA Agent + Epsilon Greedy Q Policy
from rl.agents import SARSAAgent
from rl.policy import EpsGreedyQPolicy
def agent(states, actions):
    ''' Build the neural network that estimates action values '''
    model = Sequential()
    model.add(Flatten(input_shape=(1, states)))  # keras-rl feeds (window_length, states) observations; window_length defaults to 1
    model.add(Dense(24, activation='relu'))
    model.add(Dense(24, activation='relu'))
    model.add(Dense(24, activation='relu'))
    model.add(Dense(actions, activation='linear'))
    return model
def build():
    model = agent(env.observation_space.shape[0], env.action_space.n)
    policy = EpsGreedyQPolicy()
    sarsa = SARSAAgent(model=model, policy=policy, nb_actions=env.action_space.n)
    sarsa.compile('adam', metrics=['mse'])
    return sarsa
def do_training():
    sarsa = build()
    sarsa.fit(env, nb_steps=50000, visualize=False, verbose=1)  # training
    scores = sarsa.test(env, nb_episodes=100, visualize=False)
    print('Average score over 100 test games: {}'.format(np.mean(scores.history['episode_reward'])))
    sarsa.save_weights('sarsa_weights.h5f', overwrite=True)
@click.command()
@click.option("--train", is_flag=True, default=False, help='Train the SARSA agent and save the weights')
@click.option("--show", is_flag=True, default=False, help='Load the saved weights and render 2 test episodes')
@click.option("--demo", is_flag=True, default=False, help='Run the naive random-action demo')
def main(train, show, demo):
    if train:
        do_training()
    elif show:
        sarsa = build()
        sarsa.load_weights('sarsa_weights.h5f')
        _ = sarsa.test(env, nb_episodes=2, visualize=True)
    elif demo:
        run_demo()

if __name__ == '__main__':
    main()
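The three click flags map to three modes; <xx> stays whatever filename the script is saved under:
python <xx>.py --train   # train for 50000 steps, print the 100-episode average score, save sarsa_weights.h5f
python <xx>.py --show    # load sarsa_weights.h5f and render 2 test episodes
python <xx>.py --demo    # random-action baseline over 10 episodes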