# 1. 整个逻辑线路比较清晰;导入的包中,像 Mod、Seq 等来自 tensordict,
#    其他大部分包来自 torchrl。
import time
from pathlib import Path
import numpy as np
import torch
from tensordict.nn import (
TensorDictModule as Mod,
TensorDictSequential,
TensorDictSequential as Seq,
)
from torch.optim import Adam
from torchrl._utils import timeit
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.envs import (
Compose,
GrayScale,
GymEnv,
Resize,
set_exploration_type,
StepCounter,
ToTensorImage,
TransformedEnv,
)
from torchrl.modules import ConvNet, EGreedyModule, QValueModule
from torchrl.objectives import DQNLoss, SoftUpdate
# Seed torch's RNG before building anything so weight init is reproducible.
torch.manual_seed(0)
# Atari Pong with a preprocessing pipeline: frames -> float tensors,
# resized to 84x84, grayscaled, plus a per-episode step counter.
env = TransformedEnv(
    GymEnv("ALE/Pong-v5", categorical_action_encoding=True),
    Compose(
        ToTensorImage(), Resize(84, interpolation="nearest"), GrayScale(), StepCounter()
    ),
)
# Seed the environment separately from torch's RNG.
env.set_seed(0)
# Standard Atari DQN conv trunk sized to this env's discrete action count.
value_mlp = ConvNet.default_atari_dqn(num_actions=env.action_spec.space.n)
# Wrap the network so it reads "pixels" and writes "action_value".
value_net = Mod(value_mlp, in_keys=["pixels"], out_keys=["action_value"])
# Greedy policy: argmax over the predicted action values.
policy = Seq(value_net, QValueModule(spec=env.action_spec))
# Epsilon-greedy wrapper; epsilon anneals from 0.5 over 100k steps.
exploration_module = EGreedyModule(
    env.action_spec, annealing_num_steps=100_000, eps_init=0.5
)
policy_explore = Seq(policy, exploration_module)
init_rand_steps = 5000   # frames collected with random actions before learning
frames_per_batch = 100   # frames per collector iteration
optim_steps = 10         # gradient steps per collected batch
# Synchronous collector; total_frames=-1 means collect until we break.
collector = SyncDataCollector(
    env,
    policy_explore,
    frames_per_batch=frames_per_batch,
    total_frames=-1,
    init_random_frames=init_rand_steps,
)
# Replay buffer backed by lazily-allocated tensor storage (100k capacity).
rb = ReplayBuffer(storage=LazyTensorStorage(100_000))
# DQN loss with a delayed (target) value network.
loss = DQNLoss(value_network=policy, action_space=env.action_spec, delay_value=True)
optim = Adam(loss.parameters())
# Soft (Polyak) target-network updater.
updater = SoftUpdate(loss, eps=0.99)
total_count = 0      # total environment frames trained on
total_episodes = 0   # total completed episodes seen
# Training loop: collect batches, store them in the replay buffer, and
# optimize once the buffer holds enough random-policy frames.
# NOTE(review): the pasted source lost all indentation and was not valid
# Python; the nesting below follows the upstream torchrl getting-started
# DQN tutorial this script reproduces.
t0 = time.time()
for data in collector:
    # Write data in replay buffer
    rb.extend(data)
    # Longest episode seen so far across the whole buffer — this is the
    # stopping criterion checked at the bottom of the loop.
    max_length = rb[:]["next", "step_count"].max()
    if len(rb) > init_rand_steps:
        # Optim loop (we do several optim steps
        # per batch collected for efficiency)
        for _ in range(optim_steps):
            sample = rb.sample(128)
            loss_vals = loss(sample)
            loss_vals["loss"].backward()
            optim.step()
            optim.zero_grad()
            # Update exploration factor (anneal epsilon-greedy epsilon).
            exploration_module.step(data.numel())
            # Update target params via the soft (Polyak) updater.
            updater.step()
        # Count each collected frame/episode exactly once per batch,
        # not once per optim step.
        total_count += data.numel()
        total_episodes += data["next", "done"].sum()
    # Stop once an episode longer than 200 steps has been collected.
    if max_length > 200:
        break
# (以下为网页抓取残留:阅读数「5180」与评论折叠提示,非代码内容。)