本篇文章是 OpenAI Spinning Up 的 Part 3: Intro to Policy Optimization 中示例代码的学习笔记。原文在 https://spinningup.openai.com/en/latest/spinningup/rl_intro3.html ，代码在 https://github.com/openai/spinningup/blob/master/spinup/examples/pytorch/pg_math/1_simple_pg.py 。
先给出代码
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical
from torch.optim import Adam
import numpy as np
import gym
from gym.spaces import Discrete, Box
def mlp(sizes, activation=nn.Tanh, output_activation=nn.Identity):
    """Build a feedforward neural network (multi-layer perceptron).

    Args:
        sizes: Sequence of layer widths, input width first and output
            width last. Each consecutive pair becomes one ``nn.Linear``.
        activation: Zero-argument callable (typically an ``nn.Module``
            class) instantiated after every linear layer except the last.
        output_activation: Zero-argument callable instantiated after the
            final linear layer.

    Returns:
        An ``nn.Sequential`` alternating linear layers and activations.
        It is empty when ``sizes`` has fewer than two entries.
    """
    layers = []
    # Index of the final linear layer; every layer before it is "hidden".
    last = len(sizes) - 2
    # Walk consecutive (in, out) width pairs instead of indexing by position.
    for j, (n_in, n_out) in enumerate(zip(sizes, sizes[1:])):
        act = activation if j < last else output_activation
        layers += [nn.Linear(n_in, n_out), act()]
    return nn.Sequential(*layers)
def train(env_name='CartPole-v0', hidden_sizes=[32], lr=1e-2,
epochs=50, batch_size=5000, render=False):
# make environment, check spaces, get obs / act dims
env = gym.make(env_name)
assert isinstance(env.observation_space, Box), \
"This example only works for envs with continuous state spaces."
assert isinstance(env.action_space, Discrete), \
"This example only works for envs with discrete action spaces."
obs_dim = env.observation_space.shape[0]
n_acts = env.action_space.n
# make core of policy network
logits_net = mlp(sizes=[obs_dim]+hidden_sizes+[n_acts])
# make function to compute action distribution
def get_policy(obs):
logits = logits_net(obs)
return Categorical(logits=logits)
# make action selection function (outputs int actions, sampled from policy)
def get_action(obs):
return get_policy(obs).sample().item()
# make loss function whose gradient, for the right data, is policy gradient
def compute_loss(obs, act, weights):
logp = get_policy(obs).log_prob(act)
return -(logp * weights).mean()
# make optimizer
optimizer = Adam(logits_net.parameters(), lr=lr)
# for training policy
def train_one_epoch():
# make some empty lists for logging.
batch_obs = [] # for observations
batch_acts = [] # for actions
batch_weights = [] # for R(tau) weighting in policy gradient
batch_rets = [] # for measuring episode returns
batch_lens = [] # for measuring episode lengths
# reset episode-specific variables
obs = env.reset() <