


该脚本需挂载(绑定)在智能体上。脚本中的类要继承自 ML-Agents 工具包中的 Agent 类,而不是 Unity 默认生成的 MonoBehaviour 类(需在代码中手动修改类声明),并且需要额外添加如下 using 引用:

using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;



	// ---- Inspector-tunable settings ----

    [Tooltip("Force to apply when moving")]
    public float moveForce = 2f;

    [Tooltip("Speed to pitch up or down")]
    public float pitchSpeed = 100f;

    [Tooltip("Speed to rotate around the up axis")]
    public float yawSpeed = 100f;

    [Tooltip("Transform at the tip of the beak")]
    public Transform beakTip;

    [Tooltip("The agent's camera")]
    public Camera agentCamera;

    [Tooltip("Whether this is training mode or gameplay mode")]
    public bool trainingMode;

    // ---- Cached references and runtime state ----

    /// <summary>
    /// The rigidbody of the agent. Declared with <c>new</c> because an inherited member
    /// named <c>rigidbody</c> already exists on the base class.
    /// </summary>
    private new Rigidbody rigidbody;

    /// <summary>The flower area that the agent is in.</summary>
    private FlowerArea flowerArea;

    /// <summary>The nearest flower to the agent (null until one has been found).</summary>
    private Flower nearestFlower;

    /// <summary>Smoothed pitch input; allows for smoother pitch changes.</summary>
    private float smoothPitchChange;

    /// <summary>Smoothed yaw input; allows for smoother yaw changes.</summary>
    private float smoothYawChange;

    /// <summary>Maximum angle (degrees) that the bird can pitch up or down.</summary>
    private const float MaxPitchAngle = 80f;

    /// <summary>Maximum distance from the beak tip at which a nectar collision is accepted.</summary>
    private const float BeakTipRadius = 0.008f;

    /// <summary>Whether the agent is frozen (intentionally not flying).</summary>
    private bool frozen;

    /// <summary>
    /// The amount of nectar the agent has obtained this episode
    /// </summary>
    public float NectarObtained { get; private set; }



    /// <summary>
    /// Initialize the agent: cache the Rigidbody on this object and the FlowerArea on a parent.
    /// </summary>
    public override void Initialize()
    {
        rigidbody = GetComponent<Rigidbody>();
        flowerArea = GetComponentInParent<FlowerArea>();

        // If not training mode, no max step (0 = unlimited), play forever
        if (!trainingMode)
        {
            MaxStep = 0;
        }
    }


 /// <summary>
    /// Reset the agent (and, in training, the flowers) at the start of each episode.
    /// </summary>
    public override void OnEpisodeBegin()
    {
        if (trainingMode)
        {
            // Only reset flowers in training when there is one agent per area
            // NOTE(review): assumes FlowerArea exposes ResetFlowers() — confirm against FlowerArea
            flowerArea.ResetFlowers();
        }

        // Reset nectar obtained (always, not only in training)
        NectarObtained = 0f;

        // Zero out velocities so that movement stops before a new episode begins
        rigidbody.velocity = Vector3.zero;
        rigidbody.angularVelocity = Vector3.zero;

        // Default to spawning in front of a flower
        bool inFrontOfFlower = true;
        if (trainingMode)
        {
            // Spawn in front of a flower 50% of the time during training
            inFrontOfFlower = UnityEngine.Random.value > 0.5f;
        }

        // Move the agent to a new random position
        MoveToSafeRandomPosition(inFrontOfFlower);

        // Recalculate the nearest flower now that the agent has moved
        UpdateNearestFlower();
    }


/// <summary>
    /// Move the agent to a safe random position (i.e. one where a small sphere around it
    /// overlaps no colliders). If in front of a flower, also point the beak at the flower.
    /// At most 100 candidates are tried; if none is safe, an assertion is logged and the
    /// last candidate is used anyway.
    /// </summary>
    /// <param name="inFrontOfFlower">Whether to choose a spot in front of a flower</param>
    private void MoveToSafeRandomPosition(bool inFrontOfFlower)
    {
        bool safePositionFound = false;
        int attemptsRemaining = 100; // Prevent an infinite loop
        Vector3 potentialPosition = Vector3.zero;
        Quaternion potentialRotation = Quaternion.identity;

        // A flower can only be picked if the area actually contains flowers
        bool pickFlower = inFrontOfFlower && flowerArea.Flowers.Count > 0;

        // Loop until a safe position is found or we run out of attempts
        while (!safePositionFound && attemptsRemaining > 0)
        {
            attemptsRemaining--;

            if (pickFlower)
            {
                // Pick a random flower
                Flower randomFlower = flowerArea.Flowers[UnityEngine.Random.Range(0, flowerArea.Flowers.Count)];

                // Position 10 to 20 cm in front of the flower
                float distanceFromFlower = UnityEngine.Random.Range(.1f, .2f);
                potentialPosition = randomFlower.transform.position + randomFlower.FlowerUpVector * distanceFromFlower;

                // Point beak at flower (bird's head is center of transform)
                Vector3 toFlower = randomFlower.FlowerCenterPosition - potentialPosition;
                potentialRotation = Quaternion.LookRotation(toFlower, Vector3.up);
            }
            else
            {
                // Pick a random height from the ground
                float height = UnityEngine.Random.Range(1.2f, 2.5f);

                // Pick a random radius from the center of the area
                float radius = UnityEngine.Random.Range(2f, 7f);

                // Pick a random direction rotated around the y axis
                Quaternion direction = Quaternion.Euler(0f, UnityEngine.Random.Range(-180f, 180f), 0f);

                // Combine height, radius, and direction to pick a potential position
                potentialPosition = flowerArea.transform.position + Vector3.up * height + direction * Vector3.forward * radius;

                // Choose and set random starting pitch and yaw
                float pitch = UnityEngine.Random.Range(-60f, 60f);
                float yaw = UnityEngine.Random.Range(-180f, 180f);
                potentialRotation = Quaternion.Euler(pitch, yaw, 0f);
            }

            // Check to see if the agent (approximated by a sphere of radius 0.05) would collide with anything
            Collider[] colliders = Physics.OverlapSphere(potentialPosition, 0.05f);

            // Safe position has been found if no colliders are overlapped
            safePositionFound = colliders.Length == 0;
        }

        Debug.Assert(safePositionFound, "Could not find a safe position to spawn");

        // Set the position and rotation
        transform.position = potentialPosition;
        transform.rotation = potentialRotation;
    }


 /// <summary>
    /// Update <c>nearestFlower</c> to the closest flower (measured from the beak tip to the
    /// flower's transform position) that still has nectar. If no flower has nectar, the
    /// current value is left unchanged.
    /// </summary>
    private void UpdateNearestFlower()
    {
        foreach (Flower flower in flowerArea.Flowers)
        {
            // Empty flowers are never candidates
            if (!flower.HasNectar)
            {
                continue;
            }

            if (nearestFlower == null)
            {
                // No current nearest flower and this flower has nectar, so set to this flower
                nearestFlower = flower;
                continue;
            }

            // Calculate distance to this flower and distance to the current nearest flower
            float distanceToFlower = Vector3.Distance(flower.transform.position, beakTip.position);
            float distanceToCurrentNearestFlower = Vector3.Distance(nearestFlower.transform.position, beakTip.position);

            // If current nearest flower is empty OR this flower is closer, update the nearest flower
            if (!nearestFlower.HasNectar || distanceToFlower < distanceToCurrentNearestFlower)
            {
                nearestFlower = flower;
            }
        }
    }


	/// <summary>
    /// Called when an action is received from either the player input or the neural network.
    /// actions.ContinuousActions[i] represents:
    /// Index 0: move vector x, world space (+1 = +x, -1 = -x)
    /// Index 1: move vector y, world space (+1 = up, -1 = down)
    /// Index 2: move vector z, world space (+1 = +z, -1 = -z)
    /// Index 3: pitch input (+1 increases the Euler x angle, which tilts the nose down in Unity; -1 the opposite)
    /// Index 4: yaw input (+1 = turn right, -1 = turn left)
    /// </summary>
    /// <param name="actions">The actions to take</param>
    public override void OnActionReceived(ActionBuffers actions)
    {
        // Don't take actions if frozen
        if (frozen)
        {
            return;
        }

        var vectorAction = actions.ContinuousActions;

        // Calculate movement vector and add force in its direction (world space)
        Vector3 move = new Vector3(vectorAction[0], vectorAction[1], vectorAction[2]);
        rigidbody.AddForce(move * moveForce);

        // Get the current rotation
        Vector3 rotationVector = transform.rotation.eulerAngles;

        // Requested pitch and yaw changes
        float pitchChange = vectorAction[3];
        float yawChange = vectorAction[4];

        // Ease the smoothed values toward the requested ones so rotation doesn't snap
        smoothPitchChange = Mathf.MoveTowards(smoothPitchChange, pitchChange, 2f * Time.fixedDeltaTime);
        smoothYawChange = Mathf.MoveTowards(smoothYawChange, yawChange, 2f * Time.fixedDeltaTime);

        // Calculate new pitch based on smoothed value.
        // eulerAngles.x is in [0, 360), so map it to (-180, 180] before clamping
        // to avoid flipping upside down
        float pitch = rotationVector.x + smoothPitchChange * Time.fixedDeltaTime * pitchSpeed;
        if (pitch > 180f)
        {
            pitch -= 360f;
        }
        pitch = Mathf.Clamp(pitch, -MaxPitchAngle, MaxPitchAngle);

        float yaw = rotationVector.y + smoothYawChange * Time.fixedDeltaTime * yawSpeed;

        // Apply the new rotation (roll is always zero)
        transform.rotation = Quaternion.Euler(pitch, yaw, 0f);
    }


	/// <summary>
    /// Collect vector observations from the environment (10 values in total).
    /// </summary>
    /// <param name="sensor">The vector sensor</param>
    public override void CollectObservations(VectorSensor sensor)
    {
        // If nearestFlower is null, observe an empty array and return early
        // (still 10 values, so the observation size stays constant)
        if (nearestFlower == null)
        {
            sensor.AddObservation(new float[10]);
            return;
        }

        // Observe the agent's local rotation (4 observations)
        sensor.AddObservation(transform.localRotation.normalized);

        // Get a vector from the beak tip to the nearest flower
        Vector3 toFlower = nearestFlower.FlowerCenterPosition - beakTip.position;

        // Observe a normalized vector pointing to the nearest flower (3 observations)
        sensor.AddObservation(toFlower.normalized);

        // Observe a dot product that indicates whether the beak tip is in front of the flower (1 observation)
        // (+1 means that the beak tip is directly in front of the flower, -1 means directly behind)
        sensor.AddObservation(Vector3.Dot(toFlower.normalized, -nearestFlower.FlowerUpVector.normalized));

        // Observe a dot product that indicates whether the beak is pointing toward the flower (1 observation)
        // (+1 means that the beak is pointing directly at the flower, -1 means directly away)
        sensor.AddObservation(Vector3.Dot(beakTip.forward.normalized, -nearestFlower.FlowerUpVector.normalized));

        // Observe the relative distance from the beak tip to the flower (1 observation)
        sensor.AddObservation(toFlower.magnitude / FlowerArea.AreaDiameter);

        // 10 total observations
    }


 	/// <summary>
    /// When Behavior Type is set to "Heuristic Only" on the agent's Behavior Parameters,
    /// this function will be called. The values it writes into <paramref name="actionsOut"/>
    /// will be fed into <see cref="OnActionReceived(ActionBuffers)"/> instead of using the
    /// neural network.
    /// Keys: W/S forward/back, A/D left/right, E/C up/down (relative to the agent's facing,
    /// combined into one normalized world-space vector); Up/Down arrows pitch; Left/Right arrows yaw.
    /// </summary>
    /// <param name="actionsOut">The output action buffer (5 continuous actions)</param>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        // Create placeholders for all movement/turning
        Vector3 forward = Vector3.zero;
        Vector3 left = Vector3.zero;
        Vector3 up = Vector3.zero;
        float pitch = 0f;
        float yaw = 0f;

        // Convert keyboard inputs to movement and turning
        // All values should be between -1 and +1
        if (Input.GetKey(KeyCode.W))
        {
            forward = transform.forward;
        }
        else if (Input.GetKey(KeyCode.S))
        {
            forward = -transform.forward;
        }

        if (Input.GetKey(KeyCode.A))
        {
            left = -transform.right;
        }
        else if (Input.GetKey(KeyCode.D))
        {
            left = transform.right;
        }

        if (Input.GetKey(KeyCode.E))
        {
            up = transform.up;
        }
        else if (Input.GetKey(KeyCode.C))
        {
            up = -transform.up;
        }

        // Pitch: Up arrow = +1, Down arrow = -1
        // (+1 increases the Euler x angle in OnActionReceived, i.e. tilts the nose down)
        if (Input.GetKey(KeyCode.UpArrow))
        {
            pitch = 1f;
        }
        else if (Input.GetKey(KeyCode.DownArrow))
        {
            pitch = -1f;
        }

        // Turn left/right
        if (Input.GetKey(KeyCode.LeftArrow))
        {
            yaw = -1f;
        }
        else if (Input.GetKey(KeyCode.RightArrow))
        {
            yaw = 1f;
        }

        // Combine the movement vectors and normalize (zero stays zero)
        Vector3 combined = (forward + left + up).normalized;

        // Add the 3 movement values, pitch, and yaw to the actionsOut buffer
        var actions = actionsOut.ContinuousActions;
        actions[0] = combined.x;
        actions[1] = combined.y;
        actions[2] = combined.z;
        actions[3] = pitch;
        actions[4] = yaw;
    }


	/// <summary>
    /// Prevent the agent from moving and taking actions.
    /// Only supported outside training mode (asserts otherwise).
    /// </summary>
    public void FreezeAgent()
    {
        Debug.Assert(trainingMode == false, "Freeze/Unfreeze not supported in training");
        frozen = true;
    }

    /// <summary>
    /// Resume agent movement and actions.
    /// Only supported outside training mode (asserts otherwise).
    /// </summary>
    public void UnFreezeAgent()
    {
        Debug.Assert(trainingMode == false, "Freeze/Unfreeze not supported in training");
        frozen = false;
    }


	/// <summary>
    /// Called when the agent's collider enters a trigger collider.
    /// Delegates to the shared enter/stay handler.
    /// </summary>
    /// <param name="other">The trigger collider</param>
    private void OnTriggerEnter(Collider other)
    {
        TriggerEnterOrStay(other);
    }

    /// <summary>
    /// Called when the agent's collider stays in a trigger collider.
    /// Delegates to the shared enter/stay handler.
    /// </summary>
    /// <param name="other">The trigger collider</param>
    private void OnTriggerStay(Collider other)
    {
        TriggerEnterOrStay(other);
    }


 	/// <summary>
    /// Handles when the agent's collider enters or stays in a trigger collider:
    /// if the trigger is a nectar collider touched by the beak tip, feed from its flower.
    /// </summary>
    /// <param name="collider">The trigger collider</param>
    private void TriggerEnterOrStay(Collider collider)
    {
        // Only nectar colliders are of interest
        if (!collider.CompareTag("nectar"))
        {
            return;
        }

        // Check if the closest collision point is close to the beak tip
        // Note: a collision with anything but the beak tip should not count
        Vector3 closestPointToBeakTip = collider.ClosestPoint(beakTip.position);
        if (Vector3.Distance(beakTip.position, closestPointToBeakTip) >= BeakTipRadius)
        {
            return;
        }

        // Look up the flower for this nectar collider
        Flower flower = flowerArea.GetFlowerFromNectar(collider);

        // Attempt to take .01 nectar
        // Note: while touching, this runs once per physics step
        // (every .02 seconds / 50x per second with Unity's default fixed timestep)
        float nectarReceived = flower.Feed(.01f);

        // Keep track of nectar obtained
        NectarObtained += nectarReceived;

        if (trainingMode)
        {
            // Reward for getting nectar, plus a bonus for facing the flower actually being fed from.
            // Uses the fed flower rather than nearestFlower, which may differ from it or be null.
            float bonus = .2f * Mathf.Clamp01(Vector3.Dot(transform.forward.normalized, -flower.FlowerUpVector.normalized));
            AddReward(0.01f + bonus);
        }

        // If flower is empty, update the nearest flower
        if (!flower.HasNectar)
        {
            UpdateNearestFlower();
        }
    }

   /// <summary>
    /// Called when the agent collides with something solid.
    /// In training, hitting an object tagged "boundary" costs -0.5 reward.
    /// </summary>
    /// <param name="collision">The collision info</param>
    private void OnCollisionEnter(Collision collision)
    {
        if (trainingMode && collision.collider.CompareTag("boundary"))
        {
            // Collided with the area boundary, give a negative reward
            AddReward(-.5f);
        }
    }


    /// <summary>
    /// Called every frame. Draws a debug line (visible in the Scene view) from the beak tip
    /// to the nearest flower, when there is one.
    /// </summary>
    private void Update()
    {
        if (nearestFlower != null)
        {
            Debug.DrawLine(beakTip.position, nearestFlower.FlowerCenterPosition, Color.green);
        }
    }


    /// <summary>
    /// Called every physics step (every .02 seconds with Unity's default fixed timestep).
    /// </summary>
    private void FixedUpdate()
    {
        // If the nearest flower has been emptied by something other than this agent's own
        // feeding (presumably another agent), pick a new target
        if (nearestFlower != null && !nearestFlower.HasNectar)
        {
            UpdateNearestFlower();
        }
    }



  • Behavior Name修改为Hummingbird,行为的标识符。具有相同行为名称的智能体将学习相同的策略,需要与后续的配置文件内名称保持一致。

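  • Space Size为智能体观测向量的长度。根据上述 CollectObservations 函数,观测值依次为:自身局部旋转(4)、指向最近花朵的单位向量(3)、喙尖是否位于花朵正前方的点积(1)、喙是否朝向花朵的点积(1)、到花朵的相对距离(1),共 4+3+1+1+1=10,因此该值为10。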

  • Stacked Vectors不作修改,其表示将多少个状态矢量堆叠后输入神经网络。

  • Continuous Actions:智能体执行的连续动作数量。这里为六自由度运动去掉 roll,即 x/y/z 三个方向的移动加上 pitch、yaw 两个转动,故为5。

  • Discrete Branches:智能体执行的离散动作,一个整数数组。数组中的值对应于每个动作分支的离散值数量。这里没有离散动作,填0。

  • Model:用于推理的神经网络模型(训练完成后获得),训练时不需要加载,保持 None。

  • Behavior Type:为 Default 时,若已连接 Python 训练进程(mlagents-learn),则进行训练;否则若已指定 Model 则使用模型推理,未指定时回退为手动(Heuristic)控制。选为 Heuristic
    Only 时,只能手动控制;选为 Inference Only 时,仅使用模型推理控制。


  • 将 Max Step(每回合最大步数)改为5000,即智能体运行 5000 步后结束当前回合并开始新的回合(非训练模式下脚本会把它设为 0,即不限步数)。
  • 力和速度保持默认。
  • 双击Assets/Hummingbird/Prefabs/Hummingbird,进入Hummingbird的层级窗口,在层级窗口将BeakTip拖拽到选项卡里Beak Tip框里。
  • 在Hummingbird层级窗口里添加一个相机,按下图调整一下它的位置和范围,并关闭Audio Listener,本场景中没有音频。
  • 添加完成后,将其拖拽到Agent Camera框里。
  • 勾选训练模式。

最后,需要为智能体再添加一个组件 Decision Requester。该组件用于定期请求决策,是训练所必需的。Decision Period 为5代表每5个 step 请求一次神经网络决策。

