由于官方教程是全英版,本文为根据个人理解做的中文版
(能力有限,有问题的地方还望指出,同时欢迎志同道合的朋友参与讨论,谢谢!)
官方教程:https://learn.unity.com/course/ml-agents-hummingbirds?uv=2019.3
系列文章:unity learn—— ML_Agent:Hummingbirds中文教程 自译(一 ——配置Unity)
unity learn—— ML_Agent:Hummingbirds中文教程 自译(二——代码Flowers.cs、FlowerArea.cs)
unity learn—— ML_Agent:Hummingbirds中文教程 自译(四——Ray Perception Sensor)
unity learn—— ML_Agent:Hummingbirds中文教程 自译(五 ——安装Anaconda)
unity learn—— ML_Agent:Hummingbirds中文教程 自译(六——准备训练)
unity learn—— ML_Agent:Hummingbirds中文教程 自译(七——测试模型)
HummingbirdAgent.cs是unity与python API交互的最重要的代码,我称之为Agent代码
Note:
Agent代码创建之后的共性操作:
-
引入两个库
using Unity.MLAgents;
using Unity.MLAgents.Sensors; -
父物体由MonoBehaviour 改为 Agent.
-
删除Update()和Start()函数(不过这点不绝对,有的Agent代码还是需要Update()函数,而Start()则是被Initialize()代替)
-
需要重构三个函数
OnEpisodeBegin() : 初始化函数
CollectObservations(VectorSensor sensor):用于获取环境参数
OnActionReceived(float[] vectorAction):动作函数
以下是蜂鸟移动方式的示意图,以便你理解代码
HummingbirdAgent.cs
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents; //注意这里导入了MLAgents的库
using Unity.MLAgents.Sensors;
/// <summary>
/// 蜂鸟机器学习
/// </summary>
public class HummingbirdAgent : Agent //继承也变了
{
[Tooltip("用于移动的力")]
public float moveForce = 2f;
//pitch和yaw方向见上图
[Tooltip("pitch方向旋转的速度")]
public float pitchSpeed = 100f;
[Tooltip("yaw方向旋转的速度")]
public float yawSpeed = 100f;
[Tooltip("鸟喙尖端处的transfrom")]
public Transform beakTip;
[Tooltip("智能体的摄像头")] //对于该项目,以下所有智能体均指蜂鸟
public Camera agentCamera;
[Tooltip("是训练模式?否为游戏模式")]
public bool trainingMode;
new private Rigidbody rigidbody; //智能体的刚体
private FlowerArea flowerArea; //智能体所在的空间
private Flower nearestFlower; //离智能体最近的花
//两个用于平滑移动的变量
private float smoothPitchChange = 0f; //平滑pitch方向
private float smoothYawChange = 0f; //平滑yaw方向
private const float maxPitchAngle = 80f; //pitch方向最大角度
private const float beakTipRadius = 0.008f; //鸟喙的直径
private bool frozen = false; //是否冻结智能体的行动
/// <summary>
/// 智能体获得的花粉数量
/// </summary>
public float NectarObtained { get; private set; }
/// <summary>
/// 初始化变量
/// </summary>
public override void Initialize()
{
rigidbody = GetComponent<Rigidbody>();
flowerArea = GetComponentInParent<FlowerArea>(); //该句成功的关键是鸟必须是挂载了flowerArea脚本的物体的第一级子物体
//如果不是训练模式
if (!trainingMode) MaxStep = 0; //MaxStep定义在Agent中
}
/// <summary>
/// 当玩家输入或神经网络控制时调用
///
/// vectorAction[i] represents:
/// Index 0: move vector x (+1 = right, -1 = left)
/// Index 1: move vector y (+1 = up, -1 = down)
/// Index 2: move vector z (+1 = forward, -1 = backward)
/// Index 3: pitch angle (+1 = pitch up, -1 = pitch down)
/// Index 4: yaw angle (+1 = turn right, -1 = turn left)
/// </summary>
/// <param name="vectorAction">将要采取的动作</param>
public override void OnActionReceived(float[] vectorAction)
{
//如果冻结,则不采取任何动作
if (frozen) { return; }
//移动方向
Vector3 move = new Vector3(vectorAction[0], vectorAction[1], vectorAction[2]);
//向移动方向施加力
rigidbody.AddForce(move * moveForce);
//获取pitch和yaw方向上的移动量
float pitchChange = vectorAction[3];
float yawChange = vectorAction[4];
//平滑旋转变化,防抖动
smoothPitchChange = Mathf.MoveTowards(smoothPitchChange, pitchChange, 2f * Time.fixedDeltaTime); //注:为了更贴合物理,
smoothYawChange = Mathf.MoveTowards(smoothYawChange, yawChange, 2f * Time.fixedDeltaTime); //不是Time.deltaTime
//获得目前rotation
Vector3 rotationVector = transform.rotation.eulerAngles;
//计算新的pitch和yaw方向上的移动量
float pitch = rotationVector.x + smoothPitchChange * Time.fixedDeltaTime * pitchSpeed;
float yaw = rotationVector.y + smoothYawChange * Time.fixedDeltaTime * yawSpeed; //pithc方向需要限制范围,而yaw方向不需要
if (pitch > 180f) pitch -= 360f;
pitch = Mathf.Clamp(pitch, -maxPitchAngle, maxPitchAngle);
//旋转对应的角度
transform.rotation = Quaternion.Euler(pitch, yaw, 0f);
}
/// <summary>
/// 获取环境的坐标数据
/// </summary>
/// <param name="sensor">坐标探测器</param>
public override void CollectObservations(VectorSensor sensor)
{
//为了便于训练,以下observation都做了归一化
//如果nearestFlower是空的,就观测十个空浮点数,之后返回,防止报错
if (nearestFlower == null)
{
sensor.AddObservation(new float[10]);
return;
}
//观测localRotation和与nearestFlower的相对位置
sensor.AddObservation(transform.localRotation.normalized); //4个探测目标(归一化是为了简化分析)
Vector3 toFlower = nearestFlower.FlowerCenterPosition - beakTip.position;
sensor.AddObservation(toFlower.normalized); //3个探测目标
// 通过向量点积观测鸟喙是否在花蕊正前方(1个探测目标)
// (+1 代表在花正前方, -1 代表在花正后面)
sensor.AddObservation(Vector3.Dot(toFlower.normalized, -nearestFlower.FlowerUpVector.normalized));
// 通过向量点积观测鸟喙是否指向花蕊正前方 (1个探测目标)
// (+1代表正面指向花, -1 代表在花后面指向花)
sensor.AddObservation(Vector3.Dot(beakTip.forward.normalized, -nearestFlower.FlowerUpVector.normalized));
//观测鸟喙与花的相对场地直径的相对距离(1个探测目标)
sensor.AddObservation(toFlower.magnitude / FlowerArea.AreaDiameter);
// 一共10个探测目标
}
/// <summary>
/// 当行为模式为“Heuristic Only”,即仅玩家探索时调用
/// 此时由玩家代替神经网络操作
/// </summary>
/// <param name="actionsOut">输出的行动数组</param>
public override void Heuristic(float[] actionsOut)
{
// 创建所有用于移动的变量
Vector3 forward = Vector3.zero;
Vector3 left = Vector3.zero;
Vector3 up = Vector3.zero;
float pitch = 0f;
float yaw = 0f;
// 获取键盘输入并进行相应的移动控制
// 所有数据都应在区间( -1 , +1 )内
// 前后控制
if (Input.GetKey(KeyCode.W)) forward = transform.forward;
else if (Input.GetKey(KeyCode.S)) forward = -transform.forward;
//左右控制
if (Input.GetKey(KeyCode.A)) left = -transform.right;
else if (Input.GetKey(KeyCode.D)) left = transform.right;
//上下控制
if (Input.GetKey(KeyCode.E)) up = transform.up;
else if (Input.GetKey(KeyCode.Q)) up = -transform.up;
//俯仰控制
if (Input.GetKey(KeyCode.UpArrow)) pitch = 1f;
else if (Input.GetKey(KeyCode.DownArrow)) pitch = -1f;
//左右转身控制
if (Input.GetKey(KeyCode.LeftArrow)) yaw = -1f;
else if (Input.GetKey(KeyCode.RightArrow)) yaw = 1f;
//合并移动变量并归一化
Vector3 combined = (forward + left + up).normalized;
// 将3个移动变量和两个旋转变量加到actionsOut数组中
actionsOut[0] = combined.x;
actionsOut[1] = combined.y;
actionsOut[2] = combined.z;
actionsOut[3] = pitch;
actionsOut[4] = yaw;
}
/// <summary>
/// 停止智能体移动
/// </summary>
public void FreezeAgent()
{
//训练模式下不支freeze
Debug.Assert(trainingMode == false, "Freeze/Unfreeze not supported in training");
frozen = true;
rigidbody.Sleep(); //刚体也停止作用
}
/// <summary>
/// 恢复智能体移动
/// </summary>
public void UnfreezeAgent()
{
Debug.Assert(trainingMode == false, "Freeze/Unfreeze not supported in training");
frozen = false;
rigidbody.WakeUp();
}
/// <summary>
/// 当事件开始时重置智能体
/// </summary>
public override void OnEpisodeBegin()
{
if (trainingMode)
{
//当场景中只有一个智能体时,只重置花
flowerArea.ResetFlower();
}
NectarObtained = 0f;
//重要!!! 将所有速度归零(防止上次一情景的运动影响到现在)
rigidbody.velocity = Vector3.zero;
rigidbody.angularVelocity = Vector3.zero;
bool inFrontofFlower = true;
if (trainingMode) //如果是训练模式
{
//50%可能性使智能体初始化时面对一朵花
inFrontofFlower = UnityEngine.Random.value > .5f;
}
//将智能体移动到一个随机的安全位置(主要避免出现在树木,石头等物体内部)
MoveToSafeRandomPosition(inFrontofFlower);
//重新计算最近的花
UpdateNearestFlower();
}
/// <summary>
/// 将智能体移动到一个随机的安全位置(主要避免出现在树木,石头等物体内部)
/// 如果inFrontofFlower == true,则将其面对一朵花
/// </summary>
/// <param name="inFrontofFlower">是否选择一个面对花的位置</param>
private void MoveToSafeRandomPosition(bool inFrontofFlower)
{
bool safePositionFound = false; //是否发现安全位置
int attemptsRemaining = 100; //尝试的次数,超过该次数则认为出现了问题,防止无限循环
//初始化位置状态
Vector3 potentialPosition = Vector3.zero;
Quaternion potentialRotation = new Quaternion();
//循环到找到安全位置或者尝试次数超过限定值
while (!safePositionFound && attemptsRemaining > 0)
{
attemptsRemaining--;
if (inFrontofFlower)
{
//随机选取一朵花
Flower randomFlower = flowerArea.Flowers[UnityEngine.Random.Range(0, flowerArea.Flowers.Count)];
//生成位置距花10-20cm
float distanceFromFlower = UnityEngine.Random.Range(.1f, .2f);
potentialPosition = randomFlower.transform.position + randomFlower.FlowerUpVector * distanceFromFlower;
//使智能体看向花
Vector3 toFlower = randomFlower.FlowerCenterPosition - potentialPosition; //计算方向向量
potentialRotation = Quaternion.LookRotation(toFlower, Vector3.up);
}
else
{
//随机Position
float height = UnityEngine.Random.Range(1.2f, 2.5f); //随机高度
float radius = UnityEngine.Random.Range(2f, 7f); //以区域中心为圆心的生成半径(之所以以2开始是为了避开中心的大树)
Quaternion direction = Quaternion.Euler(0f, UnityEngine.Random.Range(-180f, 180f), 0f); //关于y轴的一个随机方向
potentialPosition = flowerArea.transform.position + Vector3.up * height + direction * Vector3.forward * radius;
//随机Rotation
float pitch = UnityEngine.Random.Range(-60f, 60f);
float yaw = UnityEngine.Random.Range(-180f, 180f);
potentialRotation = Quaternion.Euler(pitch, yaw, 0f);
}
//检查是否与其他碰撞体重叠,不重叠就表示找到了安全的生成位置
safePositionFound = !(Physics.CheckBox(potentialPosition, new Vector3(0.05f, 0.05f, 0.05f))); //以一个中心为potentialPosition边长为10cm的正方体探测
}
Debug.Assert(safePositionFound, "Could not find a safe position to spawn"); //safePositionFound为false时发出警告
//设置位置
transform.position = potentialPosition;
transform.rotation = potentialRotation;
}
/// <summary>
/// 更新最近的花(个人觉得这个方法效率挺低的,感兴趣的同学可以优化一下)
/// </summary>
private void UpdateNearestFlower()
{
foreach (Flower flower in flowerArea.Flowers)
{
//如果最近的花nearestFlower没赋值,并且循环到的这朵花有花粉
if (nearestFlower == null && flower.HasNectar)
{
nearestFlower = flower;
}//如果最近的花nearestFlower已经赋值,并且循环到的这朵花有花粉
else if (flower.HasNectar)
{
//计算并比较哪个是最近的花,然后更新nearestFlower
float distanceToFlower = Vector3.Distance(flower.transform.position, beakTip.position);
float distanceToCurrentNearestFlower = Vector3.Distance(nearestFlower.transform.position, beakTip.position);
//如果现有的nearestFlower没花粉了,或者它不再是最近的花了
if (!nearestFlower.HasNectar || distanceToCurrentNearestFlower > distanceToFlower)
{
nearestFlower = flower;
}
}
}
}
private void OnTriggerEnter(Collider other)
{
TriggerEnterOrStay(other);
}
private void OnTriggerStay(Collider other)
{
TriggerEnterOrStay(other);
}
private void TriggerEnterOrStay(Collider collider)
{
// 检查是否碰到了花蕊
if (collider.CompareTag("nectar"))
{
Vector3 closestPointToBeakTip = collider.ClosestPoint(beakTip.position);//返回花蕊碰撞体上距离鸟喙最近的点
// 检查是不是真的是鸟喙碰到了花蕊
// Note: 除鸟喙之外,别的碰撞体碰到花蕊不算
if (Vector3.Distance(beakTip.position, closestPointToBeakTip) < beakTipRadius)
{
// 通过花蕊找到对应的花
Flower flower = flowerArea.GetFlowerFromNectar(collider);
// 尝试"吃掉" 0.01 的花粉
// Note: 每 0.02 秒调用一次
float nectarReceived = flower.Feed(.01f);
// 更新已获得的花粉(不是用于训练而是用在UI上)
NectarObtained += nectarReceived;
if (trainingMode)
{
// 计算奖励
float bonus = .02f * Mathf.Clamp01(Vector3.Dot(transform.forward.normalized, -nearestFlower.FlowerUpVector.normalized));
AddReward(.01f + bonus); //在bonus之外,只要是鸟喙碰到了花蕊就保证有0.01的基础奖励
}
// 如果花粉没了,换一朵花
if (!flower.HasNectar)
{
UpdateNearestFlower();
}
}
}
}
private void OnCollisionEnter(Collision collision)
{
if (trainingMode && collision.collider.CompareTag("boundary"))
{
//碰到了障碍物或者区域边界,给负分
AddReward(-.5f);
}
}
private void Update()
{
// 在鸟喙和最近的花之间画一条线
if (nearestFlower != null)
Debug.DrawLine(beakTip.position, nearestFlower.FlowerCenterPosition, Color.green);
}
private void FixedUpdate()
{
// 防止最近花的花粉被对手“吃”光了,但是没有更新
if (nearestFlower != null && !nearestFlower.HasNectar)
UpdateNearestFlower();
}
}
脚本编好后,需要返回unity再加两个tag:nectar和boundary
找到prefab中FloatingIsland的子物体IslandBoundary,为其设置tag:boundary
找到FlowerBud中的FlowerNectarCollider,设置tag:nectar
找到蜂鸟预制体,添加本次的脚本hummingbird.cs(它会自动添加另一个脚本Behavior Parameters)
关于Behavior Parameter:
变量 | 注释 |
---|---|
Behavior Name | 动作空间名:Default:可以训练,可以人机交互(就是玩家自己操作),也可以挂上训练好的模型做推理; Heuristic Only:仅人机交互; Inference Only:挂上训练好的模型后,推理用 |
Vector Observation :Space Size | 观测的个数(其值应与重构函数CollectObservations(VectorSensor sensor)中的总探测目标个数相等) |
Vector Observation :Stacked Vectors | 实际保留的探测目标量为(Space Size*Stacked Vectors)。 将其设为1以上是赋予agent有限的“memory”的简单方法,而无需添加递归神经网络(RNN)的复杂性。 |
Vector Action : Space Type | 动作空间的性质:Discrete离散 或 Continuous连续 |
Vector Action : Branches Size | 动作空间变量的个数(应与函数OnActionReceived(float[] vectorAction)中vectorActions数组中元素个数相同) |
Model | 训练好的模型(如果没有可以不用挂载) |
Inference Device | 训练的设备:CPU 或 GPU |
Behavior Type | 训练模式: default:可以训练也可以推理;Heuristic Only: 人机交互模式,该模式下训练好的模型和Training都将没有作用; Inference Only: 推理模式, 该模式下只能使用模型 |
Team ID | 用于定义self-play的team |
Use Child Sensors | 是否使用子物体的Sensor组件(我们后面要用子物体的RaySensor,所以要勾上 |
Max Step | 每个agent的最大步数。一旦达到此数目,OnEpisodeBegin() 讲被调用,重置Agent。 |
调整后的设置如下:
其中Agent Camera不是场景中的主摄像机,而需要在HummingBird下创建一个Camera子物体,具体设置如下:
注:训练时用不到这个摄像机,要将其屏蔽掉(即取消勾选其名字旁边的方框)
另外Audio_Listener可自行决定是否勾选(如果不想听声就去掉)
最后,在hummingBird之下再加一个脚本:Decision Requester.cs
用于为智能体做决定(Decision Period 设置为5即为每5步做一个决定)