Unity-ML-Agents--Learning-Environment-Design-Agents.md-代码解读(1)

代码来源:https://github.com/Unity-Technologies/ml-agents/blob/release_19/docs/Learning-Environment-Design-Agents.md#decisions

1.Agent.CollectObservations()

1.1 代码总括

    public GameObject ball;

    public override void CollectObservations(VectorSensor sensor)
    {
        // 添加观测值
        // 立方体的方向(2个浮点数)
        sensor.AddObservation(gameObject.transform.rotation.z);
        sensor.AddObservation(gameObject.transform.rotation.x);
        // 小球相对于立方体的位置(3个浮点数)
        sensor.AddObservation(ball.transform.position - gameObject.transform.position);
        // 小球的速度(3个浮点数)
        sensor.AddObservation(m_BallRb.velocity);
    }

这段代码是Unity ML-Agents中一个Agent的重要函数CollectObservations,它用来采集当前状态的观测值,并将这些观测值传递给神经网络进行处理和决策

其中的参数VectorSensor sensor是一个向量传感器,可以用来收集观测值。在这个函数中,我们需要根据我们的游戏环境和任务设计,收集不同的观测值,并将这些观测值添加到传感器中。

具体来说,这段代码的作用是收集以下观测值:

①Cube的Z轴旋转角度(一个浮点数)。

②Cube的X轴旋转角度(一个浮点数)。

③Ball相对于Cube的位置向量(三个浮点数)。

④Ball的速度向量(三个浮点数)。

这些观测值将被组成一个向量,并传递给神经网络进行处理和决策。在训练过程中,我们会通过调整神经网络的权重和偏置来优化其决策能力,使其能够在游戏中获得更高的得分。

1.2 代码分解

1.2.1 public GameObject ball

public GameObject ball;

这行代码声明了一个GameObject类型的变量ball,用于引用场景中的球对象。在之后的代码中,可以通过这个变量来获取球对象的信息和状态。

问:GameObject类型?

在Unity中,GameObject是最基本的对象类型之一,它代表了场景中的实体。每个GameObject都可以拥有一个或多个组件(Component),例如Transform(控制GameObject的位置、旋转和缩放)、MeshRenderer(用于渲染物体的网格)等。你可以使用代码获取场景中的GameObject,以便对它们进行操作和修改。

1.2.2 public override void CollectObservations(VectorSensor sensor)

public override void CollectObservations(VectorSensor sensor)

这段代码是Unity ML-Agents中一个Agent的重要函数CollectObservations,它用来采集当前状态的观测值,并将这些观测值传递给神经网络进行处理和决策。

其中的参数VectorSensor sensor是一个向量传感器,可以用来收集观测值。在这个函数中,我们需要根据我们的游戏环境和任务设计,收集不同的观测值,并将这些观测值添加到传感器中。

VectorSensor 是ML-Agents SDK中的一个类,用于向智能体发送观察(observation)数据。而sensor则是在CollectObservations方法中作为参数传入的一个VectorSensor对象,它代表着当前智能体的观察数据。在CollectObservations方法中,我们会调用sensorAddObservation方法,将当前智能体的各项观察数据添加到VectorSensor中。这些观察数据最终会被发送到智能体中,用于训练智能体的神经网络模型。

1.2.3 sensor.AddObservation(gameObject.transform.rotation.z)

sensor.AddObservation(gameObject.transform.rotation.z);
sensor.AddObservation(gameObject.transform.rotation.x);

这两行代码会向传感器添加物体的方向信息gameObject是指当前脚本所在的游戏对象transform是指该游戏对象的变换组件。在这里,我们获取了该游戏对象在 z 和 x 轴上的旋转值,并将这些值添加到向量传感器中,以供神经网络模型使用。这个方向信息可以帮助模型了解物体在世界中的朝向。

sensor.AddObservation 是 ML-Agents SDK 提供的一个方法,用于向 Agent 发送观测(observation)的信息。在这个例子中,CollectObservations 函数中的 sensor 参数是一个 VectorSensor 类型的对象,它提供了一系列的方法来添加不同类型的观测信息。AddObservation 方法可以向 VectorSensor 中添加一个浮点型的数值,用来表示当前的观测信息。在这个例子中,我们向 VectorSensor 中添加了 2 个浮点型数值,分别是 gameObject.transform.rotation.zgameObject.transform.rotation.x。这两个数值用来表示智能体当前的朝向。

1.2.4 sensor.AddObservation(ball.transform.position - gameObject.transform.position)

 sensor.AddObservation(ball.transform.position - gameObject.transform.position);

这行代码是将球相对于物体的位置加入到观测中。具体来说,ball.transform.position表示球的世界空间坐标,而gameObject.transform.position表示当前脚本所在物体的世界空间坐标。两者相减即得到球相对于物体的位置。将这个向量加入到观测中,让智能体能够感知球相对于自己的位置信息。

这三个浮点数代表了球相对于立方体的位置,即相对位置信息。这些信息是在每个时间步骤中从环境中收集的,用于训练智能体学习如何控制立方体以使球保持在正确的位置。在神经网络中,这些相对位置信息作为输入特征被传递给模型,以帮助模型学习决策动作。

问:gameObject.transform.position ?

gameObject是一个特殊的变量,代表挂载脚本的游戏对象本身。gameObject.transform.position指的是当前游戏对象的位置,返回一个Vector3类型的坐标值,包括x、y、z三个分量。这个值表示游戏对象在三维空间中的位置。

问:为何是三个浮点数?

gameObject.transform.position 表示游戏对象的位置,是一个三维向量,包含三个浮点数分别表示x、y和z方向上的位置。而 ball.transform.position - gameObject.transform.position 是表示球相对于立方体的相对位置,同样也是一个三维向量,因此也需要三个浮点数来表示。

1.2.5 sensor.AddObservation(m_BallRb.velocity)

sensor.AddObservation(m_BallRb.velocity)

这行代码会将球的速度作为观测值添加到传感器中,球的速度是一个三维向量,因此会添加三个浮点数作为观测值。具体而言,它会将球刚体组件 m_BallRb 的速度向量添加到传感器中,以便智能体可以感知球的当前速度。

问:为何是三个浮点数?

m_BallRb.velocity 表示球体的速度,是一个三维向量,分别表示球体在 x、y、z 方向上的速度大小。因此,需要用三个浮点数来描述这个向量。在这里,这三个浮点数分别对应了神经网络中的三个输入节点。

2. Observable Fields and Properties

2.1 代码总括

using Unity.MLAgents.Sensors.Reflection;

public class Ball3DHardAgent : Agent {

    [Observable(numStackedObservations: 9)]
    Vector3 PositionDelta
    {
        get
        {
            return ball.transform.position - gameObject.transform.position;
        }
    }
}

这段代码使用了ML-Agents SDK中的ReflectionSensor,它可以通过反射机制从代码中获取特定属性的值,并将其作为观察结果传递给神经网络

具体来说,该代码定义了一个Ball3DHardAgent类,并继承了Agent类,表示这是一个可以训练的智能体。这个智能体有一个公共的GameObject类型的变量ball,代表场景中的球。同时,这个类还定义了一个名为PositionDelta的属性,返回球的位置相对于智能体位置的差量。这个属性使用了Observable特性,表示它可以被ReflectionSensor观察到,并且在多个时间步骤中进行堆叠(numStackedObservations: 9),以便于神经网络可以获取更多历史信息进行学习。

2.2 代码分解

2.2.1 using Unity.MLAgents.Sensors.Reflection

using Unity.MLAgents.Sensors.Reflection;

这行代码导入了Reflection命名空间,其中包含了ObservableAttribute特性类。

2.2.2 public class Ball3DHardAgent : Agent

public class Ball3DHardAgent : Agent

这段代码定义了一个名为 Ball3DHardAgent 的类,该类继承了 Agent 类,表示这是一个使用 Unity ML-Agents 框架的智能体。

智能体是指在机器学习中用于解决问题的实体,它能够感知环境,并通过学习策略来最大化某个目标函数,以达到问题的最优解。在 Unity ML-Agents 框架中,智能体可以是场景中的任何一个游戏对象(GameObject)。

该类中包含一个名为 PositionDelta 的属性,使用了 Observable 特性,用于在观察数据中添加神经网络的输入。具体来说, PositionDelta 是一个 Vector3 类型的属性,它计算了球相对于智能体的位置偏移,并将其作为一个长度为 9 的向量(numStackedObservations: 9)添加到观察数据中。

在使用 Reflection Sensor 时,Observable 特性用于标记智能体的观测数据(observation)应该被添加到 Reflection Sensor 的缓冲区中。该特性还允许指定观测数据在缓冲区中的位置,以及多个观测数据的堆叠数量。

2.2.3 [Observable(numStackedObservations: 9)]

[Observable(numStackedObservations: 9)]

这是一个特性(attribute),用于标识一个字段或属性需要在观测中被包含,参数numStackedObservations表示要在多少个步骤中记录该观测。在这个例子中,PositionDelta属性将被包含在观测中,并在9个连续的步骤中记录。

2.2.4 Vector3 PositionDelta

Vector3 PositionDelta

Vector3 PositionDelta 是一个公共的 get 属性,返回代表球相对于代理的位置的 3D 矢量。它标记为 [Observable(numStackedObservations: 9)],表示该属性可以被观察,并且要被观察的历史记录数量为 9。

问:get ?

get是一个C#中的关键字,用于获取属性值。在这个上下文中,getPositionDelta属性的一个访问器,表示获取球和代理之间的位置差矢量。当我们在其它部分使用PositionDelta属性时,它将返回该方法中定义的球和代理之间的位置差矢量。

问:ObservableAttribute?

ObservableAttribute是一个用于标记属性的特性,指示这个属性应该被添加到观测中。ML-Agents SDK中,观测是用来提供环境信息给智能体的一种方式,可以被训练的神经网络用来做出决策。numStackedObservations参数用于指定这个属性在观测中应该被添加多少次,以产生一系列历史信息,也称为"stacked observations"。在这个示例中,PositionDelta属性被标记为可观测,并指定了它应该被添加9次。

3.One-hot encoding categorical information

3.1 代码总括

enum ItemType { Sword, Shield, Bow, LastItem }

public override void CollectObservations(VectorSensor sensor)
{
    // 遍历所有物品类型
    for (int ci = 0; ci < (int)ItemType.LastItem; ci++)
    {
        // 如果当前持有物品是该类型,就设置为 1,否则为 0
        sensor.AddObservation((int)currentItem == ci ? 1.0f : 0.0f);
    }
}

这段代码使用枚举定义了物品类型,然后在CollectObservations函数中使用VectorSensor来收集观察数据。在这里,通过遍历所有物品类型并检查当前持有的物品类型,将当前持有物品的类型转换为整数(0,1,2 等)并添加到观察向量中。如果当前持有的物品类型当前遍历的物品类型相同,那么将值设置为 1,否则将值设置为 0。

3.2 代码分解

3.2.1 enum ItemType { Sword, Shield, Bow, LastItem }

enum ItemType { Sword, Shield, Bow, LastItem }

这里定义了一个枚举类型 ItemType,包含了四个成员:Sword、Shield、Bow 和 LastItem。其中,LastItem 可以看做是一个标记,用来标记 ItemType 枚举类型的成员数量

枚举类型可以用来定义一些具有特定含义的常量。在这个例子中,ItemType 可以表示游戏中的不同物品类型。Sword、Shield 和 Bow 代表三种不同的武器类型,而 LastItem 则用来帮助计算枚举类型的成员数量。

需要注意的是,枚举类型的成员默认情况下会被赋予一个整数值,从0开始自增。在这个例子中,Sword 的值为0,Shield 的值为1,Bow 的值为2,LastItem 的值为3。

3.2.2 public override void CollectObservations(VectorSensor sensor)

public override void CollectObservations(VectorSensor sensor)

这个函数是用来收集Agent的观测(observation)信息,其中的 VectorSensor 参数是一个用于存储Agent观测信息的向量。

在这个例子中,函数会为每一个枚举值添加一个观测。循环会遍历枚举中的每一个值,对于每一个值,如果当前Agent拥有该物品,则将1.0添加到观测向量中,否则添加0.0。这样就可以将Agent拥有的物品情况作为观测信息传给神经网络进行训练。

3.2.3 for (int ci = 0; ci < (int)ItemType.LastItem; ci++)

for (int ci = 0; ci < (int)ItemType.LastItem; ci++)
    {
        // 如果当前持有物品是该类型,就设置为 1,否则为 0
        sensor.AddObservation((int)currentItem == ci ? 1.0f : 0.0f);
    }

这段代码是在 CollectObservations 函数中,用于向智能体的观测空间(observation space)中添加一些数据。具体来说,它遍历了一个枚举类型 ItemType 中的所有枚举值不包括 LastItem),并根据当前持有的物品类型,将该类型对应的观测值设置为 1,其余类型对应的观测值设置为 0。这些观测值将在智能体的感知过程中被用来决策其下一步的行动。

sensor.AddObservation((int)currentItem == ci ? 1.0f : 0.0f) 这行代码的作用是将当前持有物品的类型转换成整数,并与循环变量 ci 进行比较。如果当前持有物品的类型等于 ci,则向观察向量中添加一个值为 1 的观察值,否则添加一个值为 0 的观察值。这样就可以将当前持有物品的信息传递给智能体的神经网络,使其能够感知到智能体当前所持有的物品类型。

第二种遍历方式:

// 定义枚举类型 ItemType,表示拥有的物品类型
enum ItemType { Sword, Shield, Bow, LastItem }

// 定义 NUM_ITEM_TYPES 常量表示物品类型的数量,这里为枚举类型中所有值的数量
const int NUM_ITEM_TYPES = (int)ItemType.LastItem;

public override void CollectObservations(VectorSensor sensor)
{
    // 使用 AddOneHotObservation 方法将一个整型值编码为 one-hot 向量
    // 第一个参数是要编码的整型值,第二个参数是 one-hot 向量的维度(即可能的取值数量)
    sensor.AddOneHotObservation((int)currentItem, NUM_ITEM_TYPES);
}

这段代码实现了将一个枚举类型的变量编码成 one-hot 向量,方便神经网络进行处理。

const int NUM_ITEM_TYPES = (int)ItemType.LastItem;

NUM_ITEM_TYPES 是一个常量,它被设置为 ItemType 枚举的最后一个元素(即 LastItem)的整数值。由于 ItemType 枚举中的元素都是连续的整数值,因此 LastItem 的整数值等于所有元素数量加 1。通过将 NUM_ITEM_TYPES 设置为 (int)ItemType.LastItem,我们可以保证它的值总是等于 ItemType 枚举中元素的数量。这在后续的代码中会被用到。

sensor.AddOneHotObservation((int)currentItem, NUM_ITEM_TYPES);

这段代码使用了 AddOneHotObservation 方法,它将指定索引的观察值编码为 one-hot 向量。one-hot 向量是指只有一个元素为 1,其余为 0 的向量。这种编码方式在分类问题中非常常见,例如将物品类型(剑、盾、弓等)编码为 one-hot 向量。

在这里,我们将当前所持物品类型的索引传递给 AddOneHotObservation 方法,同时还需要传递该特征的可能取值数量,也就是枚举类型 ItemType 中元素的个数(通过将 LastItem 强制转换为整数获得)。这将为神经网络提供一组 0 和 1 的观察值,表示当前所持物品类型是哪一种。

第三种方式:

enum ItemType { Sword, Shield, Bow }

public class HeroAgent : Agent
{
    [Observable]
    ItemType m_CurrentItem;
}

这段代码定义了一个枚举 ItemType 表示不同种类的游戏物品,还定义了一个 HeroAgent 类,该类继承了 Agent 类。HeroAgent 类中有一个 m_CurrentItem 变量,用来表示当前英雄所持有的游戏物品类型,该变量被标记为 [Observable],表示它可以被 ML-Agents 访问并收集为状态信息。

[Observable]是Unity ML-Agents中的一个属性,可以用于将变量标记为需要观察的属性。在训练过程中,智能体会观察这些属性,并将它们作为状态传递给神经网络进行学习。这些属性可以是任何类型的基本数据类型,例如整数、浮点数、布尔值或枚举类型。使用[Observable]属性来标记属性可以确保它们在训练过程中被正确地观察和记录。

 ItemType m_CurrentItem;

这行代码声明了一个公共字段 m_CurrentItem,类型为 ItemTypeItemType 是一个枚举类型,包含了三个枚举值 SwordShieldBow,表示英雄当前持有的装备。

[Observable] 属性告诉 ML-Agents,该字段需要被作为观察值传递给智能体的神经网络。因此,在智能体的 CollectObservations() 函数中,你可以通过 VectorSensor.AddObservation() 方法将 m_CurrentItem 的整数值作为观察值传递给神经网络。

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

天寒心亦热

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值