import torch
import torch.nn as nn
import torch.nn.functional as F
class PolicyValueNet(nn.Module):
def __init__(self, input_size, num_actions):
super().__init__()
self.fc1 = nn.Linear(input_size, 128)
self.fc_pi = nn.Linear(128, num_actions)
self.fc_v = nn.Linear(128, 1)
def forward(self, x):
x = F.relu(self.fc1(x))
log_act_probs = F.log_softmax(self.fc_pi(x), dim=1)
state_value = self.fc_v(x)
return log_act_probs, state_value
# 假设我们的游戏状态是一个长度为64的向量,并且有4个可能的动作
input_size = 64
num_actions = 4
# 创建一个网络实例
net = PolicyValueNet(input_size, num_actions)
# 示例:处理单个状态输入
single_state = torch.randn(1, input_size) # 假设这是一个随机生成的游戏状态
log_probs, value = net(single_state)
print("Single state log probabilities:", log_probs)
print("Single state value:", value)
# 示例:处理批量状态输入(一个批次包含3个状态)
batch_states = torch.randn(3, input_size) # 假设这是一个包含3个随机生成的游戏状态的批次
log_probs_batch, value_batch = net(batch_states)
print("Batch state log probabilities:", log_probs_batch)
print("Batch state values:", value_batch)
在这个例子中,PolicyValueNet
类定义了一个简单的神经网络,它有一个共享的全连接层(fc1
),然后分别有两个输出层:一个用于策略(动作概率分布,fc_pi
),另一个用于价值估计(fc_v
)。
在 forward
方法中,输入 x
被传递给网络,并经过一系列计算得到对数动作概率 log_act_probs
和状态价值 state_value
。无论是单个状态还是批量状态,这个前向传播过程都是相同的。
*注意,对于单个状态,我们将其形状设置为 (1, input_size)
,以便与批量输入的形状 (batch_size, input_size)
保持一致。这样,网络就可以无缝地处理这两种情况。
-
批次维度:在
(1, input_size)
这个形状中,第一个维度1
表示批次大小。在神经网络的前向传播中,批次大小可以是任意的(包括1),只要特征维度(即第二个维度)与全连接层期望的输入维度匹配。 -
特征维度:第二个维度
input_size
表示每个样本的特征数量。这个维度必须与全连接层期望的输入特征数量相匹配。在你的例子中,input_size
既是全连接层期望的输入特征数量,也是单个样本的特征数量。
当 single_state
被传递给网络时,PyTorch 会自动处理这个批次大小为1的输入。全连接层会忽略批次维度(因为它只关心特征维度),并对每个样本(在这个例子中只有一个样本)执行相同的计算。
简而言之,能够处理 (1, input_size)
形状的输入是因为全连接层只关心特征维度 input_size
,而批次大小可以是任意的(包括1)。