Code repository: https://github.com/sfujim/TD3
Class TD3:
Method 1: select_action
    state = torch.FloatTensor(state.reshape(1, -1)).to(device)  # reshape to a 1-row batch and convert to a tensor
    return self.actor(state).cpu().data.numpy().flatten()  # flatten the output back to a 1-D numpy array
Pass the state into the actor network. Since TD3's actor is a deterministic policy, it outputs the action directly; it does not enumerate actions by Q value (Q estimation is the critic's job).
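A minimal sketch of select_action with a stand-in actor. The two-layer `Actor` network here is a simplified assumption for illustration, not the exact architecture from the repo; the reshape/flatten pattern matches the snippet above.

```python
import numpy as np
import torch
import torch.nn as nn

device = torch.device("cpu")

class Actor(nn.Module):
    # Hypothetical small actor: tanh bounds the output to [-1, 1],
    # then we scale by max_action to cover the action range.
    def __init__(self, state_dim, action_dim, max_action):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, 32), nn.ReLU(),
            nn.Linear(32, action_dim), nn.Tanh(),
        )
        self.max_action = max_action

    def forward(self, state):
        return self.max_action * self.net(state)

actor = Actor(state_dim=3, action_dim=2, max_action=1.0).to(device)

def select_action(state):
    # reshape to a (1, state_dim) batch, run the actor, return a 1-D numpy action
    state = torch.FloatTensor(state.reshape(1, -1)).to(device)
    return actor(state).cpu().data.numpy().flatten()

action = select_action(np.zeros(3))
print(action.shape)  # (2,)
```

Note that the single state must be reshaped into a batch of size 1 before the forward pass, and flattened back afterwards so the environment receives a plain 1-D action array.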
Method 2: train
    def train(self, replay_buffer, batch_size=256):
        self.total_it += 1

        # Sample replay buffer
        state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)

        with torch.no_grad():
            # Select action according to policy and add clipped noise
            noise = (
                torch.randn_like(action) * self.policy_noise
            ).clamp(-self.noise_clip, self.noise_clip)

            next_action = (
                self.actor_target(next_state) + noise
            ).clamp(-self.max_action, self.max_action)
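The clipped-noise step above is TD3's target policy smoothing: Gaussian noise is scaled by policy_noise and clamped to [-noise_clip, noise_clip] before being added to the target action. A standalone sketch (the values 0.2 and 0.5 are the paper's defaults, assumed here; `action` is a dummy batch):

```python
import torch

policy_noise, noise_clip = 0.2, 0.5   # assumed defaults from the TD3 paper
action = torch.zeros(256, 2)          # dummy batch: 256 actions of dimension 2

# Sample Gaussian noise with the same shape as the action batch,
# scale it, then clamp so no single perturbation exceeds noise_clip.
noise = (
    torch.randn_like(action) * policy_noise
).clamp(-noise_clip, noise_clip)

assert noise.abs().max() <= noise_clip  # clamping bounds every element
```

Clamping keeps rare large noise samples from pushing the target action far from the policy's output, which smooths the value estimate along similar actions.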