1.Agent.CollectObservations()
1.1 代码总括
public GameObject ball;
public override void CollectObservations(VectorSensor sensor)
{
// 添加观测值
// 立方体的方向(2个浮点数)
sensor.AddObservation(gameObject.transform.rotation.z);
sensor.AddObservation(gameObject.transform.rotation.x);
// 小球相对于立方体的位置(3个浮点数)
sensor.AddObservation(ball.transform.position - gameObject.transform.position);
// 小球的速度(3个浮点数)
sensor.AddObservation(m_BallRb.velocity);
}
这段代码是Unity ML-Agents中一个Agent的重要函数CollectObservations
,它用来采集当前状态的观测值,并将这些观测值传递给神经网络进行处理和决策。
其中的参数VectorSensor sensor
是一个向量传感器,可以用来收集观测值。在这个函数中,我们需要根据我们的游戏环境和任务设计,收集不同的观测值,并将这些观测值添加到传感器中。
具体来说,这段代码的作用是收集以下观测值:
①Cube的Z轴旋转角度(一个浮点数)。
②Cube的X轴旋转角度(一个浮点数)。
③Ball相对于Cube的位置向量(三个浮点数)。
④Ball的速度向量(三个浮点数)。
这些观测值将被组成一个向量,并传递给神经网络进行处理和决策。在训练过程中,我们会通过调整神经网络的权重和偏置来优化其决策能力,使其能够在游戏中获得更高的得分。
1.2 代码分解
1.2.1 public GameObject ball
public GameObject ball;
这行代码声明了一个GameObject类型的变量ball,用于引用场景中的球对象。在之后的代码中,可以通过这个变量来获取球对象的信息和状态。
问:GameObject类型?
在Unity中,GameObject是最基本的对象类型之一,它代表了场景中的实体。每个GameObject都可以拥有一个或多个组件(Component),例如Transform(控制GameObject的位置、旋转和缩放)、MeshRenderer(用于渲染物体的网格)等。你可以使用代码获取场景中的GameObject,以便对它们进行操作和修改。
1.2.2 public override void CollectObservations(VectorSensor sensor)
public override void CollectObservations(VectorSensor sensor)
这段代码是Unity ML-Agents中一个Agent的重要函数CollectObservations
,它用来采集当前状态的观测值,并将这些观测值传递给神经网络进行处理和决策。
其中的参数VectorSensor sensor
是一个向量传感器,可以用来收集观测值。在这个函数中,我们需要根据我们的游戏环境和任务设计,收集不同的观测值,并将这些观测值添加到传感器中。
VectorSensor
是ML-Agents SDK中的一个类,用于向智能体发送观察(observation)数据。而sensor
则是在CollectObservations
方法中作为参数传入的一个VectorSensor
对象,它代表着当前智能体的观察数据。在CollectObservations
方法中,我们会调用sensor
的AddObservation
方法,将当前智能体的各项观察数据添加到VectorSensor
中。这些观察数据最终会被发送到智能体中,用于训练智能体的神经网络模型。
1.2.3 sensor.AddObservation(gameObject.transform.rotation.z)
sensor.AddObservation(gameObject.transform.rotation.z);
sensor.AddObservation(gameObject.transform.rotation.x);
这两行代码会向传感器添加物体的方向信息。gameObject
是指当前脚本所在的游戏对象,transform
是指该游戏对象的变换组件。在这里,我们获取了该游戏对象在 z 和 x 轴上的旋转值,并将这些值添加到向量传感器中,以供神经网络模型使用。这个方向信息可以帮助模型了解物体在世界中的朝向。
sensor.AddObservation
是 ML-Agents SDK 提供的一个方法,用于向 Agent 发送观测(observation)的信息。在这个例子中,CollectObservations
函数中的 sensor
参数是一个 VectorSensor
类型的对象,它提供了一系列的方法来添加不同类型的观测信息。AddObservation
方法可以向 VectorSensor
中添加一个浮点型的数值,用来表示当前的观测信息。在这个例子中,我们向 VectorSensor
中添加了 2 个浮点型数值,分别是 gameObject.transform.rotation.z
和 gameObject.transform.rotation.x
。这两个数值用来表示智能体当前的朝向。
1.2.4 sensor.AddObservation(ball.transform.position - gameObject.transform.position)
sensor.AddObservation(ball.transform.position - gameObject.transform.position);
这行代码是将球相对于物体的位置加入到观测中。具体来说,ball.transform.position
表示球的世界空间坐标,而gameObject.transform.position
表示当前脚本所在物体的世界空间坐标。两者相减即得到球相对于物体的位置。将这个向量加入到观测中,让智能体能够感知球相对于自己的位置信息。
这三个浮点数代表了球相对于立方体的位置,即相对位置信息。这些信息是在每个时间步骤中从环境中收集的,用于训练智能体学习如何控制立方体以使球保持在正确的位置。在神经网络中,这些相对位置信息作为输入特征被传递给模型,以帮助模型学习决策动作。
问:gameObject.transform.position ?
gameObject
是一个特殊的变量,代表挂载脚本的游戏对象本身。gameObject.transform.position
指的是当前游戏对象的位置,返回一个Vector3类型的坐标值,包括x、y、z三个分量。这个值表示游戏对象在三维空间中的位置。
问:为何是三个浮点数?
gameObject.transform.position
表示游戏对象的位置,是一个三维向量,包含三个浮点数分别表示x、y和z方向上的位置。而 ball.transform.position - gameObject.transform.position
是表示球相对于立方体的相对位置,同样也是一个三维向量,因此也需要三个浮点数来表示。
1.2.5 sensor.AddObservation(m_BallRb.velocity)
sensor.AddObservation(m_BallRb.velocity)
这行代码会将球的速度作为观测值添加到传感器中,球的速度是一个三维向量,因此会添加三个浮点数作为观测值。具体而言,它会将球刚体组件 m_BallRb
的速度向量添加到传感器中,以便智能体可以感知球的当前速度。
问:为何是三个浮点数?
m_BallRb.velocity
表示球体的速度,是一个三维向量,分别表示球体在 x、y、z 方向上的速度大小。因此,需要用三个浮点数来描述这个向量。在这里,这三个浮点数分别对应了神经网络中的三个输入节点。
2. Observable Fields and Properties
2.1 代码总括
using Unity.MLAgents.Sensors.Reflection;
public class Ball3DHardAgent : Agent {
[Observable(numStackedObservations: 9)]
Vector3 PositionDelta
{
get
{
return ball.transform.position - gameObject.transform.position;
}
}
}
这段代码使用了ML-Agents SDK中的ReflectionSensor,它可以通过反射机制从代码中获取特定属性的值,并将其作为观察结果传递给神经网络。
具体来说,该代码定义了一个Ball3DHardAgent类,并继承了Agent类,表示这是一个可以训练的智能体。这个智能体有一个公共的GameObject类型的变量ball,代表场景中的球。同时,这个类还定义了一个名为PositionDelta的属性,返回球的位置相对于智能体位置的差量。这个属性使用了Observable特性,表示它可以被ReflectionSensor观察到,并且在多个时间步骤中进行堆叠(numStackedObservations: 9),以便于神经网络可以获取更多历史信息进行学习。
2.2 代码分解
2.2.1 using Unity.MLAgents.Sensors.Reflection
using Unity.MLAgents.Sensors.Reflection;
这行代码导入了Reflection命名空间,其中包含了ObservableAttribute特性类。
2.2.2 public class Ball3DHardAgent : Agent
public class Ball3DHardAgent : Agent
这段代码定义了一个名为 Ball3DHardAgent
的类,该类继承了 Agent
类,表示这是一个使用 Unity ML-Agents 框架的智能体。
智能体是指在机器学习中用于解决问题的实体,它能够感知环境,并通过学习策略来最大化某个目标函数,以达到问题的最优解。在 Unity ML-Agents 框架中,智能体可以是场景中的任何一个游戏对象(GameObject)。
该类中包含一个名为 PositionDelta
的属性,使用了 Observable
特性,用于在观察数据中添加神经网络的输入。具体来说, PositionDelta
是一个 Vector3 类型的属性,它计算了球相对于智能体的位置偏移,并将其作为一个长度为 9 的向量(numStackedObservations: 9)添加到观察数据中。
在使用 Reflection Sensor 时,Observable
特性用于标记智能体的观测数据(observation)应该被添加到 Reflection Sensor 的缓冲区中。该特性还允许指定观测数据在缓冲区中的位置,以及多个观测数据的堆叠数量。
2.2.3 [Observable(numStackedObservations: 9)]
[Observable(numStackedObservations: 9)]
这是一个特性(attribute),用于标识一个字段或属性需要在观测中被包含,参数numStackedObservations
表示要在多少个步骤中记录该观测。在这个例子中,PositionDelta
属性将被包含在观测中,并在9个连续的步骤中记录。
2.2.4 Vector3 PositionDelta
Vector3 PositionDelta
Vector3 PositionDelta
是一个公共的 get
属性,返回代表球相对于代理的位置的 3D 矢量。它标记为 [Observable(numStackedObservations: 9)]
,表示该属性可以被观察,并且要被观察的历史记录数量为 9。
问:get ?
get
是一个C#中的关键字,用于获取属性值。在这个上下文中,get
是PositionDelta
属性的一个访问器,表示获取球和代理之间的位置差矢量。当我们在其它部分使用PositionDelta
属性时,它将返回该方法中定义的球和代理之间的位置差矢量。
问:ObservableAttribute?
ObservableAttribute
是一个用于标记属性的特性,指示这个属性应该被添加到观测中。ML-Agents SDK中,观测是用来提供环境信息给智能体的一种方式,可以被训练的神经网络用来做出决策。numStackedObservations
参数用于指定这个属性在观测中应该被添加多少次,以产生一系列历史信息,也称为"stacked observations"。在这个示例中,PositionDelta
属性被标记为可观测,并指定了它应该被添加9次。
3.One-hot encoding categorical information
3.1 代码总括
enum ItemType { Sword, Shield, Bow, LastItem }
public override void CollectObservations(VectorSensor sensor)
{
// 遍历所有物品类型
for (int ci = 0; ci < (int)ItemType.LastItem; ci++)
{
// 如果当前持有物品是该类型,就设置为 1,否则为 0
sensor.AddObservation((int)currentItem == ci ? 1.0f : 0.0f);
}
}
这段代码使用枚举定义了物品类型,然后在CollectObservations
函数中使用VectorSensor
来收集观察数据。在这里,通过遍历所有物品类型并检查当前持有的物品类型,将当前持有物品的类型转换为整数(0,1,2 等)并添加到观察向量中。如果当前持有的物品类型与当前遍历的物品类型相同,那么将值设置为 1,否则将值设置为 0。
3.2 代码分解
3.2.1 enum ItemType { Sword, Shield, Bow, LastItem }
enum ItemType { Sword, Shield, Bow, LastItem }
这里定义了一个枚举类型 ItemType,包含了四个成员:Sword、Shield、Bow 和 LastItem。其中,LastItem 可以看做是一个标记,用来标记 ItemType 枚举类型的成员数量。
枚举类型可以用来定义一些具有特定含义的常量。在这个例子中,ItemType 可以表示游戏中的不同物品类型。Sword、Shield 和 Bow 代表三种不同的武器类型,而 LastItem 则用来帮助计算枚举类型的成员数量。
需要注意的是,枚举类型的成员默认情况下会被赋予一个整数值,从0开始自增。在这个例子中,Sword 的值为0,Shield 的值为1,Bow 的值为2,LastItem 的值为3。
3.2.2 public override void CollectObservations(VectorSensor sensor)
public override void CollectObservations(VectorSensor sensor)
这个函数是用来收集Agent的观测(observation)信息,其中的 VectorSensor
参数是一个用于存储Agent观测信息的向量。
在这个例子中,函数会为每一个枚举值添加一个观测。循环会遍历枚举中的每一个值,对于每一个值,如果当前Agent拥有该物品,则将1.0添加到观测向量中,否则添加0.0。这样就可以将Agent拥有的物品情况作为观测信息传给神经网络进行训练。
3.2.3 for (int ci = 0; ci < (int)ItemType.LastItem; ci++)
for (int ci = 0; ci < (int)ItemType.LastItem; ci++)
{
// 如果当前持有物品是该类型,就设置为 1,否则为 0
sensor.AddObservation((int)currentItem == ci ? 1.0f : 0.0f);
}
这段代码是在 CollectObservations
函数中,用于向智能体的观测空间(observation space)中添加一些数据。具体来说,它遍历了一个枚举类型 ItemType
中的所有枚举值(不包括 LastItem
),并根据当前持有的物品类型,将该类型对应的观测值设置为 1,其余类型对应的观测值设置为 0。这些观测值将在智能体的感知过程中被用来决策其下一步的行动。
sensor.AddObservation((int)currentItem == ci ? 1.0f : 0.0f) 这行代码的作用是将当前持有物品的类型转换成整数,并与循环变量 ci
进行比较。如果当前持有物品的类型等于 ci
,则向观察向量中添加一个值为 1 的观察值,否则添加一个值为 0 的观察值。这样就可以将当前持有物品的信息传递给智能体的神经网络,使其能够感知到智能体当前所持有的物品类型。
第二种遍历方式:
// 定义枚举类型 ItemType,表示拥有的物品类型
enum ItemType { Sword, Shield, Bow, LastItem }
// 定义 NUM_ITEM_TYPES 常量表示物品类型的数量,这里为枚举类型中所有值的数量
const int NUM_ITEM_TYPES = (int)ItemType.LastItem;
public override void CollectObservations(VectorSensor sensor)
{
// 使用 AddOneHotObservation 方法将一个整型值编码为 one-hot 向量
// 第一个参数是要编码的整型值,第二个参数是 one-hot 向量的维度(即可能的取值数量)
sensor.AddOneHotObservation((int)currentItem, NUM_ITEM_TYPES);
}
这段代码实现了将一个枚举类型的变量编码成 one-hot 向量,方便神经网络进行处理。
const int NUM_ITEM_TYPES = (int)ItemType.LastItem;
NUM_ITEM_TYPES
是一个常量,它被设置为 ItemType 枚举的最后一个元素(即 LastItem
)的整数值。由于 ItemType
枚举中的元素都是连续的整数值,因此 LastItem
的整数值等于所有元素数量加 1。通过将 NUM_ITEM_TYPES
设置为 (int)ItemType.LastItem
,我们可以保证它的值总是等于 ItemType 枚举中元素的数量。这在后续的代码中会被用到。
sensor.AddOneHotObservation((int)currentItem, NUM_ITEM_TYPES);
这段代码使用了 AddOneHotObservation
方法,它将指定索引的观察值编码为 one-hot 向量。one-hot 向量是指只有一个元素为 1,其余为 0 的向量。这种编码方式在分类问题中非常常见,例如将物品类型(剑、盾、弓等)编码为 one-hot 向量。
在这里,我们将当前所持物品类型的索引传递给 AddOneHotObservation
方法,同时还需要传递该特征的可能取值数量,也就是枚举类型 ItemType
中元素的个数(通过将 LastItem
强制转换为整数获得)。这将为神经网络提供一组 0 和 1 的观察值,表示当前所持物品类型是哪一种。
第三种方式:
enum ItemType { Sword, Shield, Bow }
public class HeroAgent : Agent
{
[Observable]
ItemType m_CurrentItem;
}
这段代码定义了一个枚举 ItemType
表示不同种类的游戏物品,还定义了一个 HeroAgent
类,该类继承了 Agent
类。HeroAgent
类中有一个 m_CurrentItem
变量,用来表示当前英雄所持有的游戏物品类型,该变量被标记为 [Observable]
,表示它可以被 ML-Agents 访问并收集为状态信息。
[Observable]
是Unity ML-Agents中的一个属性,可以用于将变量标记为需要观察的属性。在训练过程中,智能体会观察这些属性,并将它们作为状态传递给神经网络进行学习。这些属性可以是任何类型的基本数据类型,例如整数、浮点数、布尔值或枚举类型。使用[Observable]
属性来标记属性可以确保它们在训练过程中被正确地观察和记录。
ItemType m_CurrentItem;
这行代码声明了一个公共字段 m_CurrentItem
,类型为 ItemType
。ItemType
是一个枚举类型,包含了三个枚举值 Sword
、Shield
和 Bow
,表示英雄当前持有的装备。
[Observable]
属性告诉 ML-Agents,该字段需要被作为观察值传递给智能体的神经网络。因此,在智能体的 CollectObservations()
函数中,你可以通过 VectorSensor.AddObservation()
方法将 m_CurrentItem
的整数值作为观察值传递给神经网络。