import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.distributions import Normal
import random
import numpy as np
1. Defining the Algorithm
1.1 Building the Q-Network and the Policy Network
class ValueNet(nn.Module):
    def __init__(self, n_states, hidden_dim, init_w=3e-3):
        '''Value network: n_states, hidden_dim and init_w are the state dimension,
        the hidden layer size and the range used to initialize the output layer weights.
        '''
        super(ValueNet, self).__init__()
        self.linear1 = nn.Linear(n_states, hidden_dim)    # input layer
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)  # hidden layer
        self.linear3 = nn.Linear(hidden_dim, 1)           # output layer: a single state value V(s)
        self.linear3.weight.data.uniform_(-init_w, init_w)  # initialize the output layer weights
        self.linear3.bias.data.uniform_(-init_w, init_w)

    def forward(self, state):
        x = F.relu(self.linear1(state))
        x = F.relu(self.linear2(x))
        x = self.linear3(x)
        return x
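As a quick sanity check on the value network, the sketch below feeds a batch of random states through ValueNet and confirms it returns one scalar value per state. The state dimension, hidden size and batch size here are illustrative assumptions, not values from the text.

# Hypothetical shape check for ValueNet; dimensions are made-up examples
value_net = ValueNet(n_states=3, hidden_dim=256)
states = torch.randn(32, 3)      # a batch of 32 states
values = value_net(states)       # shape (32, 1): one V(s) estimate per state
print(values.shape)              # torch.Size([32, 1])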
class SoftQNet(nn.Module):
    def __init__(self, n_states, n_actions, hidden_dim, init_w=3e-3):
        '''Soft Q-network: n_states, n_actions, hidden_dim and init_w are the state dimension,
        the action dimension, the hidden layer size and the weight-initialization range.
        '''
        super(SoftQNet, self).__init__()
        self.linear1 = nn.Linear(n_states + n_actions, hidden_dim)  # input is the concatenated state-action pair
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, 1)                      # output is a single Q(s, a) value
        self.linear3.weight.data.uniform_(-init_w, init_w)
        self.linear3.bias.data.uniform_(-init_w, init_w)

    def forward(self, state, action):
        x = torch.cat([state, action], 1)  # concatenate state and action along the feature dimension
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        x = self.linear3(x)
        return x
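The soft Q-network scores a state-action pair, so it is called with both tensors. A minimal sketch, again with made-up dimensions, looks like this:

# Hypothetical shape check for SoftQNet; dimensions are made-up examples
q_net = SoftQNet(n_states=3, n_actions=1, hidden_dim=256)
states = torch.randn(32, 3)
actions = torch.randn(32, 1)
q_values = q_net(states, actions)  # shape (32, 1): one Q(s, a) estimate per pair
print(q_values.shape)              # torch.Size([32, 1])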
class PolicyNet(nn.Module):
    def __init__(self, n_states, n_actions, hidden_dim, init_w=3e-3, log_std_min=-20, log_std_max=2):
        '''Policy network: n_states, n_actions, hidden_dim and init_w are the state dimension,
        the action dimension, the hidden layer size and the weight-initialization range;
        log_std_min and log_std_max are the lower and upper bounds of the log standard deviation.
        '''
        super(PolicyNet, self).__init__()
        self.log_std_min = log_std_min
        self.log_std_max = log_std_max
        self.linear1 = nn.Linear(n_states, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        # output head for the Gaussian mean
        self.mean_linear = nn.Linear(hidden_dim, n_actions)
        self.mean_linear.weight.data.uniform_(-init_w, init_w)
        self.mean_linear.bias.data.uniform_(-init_w, init_w)
        # output head for the log standard deviation
        self.log_std_linear = nn.Linear(hidden_dim, n_actions)
        self.log_std_linear.weight.data.uniform_(-init_w, init_w)
        self.log_std_linear.bias.data.uniform_(-init_w, init_w)

    def forward(self, state):
        x = F.relu(self.linear1(state))
        x = F.relu(self.linear2(x))
        mean = self.mean_linear(x)
        log_std = self.log_std_linear(x)
        log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)  # keep the std in a numerically stable range
        return mean, log_std
    def evaluate(self, state, epsilon=1e-6):
        mean, log_std = self.forward(state)
        std = log_std.exp()
        ## sample an action with the reparameterization trick
        normal = Normal(mean, std)
        z = normal.rsample()    # rsample() keeps the sample differentiable w.r.t. mean and std
        action = torch.tanh(z)  # squash the Gaussian sample into (-1, 1)
        ## log-probability of the squashed action:
        ## log pi(a|s) = log N(z; mean, std) - sum log(1 - tanh(z)^2), with epsilon for numerical stability
        log_prob = normal.log_prob(z) - torch.log(1 - action.pow(2) + epsilon)
        log_prob = log_prob.sum(-1, keepdim=True)
        return action, log_prob, z, mean, log_std

    def get_action(self, state):
        state = torch.FloatTensor(state).unsqueeze(