ML-agents的ball3D代码解析

代码全部Ball3DAgent.cs

using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using Random = UnityEngine.Random;

public class Ball3DAgent : Agent
{
    [Header("Specific to Ball3D")]
    public GameObject ball;
    [Tooltip("Whether to use vector observation. This option should be checked " +
        "in 3DBall scene, and unchecked in Visual3DBall scene. ")]
    public bool useVecObs;
    Rigidbody m_BallRb;
    EnvironmentParameters m_ResetParams;

    public override void Initialize()
    {
        m_BallRb = ball.GetComponent<Rigidbody>();
        m_ResetParams = Academy.Instance.EnvironmentParameters;
        SetResetParameters();
    }
    // 会收集当前游戏的各种环境,包括智能体的位置,速度等信息
    public override void CollectObservations(VectorSensor sensor)
    {
        if (useVecObs)
        {
            sensor.AddObservation(gameObject.transform.rotation.z);
            sensor.AddObservation(gameObject.transform.rotation.x);
            sensor.AddObservation(ball.transform.position - gameObject.transform.position);
            sensor.AddObservation(m_BallRb.velocity);
        }
    }
    // 实现的是整个游戏中一个Step中的操作,接收神经网络的输出,使其转换为智能体的动作,设置奖励函数,并且判断游戏是否结束。
    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        var actionZ = 2f * Mathf.Clamp(actionBuffers.ContinuousActions[0], -1f, 1f);
        var actionX = 2f * Mathf.Clamp(actionBuffers.ContinuousActions[1], -1f, 1f);

        if ((gameObject.transform.rotation.z < 0.25f && actionZ > 0f) ||
            (gameObject.transform.rotation.z > -0.25f && actionZ < 0f))
        {
            gameObject.transform.Rotate(new Vector3(0, 0, 1), actionZ);
        }

        if ((gameObject.transform.rotation.x < 0.25f && actionX > 0f) ||
            (gameObject.transform.rotation.x > -0.25f && actionX < 0f))
        {
            gameObject.transform.Rotate(new Vector3(1, 0, 0), actionX);
        }
        if ((ball.transform.position.y - gameObject.transform.position.y) < -2f ||
            Mathf.Abs(ball.transform.position.x - gameObject.transform.position.x) > 3f ||
            Mathf.Abs(ball.transform.position.z - gameObject.transform.position.z) > 3f)
        {
            SetReward(-1f);
            EndEpisode();
        }
        else
        {
            SetReward(0.1f);
        }
    }

    public override void OnEpisodeBegin()
    {
        gameObject.transform.rotation = new Quaternion(0f, 0f, 0f, 0f);
        gameObject.transform.Rotate(new Vector3(1, 0, 0), Random.Range(-10f, 10f));
        gameObject.transform.Rotate(new Vector3(0, 0, 1), Random.Range(-10f, 10f));
        m_BallRb.velocity = new Vector3(0f, 0f, 0f);
        ball.transform.position = new Vector3(Random.Range(-1.5f, 1.5f), 4f, Random.Range(-1.5f, 1.5f))
            + gameObject.transform.position;
        //Reset the parameters when the Agent is reset.
        SetResetParameters();
    }

    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var continuousActionsOut = actionsOut.ContinuousActions;
        continuousActionsOut[0] = -Input.GetAxis("Horizontal");
        continuousActionsOut[1] = Input.GetAxis("Vertical");
    }

    public void SetBall()
    {
        //Set the attributes of the ball by fetching the information from the academy
        m_BallRb.mass = m_ResetParams.GetWithDefault("mass", 1.0f);
        var scale = m_ResetParams.GetWithDefault("scale", 1.0f);
        ball.transform.localScale = new Vector3(scale, scale, scale);
    }

    public void SetResetParameters()
    {
        SetBall();
    }
}

局部解析

Max Step默认生成,训练最大轮数,迭代次数

在这里插入图片描述

    [Header("Specific to Ball3D")]

此句的功能是在inference里面会显示一个标题

在这里插入图片描述

    public GameObject ball;

初始化一个GameObject实例

    [Tooltip("Whether to use vector observation. This option should be checked " +
        "in 3DBall scene, and unchecked in Visual3DBall scene. ")]
    public bool useVecObs;

上面的Tooltip这个装饰器的功能是在useVecObs上加入一个备注,当鼠标放在useVecobs上会自动弹出这个

在这里插入图片描述

    Rigidbody m_BallRb;
    EnvironmentParameters m_ResetParams;

这两句话的目的是初始化一个刚体变量,此外初始化一个环境参数的容器。

其中EnvironmentParameters是ML-Agents里面的

这里面提到了一个C#关键字 sealed关键字

作用

(1)在类中使用sealed修饰符可防止其他类继承此类

(2)在方法声明中使用sealed修饰符可防止扩充类重写此方法

具体内容可参考C#sealed关键字

Initialize() 方法

Initialize方法,初始化环境,获取组件信息,设置参数在这里完成。

    public override void Initialize()
    {
        m_BallRb = ball.GetComponent<Rigidbody>();
        m_ResetParams = Academy.Instance.EnvironmentParameters;
        SetResetParameters();
    }

调用SetBall方法,用于设置球体的属性。

这个设计允许在需要时重新设置球体的属性,例如,在环境重置时。

    public void SetResetParameters()
    {
        SetBall();
    }
    public void SetBall()
    {
        //Set the attributes of the ball by fetching the information from the academy
        m_BallRb.mass = m_ResetParams.GetWithDefault("mass", 1.0f);
        var scale = m_ResetParams.GetWithDefault("scale", 1.0f);
        ball.transform.localScale = new Vector3(scale, scale, scale);
    }

获取小球的组件参数和环境参数,并且设置初始化条件

接下来逐句解析Initialize

m_BallRb = ball.GetComponent<Rigidbody>();

获取球体的Rigidbody组件,用于后续的物理属性设置。

Rigidbody是Unity中用于添加物理行为的组件,如重力、碰撞等。

m_ResetParams = Academy.Instance.EnvironmentParameters;

获取环境参数,这些参数由Academy类管理,用于在训练过程中调整环境。

Academy相当于环境的意思;

SetResetParameters();

调用SetBall,通过m_ResetParams.GetWithDefault("mass", 1.0f);获取名为mass的环境参数,如果该参数不存在,则默认值为1.0f。这个值被用来设置球体的质量(m_BallRb.mass

m_ResetParams.GetWithDefault("scale", 1.0f);以同样的方式获取名为scale的环境参数,用作球体的缩放比例。如果scale参数不存在,同样默认为1.0f

然后,这个缩放值被应用于球体的localScale属性,这是通过创建一个新的Vector3实例并将其赋值给ball.transform.localScale来实现的。Vector3(scale, scale, scale)表示在三个维度(X、Y、Z轴)上都应用相同的缩放比例,这样球体就会等比例缩放。

SetBall展示了如何在Unity ML-Agents框架中,通过从环境参数中获取值来动态调整游戏对象的物理属性和外观。这种方法使得在机器学习训练过程中,能够灵活地调整训练环境,以适应不同的训练需求和目标。

CollectObservations方法

会收集当前游戏的各种环境,包括智能体的位置,速度等信息给sensor

gameObject是游戏对象,这里就是agent

//需要调用
using Unity.MLAgents.Sensors;

public override void CollectObservations(VectorSensor sensor)
    {
        if (useVecObs)
        {
            sensor.AddObservation(gameObject.transform.rotation.z);
            sensor.AddObservation(gameObject.transform.rotation.x);
            sensor.AddObservation(ball.transform.position - gameObject.transform.position);
            sensor.AddObservation(m_BallRb.velocity);
        }
    }

使用sensor.AddObservation方法添加观察数据到传感器(VectorSensor对象)中通过这些观察数据,智能体可以学习如何根据球体的位置和移动情况来调整自己的行为。

♢ \diamondsuit OnActionReceived方法

    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        var actionZ = 2f * Mathf.Clamp(actionBuffers.ContinuousActions[0], -1f, 1f);
        var actionX = 2f * Mathf.Clamp(actionBuffers.ContinuousActions[1], -1f, 1f);

        if ((gameObject.transform.rotation.z < 0.25f && actionZ > 0f) ||
            (gameObject.transform.rotation.z > -0.25f && actionZ < 0f))
        {
            gameObject.transform.Rotate(new Vector3(0, 0, 1), actionZ);
        }

        if ((gameObject.transform.rotation.x < 0.25f && actionX > 0f) ||
            (gameObject.transform.rotation.x > -0.25f && actionX < 0f))
        {
            gameObject.transform.Rotate(new Vector3(1, 0, 0), actionX);
        }
        if ((ball.transform.position.y - gameObject.transform.position.y) < -2f ||
            Mathf.Abs(ball.transform.position.x - gameObject.transform.position.x) > 3f ||
            Mathf.Abs(ball.transform.position.z - gameObject.transform.position.z) > 3f)
        {
            SetReward(-1f);
            EndEpisode();
        }
        else
        {
            SetReward(0.1f);
        }
    }

这段代码是Unity ML-Agents框架中的一部分,用于定义智能体如何根据接收到的动作来响应环境。OnActionReceived方法是智能体接收动作并执行相应逻辑的核心方法。

        var actionZ = 2f * Mathf.Clamp(actionBuffers.ContinuousActions[0], -1f, 1f);
        var actionX = 2f * Mathf.Clamp(actionBuffers.ContinuousActions[1], -1f, 1f);

通过**actionBuffers.ContinuousActions获取连续动作的数组**,并对数组中的第一个和第二个元素进行处理。这些动作值被限制在-1到1之间,并乘以2,得到actionZ和actionX获取连续动作的数组,并对数组中的第一个和第二个元素进行处理。这些动作值被限制在-1到1之间,并乘以2,得到actionZactionX。这些值代表智能体在Z轴和X轴上的旋转动作。


        if ((gameObject.transform.rotation.z < 0.25f && actionZ > 0f) ||
            (gameObject.transform.rotation.z > -0.25f && actionZ < 0f))
        {
            gameObject.transform.Rotate(new Vector3(0, 0, 1), actionZ);
        }

        if ((gameObject.transform.rotation.x < 0.25f && actionX > 0f) ||
            (gameObject.transform.rotation.x > -0.25f && actionX < 0f))
        {
            gameObject.transform.Rotate(new Vector3(1, 0, 0), actionX);
        }
  1. Z轴旋转条件
    • 首先,检查对象在Z轴的旋转角度是否小于0.25f且actionZ大于0,或者对象在Z轴的旋转角度是否大于-0.25f且actionZ小于0。
    • 如果满足上述任一条件,使用Rotate方法沿Z轴(向量(0, 0, 1))旋转对象,旋转角度为actionZ
  2. X轴旋转条件
    • 接着,检查对象在X轴的旋转角度是否小于0.25f且actionX大于0,或者对象在X轴的旋转角度是否大于-0.25f且actionX小于0。
    • 如果满足上述任一条件,使用Rotate方法沿X轴(向量(1, 0, 0))旋转对象,旋转角度为actionX
        if ((ball.transform.position.y - gameObject.transform.position.y) < -2f ||
            Mathf.Abs(ball.transform.position.x - gameObject.transform.position.x) > 3f ||
            Mathf.Abs(ball.transform.position.z - gameObject.transform.position.z) > 3f)
        {
            SetReward(-1f);
            EndEpisode();
        }
        else
        {
            SetReward(0.1f);
        }
  1. 条件判断
    • 首先,判断球体相对于游戏对象在y轴上的位置是否低于-2f(即球体是否在游戏对象下方至少2个单位)。
    • 然后,判断球体和游戏对象在x轴和z轴上的距离是否超过3个单位。这是通过计算两者在x轴和z轴上位置的绝对值差来实现的。
  2. 奖励和回合控制
    • 如果上述任一条件满足,即球体距离游戏对象太远或太低,那么调用SetReward(-1f)给予-1的奖励(惩罚),并通过EndEpisode()结束当前回合。
    • 如果上述条件都不满足,即球体与游戏对象的位置关系在可接受的范围内,那么调用SetReward(0.1f)给予0.1的奖励。
  • SetReward(float reward):这个函数用于设置当前步骤或回合的奖励值。reward参数是一个浮点数,表示要分配给智能体的奖励(或惩罚,如果是负值)。
  • EndEpisode():这个函数用于结束当前的训练回合。在强化学习中,一个回合(Episode)是智能体从开始到达到某个终止状态的一系列动作。==调用EndEpisode()时,ML-Agents会重置环境到初始状态或新的随机状态(OnEpisodeBegin),以便开始新的回合。==这通常在智能体达到目标、失败或其他某些终止条件时发生。

OnEpisodeBegin方法

在每个训练周期开始时重置环境和Agent的状态。

    public override void OnEpisodeBegin()
    {
        gameObject.transform.rotation = new Quaternion(0f, 0f, 0f, 0f);
        gameObject.transform.Rotate(new Vector3(1, 0, 0), Random.Range(-10f, 10f));
        gameObject.transform.Rotate(new Vector3(0, 0, 1), Random.Range(-10f, 10f));
        m_BallRb.velocity = new Vector3(0f, 0f, 0f);
        ball.transform.position = new Vector3(Random.Range(-1.5f, 1.5f), 4f, Random.Range(-1.5f, 1.5f))
            + gameObject.transform.position;
        //Reset the parameters when the Agent is reset.
        SetResetParameters();
    }
  1. 重置旋转:

    gameObject.transform.rotation = new Quaternion(0f, 0f, 0f, 0f);
    

    这一行代码将游戏对象的旋转重置为初始状态。注意,这里使用的是四元数(Quaternion),(0f, 0f, 0f, 0f)实际上是一个无效的四元数,应该是(0f, 0f, 0f, 1f)来表示没有旋转。

  2. 随机旋转:

    gameObject.transform.Rotate(new Vector3(1, 0, 0), Random.Range(-10f, 10f));
    gameObject.transform.Rotate(new Vector3(0, 0, 1), Random.Range(-10f, 10f));
    

    这两行代码分别在x轴和z轴方向上对游戏对象进行随机旋转,旋转角度在-10度到10度之间。

  3. 重置速度:

    m_BallRb.velocity = new Vector3(0f, 0f, 0f);
    

    这一行代码将的速度重置为(0f, 0f, 0f),确保球在新的训练周期开始时静止。

  4. 重置位置:

    ball.transform.position = new Vector3(Random.Range(-1.5f, 1.5f), 4f, Random.Range(-1.5f, 1.5f)) + gameObject.transform.position;
    

    这一行代码将球的位置重置为一个随机位置,x和z方向上的位置范围在-1.5到1.5之间,y方向固定为4f。这个位置偏移是基于游戏对象的当前位置。

  5. 重置参数:

    SetResetParameters();
    

    这一行代码调用了SetResetParameters方法,重置Agent的参数。这个方法通常包含一些自定义的重置逻辑,例如重置奖励函数、状态变量等。

Heuristic方法

Heuristic方法,用于在ML-Agents中实现手动控制Agent。这个方法在没有训练模型时,可以用于手动测试和调试Agent的行为。

    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var continuousActionsOut = actionsOut.ContinuousActions;
        continuousActionsOut[0] = -Input.GetAxis("Horizontal");
        continuousActionsOut[1] = Input.GetAxis("Vertical");
    }

获取连续动作数组:

var continuousActionsOut = actionsOut.ContinuousActions;

这行代码从actionsOut中获取连续动作数组。这个数组用于存储Agent在当前时间步执行的连续动作。

读取并设置水平方向的动作:

continuousActionsOut[0] = -Input.GetAxis("Horizontal");

这行代码读取用户在水平方向的输入(例如通过按下左/右箭头键或使用游戏控制器的左/右摇杆),并将其值设置为连续动作数组的第一个元素。Input.GetAxis("Horizontal")返回一个值,表示水平输入方向(左为负,右为正)。通过添加负号,将左侧输入变为正值,右侧输入变为负值。

读取并设置垂直方向的动作:

continuousActionsOut[1] = Input.GetAxis("Vertical");

这行代码读取用户在垂直方向的输入(例如通过按下上/下箭头键或使用游戏控制器的上/下摇杆),并将其值设置为连续动作数组的第二个元素。Input.GetAxis("Vertical")返回一个值,表示垂直输入方向(下为负,上为正)。

补充:

在Unity的ML-Agents中,OnEpisodeBeginInitialize是两个重要的方法,它们在强化学习代理(Agent)的生命周期中扮演不同的角色。

Initialize

Initialize 方法在Agent实例化时调用一次。这个方法主要用于初始化Agent的状态和参数。具体来说,你可以在Initialize方法中执行以下操作:

  1. 初始化参数:设置强化学习所需的一些初始参数,如学习率、折扣因子等。
  2. 设置环境:配置Agent所处的环境,例如加载场景、设置环境参数等。
  3. 定义变量:定义一些在整个Agent生命周期内需要使用的变量。

例如:

public override void Initialize()
{
    // 初始化参数
    learningRate = 0.01f;
    
    // 设置环境
    SetupEnvironment();

    // 定义变量
    stepCount = 0;
}
OnEpisodeBegin

OnEpisodeBegin 方法在每个训练周期(episode)开始时调用。这个方法主要用于在每个新的训练周期开始时重置环境和Agent的状态。具体来说,你可以在OnEpisodeBegin方法中执行以下操作:

  1. 重置环境:重置环境的状态,使得每个训练周期都从一个初始状态开始。
  2. 重置Agent状态:重置Agent的状态,如位置、速度等,以确保Agent从相同的起点开始新的训练周期。
  3. 其他初始化操作:如果在每个训练周期开始时需要执行特定的操作,可以在这里进行。

例如:

public override void OnEpisodeBegin()
{
    // 重置环境
    ResetEnvironment();

    // 重置Agent状态
    ResetAgentState();

    // 其他初始化操作
    stepCount = 0;
}
总结
  • Initialize方法在Agent实例化时调用一次,主要用于Agent的初始配置和参数设置。
  • OnEpisodeBegin方法在每个训练周期开始时调用,主要用于重置环境和Agent的状态,以便开始新的训练周期。

这两个方法的区别在于调用的时机和目的:Initialize用于Agent的初始配置,而OnEpisodeBegin用于每个训练周期的重新初始化。

  • 20
    点赞
  • 26
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值