The Basic Structure of a GAN
The main structure of a GAN consists of a generator G (Generator) and a discriminator D (Discriminator).
A GAN trains these two neural networks through an adversarial process: the two networks play a game against each other until they reach an ideal equilibrium, and the police officer and the criminal in our earlier example correspond exactly to these two networks. One network, the generator G(z), takes random noise as input and produces data that closely resembles the existing dataset; what it learns is the data distribution. The other network, the discriminator D(x), takes data as input and tries to tell which samples are generated and which are real. At its core the discriminator is a binary classifier: its output is the probability that the input comes from the real dataset (as opposed to synthetic or fake data).
Formally, the objective of this adversarial process can be written as the minimax game

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]$$

where $p_{\mathrm{data}}$ is the real data distribution and $p_z$ is the noise distribution fed to the generator.
The ideal equilibrium mentioned above means that the generator can reproduce the real data distribution and the discriminator outputs a probability of 0.5 for every sample, i.e. the generated data is indistinguishable from the real data. In other words, the discriminator cannot tell whether a new sample from the generator is real or fake; both outcomes are equally likely, which maximizes the entropy of its output.
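This value of 0.5 can be derived from the objective above: for a fixed generator, the optimal discriminator has a closed form (a standard result from the original GAN paper),

$$D^{*}(x) = \frac{p_{\mathrm{data}}(x)}{p_{\mathrm{data}}(x) + p_g(x)},$$

so once the generator matches the real distribution, i.e. $p_g = p_{\mathrm{data}}$, the optimal discriminator outputs 1/2 for every input.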
Here we use a GAN to generate a sine signal; the code is given below:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

# torch.manual_seed(1)    # reproducible
# np.random.seed(1)

# Hyper parameters
BATCH_SIZE = 64
LR_G = 0.0001        # learning rate for the generator
LR_D = 0.0001        # learning rate for the discriminator
N_IDEAS = 8          # think of this as the number of ideas for generating an art work (generator input size)
ART_COMPONENTS = 15  # total number of points G can draw on the canvas
PAINT_POINTS = np.vstack([np.linspace(-1, 1, ART_COMPONENTS) for _ in range(BATCH_SIZE)])


def artist_works():  # painting from the famous artist (real target)
    # a = np.random.uniform(1, 2, size=BATCH_SIZE)[:, np.newaxis]
    r = 0.02 * np.random.randn(1, ART_COMPONENTS)
    paintings = np.sin(PAINT_POINTS * np.pi) + r
    paintings = torch.from_numpy(paintings).float()
    return paintings


# Equivalent nn.Sequential definitions of the two networks:
# G = nn.Sequential(                   # Generator
#     nn.Linear(N_IDEAS, 128),         # random ideas (could come from a normal distribution)
#     nn.ReLU(),
#     nn.Linear(128, ART_COMPONENTS),  # making a painting from these random ideas
# )
#
# D = nn.Sequential(                   # Discriminator
#     nn.Linear(ART_COMPONENTS, 128),  # receives an art work either from the famous artist or a newbie like G
#     nn.ReLU(),
#     nn.Linear(128, 1),
#     nn.Sigmoid(),                    # tells the probability that the art work was made by the artist
# )

class Ge(nn.Module):      # Generator: maps N_IDEAS random numbers to ART_COMPONENTS curve points
    def __init__(self):
        super(Ge, self).__init__()
        self.fc1 = nn.Linear(N_IDEAS, 128)
        self.fc2 = nn.Linear(128, ART_COMPONENTS)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


class De(nn.Module):      # Discriminator: maps a curve to the probability that it is real
    def __init__(self):
        super(De, self).__init__()
        self.fc1 = nn.Linear(ART_COMPONENTS, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))   # torch.sigmoid instead of the deprecated F.sigmoid
        return x


G = Ge()
D = De()

opt_D = torch.optim.Adam(D.parameters(), lr=LR_D)
opt_G = torch.optim.Adam(G.parameters(), lr=LR_G)
plt.ion()   # interactive plotting

D_loss_history = []
G_loss_history = []
for step in range(10000):
    artist_paintings = artist_works()            # real paintings from the artist
    G_ideas = torch.randn(BATCH_SIZE, N_IDEAS)   # random ideas
    G_paintings = G(G_ideas)                     # fake paintings from G (random ideas)

    # update D: detach G_paintings so this step only touches D's graph
    prob_artist0 = D(artist_paintings)           # D tries to increase this prob
    prob_artist1 = D(G_paintings.detach())       # D tries to reduce this prob
    D_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1))
    opt_D.zero_grad()
    D_loss.backward()
    opt_D.step()

    # update G: a fresh forward pass through the just-updated discriminator
    prob_artist1 = D(G_paintings)
    G_loss = torch.mean(torch.log(1. - prob_artist1))
    opt_G.zero_grad()
    G_loss.backward()
    opt_G.step()

    D_loss_history.append(D_loss.item())
    G_loss_history.append(G_loss.item())

    if step % 1000 == 0:  # plotting
        plt.cla()
        plt.plot(PAINT_POINTS[0], G_paintings.data.numpy()[0], c='r', lw=3, label='Generated painting')
        plt.plot(PAINT_POINTS[0], np.sin(PAINT_POINTS[0] * np.pi), c='b', lw=3, label='upper bound')
        plt.text(-1, 0.75, 'D accuracy=%.2f (0.5 for D to converge)' % prob_artist0.data.numpy().mean(),
                 fontdict={'size': 13})
        plt.text(-1, 0.5, 'D score= %.2f (-1.38 for G to converge)' % -D_loss.data.numpy(), fontdict={'size': 13})
        plt.ylim((-1, 1))
        plt.legend(loc='upper right', fontsize=10)
        plt.draw()
        plt.pause(0.01)

# plt.ioff()
# plt.show()
In the code above, the artist_works() function produces the target sine signal:
def artist_works():  # painting from the famous artist (real target)
    # a = np.random.uniform(1, 2, size=BATCH_SIZE)[:, np.newaxis]
    r = 0.02 * np.random.randn(1, ART_COMPONENTS)
    paintings = np.sin(PAINT_POINTS * np.pi) + r
    paintings = torch.from_numpy(paintings).float()
    return paintings
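As a quick sanity check (not part of the original script), you can inspect one batch returned by artist_works(); with the hyperparameters above it is a 64 x 15 tensor, one noisy sine curve per row:

batch = artist_works()
print(batch.shape)                              # torch.Size([64, 15]): BATCH_SIZE rows, ART_COMPONENTS points each
print(batch.min().item(), batch.max().item())   # roughly -1 to 1, plus a little Gaussian noise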
The following code builds the generator and discriminator networks; both are implemented as PyTorch modules.
class Ge(nn.Module):      # Generator
    def __init__(self):
        super(Ge, self).__init__()
        self.fc1 = nn.Linear(N_IDEAS, 128)
        self.fc2 = nn.Linear(128, ART_COMPONENTS)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


class De(nn.Module):      # Discriminator
    def __init__(self):
        super(De, self).__init__()
        self.fc1 = nn.Linear(ART_COMPONENTS, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))
        return x
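To see how the two networks fit together, here is a small shape check (illustrative only, using the G and D instances created in the full listing): a batch of N_IDEAS-dimensional noise goes through G to give 15-point curves, and D maps each curve to a single probability.

z = torch.randn(BATCH_SIZE, N_IDEAS)   # random "ideas"
fake = G(z)                            # shape (64, 15): one generated curve per row
score = D(fake)                        # shape (64, 1): probability that each curve is real
print(fake.shape, score.shape)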
The following two lines define the loss functions of the discriminator and the generator:
D_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1))
G_loss = torch.mean(torch.log(1. - prob_artist1))
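This G_loss is the original minimax form, which minimizes log(1 - D(G(z))). Many implementations instead use the non-saturating variant, which maximizes log D(G(z)) and gives the generator stronger gradients early in training; a possible drop-in replacement (not used in the code above) would be:

G_loss = - torch.mean(torch.log(prob_artist1))   # non-saturating generator loss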
The results: the first figure shows the curve produced from random noise at the very start of training, and the second figure shows the stage where the discriminator's output probability is about 0.5; the generated curve fits the target sine wave well.
With this GAN experience in hand, we now turn to generative adversarial imitation learning (GAIL).
The project consists of two files: env_OppositeV4.py, which builds the environment, and GAIL_OppositeV4.py, which runs the training.
We start with env_OppositeV4.py. Below is a rendering of the environment it builds:
In the figure, the red cell is the starting point and the green cell is the goal. The code of env_OppositeV4.py is as follows:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import random
import cv2


class EnvOppositeV4(object):
    def __init__(self, size):
        self.map_size = size
        self.raw_occupancy = np.zeros((self.map_size, self.map_size))
        for i in range(self.map_size):
            self.raw_occupancy[0][i] = 1
            self.raw_occupancy[self.map_size - 1][i] = 1
            self.raw_occupancy[i][0] = 1
            self.raw_occupancy[i][self.map_size - 1] = 1
            self.raw_occupancy[i][int((self.map_size - 1) / 2)] = 1
        self.raw_occupancy[1][int((self.map_size - 1) / 2)] = 0
        self.raw_occupancy[self.map_size - 2][int((self.map_size - 1) / 2)] = 0
        self.occupancy = self.raw_occupancy.copy()
        self.agt1_pos = [int((self.map_size - 1) / 2), 1]
        self.goal1_pos = [int((self.map_size - 1) / 2), self.map_size - 2]
        self.occupancy[self.agt1_pos[0]][self.agt1_pos[1]] = 1

    def reset(self):
        self.occupancy = self.raw_occupancy.copy()
        self.agt1_pos = [int((self.map_size - 1) / 2), 1]
        self.goal1_pos = [int((self.map_size - 1) / 2), self.map_size - 2]
        self.occupancy[self.agt1_pos[0]][self.agt1_pos[1]] = 1

    def get_state(self):
        state = np.zeros((1, 2))
        state[0, 0] = self.agt1_pos[0] / self.map_size
        state[0, 1] = self.agt1_pos[1] / self.map_size
        return state

    def step(self, action_list):
        reward = 0
        # agent1 move
        if action_list[0] == 0:    # move up
            if self.occupancy[self.agt1_pos[0] - 1][self.agt1_pos[1]] != 1:  # if can move
                self.agt1_pos[0] = self.agt1_pos[0] - 1
                self.occupancy[self.agt1_pos[0] + 1][self.agt1_pos[1]] = 0
                self.occupancy[self.agt1_pos[0]][self.agt1_pos[1]] = 1
        elif action_list[0] == 1:  # move down
            if self.occupancy[self.agt1_pos[0] + 1][self.agt1_pos[1]] != 1:  # if can move
                self.agt1_pos[0] = self.agt1_pos[0] + 1
                self.occupancy[self.agt1_pos[0] - 1][self.agt1_pos[1]] = 0
                self.occupancy[self.agt1_pos[0]][self.agt1_pos[1]] = 1
        elif action_list[0] == 2:  # move left
            if self.occupancy[self.agt1_pos[0]][self.agt1_pos[1] - 1] != 1:  # if can move
                self.agt1_pos[1] = self.agt1_pos[1] - 1
                self.occupancy[self.agt1_pos[0]][self.agt1_pos[1] + 1] = 0
                self.occupancy[self.agt1_pos[0]][self.agt1_pos[1]] = 1
        elif action_list[0] == 3:  # move right
            if self.occupancy[self.agt1_pos[0]][self.agt1_pos[1] + 1] != 1:  # if can move
                self.agt1_pos[1] = self.agt1_pos[1] + 1
                self.occupancy[self.agt1_pos[0]][self.agt1_pos[1] - 1] = 0
                self.occupancy[self.agt1_pos[0]][self.agt1_pos[1]] = 1
        if self.agt1_pos == self.goal1_pos:
            reward = reward + 5
        done = False
        if reward == 5:
            done = True
        return reward, done

    def get_global_obs(self):
        obs = np.zeros((self.map_size, self.map_size, 3))
        for i in range(self.map_size):
            for j in range(self.map_size):
                if self.occupancy[i][j] == 0:
                    obs[i, j, 0] = 1.0
                    obs[i, j, 1] = 1.0
                    obs[i, j, 2] = 1.0
        obs[self.agt1_pos[0], self.agt1_pos[1], 0] = 1.0
        obs[self.agt1_pos[0], self.agt1_pos[1], 1] = 0.0
        obs[self.agt1_pos[0], self.agt1_pos[1], 2] = 0.0
        return obs

    def render(self):
        obs = self.get_global_obs()
        enlarge = 30
        new_obs = np.ones((self.map_size * enlarge, self.map_size * enlarge, 3))
        for i in range(self.map_size):
            for j in range(self.map_size):
                if obs[i][j][0] == 0.0 and obs[i][j][1] == 0.0 and obs[i][j][2] == 0.0:
                    cv2.rectangle(new_obs, (j * enlarge, i * enlarge), (j * enlarge + enlarge, i * enlarge + enlarge), (0, 0, 0), -1)
                if obs[i][j][0] == 1.0 and obs[i][j][1] == 0.0 and obs[i][j][2] == 0.0:
                    cv2.rectangle(new_obs, (j * enlarge, i * enlarge), (j * enlarge + enlarge, i * enlarge + enlarge), (0, 0, 255), -1)
                if obs[i][j][0] == 0.0 and obs[i][j][1] == 1.0 and obs[i][j][2] == 0.0:
                    cv2.rectangle(new_obs, (j * enlarge, i * enlarge), (j * enlarge + enlarge, i * enlarge + enlarge), (0, 255, 0), -1)
                if obs[i][j][0] == 0.0 and obs[i][j][1] == 0.0 and obs[i][j][2] == 1.0:
                    cv2.rectangle(new_obs, (j * enlarge, i * enlarge), (j * enlarge + enlarge, i * enlarge + enlarge), (255, 0, 0), -1)
        cv2.imshow('image', new_obs)
        cv2.waitKey(100)
In the code above, __init__ builds the occupancy map shown in the figure: cells with value 1 will later be drawn in black (walls) and cells with value 0 in white (free space). It also sets the agent's starting position and the goal position.
def __init__(self, size):
    self.map_size = size
    self.raw_occupancy = np.zeros((self.map_size, self.map_size))
    for i in range(self.map_size):
        self.raw_occupancy[0][i] = 1
        self.raw_occupancy[self.map_size - 1][i] = 1
        self.raw_occupancy[i][0] = 1
        self.raw_occupancy[i][self.map_size - 1] = 1
        self.raw_occupancy[i][int((self.map_size - 1) / 2)] = 1
    self.raw_occupancy[1][int((self.map_size - 1) / 2)] = 0
    self.raw_occupancy[self.map_size - 2][int((self.map_size - 1) / 2)] = 0
    self.occupancy = self.raw_occupancy.copy()
    self.agt1_pos = [int((self.map_size - 1) / 2), 1]
    self.goal1_pos = [int((self.map_size - 1) / 2), self.map_size - 2]
    self.occupancy[self.agt1_pos[0]][self.agt1_pos[1]] = 1
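To see the resulting layout, a short check (illustrative, not part of the original file) can print the raw occupancy map of a 9 x 9 environment; the border and the middle column are walls (1), with two gaps in the middle column at rows 1 and 7:

from env_OppositeV4 import EnvOppositeV4

env = EnvOppositeV4(9)
print(env.raw_occupancy)              # 9 x 9 array of 0/1: border walls plus a middle wall with two gaps
print(env.agt1_pos, env.goal1_pos)    # start [4, 1], goal [4, 7]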
The following code turns the occupancy map into an RGB observation: cells with value 0 are painted white, cells with value 1 stay black, and the agent's cell is marked in red, producing the image shown above.
def get_global_obs(self):
    obs = np.zeros((self.map_size, self.map_size, 3))
    for i in range(self.map_size):
        for j in range(self.map_size):
            if self.occupancy[i][j] == 0:
                obs[i, j, 0] = 1.0
                obs[i, j, 1] = 1.0
                obs[i, j, 2] = 1.0
    obs[self.agt1_pos[0], self.agt1_pos[1], 0] = 1.0
    obs[self.agt1_pos[0], self.agt1_pos[1], 1] = 0.0
    obs[self.agt1_pos[0], self.agt1_pos[1], 2] = 0.0
    return obs
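Since the three channels are laid out as RGB values in [0, 1], the observation can also be displayed directly with matplotlib (a quick visual check, not part of the original file):

import matplotlib.pyplot as plt
from env_OppositeV4 import EnvOppositeV4

env = EnvOppositeV4(9)
obs = env.get_global_obs()   # shape (9, 9, 3)
plt.imshow(obs)              # white free cells, black walls, red agent
plt.show()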
The render method below enlarges each grid cell by a factor of 30 and draws the result with OpenCV (note that cv2 expects colors in BGR order, so (0, 0, 255) is red and (0, 255, 0) is green).
def render(self):
    obs = self.get_global_obs()
    enlarge = 30
    new_obs = np.ones((self.map_size * enlarge, self.map_size * enlarge, 3))
    for i in range(self.map_size):
        for j in range(self.map_size):
            if obs[i][j][0] == 0.0 and obs[i][j][1] == 0.0 and obs[i][j][2] == 0.0:
                cv2.rectangle(new_obs, (j * enlarge, i * enlarge), (j * enlarge + enlarge, i * enlarge + enlarge), (0, 0, 0), -1)
            if obs[i][j][0] == 1.0 and obs[i][j][1] == 0.0 and obs[i][j][2] == 0.0:
                cv2.rectangle(new_obs, (j * enlarge, i * enlarge), (j * enlarge + enlarge, i * enlarge + enlarge), (0, 0, 255), -1)
            if obs[i][j][0] == 0.0 and obs[i][j][1] == 1.0 and obs[i][j][2] == 0.0:
                cv2.rectangle(new_obs, (j * enlarge, i * enlarge), (j * enlarge + enlarge, i * enlarge + enlarge), (0, 255, 0), -1)
            if obs[i][j][0] == 0.0 and obs[i][j][1] == 0.0 and obs[i][j][2] == 1.0:
                cv2.rectangle(new_obs, (j * enlarge, i * enlarge), (j * enlarge + enlarge, i * enlarge + enlarge), (255, 0, 0), -1)
    cv2.imshow('image', new_obs)
    cv2.waitKey(100)
The following code describes the agent's actions and the reward: actions 0 to 3 move the agent up, down, left, or right (a move is executed only if the target cell is free), any other action value leaves the agent in place, and reaching the goal gives a reward of 5 and ends the episode.
def step(self, action_list):
    reward = 0
    # agent1 move
    if action_list[0] == 0:    # move up
        if self.occupancy[self.agt1_pos[0] - 1][self.agt1_pos[1]] != 1:  # if can move
            self.agt1_pos[0] = self.agt1_pos[0] - 1
            self.occupancy[self.agt1_pos[0] + 1][self.agt1_pos[1]] = 0
            self.occupancy[self.agt1_pos[0]][self.agt1_pos[1]] = 1
    elif action_list[0] == 1:  # move down
        if self.occupancy[self.agt1_pos[0] + 1][self.agt1_pos[1]] != 1:  # if can move
            self.agt1_pos[0] = self.agt1_pos[0] + 1
            self.occupancy[self.agt1_pos[0] - 1][self.agt1_pos[1]] = 0
            self.occupancy[self.agt1_pos[0]][self.agt1_pos[1]] = 1
    elif action_list[0] == 2:  # move left
        if self.occupancy[self.agt1_pos[0]][self.agt1_pos[1] - 1] != 1:  # if can move
            self.agt1_pos[1] = self.agt1_pos[1] - 1
            self.occupancy[self.agt1_pos[0]][self.agt1_pos[1] + 1] = 0
            self.occupancy[self.agt1_pos[0]][self.agt1_pos[1]] = 1
    elif action_list[0] == 3:  # move right
        if self.occupancy[self.agt1_pos[0]][self.agt1_pos[1] + 1] != 1:  # if can move
            self.agt1_pos[1] = self.agt1_pos[1] + 1
            self.occupancy[self.agt1_pos[0]][self.agt1_pos[1] - 1] = 0
            self.occupancy[self.agt1_pos[0]][self.agt1_pos[1]] = 1
    if self.agt1_pos == self.goal1_pos:
        reward = reward + 5
    done = False
    if reward == 5:
        done = True
    return reward, done
This completes the description of the environment the agent runs in.
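Before moving on, here is a minimal usage sketch (assuming env_OppositeV4.py is importable) that drives the environment with random actions, just to confirm the interface used later, namely get_state(), step() and render():

import numpy as np
from env_OppositeV4 import EnvOppositeV4

env = EnvOppositeV4(9)
env.reset()
for t in range(50):
    action = np.random.randint(0, 4)       # 0: up, 1: down, 2: left, 3: right
    reward, done = env.step([action, 0])   # only the first entry of the action list is used by this env
    env.render()
    if done:
        print('reached the goal at step', t, 'reward', reward)
        break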
Next, the code of GAIL_OppositeV4.py:
from torch.distributions.categorical import Categorical
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
from env_OppositeV4 import EnvOppositeV4
import numpy as np
import csv
from collections import deque
import os


class Actor(nn.Module):
    def __init__(self, N_action):
        super(Actor, self).__init__()
        self.N_action = N_action
        self.fc1 = nn.Linear(2, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, self.N_action)

    def get_action(self, h):
        h = F.relu(self.fc1(h))
        h = F.relu(self.fc2(h))
        h = F.softmax(self.fc3(h), dim=1)
        m = Categorical(h.squeeze(0))
        a = m.sample()
        log_prob = m.log_prob(a)
        return a.item(), h, log_prob


class Discriminator(nn.Module):
    def __init__(self, s_dim, N_action):
        super(Discriminator, self).__init__()
        self.s_dim = s_dim
        self.N_action = N_action
        self.fc1 = nn.Linear(self.s_dim + self.N_action, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 1)

    def forward(self, state, action):
        state_action = torch.cat([state, action], 1)
        x = torch.relu(self.fc1(state_action))
        x = torch.relu(self.fc2(x))
        x = torch.sigmoid(self.fc3(x))
        return x


class GAIL(object):
    def __init__(self, s_dim, N_action):
        self.s_dim = s_dim
        self.N_action = N_action
        self.actor1 = Actor(self.N_action)
        self.disc1 = Discriminator(self.s_dim, self.N_action)
        self.d1_optimizer = torch.optim.Adam(self.disc1.parameters(), lr=1e-3)
        self.a1_optimizer = torch.optim.Adam(self.actor1.parameters(), lr=1e-3)
        self.loss_fn = torch.nn.MSELoss()
        self.adv_loss_fn = torch.nn.BCELoss()
        self.gamma = 0.9

    def get_action(self, obs1):
        action1, pi_a1, log_prob1 = self.actor1.get_action(torch.from_numpy(obs1).float())
        return action1, pi_a1, log_prob1

    def int_to_tensor(self, action):
        temp = torch.zeros(1, self.N_action)
        temp[0, action] = 1
        return temp

    def train_D(self, s1_list, a1_list, e_s1_list, e_a1_list):
        p_s1 = torch.from_numpy(s1_list[0]).float()
        p_a1 = self.int_to_tensor(a1_list[0])
        for i in range(1, len(s1_list)):
            temp_p_s1 = torch.from_numpy(s1_list[i]).float()
            p_s1 = torch.cat([p_s1, temp_p_s1], dim=0)
            temp_p_a1 = self.int_to_tensor(a1_list[i])
            p_a1 = torch.cat([p_a1, temp_p_a1], dim=0)
        e_s1 = torch.from_numpy(e_s1_list[0]).float()
        e_a1 = self.int_to_tensor(e_a1_list[0])
        for i in range(1, len(e_s1_list)):
            temp_e_s1 = torch.from_numpy(e_s1_list[i]).float()
            e_s1 = torch.cat([e_s1, temp_e_s1], dim=0)
            temp_e_a1 = self.int_to_tensor(e_a1_list[i])
            e_a1 = torch.cat([e_a1, temp_e_a1], dim=0)
        p1_label = torch.zeros(len(s1_list), 1)   # policy samples are labelled 0
        e1_label = torch.ones(len(e_s1_list), 1)  # expert samples are labelled 1
        e1_pred = self.disc1(e_s1, e_a1)
        # print('e1_pred', e1_pred)
        loss = self.adv_loss_fn(e1_pred, e1_label)
        p1_pred = self.disc1(p_s1, p_a1)
        # print('p1_pred', p1_pred)
        loss = loss + self.adv_loss_fn(p1_pred, p1_label)
        self.d1_optimizer.zero_grad()
        loss.backward()
        self.d1_optimizer.step()

    def train_G(self, s1_list, a1_list, log_pi_a1_list, r1_list, e_s1_list, e_a1_list):
        T = len(s1_list)
        p_s1 = torch.from_numpy(s1_list[0]).float()
        p_a1 = self.int_to_tensor(a1_list[0])
        for i in range(1, len(s1_list)):
            temp_p_s1 = torch.from_numpy(s1_list[i]).float()
            p_s1 = torch.cat([p_s1, temp_p_s1], dim=0)
            temp_p_a1 = self.int_to_tensor(a1_list[i])
            p_a1 = torch.cat([p_a1, temp_p_a1], dim=0)
        e_s1 = torch.from_numpy(e_s1_list[0]).float()
        e_a1 = self.int_to_tensor(e_a1_list[0])
        for i in range(1, len(e_s1_list)):
            temp_e_s1 = torch.from_numpy(e_s1_list[i]).float()
            e_s1 = torch.cat([e_s1, temp_e_s1], dim=0)
            temp_e_a1 = self.int_to_tensor(e_a1_list[i])
            e_a1 = torch.cat([e_a1, temp_e_a1], dim=0)
        p1_pred = self.disc1(p_s1, p_a1)
        fake_reward = p1_pred.mean()              # discriminator output used as the policy's reward signal
        a1_loss = torch.FloatTensor([0.0])
        for t in range(T):
            a1_loss = a1_loss + fake_reward * log_pi_a1_list[t]
        a1_loss = -a1_loss / T
        # print(a1_loss)
        self.a1_optimizer.zero_grad()
        a1_loss.backward()
        self.a1_optimizer.step()


class REINFORCE(object):
    def __init__(self, N_action):
        self.N_action = N_action
        self.actor1 = Actor(self.N_action)

    def get_action(self, obs):
        action1, pi_a1, log_prob1 = self.actor1.get_action(torch.from_numpy(obs).float())
        return action1, pi_a1, log_prob1

    def train(self, a1_list, pi_a1_list, r_list):
        a1_optimizer = torch.optim.Adam(self.actor1.parameters(), lr=1e-3)
        T = len(r_list)
        G_list = torch.zeros(1, T)
        G_list[0, T - 1] = torch.FloatTensor([r_list[T - 1]])
        for k in range(T - 2, -1, -1):
            G_list[0, k] = r_list[k] + 0.95 * G_list[0, k + 1]
        a1_loss = torch.FloatTensor([0.0])
        for t in range(T):
            a1_loss = a1_loss + G_list[0, t] * torch.log(pi_a1_list[t][0, a1_list[t]])
        a1_loss = -a1_loss / T
        a1_optimizer.zero_grad()
        a1_loss.backward()
        a1_optimizer.step()

    def save_model(self):
        torch.save(self.actor1, 'V4_actor.pkl')

    def load_model(self):
        self.actor1 = torch.load('V4_actor.pkl')
if __name__ == '__main__':
    torch.set_num_threads(1)
    env = EnvOppositeV4(9)
    max_epi_iter = 100
    max_MC_iter = 100

    # train expert policy by REINFORCE algorithm
    agent = REINFORCE(N_action=5)
    if os.path.exists('./V4_actor.pkl'):
        agent.load_model()
    else:
        print('No saved model found, training the expert from scratch!')
        for epi_iter in range(max_epi_iter):
            env.reset()
            a1_list = []
            pi_a1_list = []
            r_list = []
            acc_r = 0
            for MC_iter in range(max_MC_iter):
                env.render()
                state = env.get_state()
                action1, pi_a1, log_prob1 = agent.get_action(state)
                a1_list.append(action1)
                pi_a1_list.append(pi_a1)
                reward, done = env.step([action1, 0])
                acc_r = acc_r + reward
                r_list.append(reward)
                if done:
                    break
            print('Train expert, Episode', epi_iter, 'average reward', acc_r / MC_iter)
            if done:
                agent.train(a1_list, pi_a1_list, r_list)

    # record expert policy
    agent.save_model()
    exp_s_list = []
    exp_a_list = []
    env.reset()
    for MC_iter in range(max_MC_iter):
        env.render()
        state = env.get_state()
        action1, pi_a1, log_prob1 = agent.get_action(state)
        exp_s_list.append(state)
        exp_a_list.append(action1)
        reward, done = env.step([action1, 0])
        print('step', MC_iter, 'agent 1 at', exp_s_list[MC_iter], 'agent 1 action', exp_a_list[MC_iter], 'reward', reward, 'done', done)
        if done:
            break

    # generative adversarial imitation learning from [exp_s_list, exp_a_list]
    agent = GAIL(s_dim=2, N_action=5)
    for epi_iter in range(max_epi_iter):
        env.reset()
        s1_list = []
        a1_list = []
        r1_list = []
        log_pi_a1_list = []
        acc_r = 0
        for MC_iter in range(max_MC_iter):
            # env.render()
            state = env.get_state()
            action1, pi_a1, log_prob1 = agent.get_action(state)
            s1_list.append(state)
            a1_list.append(action1)
            log_pi_a1_list.append(log_prob1)
            reward, done = env.step([action1, 0])
            acc_r = acc_r + reward
            r1_list.append(reward)
            if done:
                break
        print('Imitate by GAIL, Episode', epi_iter, 'average reward', acc_r / MC_iter)
        # train Discriminator
        agent.train_D(s1_list, a1_list, exp_s_list, exp_a_list)
        # train Generator
        agent.train_G(s1_list, a1_list, log_pi_a1_list, r1_list, exp_s_list, exp_a_list)
    # learnt policy
    print('expert trajectory')
    for i in range(len(exp_a_list)):
        print('step', i, 'agent 1 at', exp_s_list[i], 'agent 1 action', exp_a_list[i])
    print('learnt trajectory')
    env.reset()
    for MC_iter in range(max_MC_iter):
        # env.render()
        state = env.get_state()
        action1, pi_a1, log_prob1 = agent.get_action(state)
        reward, done = env.step([action1, 0])
        print('step', MC_iter, 'agent 1 at', state, 'agent 1 action', action1)
        if done:
            break
The program output is:
expert trajectory
step 0 agent 1 at [[0.44444444 0.11111111]] agent 1 action 1
step 1 agent 1 at [[0.55555556 0.11111111]] agent 1 action 4
step 2 agent 1 at [[0.55555556 0.11111111]] agent 1 action 3
step 3 agent 1 at [[0.55555556 0.22222222]] agent 1 action 1
step 4 agent 1 at [[0.66666667 0.22222222]] agent 1 action 0
step 5 agent 1 at [[0.55555556 0.22222222]] agent 1 action 0
step 6 agent 1 at [[0.44444444 0.22222222]] agent 1 action 3
step 7 agent 1 at [[0.44444444 0.33333333]] agent 1 action 4
step 8 agent 1 at [[0.44444444 0.33333333]] agent 1 action 3
step 9 agent 1 at [[0.44444444 0.33333333]] agent 1 action 0
step 10 agent 1 at [[0.33333333 0.33333333]] agent 1 action 4
step 11 agent 1 at [[0.33333333 0.33333333]] agent 1 action 1
step 12 agent 1 at [[0.44444444 0.33333333]] agent 1 action 0
step 13 agent 1 at [[0.33333333 0.33333333]] agent 1 action 0
step 14 agent 1 at [[0.22222222 0.33333333]] agent 1 action 1
step 15 agent 1 at [[0.33333333 0.33333333]] agent 1 action 1
step 16 agent 1 at [[0.44444444 0.33333333]] agent 1 action 1
step 17 agent 1 at [[0.55555556 0.33333333]] agent 1 action 2
step 18 agent 1 at [[0.55555556 0.22222222]] agent 1 action 3
step 19 agent 1 at [[0.55555556 0.33333333]] agent 1 action 3
step 20 agent 1 at [[0.55555556 0.33333333]] agent 1 action 3
step 21 agent 1 at [[0.55555556 0.33333333]] agent 1 action 3
step 22 agent 1 at [[0.55555556 0.33333333]] agent 1 action 3
step 23 agent 1 at [[0.55555556 0.33333333]] agent 1 action 0
step 24 agent 1 at [[0.44444444 0.33333333]] agent 1 action 4
step 25 agent 1 at [[0.44444444 0.33333333]] agent 1 action 0
step 26 agent 1 at [[0.33333333 0.33333333]] agent 1 action 0
step 27 agent 1 at [[0.22222222 0.33333333]] agent 1 action 1
step 28 agent 1 at [[0.33333333 0.33333333]] agent 1 action 3
step 29 agent 1 at [[0.33333333 0.33333333]] agent 1 action 1
step 30 agent 1 at [[0.44444444 0.33333333]] agent 1 action 3
step 31 agent 1 at [[0.44444444 0.33333333]] agent 1 action 3
step 32 agent 1 at [[0.44444444 0.33333333]] agent 1 action 0
step 33 agent 1 at [[0.33333333 0.33333333]] agent 1 action 0
step 34 agent 1 at [[0.22222222 0.33333333]] agent 1 action 2
step 35 agent 1 at [[0.22222222 0.22222222]] agent 1 action 3
step 36 agent 1 at [[0.22222222 0.33333333]] agent 1 action 3
step 37 agent 1 at [[0.22222222 0.33333333]] agent 1 action 1
step 38 agent 1 at [[0.33333333 0.33333333]] agent 1 action 3
step 39 agent 1 at [[0.33333333 0.33333333]] agent 1 action 0
step 40 agent 1 at [[0.22222222 0.33333333]] agent 1 action 3
step 41 agent 1 at [[0.22222222 0.33333333]] agent 1 action 3
step 42 agent 1 at [[0.22222222 0.33333333]] agent 1 action 1
step 43 agent 1 at [[0.33333333 0.33333333]] agent 1 action 3
step 44 agent 1 at [[0.33333333 0.33333333]] agent 1 action 1
step 45 agent 1 at [[0.44444444 0.33333333]] agent 1 action 3
step 46 agent 1 at [[0.44444444 0.33333333]] agent 1 action 1
step 47 agent 1 at [[0.55555556 0.33333333]] agent 1 action 3
step 48 agent 1 at [[0.55555556 0.33333333]] agent 1 action 1
step 49 agent 1 at [[0.66666667 0.33333333]] agent 1 action 3
step 50 agent 1 at [[0.66666667 0.33333333]] agent 1 action 3
step 51 agent 1 at [[0.66666667 0.33333333]] agent 1 action 0
step 52 agent 1 at [[0.55555556 0.33333333]] agent 1 action 0
step 53 agent 1 at [[0.44444444 0.33333333]] agent 1 action 3
step 54 agent 1 at [[0.44444444 0.33333333]] agent 1 action 1
step 55 agent 1 at [[0.55555556 0.33333333]] agent 1 action 1
step 56 agent 1 at [[0.66666667 0.33333333]] agent 1 action 3
step 57 agent 1 at [[0.66666667 0.33333333]] agent 1 action 4
step 58 agent 1 at [[0.66666667 0.33333333]] agent 1 action 1
step 59 agent 1 at [[0.77777778 0.33333333]] agent 1 action 1
step 60 agent 1 at [[0.77777778 0.33333333]] agent 1 action 4
step 61 agent 1 at [[0.77777778 0.33333333]] agent 1 action 0
step 62 agent 1 at [[0.66666667 0.33333333]] agent 1 action 3
step 63 agent 1 at [[0.66666667 0.33333333]] agent 1 action 3
step 64 agent 1 at [[0.66666667 0.33333333]] agent 1 action 3
step 65 agent 1 at [[0.66666667 0.33333333]] agent 1 action 1
step 66 agent 1 at [[0.77777778 0.33333333]] agent 1 action 0
step 67 agent 1 at [[0.66666667 0.33333333]] agent 1 action 1
step 68 agent 1 at [[0.77777778 0.33333333]] agent 1 action 3
step 69 agent 1 at [[0.77777778 0.44444444]] agent 1 action 3
step 70 agent 1 at [[0.77777778 0.55555556]] agent 1 action 0
step 71 agent 1 at [[0.66666667 0.55555556]] agent 1 action 0
step 72 agent 1 at [[0.55555556 0.55555556]] agent 1 action 0
step 73 agent 1 at [[0.44444444 0.55555556]] agent 1 action 0
step 74 agent 1 at [[0.33333333 0.55555556]] agent 1 action 1
step 75 agent 1 at [[0.44444444 0.55555556]] agent 1 action 4
step 76 agent 1 at [[0.44444444 0.55555556]] agent 1 action 0
step 77 agent 1 at [[0.33333333 0.55555556]] agent 1 action 1
step 78 agent 1 at [[0.44444444 0.55555556]] agent 1 action 3
step 79 agent 1 at [[0.44444444 0.66666667]] agent 1 action 0
step 80 agent 1 at [[0.33333333 0.66666667]] agent 1 action 3
step 81 agent 1 at [[0.33333333 0.77777778]] agent 1 action 1
learnt trajectory
step 0 agent 1 at [[0.44444444 0.11111111]] agent 1 action 1
step 1 agent 1 at [[0.55555556 0.11111111]] agent 1 action 4
step 2 agent 1 at [[0.55555556 0.11111111]] agent 1 action 3
step 3 agent 1 at [[0.55555556 0.22222222]] agent 1 action 1
step 4 agent 1 at [[0.66666667 0.22222222]] agent 1 action 0
step 5 agent 1 at [[0.55555556 0.22222222]] agent 1 action 0
step 6 agent 1 at [[0.44444444 0.22222222]] agent 1 action 3
step 7 agent 1 at [[0.44444444 0.33333333]] agent 1 action 4
step 8 agent 1 at [[0.44444444 0.33333333]] agent 1 action 3
step 9 agent 1 at [[0.44444444 0.33333333]] agent 1 action 0
step 10 agent 1 at [[0.33333333 0.33333333]] agent 1 action 4
step 11 agent 1 at [[0.33333333 0.33333333]] agent 1 action 1
step 12 agent 1 at [[0.44444444 0.33333333]] agent 1 action 0
step 13 agent 1 at [[0.33333333 0.33333333]] agent 1 action 0
step 14 agent 1 at [[0.22222222 0.33333333]] agent 1 action 1
step 15 agent 1 at [[0.33333333 0.33333333]] agent 1 action 1
step 16 agent 1 at [[0.44444444 0.33333333]] agent 1 action 1
step 17 agent 1 at [[0.55555556 0.33333333]] agent 1 action 2
step 18 agent 1 at [[0.55555556 0.22222222]] agent 1 action 3
step 19 agent 1 at [[0.55555556 0.33333333]] agent 1 action 3
step 20 agent 1 at [[0.55555556 0.33333333]] agent 1 action 3
step 21 agent 1 at [[0.55555556 0.33333333]] agent 1 action 3
step 22 agent 1 at [[0.55555556 0.33333333]] agent 1 action 3
step 23 agent 1 at [[0.55555556 0.33333333]] agent 1 action 0
step 24 agent 1 at [[0.44444444 0.33333333]] agent 1 action 4
step 25 agent 1 at [[0.44444444 0.33333333]] agent 1 action 0
step 26 agent 1 at [[0.33333333 0.33333333]] agent 1 action 0
step 27 agent 1 at [[0.22222222 0.33333333]] agent 1 action 1
step 28 agent 1 at [[0.33333333 0.33333333]] agent 1 action 3
step 29 agent 1 at [[0.33333333 0.33333333]] agent 1 action 1
step 30 agent 1 at [[0.44444444 0.33333333]] agent 1 action 3
step 31 agent 1 at [[0.44444444 0.33333333]] agent 1 action 3
step 32 agent 1 at [[0.44444444 0.33333333]] agent 1 action 0
step 33 agent 1 at [[0.33333333 0.33333333]] agent 1 action 0
step 34 agent 1 at [[0.22222222 0.33333333]] agent 1 action 2
step 35 agent 1 at [[0.22222222 0.22222222]] agent 1 action 3
step 36 agent 1 at [[0.22222222 0.33333333]] agent 1 action 3
step 37 agent 1 at [[0.22222222 0.33333333]] agent 1 action 1
step 38 agent 1 at [[0.33333333 0.33333333]] agent 1 action 3
step 39 agent 1 at [[0.33333333 0.33333333]] agent 1 action 0
step 40 agent 1 at [[0.22222222 0.33333333]] agent 1 action 3
step 41 agent 1 at [[0.22222222 0.33333333]] agent 1 action 3
step 42 agent 1 at [[0.22222222 0.33333333]] agent 1 action 1
step 43 agent 1 at [[0.33333333 0.33333333]] agent 1 action 3
step 44 agent 1 at [[0.33333333 0.33333333]] agent 1 action 1
step 45 agent 1 at [[0.44444444 0.33333333]] agent 1 action 3
step 46 agent 1 at [[0.44444444 0.33333333]] agent 1 action 1
step 47 agent 1 at [[0.55555556 0.33333333]] agent 1 action 3
step 48 agent 1 at [[0.55555556 0.33333333]] agent 1 action 1
step 49 agent 1 at [[0.66666667 0.33333333]] agent 1 action 3
step 50 agent 1 at [[0.66666667 0.33333333]] agent 1 action 3
step 51 agent 1 at [[0.66666667 0.33333333]] agent 1 action 0
step 52 agent 1 at [[0.55555556 0.33333333]] agent 1 action 0
step 53 agent 1 at [[0.44444444 0.33333333]] agent 1 action 3
step 54 agent 1 at [[0.44444444 0.33333333]] agent 1 action 1
step 55 agent 1 at [[0.55555556 0.33333333]] agent 1 action 1
step 56 agent 1 at [[0.66666667 0.33333333]] agent 1 action 3
step 57 agent 1 at [[0.66666667 0.33333333]] agent 1 action 4
step 58 agent 1 at [[0.66666667 0.33333333]] agent 1 action 1
step 59 agent 1 at [[0.77777778 0.33333333]] agent 1 action 1
step 60 agent 1 at [[0.77777778 0.33333333]] agent 1 action 4
As can be seen, the learnt trajectory matches the expert trajectory.
Now let us go through the main details.
For this hand-built environment there are no ready-made expert trajectories, so we have to create them ourselves: the expert policy is first trained with REINFORCE and then used to record demonstrations.
The following code collects the samples of each training episode:
for epi_iter in range(max_epi_iter):
    env.reset()
    a1_list = []
    pi_a1_list = []
    r_list = []
    acc_r = 0
    for MC_iter in range(max_MC_iter):
        env.render()
        state = env.get_state()
        action1, pi_a1, log_prob1 = agent.get_action(state)
        a1_list.append(action1)
        pi_a1_list.append(pi_a1)
        reward, done = env.step([action1, 0])
        acc_r = acc_r + reward
        r_list.append(reward)
In the code below, the actor network is updated only when the agent actually reaches the green goal (i.e. the episode ends with done = True), so only successful episodes are used to train the expert.
if done:
    agent.train(a1_list, pi_a1_list, r_list)
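For reference, the returns used inside REINFORCE.train are plain discounted Monte Carlo returns with a discount factor of 0.95. A small stand-alone sketch of the same computation, with a made-up reward list for illustration, looks like this:

import torch

r_list = [0, 0, 0, 5]        # example rewards: the goal is reached on the last step
T = len(r_list)
G_list = torch.zeros(1, T)
G_list[0, T - 1] = r_list[T - 1]
for k in range(T - 2, -1, -1):
    G_list[0, k] = r_list[k] + 0.95 * G_list[0, k + 1]
print(G_list)                # tensor([[4.2869, 4.5125, 4.7500, 5.0000]])

Each return G_t weights the log-probability of the action taken at step t in the REINFORCE loss, so steps closer to the successful ending receive a larger weight.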