import math
import sys

import gym
import pygame
import torch
import torch.nn as nn
import torch.optim as optim
# Policy network: maps a 4-dim CartPole observation to a distribution
# over the 2 discrete actions (push left / push right).
class Actor(nn.Module):
    def __init__(self):
        super().__init__()
        layers = [
            nn.Linear(4, 10),
            nn.ReLU(),
            nn.Linear(10, 2),
            nn.Softmax(dim=-1),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Return action probabilities (last dim sums to 1)."""
        return self.fc(x)
# Value network: maps a 4-dim CartPole observation to a scalar
# state-value estimate V(s).
class Critic(nn.Module):
    def __init__(self):
        super().__init__()
        layers = [
            nn.Linear(4, 10),
            nn.ReLU(),
            nn.Linear(10, 1),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Return the estimated value of state ``x`` (shape (..., 1))."""
        return self.fc(x)
# Perform one actor-critic update from a single transition (one-step TD).
def train(actor, critic, actor_optimizer, critic_optimizer, state, action,
          reward, next_state, done, gamma=0.99):
    """One-step actor-critic update from a single (s, a, r, s', done) transition.

    Args:
        actor: policy network mapping a state to action probabilities.
        critic: value network mapping a state to a scalar value estimate.
        actor_optimizer: optimizer over ``actor``'s parameters.
        critic_optimizer: optimizer over ``critic``'s parameters.
        state, next_state: environment observations (array-like of floats).
        action: integer index of the action actually taken.
        reward: scalar reward received for the transition.
        done: True if the episode ended at ``next_state``.
        gamma: discount factor for the TD target (default 0.99, matching
            the previously hard-coded value).
    """
    state = torch.tensor(state, dtype=torch.float)
    next_state = torch.tensor(next_state, dtype=torch.float)
    action = torch.tensor(action, dtype=torch.long)
    reward = torch.tensor(reward, dtype=torch.float)

    # Bootstrap value of the next state; 0 for terminal states. detach()
    # keeps the TD target out of the critic's gradient graph.
    if done:
        next_value = 0
    else:
        next_value = critic(next_state).detach()

    # Critic loss: regress V(s) toward the TD target r + gamma * V(s').
    value = critic(state)
    expected_value = reward + gamma * next_value
    critic_loss = (value - expected_value).pow(2).mean()

    # Actor loss: policy gradient with the TD error as advantage estimate.
    probs = actor(state)
    dist = torch.distributions.Categorical(probs)
    log_prob = dist.log_prob(action)
    advantage = (expected_value - value).detach()  # no grad through critic
    actor_loss = -log_prob * advantage

    # Update both networks; the two loss graphs are independent, so the
    # order of the optimizer steps does not affect either gradient.
    critic_optimizer.zero_grad()
    critic_loss.backward()
    critic_optimizer.step()
    actor_optimizer.zero_grad()
    actor_loss.backward()
    actor_optimizer.step()
# ---- Environment, networks, and optimizers ----
env = gym.make('CartPole-v1')
actor = Actor()
critic = Critic()
actor_optimizer = optim.Adam(actor.parameters(), lr=0.001)
critic_optimizer = optim.Adam(critic.parameters(), lr=0.01)

# ---- Pygame window for live visualization ----
pygame.init()
screen = pygame.display.set_mode((600, 400))
clock = pygame.time.Clock()

# ---- Training loop ----
for episode in range(10000):
    state, _ = env.reset()  # gym >= 0.26: reset() returns (obs, info)
    done = False
    step = 0
    while not done:
        step += 1
        state_tensor = torch.tensor(state, dtype=torch.float)
        probs = actor(state_tensor)
        dist = torch.distributions.Categorical(probs)
        action = dist.sample().item()
        # BUGFIX: the original bound only `terminated` to `done` and
        # discarded `truncated`, so time-limit-truncated episodes (500
        # steps in CartPole-v1) never ended the inner loop cleanly.
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        train(actor, critic, actor_optimizer, critic_optimizer,
              state, action, reward, next_state, done)
        state = next_state

        # Pump window events so the OS keeps the window responsive.
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                sys.exit()

        # Draw cart (rect) and pole (line). state[0] is the cart position,
        # state[2] the pole angle in radians; math.sin/cos avoid building
        # a throwaway torch tensor every frame.
        screen.fill((255, 255, 255))
        cart_x = int(state[0] * 100 + 300)
        pygame.draw.rect(screen, (0, 0, 255), (cart_x, 300, 50, 30))
        pole_dx = int(50 * math.sin(state[2]))
        pole_dy = int(50 * math.cos(state[2]))
        pygame.draw.line(screen, (255, 0, 0),
                         (cart_x + 25, 300),
                         (cart_x + 25 - pole_dx, 300 - pole_dy), 5)
        pygame.display.flip()
        clock.tick(200)
    print(f"第{episode}回合,玩{step}次挂了")
# Actor-critic playing the CartPole game.
# (Blog footer residue — original post published 2024-09-20 10:40:49.)