Reference resource: https://github.com/GAOYANGAU/DRLPytorch
Nature-DQN.py
import torch
import torch.nn as nn
from collections import deque
import numpy as np
import gym
import random
from net import AtariNet
from util import preprocess
BATCH_SIZE = 32            # minibatch size sampled from replay memory
LR = 0.001                 # Adam learning rate
START_EPSILON = 1.0        # initial exploration rate
FINAL_EPSILON = 0.1        # final exploration rate
EPSILON = START_EPSILON
EXPLORE = 1000000          # steps over which epsilon is annealed linearly
GAMMA = 0.99               # discount factor
TOTAL_EPISODES = 10000000
MEMORY_SIZE = 1000000      # replay memory capacity
MEMORY_THRESHOLD = 100000  # start learning once the memory holds this many transitions
UPDATE_TIME = 10000        # sync target network every UPDATE_TIME learning steps
TEST_FREQUENCY = 1000      # run a greedy evaluation episode every TEST_FREQUENCY episodes
env = gym.make('Pong-v0')
env = env.unwrapped
ACTIONS_SIZE = env.action_space.n
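A note on the schedule: EPSILON is annealed once per episode in the training loop below, linearly from START_EPSILON down to FINAL_EPSILON:

    EPSILON <- EPSILON - (START_EPSILON - FINAL_EPSILON) / EXPLORE

so with EXPLORE = 1,000,000 the exploration rate only reaches its floor of 0.1 after one million episodes.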
class Agent(object):
    def __init__(self):
        # online network and target network share the same architecture
        self.network, self.target_network = AtariNet(ACTIONS_SIZE), AtariNet(ACTIONS_SIZE)
        self.memory = deque()
        self.learning_count = 0
        self.optimizer = torch.optim.Adam(self.network.parameters(), lr=LR)
        self.loss_func = nn.MSELoss()

    def action(self, state, israndom):
        # epsilon-greedy: random action with probability EPSILON during training
        if israndom and random.random() < EPSILON:
            return np.random.randint(0, ACTIONS_SIZE)
        state = torch.unsqueeze(torch.FloatTensor(state), 0)
        actions_value = self.network(state)
        return torch.max(actions_value, 1)[1].data.numpy()[0]

    def learn(self, state, action, reward, next_state, done):
        # store the transition; the last element is 0 for terminal states and
        # 1 otherwise, so it can mask the bootstrap term in the target below
        if done:
            self.memory.append((state, action, reward, next_state, 0))
        else:
            self.memory.append((state, action, reward, next_state, 1))
        if len(self.memory) > MEMORY_SIZE:
            self.memory.popleft()
        if len(self.memory) < MEMORY_THRESHOLD:
            return

        # periodically copy the online weights into the target network
        if self.learning_count % UPDATE_TIME == 0:
            self.target_network.load_state_dict(self.network.state_dict())
        self.learning_count += 1

        batch = random.sample(self.memory, BATCH_SIZE)
        state = torch.FloatTensor(np.array([x[0] for x in batch]))
        action = torch.LongTensor([[x[1]] for x in batch])
        reward = torch.FloatTensor([[x[2]] for x in batch])
        next_state = torch.FloatTensor(np.array([x[3] for x in batch]))
        done = torch.FloatTensor([[x[4]] for x in batch])

        eval_q = self.network(state).gather(1, action)
        next_q = self.target_network(next_state).detach()
        target_q = reward + GAMMA * next_q.max(1)[0].view(BATCH_SIZE, 1) * done
        loss = self.loss_func(eval_q, target_q)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
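The tail of learn() computes the standard Nature-DQN target with the frozen target network (parameters θ⁻). Because the stored flag is 0 for terminal transitions and 1 otherwise, multiplying by it removes the bootstrap term exactly when the episode has ended:

    target_q = r + GAMMA * max_a' Q_target(s', a') * flag

which matches y_j = r_j + γ · max_{a'} Q̂(s_{j+1}, a'; θ⁻) from the Nature paper, with the max term taken as zero on terminal steps.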
agent = Agent()

for i_episode in range(TOTAL_EPISODES):
    state = env.reset()
    state = preprocess(state)
    while True:
        # env.render()
        action = agent.action(state, True)
        next_state, reward, done, info = env.step(action)
        next_state = preprocess(next_state)
        agent.learn(state, action, reward, next_state, done)
        state = next_state
        if done:
            break
    # anneal epsilon linearly towards FINAL_EPSILON
    if EPSILON > FINAL_EPSILON:
        EPSILON -= (START_EPSILON - FINAL_EPSILON) / EXPLORE

    # TEST: run one greedy (non-exploring) episode
    if i_episode % TEST_FREQUENCY == 0:
        state = env.reset()
        state = preprocess(state)
        total_reward = 0
        while True:
            # env.render()
            action = agent.action(state, israndom=False)
            next_state, reward, done, info = env.step(action)
            next_state = preprocess(next_state)
            total_reward += reward
            state = next_state
            if done:
                break
        print('episode: {} , total_reward: {}'.format(i_episode, round(total_reward, 3)))

env.close()
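As published, the script trains indefinitely and never writes the learned weights to disk. A minimal, hedged sketch of checkpointing with the standard torch.save / torch.load API (the file name dqn_pong.pth is an assumption, not something from the repo):

# Hypothetical checkpointing sketch -- not part of the original script.
# After (or during) training, persist the online network's weights:
torch.save(agent.network.state_dict(), 'dqn_pong.pth')  # file name is an assumption

# To evaluate later, rebuild the agent and restore both networks:
agent = Agent()
agent.network.load_state_dict(torch.load('dqn_pong.pth'))
agent.target_network.load_state_dict(agent.network.state_dict())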
net.py
import torch
from torch import nn

class AtariNet(nn.Module):
    def __init__(self, num_actions):
        super(AtariNet, self).__init__()
        # input is a single preprocessed 1x84x84 frame (preprocess() in util.py
        # returns one frame; the Nature paper stacks 4 consecutive frames, which
        # would require in_channels=4 plus frame stacking in the training loop)
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=8, stride=4),   # -> 32 x 20 x 20
            nn.ReLU()
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=4, stride=2),  # -> 64 x 9 x 9
            nn.ReLU()
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, stride=1),  # -> 64 x 7 x 7
            nn.ReLU()
        )
        self.hidden = nn.Sequential(
            nn.Linear(64 * 7 * 7, 512, bias=True),
            nn.ReLU()
        )
        self.out = nn.Sequential(
            nn.Linear(512, num_actions, bias=True)
        )
        self.apply(self.init_weights)

    def init_weights(self, m):
        if type(m) == nn.Conv2d:
            m.weight.data.normal_(0.0, 0.02)
        if type(m) == nn.Linear:
            torch.nn.init.xavier_uniform_(m.weight)
            m.bias.data.fill_(0.01)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, 64*7*7)
        x = self.hidden(x)
        x = self.out(x)
        return x
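A quick, hedged sanity check of the 64 * 7 * 7 flatten size (a throwaway snippet, not part of the repo): an 84x84 input shrinks to 20x20 after conv1, 9x9 after conv2, and 7x7 after conv3, so a dummy frame should yield one Q-value per action.

import torch
from net import AtariNet

net = AtariNet(6)                  # Pong-v0 exposes 6 discrete actions
dummy = torch.zeros(1, 1, 84, 84)  # (batch, channels, height, width)
q = net(dummy)
print(q.shape)                     # torch.Size([1, 6])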
util.py
import cv2
import numpy as np

def preprocess(observation):
    """
    Image preprocessing.
    :param observation: one RGB frame of shape 210x160x3
    :return: array of shape 1x84x84
    """
    # convert the 210x160x3 RGB frame (gym returns RGB) to grayscale and
    # downsample to 110x84; cv2.resize takes (width, height), so (84, 110)
    observation = cv2.cvtColor(cv2.resize(observation, (84, 110)), cv2.COLOR_RGB2GRAY)
    # crop off the top 26 rows (the scoreboard), leaving an 84x84 matrix
    observation = observation[26:110, :]
    # binarize: pixels above 1 become 255, the rest 0; size is still 84x84
    ret, observation = cv2.threshold(observation, 1, 255, cv2.THRESH_BINARY)
    # reshape to 84x84x1
    x = np.reshape(observation, (84, 84, 1))
    # move the channel axis to the front: 84x84x1 -> 1x84x84
    x = x.transpose((2, 0, 1))
    return x
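A small, hedged smoke test for preprocess (the random array merely stands in for a real 210x160x3 Pong frame):

import numpy as np
from util import preprocess

frame = np.random.randint(0, 256, size=(210, 160, 3), dtype=np.uint8)
out = preprocess(frame)
print(out.shape)       # expected: (1, 84, 84)
print(np.unique(out))  # values are a subset of {0, 255}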