关于强化学习(RL)中网络结构设计的笔记和思考

3 篇文章 1 订阅
3 篇文章 1 订阅

问题背景:

强化学习的SAC算法(Soft Actor-Critic,SAC)下,对于代码框架结构中Q Network,Value Network的思考。尤其是Policy Network。

以下是SAC神经网络结构设计中的策略网络PolicyNetwork的代码,相较于Value Network, Soft Q Network,他除了网络结构,层数,节点的设计之外,还有forward, evaluate和get action这几个函数。我有以下问题想思考:

  1. 为什么Policy Network的结构设计,和Value Network, Soft Q Network有所不同?
  2. PolicyNetwork的功能是什么,作用是什么,在整个SAC框架中起了什么作用?
  3. PolicyNetwork的forward(),evaluate()以及get_action()的作用分别是什么?

代码描述

以下是代码部分,Policy策略网络每一行代码的作用使用注释说明:

class ValueNetwork(nn.Module):
    def __init__(self, state_dim, hidden_dim, init_w=3e-3):
        super(ValueNetwork, self).__init__()

        self.linear1 = nn.Linear(state_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, 1)

        self.linear3.weight.data.uniform_(-init_w, init_w)
        self.linear3.bias.data.uniform_(-init_w, init_w)

    def forward(self, state):
        x = F.relu(self.linear1(state))
        x = F.relu(self.linear2(x))
        x = self.linear3(x)
        return x


class SoftQNetwork(nn.Module):
    def __init__(self, num_inputs, num_actions, hidden_size, init_w=3e-3):
        super(SoftQNetwork, self).__init__()

        self.linear1 = nn.Linear(num_inputs + num_actions, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, 1)

        self.linear3.weight.data.uniform_(-init_w, init_w)
        self.linear3.bias.data.uniform_(-init_w, init_w)

    def forward(self, state, action):
        x = torch.cat([state, action], 1)
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        x = self.linear3(x)
        return x


class PolicyNetwork(nn.Module):
    def __init__(self, num_inputs, num_actions, hidden_size, init_w=3e-3, log_std_min=-20, log_std_max=2):
        super(PolicyNetwork, self).__init__()

        self.log_std_min = log_std_min
        self.log_std_max = log_std_max

        self.linear1 = nn.Linear(num_inputs, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)

        self.mean_linear = nn.Linear(hidden_size, num_actions)
        self.mean_linear.weight.data.uniform_(-init_w, init_w)
        self.mean_linear.bias.data.uniform_(-init_w, init_w)

        self.log_std_linear = nn.Linear(hidden_size, num_actions)
        self.log_std_linear.weight.data.uniform_(-init_w, init_w)
        self.log_std_linear.bias.data.uniform_(-init_w, init_w)

    def forward(self, state):
        x = F.relu(self.linear1(state))
        x = F.relu(self.linear2(x))

        mean    = self.mean_linear(x)
        log_std = self.log_std_linear(x)
        log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)

        return mean, log_std

    #这个方法主要用于训练期间
    def evaluate(self, state, epsilon=1e-6):
        mean, log_std = self.forward(state) #前向传播输入状态,得到动作的均值和对数标准差
        std = log_std.exp() #指数运算将对数标准差转换为标准差

        normal = Normal(0, 1) #创建一个标准正态分布
        z      = normal.sample() #标准分布中采样噪声
        action = torch.tanh(mean+ std*z.to(device)) #生成动作,基于标准正态分布中生成的随机噪声
        # we sample some noise from a Standard Normal distribution and multiply it with our standard deviation,
        # and then add the result to the mean.
        log_prob = Normal(mean, std).log_prob(mean+ std*z.to(device)) - torch.log(1 - action.pow(2) + epsilon) #算动作的对数概率,由正态分布的对数概率密度函数减去对动作执行tanh操作的雅可比行列式的对数得到
        return action, log_prob, z, mean, log_std 

        #action:策略网络生成的动作
        #log_prob:action的对数概率,用来计算loss
        #z:标准正态分布采样的随机噪声
        #mean:动作的均值
        #log_std:动作对数标准差,用以生成动作

    #这个方法主要用在策略执行阶段训练结束后,使用训练好的模型来决策,只返回一个动作
    def get_action(self, state):
        state = torch.FloatTensor(state).unsqueeze(0).to(device) #将输入状态转换为浮点张量,增加维度匹配网络输入形状
        mean, log_std = self.forward(state) #方法
        std = log_std.exp()

        normal = Normal(0, 1)
        z      = normal.sample().to(device)
        action = torch.tanh(mean + std*z)

        action  = action.cpu()#.detach().cpu().numpy() #将动作从GPU转移到CPU
        return action[0] #这里使用[0]是因为unsqueeze(0)增加了一个维度,所以我们需要使用[0]来移除这个额外的维度

代码分析:

部分回答借助 Chat-GPT4补充

  1. 为什么Policy Network的结构设计,和Value Network, Soft Q Network有所不同?

SAC算法中存在S(soft)的关键是我们鼓励agent在策略上偏向于有更多的探索。策略网络(即Policy Network)作用就是通过随机采样得到更多的动作。
所以在我的理解中,V网络(Value Network)用于整体状态评估,Q网络(Q Network)用于动作评估,这两个网络本质都是一个评分网络,而Policy Network则是基于评分得到新动作的策略。

  1. PolicyNetwork的功能是什么,作用是什么,在整个SAC框架中起了什么作用?

策略网络负责生成环境中的动作。

  1. PolicyNetwork的forward(),evaluate()以及get_action()的作用分别是什么?

forward()函数:网络结构前向传播,输入“状态”,返回“对应动作”均值和对数标准差。这里的“状态”是指当前环境下的状态,“对应动作”是指在此状态下,采取的预期动作或者决策,目标是学习策略,在当前状态下选择对应动作,使得奖励最大化,

evaluate()函数:在训练中使用,对于批量的动作,通过forward()函数和一系列概率参数,得到动作的均值(预期的动作)和标准差(动作的不确定性)。

get_action()函数:在测试中使用。对于一个单个的动作,执行和evaluate()函数类似的行为,得到结果。


补充代码说明

为什么get_action()会使用 state = torch.FloatTensor(state).unsqueeze(0).to(device)这句话而evaluate()不会使用?

evaluate()函数主要在训练过程中使用,训练时我们通常一次性处理一个批量的数据,而不是单个数据点。因此,我们一般会假设evaluate()的输入state已经是一个批量的数据,因此不需要使用unsqueeze(0)来增加一个额外的批量维度。

另一方面,get_action()函数主要在策略执行阶段使用,当我们使用训练好的模型来进行决策时,我们通常一次处理一个状态,因此输入state通常是一个单个的数据点。然而,由于模型期望接收批量数据,我们需要使用unsqueeze(0)来为单个数据点增加一个批量维度。

我们写一段简单的python代码来说明:

#理解 unsqueeze()的作用
import torch

# 创建一个形状为(3,)的一维张量
x = torch.tensor([1, 2, 3])
print(x)
print(f"Original tensor: {x}")
print(f"Shape of original tensor: {x.shape}\n")

# 使用unsqueeze(0)在第一维(索引为0)处增加一个维度
x_unsqueeze_0 = x.unsqueeze(0)
print(x_unsqueeze_0)
print(f"Tensor after unsqueeze(0): {x_unsqueeze_0}")
print(f"Shape of tensor after unsqueeze(0): {x_unsqueeze_0.shape}\n")

# 使用unsqueeze(1)在第二维(索引为1)处增加一个维度
x_unsqueeze_1 = x.unsqueeze(1)
print(x_unsqueeze_1)
print(f"Tensor after unsqueeze(1): {x_unsqueeze_1}")
print(f"Shape of tensor after unsqueeze(1): {x_unsqueeze_1.shape}\n")

输出:
在这里插入图片描述

总结

SAC网络下的evaluate()和get_action()函数,本质上都是得到动作,然后在函数式中计算其相关的均值,标准差,得到随机数,然后通过这些信息,生成新的动作。区别只是在于,evaluate()用在训练中,而get_action用在训练后。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值