unity learn—— ML_Agent:Hummingbirds中文教程 自译(三——代码HummingbirdAgent.cs)

由于官方教程是全英版,本文为根据个人理解做的中文版
(能力有限,有问题的地方还望指出,同时欢迎志同道合的朋友参与讨论,谢谢!)
官方教程:https://learn.unity.com/course/ml-agents-hummingbirds?uv=2019.3

系列文章:unity learn—— ML_Agent:Hummingbirds中文教程 自译(一 ——配置Unity)
unity learn—— ML_Agent:Hummingbirds中文教程 自译(二——代码Flowers.cs、FlowerArea.cs)

unity learn—— ML_Agent:Hummingbirds中文教程 自译(四——Ray Perception Sensor)

unity learn—— ML_Agent:Hummingbirds中文教程 自译(五 ——安装Anaconda)
unity learn—— ML_Agent:Hummingbirds中文教程 自译(六——准备训练)
unity learn—— ML_Agent:Hummingbirds中文教程 自译(七——测试模型)

HummingbirdAgent.cs是unity与python API交互的最重要的代码,我称之为Agent代码

Note:
Agent代码创建之后的共性操作:

  1. 引入两个库

    using Unity.MLAgents;
    using Unity.MLAgents.Sensors;

  2. 父物体由MonoBehaviour 改为 Agent.

  3. 删除Update()和Start()函数(不过这点不绝对,有的Agent代码还是需要Update()函数,而Start()则是被Initialize()代替)

  4. 需要重构三个函数

    OnEpisodeBegin() : 初始化函数
    CollectObservations(VectorSensor sensor):用于获取环境参数
    OnActionReceived(float[] vectorAction):动作函数

以下是蜂鸟移动方式的示意图,以便你理解代码
在这里插入图片描述
HummingbirdAgent.cs

using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;       //注意这里导入了MLAgents的库
using Unity.MLAgents.Sensors;

/// <summary>
/// 蜂鸟机器学习
/// </summary>
public class HummingbirdAgent : Agent   //继承也变了
{
    [Tooltip("用于移动的力")]
    public float moveForce = 2f;
    //pitch和yaw方向见上图
    [Tooltip("pitch方向旋转的速度")]
    public float pitchSpeed = 100f;
    [Tooltip("yaw方向旋转的速度")]
    public float yawSpeed = 100f;
    [Tooltip("鸟喙尖端处的transfrom")]
    public Transform beakTip;
    [Tooltip("智能体的摄像头")]      //对于该项目,以下所有智能体均指蜂鸟
    public Camera agentCamera;      
    [Tooltip("是训练模式?否为游戏模式")]
    public bool trainingMode;

    new private Rigidbody rigidbody;    //智能体的刚体
    private FlowerArea flowerArea;  //智能体所在的空间
    private Flower nearestFlower;   //离智能体最近的花 

   //两个用于平滑移动的变量
    private float smoothPitchChange = 0f; //平滑pitch方向
    private float smoothYawChange = 0f;  //平滑yaw方向

    private const float maxPitchAngle = 80f;    //pitch方向最大角度
    private const float beakTipRadius = 0.008f; //鸟喙的直径
    private bool frozen = false;    //是否冻结智能体的行动

    /// <summary>
    /// 智能体获得的花粉数量
    /// </summary>
    public float NectarObtained { get; private set; }

    /// <summary>
    /// 初始化变量
    /// </summary>
    public override void Initialize()
    {
        rigidbody = GetComponent<Rigidbody>();
        flowerArea = GetComponentInParent<FlowerArea>();    //该句成功的关键是鸟必须是挂载了flowerArea脚本的物体的第一级子物体

        //如果不是训练模式
        if (!trainingMode) MaxStep = 0; //MaxStep定义在Agent中
    }

    /// <summary>
    /// 当玩家输入或神经网络控制时调用
    /// 
    /// vectorAction[i] represents:
    /// Index 0: move vector x (+1 = right, -1 = left)
    /// Index 1: move vector y (+1 = up, -1 = down)
    /// Index 2: move vector z (+1 = forward, -1 = backward)
    /// Index 3: pitch angle (+1 = pitch up, -1 = pitch down)
    /// Index 4: yaw angle (+1 = turn right, -1 = turn left)
    /// </summary>
    /// <param name="vectorAction">将要采取的动作</param>
    public override void OnActionReceived(float[] vectorAction)
    {
        //如果冻结,则不采取任何动作
        if (frozen) { return; }
        //移动方向
        Vector3 move = new Vector3(vectorAction[0], vectorAction[1], vectorAction[2]);
        //向移动方向施加力
        rigidbody.AddForce(move * moveForce);


        //获取pitch和yaw方向上的移动量
        float pitchChange = vectorAction[3];
        float yawChange = vectorAction[4];
        //平滑旋转变化,防抖动
        smoothPitchChange = Mathf.MoveTowards(smoothPitchChange, pitchChange, 2f * Time.fixedDeltaTime);    //注:为了更贴合物理,
        smoothYawChange = Mathf.MoveTowards(smoothYawChange, yawChange, 2f * Time.fixedDeltaTime);          //不是Time.deltaTime
        
        //获得目前rotation
        Vector3 rotationVector = transform.rotation.eulerAngles;
        //计算新的pitch和yaw方向上的移动量
        float pitch = rotationVector.x + smoothPitchChange * Time.fixedDeltaTime * pitchSpeed;
        float yaw = rotationVector.y + smoothYawChange * Time.fixedDeltaTime * yawSpeed;    //pithc方向需要限制范围,而yaw方向不需要
        if (pitch > 180f) pitch -= 360f;
        pitch = Mathf.Clamp(pitch, -maxPitchAngle, maxPitchAngle);
        //旋转对应的角度
        transform.rotation = Quaternion.Euler(pitch, yaw, 0f);
    }

    /// <summary>
    /// 获取环境的坐标数据
    /// </summary>
    /// <param name="sensor">坐标探测器</param>
    public override void CollectObservations(VectorSensor sensor)
    {
    	//为了便于训练,以下observation都做了归一化
			
        //如果nearestFlower是空的,就观测十个空浮点数,之后返回,防止报错
        if (nearestFlower == null)
        {
            sensor.AddObservation(new float[10]);
            return;
        }
        //观测localRotation和与nearestFlower的相对位置
        sensor.AddObservation(transform.localRotation.normalized);  //4个探测目标(归一化是为了简化分析)
        Vector3 toFlower = nearestFlower.FlowerCenterPosition - beakTip.position;
        sensor.AddObservation(toFlower.normalized);     //3个探测目标

        // 通过向量点积观测鸟喙是否在花蕊正前方(1个探测目标)
        // (+1 代表在花正前方, -1 代表在花正后面)
        sensor.AddObservation(Vector3.Dot(toFlower.normalized, -nearestFlower.FlowerUpVector.normalized));
        // 通过向量点积观测鸟喙是否指向花蕊正前方 (1个探测目标)
        // (+1代表正面指向花, -1 代表在花后面指向花)
        sensor.AddObservation(Vector3.Dot(beakTip.forward.normalized, -nearestFlower.FlowerUpVector.normalized));
        //观测鸟喙与花的相对场地直径的相对距离(1个探测目标)
        sensor.AddObservation(toFlower.magnitude / FlowerArea.AreaDiameter);

        // 一共10个探测目标
    }

    /// <summary>
    /// 当行为模式为“Heuristic Only”,即仅玩家探索时调用 
    /// 此时由玩家代替神经网络操作
    /// </summary>
    /// <param name="actionsOut">输出的行动数组</param>
    public override void Heuristic(float[] actionsOut)
    {
        // 创建所有用于移动的变量
        Vector3 forward = Vector3.zero;
        Vector3 left = Vector3.zero;
        Vector3 up = Vector3.zero;
        float pitch = 0f;
        float yaw = 0f;

        // 获取键盘输入并进行相应的移动控制
        // 所有数据都应在区间( -1 , +1 )内

        // 前后控制
        if (Input.GetKey(KeyCode.W)) forward = transform.forward;
        else if (Input.GetKey(KeyCode.S)) forward = -transform.forward;
        //左右控制
        if (Input.GetKey(KeyCode.A)) left = -transform.right;
        else if (Input.GetKey(KeyCode.D)) left = transform.right;
        //上下控制
        if (Input.GetKey(KeyCode.E)) up = transform.up;
        else if (Input.GetKey(KeyCode.Q)) up = -transform.up;
        //俯仰控制
        if (Input.GetKey(KeyCode.UpArrow)) pitch = 1f;
        else if (Input.GetKey(KeyCode.DownArrow)) pitch = -1f;
        //左右转身控制
        if (Input.GetKey(KeyCode.LeftArrow)) yaw = -1f;
        else if (Input.GetKey(KeyCode.RightArrow)) yaw = 1f;

        //合并移动变量并归一化
        Vector3 combined = (forward + left + up).normalized;

        // 将3个移动变量和两个旋转变量加到actionsOut数组中
        actionsOut[0] = combined.x;
        actionsOut[1] = combined.y;
        actionsOut[2] = combined.z;
        actionsOut[3] = pitch;
        actionsOut[4] = yaw;
    }


    /// <summary>
    /// 停止智能体移动
    /// </summary>
    public void FreezeAgent()
    {
        //训练模式下不支freeze
        Debug.Assert(trainingMode == false, "Freeze/Unfreeze not supported in training");
        frozen = true;
        rigidbody.Sleep();  //刚体也停止作用
    }
    /// <summary>
    /// 恢复智能体移动
    /// </summary>
    public void UnfreezeAgent()
    {
        Debug.Assert(trainingMode == false, "Freeze/Unfreeze not supported in training");
        frozen = false;
        rigidbody.WakeUp();
    }

    /// <summary>
    /// 当事件开始时重置智能体
    /// </summary>
    public override void OnEpisodeBegin()
    {
        if (trainingMode)
        {
            //当场景中只有一个智能体时,只重置花
            flowerArea.ResetFlower();
        }
        
        NectarObtained = 0f;
        //重要!!! 将所有速度归零(防止上次一情景的运动影响到现在)
        rigidbody.velocity = Vector3.zero;
        rigidbody.angularVelocity = Vector3.zero;

        bool inFrontofFlower = true;
        if (trainingMode)   //如果是训练模式
        {
            //50%可能性使智能体初始化时面对一朵花
            inFrontofFlower = UnityEngine.Random.value > .5f;
        }
        //将智能体移动到一个随机的安全位置(主要避免出现在树木,石头等物体内部)
        MoveToSafeRandomPosition(inFrontofFlower);
        //重新计算最近的花
        UpdateNearestFlower();
    }

    /// <summary>
    /// 将智能体移动到一个随机的安全位置(主要避免出现在树木,石头等物体内部)
    /// 如果inFrontofFlower == true,则将其面对一朵花
    /// </summary>
    /// <param name="inFrontofFlower">是否选择一个面对花的位置</param>
    private void MoveToSafeRandomPosition(bool inFrontofFlower)
    {
        bool safePositionFound = false; //是否发现安全位置
        int attemptsRemaining = 100;    //尝试的次数,超过该次数则认为出现了问题,防止无限循环
        //初始化位置状态
        Vector3 potentialPosition = Vector3.zero;
        Quaternion potentialRotation = new Quaternion();

        //循环到找到安全位置或者尝试次数超过限定值
        while (!safePositionFound && attemptsRemaining > 0)
        {
            attemptsRemaining--;
            if (inFrontofFlower)
            {
                //随机选取一朵花
                Flower randomFlower = flowerArea.Flowers[UnityEngine.Random.Range(0, flowerArea.Flowers.Count)];
                //生成位置距花10-20cm
                float distanceFromFlower = UnityEngine.Random.Range(.1f, .2f);
                potentialPosition = randomFlower.transform.position + randomFlower.FlowerUpVector * distanceFromFlower;
                //使智能体看向花
                Vector3 toFlower = randomFlower.FlowerCenterPosition - potentialPosition; //计算方向向量
                potentialRotation = Quaternion.LookRotation(toFlower, Vector3.up);
            }
            else
            {
                //随机Position
                float height = UnityEngine.Random.Range(1.2f, 2.5f);    //随机高度
                float radius = UnityEngine.Random.Range(2f, 7f);    //以区域中心为圆心的生成半径(之所以以2开始是为了避开中心的大树)
                Quaternion direction = Quaternion.Euler(0f, UnityEngine.Random.Range(-180f, 180f), 0f); //关于y轴的一个随机方向
                potentialPosition = flowerArea.transform.position + Vector3.up * height + direction * Vector3.forward * radius;

                //随机Rotation
                float pitch = UnityEngine.Random.Range(-60f, 60f);
                float yaw = UnityEngine.Random.Range(-180f, 180f);
                potentialRotation = Quaternion.Euler(pitch, yaw, 0f);
            }
            //检查是否与其他碰撞体重叠,不重叠就表示找到了安全的生成位置
            safePositionFound = !(Physics.CheckBox(potentialPosition, new Vector3(0.05f, 0.05f, 0.05f)));  //以一个中心为potentialPosition边长为10cm的正方体探测
        }
        Debug.Assert(safePositionFound, "Could not find a safe position to spawn"); //safePositionFound为false时发出警告
        //设置位置
        transform.position = potentialPosition;
        transform.rotation = potentialRotation;

    }

    /// <summary>
    /// 更新最近的花(个人觉得这个方法效率挺低的,感兴趣的同学可以优化一下)
    /// </summary>
    private void UpdateNearestFlower()
    {
        foreach (Flower flower in flowerArea.Flowers)
        {
            //如果最近的花nearestFlower没赋值,并且循环到的这朵花有花粉
            if (nearestFlower == null && flower.HasNectar)
            {
                nearestFlower = flower;
            }//如果最近的花nearestFlower已经赋值,并且循环到的这朵花有花粉
            else if (flower.HasNectar)
            {
                //计算并比较哪个是最近的花,然后更新nearestFlower 
                float distanceToFlower = Vector3.Distance(flower.transform.position, beakTip.position);
                float distanceToCurrentNearestFlower = Vector3.Distance(nearestFlower.transform.position, beakTip.position);
                //如果现有的nearestFlower没花粉了,或者它不再是最近的花了
                if (!nearestFlower.HasNectar || distanceToCurrentNearestFlower > distanceToFlower)
                {
                    nearestFlower = flower;
                }

            }
        }
    }

    private void OnTriggerEnter(Collider other)
    {
        TriggerEnterOrStay(other);
    }
    private void OnTriggerStay(Collider other)
    { 
        TriggerEnterOrStay(other);
    }
    private void TriggerEnterOrStay(Collider collider)
    {
        // 检查是否碰到了花蕊
        if (collider.CompareTag("nectar"))
        {
            Vector3 closestPointToBeakTip = collider.ClosestPoint(beakTip.position);//返回花蕊碰撞体上距离鸟喙最近的点

            // 检查是不是真的是鸟喙碰到了花蕊
            // Note: 除鸟喙之外,别的碰撞体碰到花蕊不算
            if (Vector3.Distance(beakTip.position, closestPointToBeakTip) < beakTipRadius)
            {
                // 通过花蕊找到对应的花
                Flower flower = flowerArea.GetFlowerFromNectar(collider);

                // 尝试"吃掉" 0.01 的花粉
                // Note: 每 0.02 秒调用一次
                float nectarReceived = flower.Feed(.01f);

                //  更新已获得的花粉(不是用于训练而是用在UI上)
                NectarObtained += nectarReceived;

                if (trainingMode)
                {
                    // 计算奖励
                    float bonus = .02f * Mathf.Clamp01(Vector3.Dot(transform.forward.normalized, -nearestFlower.FlowerUpVector.normalized));
                    AddReward(.01f + bonus);        //在bonus之外,只要是鸟喙碰到了花蕊就保证有0.01的基础奖励
                }

                // 如果花粉没了,换一朵花
                if (!flower.HasNectar)
                {
                    UpdateNearestFlower();
                }
            }
        }
    }
    private void OnCollisionEnter(Collision collision)
    {
        if (trainingMode && collision.collider.CompareTag("boundary"))
        {
            //碰到了障碍物或者区域边界,给负分
            AddReward(-.5f);
        }
    }

    private void Update()
    {
        // 在鸟喙和最近的花之间画一条线
        if (nearestFlower != null)
            Debug.DrawLine(beakTip.position, nearestFlower.FlowerCenterPosition, Color.green);
    }

    private void FixedUpdate()
    {
        // 防止最近花的花粉被对手“吃”光了,但是没有更新
        if (nearestFlower != null && !nearestFlower.HasNectar)
            UpdateNearestFlower();
    }
}

脚本编好后,需要返回unity再加两个tag:nectar和boundary
在这里插入图片描述
找到prefab中FloatingIsland的子物体IslandBoundary,为其设置tag:boundary
在这里插入图片描述
找到FlowerBud中的FlowerNectarCollider,设置tag:nectar
在这里插入图片描述
找到蜂鸟预制体,添加本次的脚本hummingbird.cs(它会自动添加另一个脚本Behavior Parameters)
在这里插入图片描述

关于Behavior Parameter:

变量注释
Behavior Name动作空间名:Default:可以训练,可以人机交互(就是玩家自己操作),也可以挂上训练好的模型做推理; Heuristic Only:仅人机交互; Inference Only:挂上训练好的模型后,推理用
Vector Observation :Space Size观测的个数(其值应与重构函数CollectObservations(VectorSensor sensor)中的总探测目标个数相等)
Vector Observation :Stacked Vectors实际保留的探测目标量为(Space Size*Stacked Vectors)。 将其设为1以上是赋予agent有限的“memory”的简单方法,而无需添加递归神经网络(RNN)的复杂性。
Vector Action : Space Type动作空间的性质:Discrete离散 或 Continuous连续
Vector Action : Branches Size动作空间变量的个数(应与函数OnActionReceived(float[] vectorAction)中vectorActions数组中元素个数相同)
Model训练好的模型(如果没有可以不用挂载)
Inference Device训练的设备:CPU 或 GPU
Behavior Type训练模式: default:可以训练也可以推理;Heuristic Only: 人机交互模式,该模式下训练好的模型和Training都将没有作用; Inference Only: 推理模式, 该模式下只能使用模型
Team ID用于定义self-play的team
Use Child Sensors是否使用子物体的Sensor组件(我们后面要用子物体的RaySensor,所以要勾上
Max Step每个agent的最大步数。一旦达到此数目,OnEpisodeBegin() 讲被调用,重置Agent。

调整后的设置如下:
在这里插入图片描述
其中Agent Camera不是场景中的主摄像机,而需要在HummingBird下创建一个Camera子物体,具体设置如下:
在这里插入图片描述
在这里插入图片描述

注:训练时用不到这个摄像机,要将其屏蔽掉(即取消勾选其名字旁边的方框)
在这里插入图片描述
另外Audio_Listener可自行决定是否勾选(如果不想听声就去掉)

最后,在hummingBird之下再加一个脚本:Decision Requester.cs
用于为智能体做决定(Decision Period 设置为5即为每5步做一个决定)
在这里插入图片描述

  • 3
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

昼行plus

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值