class PolicyNet(torch.nn.Module):
    """Three-layer MLP policy mapping states to action probabilities.

    Architecture: Linear -> ReLU -> Linear -> ReLU -> Linear -> Softmax.

    Args:
        state_dim: size of the input (last) dimension.
        hidden_dim: width of both hidden layers.
        action_dim: number of actions, i.e. size of the output distribution.
    """

    def __init__(self, state_dim, hidden_dim, action_dim):
        super().__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = torch.nn.Linear(hidden_dim, action_dim)
        # nn.Softmax takes `dim` in its constructor, not in its call.
        # Without it PyTorch guesses the axis (dim=0 for 0/1/3-D inputs,
        # dim=1 otherwise) and emits a deprecation warning, so e.g. a 3-D
        # batch would be normalised over the wrong axis. Built once here
        # instead of on every forward call; it has no parameters, so the
        # state_dict keys are unchanged.
        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, x):
        """Return action probabilities for ``x``.

        Args:
            x: tensor whose last dimension is ``state_dim``.

        Returns:
            Tensor with the same leading shape as ``x`` and last dimension
            ``action_dim``; values along the last axis sum to 1.
        """
        x = F.relu(self.fc2(F.relu(self.fc1(x))))
        return self.softmax(self.fc3(x))
# Quick check of the network: one state, then a batch of five states.
net = PolicyNet(10, 5, 2)  # state_dim=10, hidden_dim=5, action_dim=2
data = torch.randn(1, 10)
data_2 = torch.randn(5, 10)
for batch in (data, data_2):
    print(net(batch))
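nn.Softmax 的 dim 参数只能在构造模块时传入(例如 nn.Softmax(dim=-1)(x)),不能在调用时传入:写成 nn.Softmax()(x, dim=-1) 会报 TypeError。如果构造时不指定 dim,PyTorch 会隐式选择维度(0 维、1 维、3 维输入取 dim=0,其余取 dim=1,因此三维的批量输入会在错误的轴上归一化),并给出如下警告: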
UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.
class PolicyNet(torch.nn.Module):
    """MLP policy network: state -> probability distribution over actions.

    Two ReLU hidden layers of width ``hidden_dim`` feed a linear head of size
    ``action_dim``, followed by a softmax over the last axis.
    """

    def __init__(self, state_dim, hidden_dim, action_dim):
        super().__init__()
        # Creation order fixes how parameter initialisation draws from the
        # global RNG, so the layers are built input-to-output.
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = torch.nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        """Map ``x`` (last dimension ``state_dim``) to action probabilities."""
        hidden = F.relu(self.fc1(x))
        hidden = F.relu(self.fc2(hidden))
        logits = self.fc3(hidden)
        # F.softmax accepts ``dim`` directly. The module equivalent is
        # nn.Softmax(dim=-1)(logits); passing dim to the module *call*
        # instead, nn.Softmax()(logits, dim=-1), raises a TypeError.
        return F.softmax(logits, dim=-1)
# Run the F.softmax-based network on a single state and on a batch of five.
net = PolicyNet(10, 5, 2)  # state_dim=10, hidden_dim=5, action_dim=2
data = torch.randn(1, 10)
data_2 = torch.randn(5, 10)
for batch in (data, data_2):
    print(net(batch))
F.softmax(x, dim=-1) 在调用时直接接受 dim 参数,显式指定维度后既不会报错,也不会出现上述警告。