[Repost] A First Look at DQN Reinforcement Learning: A PyTorch Code Walkthrough

Copyright notice: This is an original article by CSDN blogger "难受啊!马飞…", licensed under CC 4.0 BY-SA. Please include the original source link and this notice when reposting.
Original link: https://blog.csdn.net/qq_33328642/article/details/123788966

First, the complete code.
This code was provided by a presenter from Dalian University of Technology. She is meticulous and explains the theory very clearly, though I didn't fully follow the code from the video alone. Her videos are on Bilibili if you want to look them up.

Enough preamble; here is the complete code.

1. The complete code

import gym
import math
import random
import numpy as np
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count
import time

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T
from torchvision.transforms import InterpolationMode

env = gym.make('SpaceInvaders-v0').unwrapped

# if gpu is to be used
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

######################################################################
# Replay Memory

Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))

class ReplayMemory(object):

    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

######################################################################
# DQN algorithm

class DQN(nn.Module):

    def __init__(self, h, w, outputs):
        super(DQN, self).__init__()
        self.conv1 = nn.Conv2d(4, 32, kernel_size=8, stride=4)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
        self.bn3 = nn.BatchNorm2d(64)

        def conv2d_size_out(size, kernel_size, stride):
            return (size - (kernel_size - 1) - 1) // stride + 1
        convw = conv2d_size_out(conv2d_size_out(conv2d_size_out(w, 8, 4), 4, 2), 3, 1)
        convh = conv2d_size_out(conv2d_size_out(conv2d_size_out(h, 8, 4), 4, 2), 3, 1)
        linear_input_size = convw * convh * 64
        self.l1 = nn.Linear(linear_input_size, 512)
        self.l2 = nn.Linear(512, outputs)

    def forward(self, x):
        x = x.to(device)
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.l1(x.view(x.size(0), -1)))
        return self.l2(x.view(-1, 512))

######################################################################
# Input extraction

resize = T.Compose([T.ToPILImage(),
                    T.Grayscale(num_output_channels=1),
                    T.Resize((84, 84), interpolation=InterpolationMode.BICUBIC),
                    T.ToTensor()])

def get_screen():
    # Transpose it into torch order (CHW).
    screen = env.render(mode='rgb_array').transpose((2, 0, 1))
    screen = np.ascontiguousarray(screen, dtype=np.float32) / 255
    screen = torch.from_numpy(screen)
    # Resize, and add a batch dimension (BCHW)
    return resize(screen).unsqueeze(0)

######################################################################
# Training

# Hyperparameters and network initialization
BATCH_SIZE = 32
GAMMA = 0.99
EPS_START = 1.0
EPS_END = 0.1
EPS_DECAY = 10000
TARGET_UPDATE = 10

init_screen = get_screen()
_, _, screen_height, screen_width = init_screen.shape

# Get number of actions from gym action space
n_actions = env.action_space.n

policy_net = DQN(screen_height, screen_width, n_actions).to(device)
target_net = DQN(screen_height, screen_width, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

optimizer = optim.RMSprop(policy_net.parameters())
memory = ReplayMemory(100000)

steps_done = 0

def select_action(state):
    global steps_done
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * \
        math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    if sample > eps_threshold:
        with torch.no_grad():
            return policy_net(state).max(1)[1].view(1, 1)
    else:
        return torch.tensor([[random.randrange(n_actions)]], device=device, dtype=torch.long)

episode_durations = []

def plot_durations():
    plt.figure(1)
    plt.clf()
    durations_t = torch.tensor(episode_durations, dtype=torch.float)
    plt.title('Training...')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(durations_t.numpy())
    # Take 100 episode averages and plot them too
    if len(durations_t) >= 100:
        means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
        means = torch.cat((torch.zeros(99), means))
        plt.plot(means.numpy())

    plt.pause(0.001)  # pause a bit so that plots are updated

def optimize_model():
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    batch = Transition(*zip(*transitions))

    # Compute a mask of non-final states and concatenate the batch elements
    # (a final state would've been the one after which simulation ended)
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch.next_state)),
                                  device=device, dtype=torch.bool)
    non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    state_action_values = policy_net(state_batch).gather(1, action_batch)
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0].detach()
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # Compute the loss (MSE here; the source's comment said Huber loss)
    criterion = nn.MSELoss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    for param in policy_net.parameters():
        param.grad.data.clamp_(-1, 1)
    optimizer.step()

def random_start(skip_steps=30, m=4):
    env.reset()
    state_queue = deque([], maxlen=m)
    next_state_queue = deque([], maxlen=m)
    done = False
    for i in range(skip_steps):
        if (i + 1) <= m:
            state_queue.append(get_screen())
        elif m < (i + 1) <= 2 * m:
            next_state_queue.append(get_screen())
        else:
            state_queue.append(next_state_queue[0])
            next_state_queue.append(get_screen())

        action = env.action_space.sample()
        _, _, done, _ = env.step(action)
        if done:
            break
    return done, state_queue, next_state_queue

######################################################################
# Start Training

num_episodes = 10000
m = 4
for i_episode in range(num_episodes):
    # Initialize the environment and state
    done, state_queue, next_state_queue = random_start()
    if done:
        continue

    state = torch.cat(tuple(state_queue), dim=1)
    for t in count():
        reward = 0
        m_reward = 0
        # one action is held for m frames
        action = select_action(state)

        for i in range(m):
            _, reward, done, _ = env.step(action.item())
            if not done:
                next_state_queue.append(get_screen())
            else:
                break
            m_reward += reward

        if not done:
            next_state = torch.cat(tuple(next_state_queue), dim=1)
        else:
            next_state = None
            m_reward = -150
        m_reward = torch.tensor([m_reward], device=device)

        memory.push(state, action, next_state, m_reward)

        state = next_state
        optimize_model()

        if done:
            episode_durations.append(t + 1)
            plot_durations()
            break

    # Update the target network, copying all weights and biases in DQN
    if i_episode % TARGET_UPDATE == 0:
        target_net.load_state_dict(policy_net.state_dict())
        torch.save(policy_net.state_dict(), 'weights/policy_net_weights_{0}.pth'.format(i_episode))

print('Complete')
env.close()
torch.save(policy_net.state_dict(), 'weights/policy_net_weights.pth')


2. Function-by-function analysis

2.1 Defining the replay-memory Transition

The code uses the named tuple type namedtuple() to define a Transition, which stores the agent-environment interaction tuples (s, a, r, s_).

Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))
Named tuples are simple. An example:

Student = namedtuple('Student', ('name', 'gender'))
s = Student('小花', '女')  # assign the fields

# The fields can be accessed in several ways.
# Method 1: by attribute name
print(s.name)    # 小花
print(s.gender)  # 女

# Method 2: by index
print(s[0])      # 小花
print(s[1])      # 女

# Method 3: iteration also works
for i in s:
    print(i)

2.2 ReplayMemory

class ReplayMemory(object):
    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)  # deque is a double-ended list with fast inserts and removals at both ends, suited to queues and stacks
    def push(self, *args):
        self.memory.append(Transition(*args))
    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)  # random.sample draws batch_size random entries from memory
    def __len__(self):
        return len(self.memory)

  • def __init__(self, capacity): nothing much to say; it just sets up the double-ended deque.
  • def push(self, *args) appends a Transition to memory; this memory behaves like a list, and is covered in detail later.
  • def sample(self, batch_size) is the random sampling. The first argument of random.sample() is the sequence to sample from, the second is the sample size. This should be familiar; there is also an example further down.

2.3 DQN algorithm

class DQN(nn.Module):
    def __init__(self, h, w, outputs):
        super(DQN, self).__init__()
        self.conv1 = nn.Conv2d(4, 32, kernel_size=8, stride=4)   # first convolutional layer
        self.bn1 = nn.BatchNorm2d(32)                            # batch normalization after the first conv layer
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)  # second convolutional layer
        self.bn2 = nn.BatchNorm2d(64)                            # batch normalization after the second conv layer
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)  # third convolutional layer
        self.bn3 = nn.BatchNorm2d(64)                            # batch normalization after the third conv layer
        def conv2d_size_out(size, kernel_size, stride):
            return (size - (kernel_size - 1) - 1) // stride + 1
        convw = conv2d_size_out(conv2d_size_out(conv2d_size_out(w, 8, 4), 4, 2), 3, 1)  # width: 84 in, 7 out
        convh = conv2d_size_out(conv2d_size_out(conv2d_size_out(h, 8, 4), 4, 2), 3, 1)  # height: 84 in, 7 out
        linear_input_size = convw * convh * 64
        # The final feature map is 7*7 with 64 channels; flattened into one row that is 7*7*64 = 3136 values
        self.l1 = nn.Linear(linear_input_size, 512)  # first fully connected layer, 3136 -> 512 neurons (honestly, a clumsy way to do it)
        self.l2 = nn.Linear(512, outputs)            # output layer: one Q-value per action (2 in the demo below)

    def forward(self, x):
        x = x.to(device)
        x = F.relu(self.bn1(self.conv1(x)))  # conv1 -> batch norm -> ReLU
        x = F.relu(self.bn2(self.conv2(x)))  # conv2 -> batch norm -> ReLU
        x = F.relu(self.bn3(self.conv3(x)))  # conv3 -> batch norm -> ReLU
        x = F.relu(self.l1(x.view(x.size(0), -1)))  # flatten the conv output into one row per sample
        return self.l2(x.view(-1, 512))  # -1: the row count is inferred; the data is known to have 512 columns

This is a standard PyTorch model definition, which should look familiar; I have also annotated it inline. Two points deserve attention:

  • conv2d_size_out(size, kernel_size, stride) computes the output feature-map size of each conv layer. The DQN input is an 84×84 image, so the final feature map is 7×7 with 64 channels (84 → 20 → 9 → 7 per dimension); the function only exists to size the first fully connected layer. Honestly, this feels a bit redundant: ordinary code can just use flatten(). Have a look at other ways of flattening a feature map; there is also a sketch after the demo below.
  • Run the code below to see the network's output when there are only two actions. I initially assumed the output would be batched as [32, 1, 2], i.e., 32 samples of a 1-row, 2-column Q-value tensor. In fact it is [32, 2]: 32 rows, two columns. That explains how the rest of the code indexes it, but I only saw this after taking the model apart:
class DQN(nn.Module):
    def __init__(self, h, w, outputs):
        super(DQN, self).__init__()
        self.conv1 = nn.Conv2d(4, 32, kernel_size=8, stride=4)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
        self.bn3 = nn.BatchNorm2d(64)
        def conv2d_size_out(size, kernel_size, stride):
            return (size - (kernel_size - 1) - 1) // stride + 1
        convw = conv2d_size_out(conv2d_size_out(conv2d_size_out(w, 8, 4), 4, 2), 3, 1)
        convh = conv2d_size_out(conv2d_size_out(conv2d_size_out(h, 8, 4), 4, 2), 3, 1)
        linear_input_size = convw * convh * 64
        self.l1 = nn.Linear(linear_input_size, 512)
        self.l2 = nn.Linear(512, outputs)

    def forward(self, x):
        #x = x.to(device)
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.l1(x.view(x.size(0), -1)))
        return self.l2(x.view(-1, 512))

policy_net = DQN(84, 84, 2)  # the Q network, with two actions
x = torch.rand(32, 4, 84, 84)

xout = policy_net(x)

print(xout.size())
# torch.Size([32, 2])
print(xout)

tensor([[ 3.4981e-02,  3.1048e-02],
        [ 1.4112e-01, -5.2676e-02],
        [-3.3868e-01,  3.9583e-02],
        [ 7.5908e-02, -1.2230e-01],
        [ 1.4027e-01, -1.7528e-02],
        [-1.0966e-02,  6.2111e-02],
        [-2.2511e-02, -6.1829e-02],
        [ 3.2599e-02, -8.9155e-02],
        [ 9.7833e-02, -5.0325e-02],
        [-6.4633e-02, -8.8093e-02],
        [-4.3771e-02,  1.5452e-01],
        [-1.7478e-01, -1.3224e-01],
        [ 1.9658e-02,  8.1575e-03],
        [-1.6989e-01, -6.6487e-03],
        [-1.6566e-01, -1.0833e-01],
        [-9.5961e-02,  1.1235e-02],
        [ 1.0005e-01, -1.1150e-02],
        [ 1.8165e-02,  9.9491e-03],
        [-2.3947e-01,  9.7802e-02],
        [-5.2116e-02,  4.8583e-02],
        [ 2.2504e-02,  3.8262e-04],
        [-1.1822e-01, -2.0696e-01],
        [-1.4129e-01, -1.9254e-01],
        [-2.2170e-01, -1.2232e-01],
        [ 3.3542e-02,  3.3005e-03],
        [ 1.5150e-01,  1.5330e-01],
        [-2.3675e-01, -2.4939e-01],
        [-1.0502e-01,  7.2696e-02],
        [-1.3213e-01,  1.5113e-01],
        [ 6.1988e-02,  2.5367e-02],
        [-4.2924e-01, -4.0167e-02],
        [ 5.1474e-02,  2.6885e-01]], grad_fn=<AddmmBackward0>)
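
As an aside, here is what the flatten() alternative mentioned above could look like. This is a minimal sketch of my own (DQNFlatten is not in the original post); nn.Flatten and nn.LazyLinear are standard PyTorch modules, the latter available in PyTorch >= 1.8:

import torch
import torch.nn as nn

# Let PyTorch infer the flattened size instead of hand-computing it with
# conv2d_size_out: nn.LazyLinear infers its input size on the first forward pass.
class DQNFlatten(nn.Module):
    def __init__(self, outputs):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(4, 32, kernel_size=8, stride=4), nn.BatchNorm2d(32), nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.BatchNorm2d(64), nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.Flatten(),  # replaces the manual x.view(x.size(0), -1)
        )
        self.head = nn.Sequential(nn.LazyLinear(512), nn.ReLU(), nn.Linear(512, outputs))

    def forward(self, x):
        return self.head(self.features(x))

print(DQNFlatten(2)(torch.rand(32, 4, 84, 84)).size())  # torch.Size([32, 2]), as before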


2.4 Image preprocessing

resize = T.Compose([T.ToPILImage(),
                    T.Grayscale(num_output_channels=1),
                    T.Resize((84, 84), interpolation=InterpolationMode.BICUBIC),
                    T.ToTensor()])

 
 

#Compose chains several transforms together; here we have ToPILImage, grayscale conversion, Resize, and ToTensor.
#ToTensor converts a PIL Image into a torch.FloatTensor.
#ToPILImage converts a (C,H,W) Tensor, or an (H,W,C) numpy.ndarray, into a PIL.Image without changing the values.
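
A quick shape check of this pipeline (my own sketch; 210×160×3 is the raw frame size SpaceInvaders-v0 typically renders):

import torch

fake_frame = torch.rand(3, 210, 160)  # a CHW float tensor, like the one get_screen builds below
out = resize(fake_frame)
print(out.shape)               # torch.Size([1, 84, 84]): grayscale and resized
print(out.unsqueeze(0).shape)  # torch.Size([1, 1, 84, 84]): with the batch dimension added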

2.5 The screen-capture function

def get_screen():
    # Capture the game screen, used as the state in the training data.
    # Transpose it into torch order (CHW).
    screen = env.render(mode='rgb_array').transpose((2, 0, 1))
    # env.render acts as the rendering engine, exposing the current frame; transpose moves the channel axis to the front
    screen = np.ascontiguousarray(screen, dtype=np.float32) / 255
    # ascontiguousarray turns an array that is non-contiguous in memory into a contiguous one, which is faster to work with
    screen = torch.from_numpy(screen)  # create a tensor from the numpy.ndarray
    # Resize, and add a batch dimension (BCHW)
    return resize(screen).unsqueeze(0)  # unsqueeze(0) adds a dimension at index 0, turning CHW into BCHW, where B is the batch

 
 

2.6 Hyperparameters

# Hyperparameters and network initialization
BATCH_SIZE = 32     # batch size drawn from the replay memory
GAMMA = 0.99        # discount factor
EPS_START = 1.0     # initial value of the greedy parameter
EPS_END = 0.1       # minimum value of the greedy parameter
EPS_DECAY = 10000   # decay time constant of the greedy parameter, in steps
TARGET_UPDATE = 10  # update the target net every TARGET_UPDATE episodes
init_screen = get_screen()  # grab one game frame; its shape is (1, 1, 84, 84): batch, channels, height, width
_, _, screen_height, screen_width = init_screen.shape  # read off the frame height and width
n_actions = env.action_space.n  # number of actions in the game's action space
# initialize the models
policy_net = DQN(screen_height, screen_width, n_actions).to(device)  # Q (main) network
target_net = DQN(screen_height, screen_width, n_actions).to(device)  # target network
target_net.load_state_dict(policy_net.state_dict())  # at the start, the target net and the main net share the same parameters
target_net.eval()  # the target net is not trained here; it is only used for inference
optimizer = optim.RMSprop(policy_net.parameters())  # optimize the network with RMSprop
memory = ReplayMemory(100000)  # set the capacity of the replay memory
steps_done = 0

 
 

Nothing much worth commenting here; it should all be readable.

policy_net = DQN(screen_height, screen_width, n_actions).to(device)#Q
target_net = DQN(screen_height, screen_width, n_actions).to(device)#T

 
 

A junior labmate once asked me what these two lines mean.
They simply instantiate the models; since the author's network takes extra constructor arguments, it has to be initialized with those parameters.
Normally, instantiating a model is not this involved.
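
For contrast, a minimal sketch of the usual, argument-free case (MyNet is a made-up example, not part of the original code):

import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# A model whose constructor takes no extra arguments
class MyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)
    def forward(self, x):
        return self.fc(x)

net = MyNet().to(device)  # instantiated without any constructor arguments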

2.7 The action-selection function

# The action-selection function; the key quantity is the exploration/exploitation threshold eps in [0, 1]
def select_action(state):
    global steps_done
    sample = random.random()  # a random float between 0 and 1
    eps_threshold = EPS_END + (EPS_START - EPS_END) * \
        math.exp(-1.*steps_done / EPS_DECAY)  # decays toward EPS_END; about 0.43 after EPS_DECAY steps
    steps_done += 1
    if sample > eps_threshold:  # decide between a random action and the greedy action
        # sample lies in (0, 1) while eps_threshold keeps shrinking, so this greedy
        # (exploitation) branch becomes more and more likely as training progresses
        with torch.no_grad():  # torch.no_grad() is for inference: no gradients are computed
            return policy_net(state).max(1)[1].view(1, 1)  # take the greedy (highest-Q) action
    else:
        # early on, the large threshold makes this random (exploration) branch the likely one
        return torch.tensor([[random.randrange(n_actions)]], device=device, dtype=torch.long)  # random action
# random.randrange(N) draws a random integer in [0, N), where N is the size of the action space

 
 
  • The main thing to explain here is eps_threshold, the quantity that controls eps.
    It is a monotonically decreasing function; I plotted its curve. With the constants in the snippet below (0.99 and 0.1), it bottoms out at 0.1 + 0.89·e⁻¹ ≈ 0.427 over the plotted range, which is the minimum the author refers to.
    (figure: the eps_threshold decay curve)
    You can run the function below yourself to reproduce it.
    One tip: multiplying i by a constant inside the exponential controls how low eps_threshold gets within the plotted range.
    Multiply i by 2, for example, and eps_threshold decays twice as fast.
plt.figure(1)
ax = plt.subplot(111)
x = np.linspace(0, 1000, 1000)  # 1000 evenly spaced points between 0 and 1000
print(x)
r1 = []
for i in range(1000):
    r = 0.1 + (0.99 - 0.1) * \
        math.exp(-1. * (i / 1000))
    r1.append(r)
print(r1)
ax.plot(x, r1)
plt.show()

 
 

2.8 The plotting function

episode_durations = []  # list holding per-episode training statistics
def plot_durations():
    plt.figure(1)
    plt.clf()  # clear the current figure and its axes but keep the window open for reuse; no need to close every figure on re-runs
    durations_t = torch.tensor(episode_durations, dtype=torch.float)  # convert to a tensor
    plt.title('Training...')  # figure title
    plt.xlabel('Episode')     # x-axis label
    plt.ylabel('Duration')    # y-axis label
    plt.plot(durations_t.numpy())  # plot the raw durations
    # Take 100 episode averages and plot them too
    if len(durations_t) >= 100:
        means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
        means = torch.cat((torch.zeros(99), means))
        plt.plot(means.numpy())

    plt.pause(0.001)  # pause a bit so that plots are updated

Nothing special here, except perhaps the unfold() moving average; the tiny demo below shows exactly what it computes.
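
A minimal sketch of what durations_t.unfold(0, 100, 1).mean(1) does, shrunk to a window of 3 so the numbers stay readable:

import torch

t = torch.arange(6, dtype=torch.float)  # tensor([0., 1., 2., 3., 4., 5.])
windows = t.unfold(0, 3, 1)             # length-3 sliding windows, stride 1:
                                        # [[0,1,2], [1,2,3], [2,3,4], [3,4,5]]
print(windows.mean(1))                  # tensor([1., 2., 3., 4.]): the moving average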

2.9 The optimization step

def optimize_model():
    if len(memory) < BATCH_SIZE:  # check whether the replay memory holds enough samples yet
        return
    transitions = memory.sample(BATCH_SIZE)  # draw BATCH_SIZE random samples from the replay memory
    batch = Transition(*zip(*transitions))   # zip regroups corresponding fields; the * operators unpack
    # Compute a mask of non-final states and concatenate the batch elements
    # (a final state would've been the one after which simulation ended)
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch.next_state)), device=device, dtype=torch.bool)
    # map() applies the small lambda to every next_state s in the batch;
    # tuple() freezes the result into an immutable tuple
    non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
    state_batch  = torch.cat(batch.state)   # concatenate the 32 states of the batch along dim 0
    action_batch = torch.cat(batch.action)  # concatenate the 32 actions
    reward_batch = torch.cat(batch.reward)  # concatenate the 32 rewards

    # Feed state_batch through the policy network, which outputs a [BATCH_SIZE, n_actions]
    # tensor of Q-values, one row per state; gather(1, action_batch) then picks, per row,
    # the Q-value of the action that was actually taken
    state_action_values = policy_net(state_batch).gather(1, action_batch)
    next_state_values = torch.zeros(BATCH_SIZE, device=device)  # a 32-element tensor
    next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0].detach()
    # max(1) takes the row-wise maximum; [0] extracts the maximum values themselves
    expected_state_action_values = reward_batch + (next_state_values * GAMMA)  # the TD target

    # Compute the loss (MSE here; the source's comment said Huber loss)
    criterion = nn.MSELoss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    for param in policy_net.parameters():
        param.grad.data.clamp_(-1, 1)
    optimizer.step()
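
To make the TD-target line expected_state_action_values = reward_batch + next_state_values * GAMMA concrete, a small numeric sketch with made-up values (index 1 plays the role of a final state, where the mask leaves the value at zero):

import torch

reward_batch = torch.tensor([1.0, -150.0, 2.0])
next_state_values = torch.tensor([5.0, 0.0, 3.0])  # zero for the final state
print(reward_batch + 0.99 * next_state_values)     # tensor([  5.9500, -150.0000,    4.9700])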

Here we go. As I often tell students when teaching: to really understand a piece of code or an algorithm, you must work out where the data flows and how the tensor shapes change along the way.
The first block that needs close reading is the pair of lines below, the ones I said earlier would be explained later. An example is the quickest way in.

    transitions = memory.sample(BATCH_SIZE)  # draw BATCH_SIZE random samples from the replay memory
    batch = Transition(*zip(*transitions))   # zip regroups corresponding fields; the * operators unpack

 
 

The first line draws a random batch of samples from memory (32 by default).
Then comes batch. A concrete example makes it obvious at a glance.

import torch
import random
from collections import namedtuple, deque

# Create a deque with capacity 100, as in the main code
memory = deque([], maxlen=100)
# Define our Transition, as in the main code
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))
# Instantiate a few Transitions
s1 = Transition(2, 3, 4, 5)
s2 = Transition(1, 2, 3, 4)
s3 = Transition(1, 4, 5, 2)
s4 = Transition(2, 5, 7, 3)
# Push them into memory
memory.append(s1)
memory.append(s2)
memory.append(s3)
memory.append(s4)
print(memory)
# The raw memory looks like this:
# deque([Transition(state=2, action=3, next_state=4, reward=5), Transition(state=1, action=2, next_state=3, reward=4), Transition(state=1, action=4, next_state=5, reward=2), Transition(state=2, action=5, next_state=7, reward=3)], maxlen=100)
# Sample a batch of 2
m2 = random.sample(memory, 2)
# After sampling:
# [Transition(state=1, action=4, next_state=5, reward=2), Transition(state=2, action=3, next_state=4, reward=5)]
# And now the key line:
batch = Transition(*zip(*m2))
print(batch)
# Transition(state=(1, 2), action=(4, 3), next_state=(5, 4), reward=(2, 5))
# The whole point of batch = Transition(*zip(*transitions)) is to gather the
# individual s, a, r, s_ fields together, one tuple per field.
# Continuing with the main code, trace the data format line by line:
non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch.next_state)), dtype=torch.bool)
print(non_final_mask)
# Output: tensor([True, True])
# So non_final_mask is a boolean tensor marking whether each next state is non-final.

Following the same kind of shape bookkeeping, you can work out how the rest of the data is processed.
Now look at this statement:

    state_action_values = policy_net(state_batch).gather(1, action_batch)  # the column index varies per row; there is one column per action

 
 

This gather is not "gathering" in the everyday sense.
It behaves like the table lookup in a Q-table: it reads off Q-values.

  • policy_net(state_batch) takes the batch of 4×84×84 image stacks as input and outputs a 32×2 tensor of action Q-values, where 32 is the batch size.
  • .gather(1, action_batch): action_batch holds the indices of the taken actions, and gather uses them to look up the matching Q-value in each row (see the referenced blog post for more detail).

Explaining this block directly at this point is still a bit hard, because it sits downstream of a lot of preprocessing. The small demo right below isolates the gather indexing; after that, we read on.
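
A minimal, self-contained gather sketch (my own toy numbers, not from the original post): each row of q holds one state's Q-values, and a holds the index of the action taken in that state.

import torch

# Q-values for a batch of 3 states and 2 actions
q = torch.tensor([[0.1, 0.9],
                  [0.8, 0.2],
                  [0.4, 0.6]])
# The action actually taken in each state (shape [3, 1], as gather expects)
a = torch.tensor([[1], [0], [1]])
print(q.gather(1, a))  # tensor([[0.9000], [0.8000], [0.6000]]): row i keeps q[i, a[i]]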

2.10 Random start

def random_start(skip_steps=30, m=4):
    env.reset()  # re-initialize: after each attempt reaches a terminal state, the agent must start over
    state_queue = deque([], maxlen=m)       # current state: m=4 means each state is a stack of 4 frames
    next_state_queue = deque([], maxlen=m)  # next state
    done = False  # whether the episode has ended
    for i in range(skip_steps):
        if (i+1) <= m:  # fewer than 4 frames collected so far:
            state_queue.append(get_screen())  # keep appending frames to the current state
        elif m < (i+1) <= 2*m:  # between 4 and 8 frames:
            next_state_queue.append(get_screen())  # store these frames in the next state
        else:
            # beyond 8 frames: move the oldest frame of next_state_queue into the current state_queue,
            state_queue.append(next_state_queue[0])
            # and append the freshly captured frame to the next state;
            next_state_queue.append(get_screen())
            # since both containers are deque()s with maxlen=m, the oldest entries fall off automatically
        action = env.action_space.sample()  # sample a random action
        _, _, done, _ = env.step(action)  # apply it; env.step returns: next observation, immediate reward, done flag, debug info
        if done:
            break
    return done, state_queue, next_state_queue
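
To convince yourself of what this warm-up leaves in the two queues, here is a dry run of just the queue logic, with get_screen() replaced by the loop counter so the contents are readable (a sketch, not real frames):

from collections import deque

m, skip_steps = 4, 30
state_queue = deque([], maxlen=m)
next_state_queue = deque([], maxlen=m)
for i in range(skip_steps):
    if (i+1) <= m:
        state_queue.append(i)
    elif m < (i+1) <= 2*m:
        next_state_queue.append(i)
    else:
        state_queue.append(next_state_queue[0])
        next_state_queue.append(i)

print(state_queue)       # deque([22, 23, 24, 25], maxlen=4)
print(next_state_queue)  # deque([26, 27, 28, 29], maxlen=4)

So after the warm-up, the current state and the next state are stacks of frames exactly m steps apart, which is roughly the relationship the training loop below keeps maintaining.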

2.11 Start training

# Start Training

num_episodes = 10000
m = 4  # each state is a stack of 4 frames
for i_episode in range(num_episodes):  # run 10000 episodes
    # Initialize the environment and state
    done, state_queue, next_state_queue = random_start()
    if done:
        continue
    state = torch.cat(tuple(state_queue), dim=1)  # concatenate the frame stack into one state tensor
    for t in count():
        reward = 0
        m_reward = 0
        # one action is executed for m frames
        action = select_action(state)  # pick an action for the current state
        for i in range(m):
            _, reward, done, _ = env.step(action.item())  # interact with the environment: get reward and termination flag
            if not done:  # not terminal:
                next_state_queue.append(get_screen())  # capture a frame into the next state
            else:  # terminal (game over): leave the loop
                break
            m_reward += reward  # accumulate the reward

        if not done:  # the episode is still running,
            next_state = torch.cat(tuple(next_state_queue), dim=1)
        else:  # the episode is over,
            next_state = None  # there is no next state, i.e. the agent died,
            m_reward = -150  # so the reward is a flat -150
        m_reward = torch.tensor([m_reward], device=device)
        memory.push(state, action, next_state, m_reward)  # store this transition in the replay memory
        state = next_state  # the next state becomes the current state
        optimize_model()  # run one optimization step

        if done:  # if the episode ended,
            episode_durations.append(t + 1)  # record the episode length
            plot_durations()  # and plot progress
            break

    # Update the target network, copying all weights and biases in DQN
    if i_episode % TARGET_UPDATE == 0:  # every TARGET_UPDATE episodes, sync the target network to stabilize the bootstrap targets
        target_net.load_state_dict(policy_net.state_dict())
        torch.save(policy_net.state_dict(), 'weights/policy_net_weights_{0}.pth'.format(i_episode))  # checkpoint the model

print('Complete')
env.close()  # shut down the environment
torch.save(policy_net.state_dict(), 'weights/policy_net_weights.pth')
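
Two design choices in this loop are worth calling out. Holding one chosen action for m = 4 consecutive frames is the classic Atari action-repeat trick: it makes interaction cheaper and gives each decision a visible effect. And substituting a flat -150 reward on death ensures the terminal transition carries a strong negative learning signal even when the raw game reward at that moment is small.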

Running the full code just to inspect these details would be cumbersome,
so I wrote a small demo myself to verify the data flow:

import random
import torch
from collections import namedtuple, deque

state_que = deque([], maxlen=4)

memory = deque([], maxlen=100)
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))
st1 = torch.rand(2,2)
st2 = torch.rand(2,2)
st3 = torch.rand(2,2)
st4 = torch.rand(2,2)

a1 = torch.ones(1)
a2 = torch.ones(1)
a3 = torch.ones(1)
a4 = torch.ones(1)

# simulate the screen-capture code get_screen(), which shapes each frame as (1, 1, 84, 84);
# in this demo I use 2x2 "images" instead
nst1 = torch.rand(2,2)
nst1 = nst1.unsqueeze(0)
nst1 = nst1.unsqueeze(0)
nst2 = torch.rand(2,2)
nst2 = nst2.unsqueeze(0)
nst2 = nst2.unsqueeze(0)
nst3 = torch.rand(2,2)
nst3 = nst3.unsqueeze(0)
nst3 = nst3.unsqueeze(0)
nst4 = torch.rand(2,2)
nst4 = nst4.unsqueeze(0)
nst4 = nst4.unsqueeze(0)

# pack the variables into Transitions
s1 = Transition(st1, a1, nst1, 5)
s2 = Transition(st2, a2, nst2, 4)
s3 = Transition(st3, a3, nst3, 2)
s4 = Transition(st4, a4, nst4, 3)
# append the frames to state_que
state_que.append(nst1)
state_que.append(nst2)
state_que.append(nst3)
state_que.append(nst4)
print('state_que', state_que)
# convert to a tuple and concatenate
print('convert to tuple and concatenate')
state = torch.cat(tuple(state_que), dim=1)
print('state', state)
print('state size', state.size())

memory.append(s1)
memory.append(s2)
memory.append(s3)
memory.append(s4)

#print(memory)

m2 = random.sample(memory, 2)
print('m2', m2)
print()
batch = Transition(*zip(*m2))  # note: *zip(*m2), the same trick as in the training code
print('zip*-----------------------')
print('batch.state', batch.state)
non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch.state)), dtype=torch.bool)
print(non_final_mask)
next_state_batch = torch.cat(batch.next_state)
print('next_state_batch', next_state_batch)
print('next_state_batch size =', next_state_batch.size())
action_batch = torch.cat(batch.action)
print('action_batch', action_batch)
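
If you run this, torch.cat(tuple(state_que), dim=1) turns the four (1, 1, 2, 2) "frames" into a single (1, 4, 2, 2) state tensor, the miniature analogue of the (1, 4, 84, 84) stack built from real screens, while torch.cat(batch.next_state) stacks the two sampled next states along dim 0 into a (2, 1, 2, 2) batch. Once these shape transitions are clear, the gather lookup and the non_final_mask from earlier follow naturally.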
