多智能体游戏环境PettingZoo_天授(Tianshou)训练智能体

本文展示了如何使用Tianshou库结合MARL训练井字棋游戏的智能体。训练过程中,作者添加了保存和加载模型的代码,利用pickle进行模型的序列化。训练完成后,智能体的模型被保存为pickle文件,但在尝试使用torch的`load_state_dict`方法加载模型时遇到问题。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

原始代码:https://pettingzoo.farama.org/tutorials/tianshou/intermediate/
代码使用的PettingZoo游戏环境是井字棋(Tic Tac Toe)。
我在原始代码里加了保存模型的代码,以便训练结束后可以随时拿出来测试、可视化训练的成果(智能体的表现)。
训练的代码写在train.py,测试的代码(我额外写的)写在t_ttt.py。

train.py

# 用天授训练智能体
''''''
"""This is a minimal example of using Tianshou with MARL to train agents.
这是将天授与MARL一起用于训练智能体的一个最小示例。
Author: Will (https://github.com/WillDudley)

Python version used: 3.8.10

Requirements:
pettingzoo == 1.22.0
git+https://github.com/thu-ml/tianshou
"""

import os
from typing import Optional, Tuple

import gym
import numpy as np
import torch
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.env.pettingzoo_env import PettingZooEnv
from tianshou.policy import BasePolicy, DQNPolicy, MultiAgentPolicyManager, RandomPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils.net.common import Net

from pettingzoo.classic import tictactoe_v3

import pickle # 我用pickle保存模型

def _get_agents(
    agent_learn: Optional[BasePolicy] = None,
    agent_opponent: Optional[BasePolicy] = None,
    optim: Optional[torch.optim.Optimizer] = None,
) -> Tuple[BasePolicy, torch.optim.Optimizer, list]:
    env = _get_env()
    observation_space = (
        env.observation_space["observation"]
        if isinstance(env.observation_space, gym.spaces.Dict)
        else env.observation_space
    )
    if agent_learn is None:  # 这里的学习器默认用DQN算法。
        # model
        net = Net(
            state_shape=observation_space["observation"].shape
            or observation_space["observation"].n,
            action_shape=env.action_space.shape or env.action_space.n,
            hidden_sizes=[128, 128, 128, 128],
            device="cuda" if torch.cuda.is_available() else "cpu",
        ).to("cuda" if torch.cuda.is_available() else "cpu")
        if optim is None:
            optim = torch.optim.Adam(net.parameters(), lr=1e-4)
        agent_learn = DQNPolicy(
            model=net,
            optim=optim,
            discount_factor=0.9,
            estimation_step=3,
            target_update_freq=320,
        )

    if agent_opponent is None:  # 这里的对手默认用随机策略
        agent_opponent = RandomPolicy()

    agents = [agent_opponent, agent_learn] # 这里设置了2个智能体,第0号是对手,第1号是学习器。
    policy = MultiAgentPolicyManager(agents, env)
    return policy, optim, env.agents


def _get_env():
    """This function is needed to provide callables for DummyVectorEnv."""
    '''此函数是为DummyVectorEnv提供可调用程序所必需的。'''
    return PettingZooEnv(tictactoe_v3.env())


if __name__ == "__main__":
    # ======== Step 1: Environment setup =========
    # ======== 第一步:环境设置 =========
    train_envs = DummyVectorEnv([_get_env for _ in range(10)])  # 训练的环境
    test_envs = DummyVectorEnv([_get_env for _ in range(10)])  # 测试的环境

    # seed
    seed = 1
    np.random.seed(seed)
    torch.manual_seed(seed)
    train_envs.seed(seed)
    test_envs.seed(seed)

    # ======== Step 2: Agent setup =========
    # ======== 第二步:智能体设置 ========
    policy, optim, agents = _get_agents()

    # ======== Step 3: Collector setup =========
    # ======== 第三步:采集器设置 ========
    train_collector = Collector(
        policy,
        train_envs,
        VectorReplayBuffer(20_000, len(train_envs)),
        exploration_noise=True,
    )
    test_collector = Collector(policy, test_envs, exploration_noise=True)
    # policy.set_eps(1)
    train_collector.collect(n_step=64 * 10)  # batch size * training_num

    # ======== Step 4: Callback functions setup =========
    # ======== 第四步:回调函数设置 ========
    def save_best_fn(policy):
        model_save_path = os.path.join("log", "rps", "dqn", "policy.pth")
        os.makedirs(os.path.join("log", "rps", "dqn"), exist_ok=True)
        torch.save(policy.policies[agents[1]].state_dict(), model_save_path)

    def stop_fn(mean_rewards):
        return mean_rewards >= 0.6

    def train_fn(epoch, env_step):
        policy.policies[agents[1]].set_eps(0.1)

    def test_fn(epoch, env_step):
        policy.policies[agents[1]].set_eps(0.05)

    def reward_metric(rews):
        return rews[:, 1]

    # ======== Step 5: Run the trainer =========
    # ======== 第五步:运行训练器 ========
    result = offpolicy_trainer(  # 这里使用了异策略的技巧
        policy=policy,
        train_collector=train_collector,
        test_collector=test_collector,
        max_epoch=50,
        step_per_epoch=1000,
        step_per_collect=50,
        episode_per_test=10,
        batch_size=64,
        train_fn=train_fn,
        test_fn=test_fn,
        stop_fn=stop_fn,
        save_best_fn=save_best_fn,
        update_per_step=0.1,
        test_in_train=False,
        reward_metric=reward_metric,
    )

    # return result, policy.policies[agents[1]]
    print(f"\n==========Result==========\n{result}")
    print("\n(the trained policy can be accessed via policy.policies[agents[1]])")
    # 训练好的策略可以通过policy.policies[agents[1]]访问(存取)

    # 以下代码是我加上的:
    # 保存模型参数:
    torch.save(policy.policies[agents[1]].state_dict(), 'tictactoe_dqn.pth') # 这里仅保存了第1号智能体(学习器)
    

t_ttt.py

# 把训练好的智能体拿来测试
# 'tictactoe_dqn.pth'
#import torch
import tianshou as ts
from pettingzoo.classic import tictactoe_v3
import pickle

env=tictactoe_v3.env(render_mode="human")
env = ts.env.PettingZooEnv(env)

device="cuda" if torch.cuda.is_available() else "cpu"
net = Net(
    state_shape=env.observation_space["observation"].shape
                or env.observation_space["observation"].n,
    action_shape=env.action_space.shape or env.action_space.n,
    hidden_sizes=[128, 128, 128, 128],
    device=device,
).to(device)
optim = torch.optim.Adam(net.parameters(), lr=1e-4)
p1=DQNPolicy(
            model=net,
            optim=optim,
            discount_factor=0.9,
            estimation_step=3,
            target_update_freq=320,
        )
#加载模型参数
p1.load_state_dict(torch.load('tictactoe_dqn.pth'))

p2=ts.policy.RandomPolicy()
po=ts.policy.MultiAgentPolicyManager([p1, p2], env)

env = ts.env.DummyVectorEnv([lambda: env])

collector = ts.data.Collector(po, env)
result = collector.collect(n_episode=1, render=0.2)
print(result)

输出如下(截图不完全):
在这里插入图片描述

目前我遇到的问题是:使用Tianshou的方法【policy.load_state_dict(torch.load(‘tictactoe_dqn.pth’))】加载模型不行,总是提示没有这个函数。所以我仍然使用pickle来保存和加载模型。

### Tianshou 强化学习框架简介 Tianshou 是一个基于 PyTorch 的强化学习库,旨在提供简洁而灵活的接口以便于研究者快速开发和测试新的算法。它支持多种经典的强化学习方法以及最新的研究成果,并提供了模块化的结构设计使得用户可以轻松定制自己的实验环境[^6]。 #### 主要特点 - **模块化设计**:Tianshou 将策略、收集器、回放缓冲区等功能分离成独立组件,方便组合使用。 - **高效训练流程**:内置高效的采样机制与并行数据采集工具,提升整体性能。 - **易于扩展**:允许研究人员自由定义新模型或者调整现有实现细节。 - **文档齐全**:官方维护详尽的技术文档及示例程序帮助初学者入门。 以下是创建一个简单的 DQN 模型并通过 Gym 环境进行训练的例子: ```python from tianshou.env import SubprocVectorEnv from tianshou.policy import DQNPolicy from tianshou.trainer import offpolicy_trainer from tianshou.data import Collector, ReplayBuffer import torch.nn as nn import gym def make_env(): return gym.make('CartPole-v0') net = Net(state_shape=4, action_shape=2).cuda() optim = torch.optim.Adam(net.parameters(), lr=1e-3) policy = DQNPolicy( model=net, optim=optim, discount_factor=0.9, estimation_step=3, target_update_freq=320 ) train_collector = Collector(policy=policy, env=SubprocVectorEnv([make_env for _ in range(8)]), buffer=ReplayBuffer(size=20000)) test_collector = Collector(policy=policy, env=gym.make('CartPole-v0')) result = offpolicy_trainer( policy=policy, train_collector=train_collector, test_collector=test_collector, max_epoch=10, step_per_epoch=1000, collect_per_step=10, update_per_step=0.1, batch_size=64, save_best=True ) ``` 此代码片段展示了如何利用 Tianshou 构建基础离线策略训练过程中的几个核心要素——网络初始化、优化器配置、策略实例化、数据集构建以及最终调用 `offpolicy_trainer` 函数完成整个循环逻辑控制[^7]。 ### 学习资源推荐 对于希望深入理解该框架的朋友来说,可以从以下几个方面入手: 1. 阅读项目主页上的 README 文件获取概览信息; 2. 浏览 API 文档熟悉各个类别的功能描述及其参数选项; 3. 查看具体案例分析掌握实际应用场景下的编码技巧;
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值