设计网络接受批量/单个输入

import torch  
import torch.nn as nn  
import torch.nn.functional as F  
  
class PolicyValueNet(nn.Module):  
    def __init__(self, input_size, num_actions):  
        super().__init__()  
        self.fc1 = nn.Linear(input_size, 128)  
        self.fc_pi = nn.Linear(128, num_actions)  
        self.fc_v = nn.Linear(128, 1)  
  
    def forward(self, x):  
        x = F.relu(self.fc1(x))  
        log_act_probs = F.log_softmax(self.fc_pi(x), dim=1)  
        state_value = self.fc_v(x)  
        return log_act_probs, state_value  
  
# 假设我们的游戏状态是一个长度为64的向量,并且有4个可能的动作  
input_size = 64  
num_actions = 4  
  
# 创建一个网络实例  
net = PolicyValueNet(input_size, num_actions)  
  
# 示例:处理单个状态输入  
single_state = torch.randn(1, input_size)  # 假设这是一个随机生成的游戏状态  
log_probs, value = net(single_state)  
print("Single state log probabilities:", log_probs)  
print("Single state value:", value)  
  
# 示例:处理批量状态输入(一个批次包含3个状态)  
batch_states = torch.randn(3, input_size)  # 假设这是一个包含3个随机生成的游戏状态的批次  
log_probs_batch, value_batch = net(batch_states)  
print("Batch state log probabilities:", log_probs_batch)  
print("Batch state values:", value_batch)

在这个例子中,PolicyValueNet 类定义了一个简单的神经网络,它有一个共享的全连接层(fc1),然后分别有两个输出层:一个用于策略(动作概率分布,fc_pi),另一个用于价值估计(fc_v)。

在 forward 方法中,输入 x 被传递给网络,并经过一系列计算得到对数动作概率 log_act_probs 和状态价值 state_value。无论是单个状态还是批量状态,这个前向传播过程都是相同的。

*注意,对于单个状态,我们将其形状设置为 (1, input_size),以便与批量输入的形状 (batch_size, input_size) 保持一致。这样,网络就可以无缝地处理这两种情况。

  1. 批次维度:在 (1, input_size) 这个形状中,第一个维度 1 表示批次大小。在神经网络的前向传播中,批次大小可以是任意的(包括1),只要特征维度(即第二个维度)与全连接层期望的输入维度匹配。

  2. 特征维度:第二个维度 input_size 表示每个样本的特征数量。这个维度必须与全连接层期望的输入特征数量相匹配。在你的例子中,input_size 既是全连接层期望的输入特征数量,也是单个样本的特征数量。

当 single_state 被传递给网络时,PyTorch 会自动处理这个批次大小为1的输入。全连接层会忽略批次维度(因为它只关心特征维度),并对每个样本(在这个例子中只有一个样本)执行相同的计算。

简而言之,能够处理 (1, input_size) 形状的输入是因为全连接层只关心特征维度 input_size,而批次大小可以是任意的(包括1)。

  • 5
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值