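This part creates the C# script for the agent. It covers manual control of the agent, the agent's observation and action spaces, and the reinforcement-learning reward function.
Note: all code in this article comes from the official getting-started tutorial and is used here for learning purposes only.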
I. HummingbirdAgent.cs
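This script is attached to the agent. The class inherits from the ML-Agents Agent class rather than MonoBehaviour (you have to change the base class in the generated code yourself), and it needs the following additional using directives: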
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
1. Properties
The class's fields and properties, and what each one means, are listed below.
[Tooltip("Force to apply when moving")]
public float moveForce = 2f;
[Tooltip("Speed to pitch up or down")]
public float pitchSpeed = 100f;
[Tooltip("Speed to rotate around the up axis")]
public float yawSpeed = 100f;
[Tooltip("Transform at the tip of the beak")]
public Transform beakTip;
[Tooltip("The agent's camera")]
public Camera agentCamera;
[Tooltip("Whether this is training mode or gameplay mode")]
public bool trainingMode;
//The rigidbody of the agent ('new' hides the obsolete Component.rigidbody property)
new private Rigidbody rigidbody;
//The flower area that the agent is in
private FlowerArea flowerArea;
//The nearest flower to the agent
private Flower nearestFlower;
//Allows for smoother pitch changes
private float smoothPitchChange = 0f;
//Allows for smoother yaw changes
private float smoothYawChange = 0f;
//Maximum angle that the bird can pitch up or down
private const float MaxPitchAngle = 80f;
//Maximum distance from the beak tip to accept nectar collision
private const float BeakTipRadius = 0.008f;
//Whether the agent is frozen (intentionally not flying)
private bool frozen = false;
/// <summary>
/// The amount of nectar the agent has obtained this episode
/// </summary>
public float NectarObtained { get; private set; }
2. Methods
Initialize the agent:
/// <summary>
/// Initialize the agent
/// </summary>
public override void Initialize()
{
rigidbody = GetComponent<Rigidbody>();
flowerArea = GetComponentInParent<FlowerArea>();
//If not training mode, no max step, play forever
if (!trainingMode) MaxStep = 0;
}
The episode initialization function, called at the start of every episode:
public override void OnEpisodeBegin()
{
if (trainingMode)
{
//Only reset flowers in training when there is one agent per area
flowerArea.ResetFlowers();
}
//Reset nectar obtained
NectarObtained = 0f;
//Zero out velocities so that movement stops before a new episode begins
rigidbody.velocity = Vector3.zero;
rigidbody.angularVelocity = Vector3.zero;
bool inFrontOfFlower = true;
if (trainingMode)
{
//Spawn in front of flower 50% of the time during training
inFrontOfFlower = UnityEngine.Random.value > 0.5f;
}
//Move the agent to a new random position
MoveToSafeRandomPosition(inFrontOfFlower);
//Recalculate the nearest flower now that the agent has moved
UpdateNearestFlower();
}
OnEpisodeBegin calls two helper functions. The first moves the agent to a random, safe position:
/// <summary>
/// Move the agent to a safe random position (i.e. does not collide with anything)
/// If in front of flower, also point the beak at the flower
/// </summary>
/// <param name="inFrontOfFlower">Whether to choose a spot in front of a flower</param>
private void MoveToSafeRandomPosition(bool inFrontOfFlower)
{
bool safePositionFound = false;
int attemptsRemaining = 100;//Prevent an infinite loop
Vector3 potentialPosition = Vector3.zero;
Quaternion potentialRotation = new Quaternion();
//Loop until a safe position is found or we run out of attempts
while(!safePositionFound && attemptsRemaining > 0)
{
attemptsRemaining--;
if (inFrontOfFlower)
{
//Pick a random flower
Flower randomFlower = flowerArea.Flowers[UnityEngine.Random.Range(0, flowerArea.Flowers.Count)];
//Position 10 to 20 cm in front of the flower
float distanceFromFlower = UnityEngine.Random.Range(.1f, .2f);
potentialPosition = randomFlower.transform.position + randomFlower.FlowerUpVector * distanceFromFlower;
//Point beak at flower(bird's head is center of transform)
Vector3 toFlower = randomFlower.FlowerCenterPosition - potentialPosition;
potentialRotation = Quaternion.LookRotation(toFlower, Vector3.up);
}
else
{
//Pick a random height from the ground
float height = UnityEngine.Random.Range(1.2f, 2.5f);
//Pick a random radius from the center of the area
float radius = UnityEngine.Random.Range(2f, 7f);
//Pick a random direction rotated around the y axis
Quaternion direction = Quaternion.Euler(0f, UnityEngine.Random.Range(-180f, 180f), 0f);
//Combine height, radius, and direction to pick a potential position
potentialPosition = flowerArea.transform.position + Vector3.up * height + direction * Vector3.forward * radius;
//Choose and set random starting pitch and yaw
float pitch = UnityEngine.Random.Range(-60f, 60f);
float yaw = UnityEngine.Random.Range(-180f,180f);
potentialRotation = Quaternion.Euler(pitch, yaw, 0f);
}
//Check to see if the agent will collide with anything
//Physics.OverlapSphere returns every collider that touches or is inside the given sphere
Collider[] colliders = Physics.OverlapSphere(potentialPosition, 0.05f);
//Safe position has been found if no colliders are overlapped
safePositionFound = colliders.Length == 0;
}
Debug.Assert(safePositionFound, "Could not find a safe position to spawn");
//Set the position and rotation
transform.position = potentialPosition;
transform.rotation = potentialRotation;
}
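To check the spawn logic outside the editor, here is a minimal sketch in plain .NET (System.Numerics, not UnityEngine) of the "not in front of a flower" branch and its retry loop. The class name SpawnSketch and the overlapsSomething predicate are hypothetical: the predicate stands in for Physics.OverlapSphere, and unlike the method above, the sketch reports failure instead of falling back to the last sampled position.

```csharp
using System;
using System.Numerics;

// Sketch only: plain .NET, not Unity. SpawnSketch and overlapsSomething are
// hypothetical stand-ins for the tutorial's MoveToSafeRandomPosition logic.
public static class SpawnSketch
{
    // Same ranges as the "not in front of a flower" branch:
    // height 1.2-2.5, radius 2-7, yaw in [-180, 180) degrees.
    public static Vector3 SamplePosition(Random rng, Vector3 areaCenter)
    {
        float height = Lerp(1.2f, 2.5f, (float)rng.NextDouble());
        float radius = Lerp(2f, 7f, (float)rng.NextDouble());
        float yawRad = Lerp(-180f, 180f, (float)rng.NextDouble()) * MathF.PI / 180f;

        // Rotating the forward axis (0, 0, 1) by yaw about +Y gives (sin, 0, cos),
        // matching Quaternion.Euler(0, yaw, 0) * Vector3.forward in Unity.
        var direction = new Vector3(MathF.Sin(yawRad), 0f, MathF.Cos(yawRad));
        return areaCenter + Vector3.UnitY * height + direction * radius;
    }

    // Retry loop mirroring MoveToSafeRandomPosition; overlapsSomething plays the
    // role of Physics.OverlapSphere(position, 0.05f).Length > 0.
    public static bool TryFindSafePosition(Random rng, Vector3 areaCenter,
        Func<Vector3, bool> overlapsSomething, int maxAttempts, out Vector3 position)
    {
        for (int attempt = 0; attempt < maxAttempts; attempt++)
        {
            position = SamplePosition(rng, areaCenter);
            if (!overlapsSomething(position)) return true;
        }
        position = Vector3.Zero;
        return false;
    }

    static float Lerp(float a, float b, float t) => a + (b - a) * t;
}
```

Because the radius is sampled uniformly (as in the tutorial), spawns are denser per unit area near the inner 2 m edge than near the outer 7 m edge; the sketch keeps that property rather than changing it.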
The second recalculates which flower is closest to the agent:
private void UpdateNearestFlower()
{
foreach (Flower flower in flowerArea.Flowers)
{
if (nearestFlower == null && flower.HasNectar)
{
//No current nearest flower and this flower has nectar, so set to this flower
nearestFlower = flower;
}
else if (flower.HasNectar)
{
//Calculate distance to this flower and distance to the current nearest flower
float distanceToFlower = Vector3.Distance(flower.transform.position, beakTip.position);
float distanceToCurrentNearestFlower = Vector3.Distance(nearestFlower.transform.position, beakTip.position);
//If current nearest flower is empty OR this flower is closer, update the nearest flower
if (!nearestFlower.HasNectar || distanceToFlower < distanceToCurrentNearestFlower)
{
nearestFlower = flower;
}
}
}
}
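The loop above effectively ends with the nectar-bearing flower closest to the beak tip. A stateless sketch of the same selection, written in plain C# with hypothetical FlowerInfo and NearestFlowerSketch types, makes that explicit. One difference: it returns -1 when no flower has nectar, whereas UpdateNearestFlower simply keeps its previous nearestFlower in that case.

```csharp
using System.Collections.Generic;
using System.Numerics;

// Hypothetical plain-data stand-in for the tutorial's Flower component.
public struct FlowerInfo
{
    public Vector3 Position;
    public bool HasNectar;

    public FlowerInfo(Vector3 position, bool hasNectar)
    {
        Position = position;
        HasNectar = hasNectar;
    }
}

public static class NearestFlowerSketch
{
    // Index of the closest flower that still has nectar, or -1 if none do.
    public static int FindNearestWithNectar(IReadOnlyList<FlowerInfo> flowers, Vector3 beakTip)
    {
        int bestIndex = -1;
        float bestDistance = float.PositiveInfinity;
        for (int i = 0; i < flowers.Count; i++)
        {
            if (!flowers[i].HasNectar) continue; // empty flowers are never chosen
            float d = Vector3.Distance(flowers[i].Position, beakTip);
            if (d < bestDistance)
            {
                bestDistance = d;
                bestIndex = i;
            }
        }
        return bestIndex;
    }
}
```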
The following method defines how the agent acts on the actions it receives.
/// <summary>
/// Called when an action is received from either the player input or the neural network
///
/// vectorAction[i] represents:
/// Index 0: move vector x (+1 = right, -1 = left)
/// Index 1: move vector y (+1 = up, -1 = down)
/// Index 2: move vector z (+1 = forward, -1 = backward)
/// Index 3: pitch angle (+1 = pitch up, -1 = pitch down)
/// Index 4: yaw angle (+1 = turn right, -1 = turn left)
/// </summary>
/// <param name="actions">The actions to take; ContinuousActions holds the five values above</param>
public override void OnActionReceived(ActionBuffers actions)
{
var vectorAction = actions.ContinuousActions;
//Don't take actions if frozen
if (frozen) return;
//Calculate movement vector
Vector3 move = new Vector3(vectorAction[0], vectorAction[1], vectorAction[2]);
//Add force in the direction of the move vector
rigidbody.AddForce(move * moveForce);
//Get the current rotation
Vector3 rotationVector = transform.rotation.eulerAngles;
//Calculate pitch and yaw rotation
float pitchChange = vectorAction[3];
float yawChange = vectorAction[4];
//Calculate smooth rotation changes
smoothPitchChange = Mathf.MoveTowards(smoothPitchChange, pitchChange, 2f * Time.fixedDeltaTime);
smoothYawChange = Mathf.MoveTowards(smoothYawChange, yawChange, 2f * Time.fixedDeltaTime);
//Calculate new pitch and yaw based on smoothed values
//Clamp pitch to avoid flipping upside down
float pitch = rotationVector.x + smoothPitchChange * Time.fixedDeltaTime * pitchSpeed;
if (pitch > 180f) pitch -= 360f;
pitch = Mathf.Clamp(pitch, -MaxPitchAngle, MaxPitchAngle);
float yaw = rotationVector.y + smoothYawChange * Time.fixedDeltaTime * yawSpeed;
//Apply the new rotation
transform.rotation = Quaternion.Euler(pitch, yaw, 0f);
}
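The two numeric steps in OnActionReceived, the MoveTowards smoothing and the pitch wrap-and-clamp, can be reproduced outside Unity. The sketch below re-implements them in plain C#; MoveTowards is written to the same contract as UnityEngine.Mathf.MoveTowards, and the class name RotationSketch is hypothetical.

```csharp
using System;

// Sketch only: plain-C# versions of the smoothing and pitch-clamping steps.
public static class RotationSketch
{
    // Same contract as UnityEngine.Mathf.MoveTowards: move current toward target
    // by at most maxDelta, landing exactly on target once it is within range.
    public static float MoveTowards(float current, float target, float maxDelta)
    {
        if (MathF.Abs(target - current) <= maxDelta) return target;
        return current + MathF.Sign(target - current) * maxDelta;
    }

    // Unity reports eulerAngles.x in [0, 360); shift values above 180 into the
    // negative range so that "slightly nose-up/down" clamps correctly, then clamp.
    public static float WrapAndClampPitch(float pitchDegrees, float maxPitch)
    {
        if (pitchDegrees > 180f) pitchDegrees -= 360f;
        return Math.Clamp(pitchDegrees, -maxPitch, maxPitch);
    }
}
```

With the tutorial's values the smoothing step is 2 × 0.02 = 0.04 per physics step, so a full-scale input change from 0 to 1 takes roughly 25 steps (about 0.5 s), and with pitchSpeed = 100 the pitch changes by at most 1 × 0.02 × 100 = 2 degrees per step.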
The following method defines the agent's observation space.
/// <summary>
/// Collect vector observations from the environment
/// </summary>
/// <param name="sensor">The vector sensor</param>
public override void CollectObservations(VectorSensor sensor)
{
//If nearestFlower is null, observe an empty array and return early
if (nearestFlower == null)
{
sensor.AddObservation(new float[10]);
return;
}
//Observe the agent's local rotation (4 observations)
sensor.AddObservation(transform.localRotation.normalized);
//Get a vector from the beak tip to the nearest flower
Vector3 toFlower = nearestFlower.FlowerCenterPosition - beakTip.position;
//Observe a normalized vector pointing to the nearest flower(3 observations)
sensor.AddObservation(toFlower.normalized);
//Observe a dot product that indicates whether the beak tip is in front of the flower(1 observation)
//(+1 means that the beak tip is directly in front of the flower, -1 means directly behind)
sensor.AddObservation(Vector3.Dot(toFlower.normalized, -nearestFlower.FlowerUpVector.normalized));
//Observe a dot product that indicates whether the beak is pointing toward the flower(1 observation)
//(+1 means that the beak is pointing directly at the flower, -1 means directly away)
sensor.AddObservation(Vector3.Dot(beakTip.forward.normalized, -nearestFlower.FlowerUpVector.normalized));
//Observe the relative distance from the beak tip to the flower(1 observation)
sensor.AddObservation(toFlower.magnitude / FlowerArea.AreaDiameter);
//10 total observations
}
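To see how the 10 values add up (4 for the rotation quaternion, 3 for the direction to the flower, then two dot products and one normalized distance), here is a plain-C# sketch that builds the same vector with System.Numerics. The method signature is hypothetical, areaDiameter stands in for FlowerArea.AreaDiameter (whose value is not given here), and the quaternion is assumed to be written in x, y, z, w order.

```csharp
using System;
using System.Collections.Generic;
using System.Numerics;

// Sketch only: mirrors CollectObservations without Unity types.
public static class ObservationSketch
{
    public static float[] Build(Quaternion localRotation, Vector3 beakTip, Vector3 beakForward,
        Vector3 flowerCenter, Vector3 flowerUp, float areaDiameter)
    {
        var obs = new List<float>(10);

        // 4 values: normalized rotation (x, y, z, w order assumed)
        Quaternion q = Quaternion.Normalize(localRotation);
        obs.AddRange(new[] { q.X, q.Y, q.Z, q.W });

        // 3 values: unit vector from the beak tip to the flower center
        Vector3 toFlower = flowerCenter - beakTip;
        Vector3 toFlowerDir = Vector3.Normalize(toFlower);
        obs.AddRange(new[] { toFlowerDir.X, toFlowerDir.Y, toFlowerDir.Z });

        // 1 value each: is the beak tip in front of the flower, and is the beak pointing at it?
        Vector3 intoFlower = -Vector3.Normalize(flowerUp);
        obs.Add(Vector3.Dot(toFlowerDir, intoFlower));
        obs.Add(Vector3.Dot(Vector3.Normalize(beakForward), intoFlower));

        // 1 value: distance relative to the area size
        obs.Add(toFlower.Length() / areaDiameter);
        return obs.ToArray();
    }
}
```

In the test scenario below the beak tip sits 10 cm in front of a flower facing +Z and points straight at it, so both dot products come out as 1.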
The following method defines the logic for controlling the agent manually.
/// <summary>
/// When Behavior Type is set to "Heuristic Only" on the agent's Behavior Parameters,
/// this function will be called. The actions it writes into actionsOut are fed into
/// <see cref="OnActionReceived(ActionBuffers)"/> instead of the neural network's output
/// </summary>
/// <param name="actionsOut">The output action buffers</param>
public override void Heuristic(in ActionBuffers actionsOut)
{
//Create placeholders for all movement/turning
Vector3 forward = Vector3.zero;
Vector3 left = Vector3.zero;
Vector3 up = Vector3.zero;
float pitch = 0f;
float yaw = 0f;
//Convert keyboard inputs to movement and turning
//All values should be between -1 and +1
//Forward/backward
if (Input.GetKey(KeyCode.W)) forward = transform.forward;
else if (Input.GetKey(KeyCode.S)) forward = -transform.forward;
//Left/right
if (Input.GetKey(KeyCode.A)) left = -transform.right;
else if (Input.GetKey(KeyCode.D)) left = transform.right;
//Up/down
if (Input.GetKey(KeyCode.E)) up = transform.up;
else if (Input.GetKey(KeyCode.C)) up = -transform.up;
//Pitch up/down
if (Input.GetKey(KeyCode.UpArrow)) pitch = 1f;
else if (Input.GetKey(KeyCode.DownArrow)) pitch = -1f;
//Turn left/right
if (Input.GetKey(KeyCode.LeftArrow)) yaw = -1f;
else if (Input.GetKey(KeyCode.RightArrow)) yaw = 1f;
//Combine the movement vectors and normalize
Vector3 combined = (forward + left + up).normalized;
//Add the 3 movement values, pitch, and yaw to the actionsOut array
var actions = actionsOut.ContinuousActions;
actions[0] = combined.x;
actions[1] = combined.y;
actions[2] = combined.z;
actions[3] = pitch;
actions[4] = yaw;
}
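The keyboard mapping in Heuristic is easy to check in isolation. The sketch below reproduces it in plain C#; the HeldKeys flags enum is a hypothetical stand-in for the Input.GetKey calls, and the three axis parameters play the role of transform.forward, transform.right and transform.up.

```csharp
using System;
using System.Numerics;

// Hypothetical flags standing in for Input.GetKey checks.
[Flags]
public enum HeldKeys
{
    None = 0, W = 1, S = 2, A = 4, D = 8, E = 16, C = 32,
    Up = 64, Down = 128, Left = 256, Right = 512
}

public static class HeuristicSketch
{
    // Produces the 5 continuous actions [x, y, z, pitch, yaw] the same way Heuristic does.
    public static float[] FromKeys(HeldKeys keys, Vector3 forwardAxis, Vector3 rightAxis, Vector3 upAxis)
    {
        Vector3 forward = keys.HasFlag(HeldKeys.W) ? forwardAxis
                        : keys.HasFlag(HeldKeys.S) ? -forwardAxis : Vector3.Zero;
        Vector3 left = keys.HasFlag(HeldKeys.A) ? -rightAxis
                     : keys.HasFlag(HeldKeys.D) ? rightAxis : Vector3.Zero;
        Vector3 up = keys.HasFlag(HeldKeys.E) ? upAxis
                   : keys.HasFlag(HeldKeys.C) ? -upAxis : Vector3.Zero;
        float pitch = keys.HasFlag(HeldKeys.Up) ? 1f : keys.HasFlag(HeldKeys.Down) ? -1f : 0f;
        float yaw = keys.HasFlag(HeldKeys.Left) ? -1f : keys.HasFlag(HeldKeys.Right) ? 1f : 0f;

        Vector3 sum = forward + left + up;
        // Unity's Vector3.normalized returns zero for a near-zero vector, whereas
        // System.Numerics.Vector3.Normalize would yield NaN, so guard explicitly.
        Vector3 combined = sum.Length() > 1e-5f ? Vector3.Normalize(sum) : Vector3.Zero;
        return new[] { combined.X, combined.Y, combined.Z, pitch, yaw };
    }
}
```

Because the three movement directions are summed and then normalized, holding W and D together moves diagonally with the same force as holding W alone, not √2 times stronger.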
The following two methods freeze and unfreeze the agent.
/// <summary>
/// Prevent the agent from moving and taking actions
/// </summary>
public void FreezeAgent()
{
Debug.Assert(trainingMode == false, "Freeze/Unfreeze not supported in training");
frozen = true;
rigidbody.Sleep();
}
/// <summary>
/// Resume agent movement and actions
/// </summary>
public void UnFreezeAgent()
{
Debug.Assert(trainingMode == false, "Freeze/Unfreeze not supported in training");
frozen = false;
rigidbody.WakeUp();
}
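The following two methods are called when the agent's collider enters a trigger collider and on every physics step while it stays inside one.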
/// <summary>
/// Called when the agent's collider enters a trigger collider
/// </summary>
/// <param name="other">The trigger collider</param>
private void OnTriggerEnter(Collider other)
{
TriggerEnterOrStay(other);
}
/// <summary>
/// Called when the agent's collider stays in a trigger collider
/// </summary>
/// <param name="other">The trigger collider</param>
private void OnTriggerStay(Collider other)
{
TriggerEnterOrStay(other);
}
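Both delegate to the method below, which handles trigger events. In this project the nectar is the only trigger, and this method defines the reward for collecting nectar. Remember to set the Tag of FlowerNectarCollider to nectar in Unity.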
/// <summary>
/// Handles when the agent's collider enters or stays in a trigger collider
/// </summary>
/// <param name="collider">The trigger collider</param>
private void TriggerEnterOrStay(Collider collider)
{
//Check if agent is colliding with nectar
if (collider.CompareTag("nectar"))
{
Vector3 closestPointToBeakTip = collider.ClosestPoint(beakTip.position);
//Check if the closest collision point is close to the beak tip
//Note: a collision with anything but the beak tip should not count
if (Vector3.Distance(beakTip.position, closestPointToBeakTip) < BeakTipRadius)
{
//Look up the flower for this nectar collider
Flower flower = flowerArea.GetFlowerFromNectar(collider);
//Attempt to take .01 nectar
//Note: this is per fixed timestep, meaning it happens every .02 seconds, or 50x per second
float nectarReceived = flower.Feed(.01f);
//Keep track of nectar obtained
NectarObtained += nectarReceived;
if (trainingMode)
{
//Calculate reward for getting nectar
float bonus = .2f * Mathf.Clamp01(Vector3.Dot(transform.forward.normalized, -nearestFlower.FlowerUpVector.normalized));
AddReward(0.01f + bonus);
}
//If flower is empty, update the nearest flower
if(!flower.HasNectar)
{
UpdateNearestFlower();
}
}
}
}
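The reward in TriggerEnterOrStay is a base 0.01 per physics step plus a bonus of up to 0.2 for approaching the flower head-on. A plain-C# sketch of that arithmetic and of the beak-tip distance check follows; the names are hypothetical, and the flower's up vector is passed in directly (the method above reads it from nearestFlower).

```csharp
using System;
using System.Numerics;

// Sketch only (hypothetical names): the per-step nectar reward as pure functions.
public static class NectarRewardSketch
{
    public const float BeakTipRadius = 0.008f;

    // A contact only counts when the collider's closest point is within
    // BeakTipRadius of the beak tip.
    public static bool IsBeakContact(Vector3 beakTip, Vector3 closestPointOnNectar) =>
        Vector3.Distance(beakTip, closestPointOnNectar) < BeakTipRadius;

    // 0.01 base reward plus up to 0.2 extra, scaled by how directly the bird's
    // forward vector faces into the flower; negative alignment clamps to 0.
    public static float StepReward(Vector3 birdForward, Vector3 flowerUp)
    {
        float alignment = Vector3.Dot(Vector3.Normalize(birdForward), -Vector3.Normalize(flowerUp));
        return 0.01f + 0.2f * Math.Clamp(alignment, 0f, 1f);
    }
}
```

Since nectar is requested at 0.01 per fixed step and the code comment assumes 50 steps per second, continuous feeding asks for up to 0.5 nectar per second; how much is actually returned depends on Flower.Feed, which is not shown here.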
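The following method is called when the agent collides with another solid (non-trigger) object; it defines the penalty for hitting the boundary. Remember to set the Tag of IslandBoundaries to boundary in Unity.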
/// <summary>
/// Called when the agent collides with something solid
/// </summary>
/// <param name="collision">The collision info</param>
private void OnCollisionEnter(Collision collision)
{
if (trainingMode && collision.collider.CompareTag("boundary"))
{
//Collided with the area boundary, give a negative reward
AddReward(-.5f);
}
}
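The following method runs every frame and draws a green line from the hummingbird's beak tip to the nearest flower (visible in the Scene view).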
/// <summary>
/// Called every frame
/// </summary>
private void Update()
{
//Draw a line from the beak tip to the nearest flower
if (nearestFlower != null)
Debug.DrawLine(beakTip.position, nearestFlower.FlowerCenterPosition, Color.green);
}
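The following method runs at a fixed interval: once the hummingbird has drained the current flower, it looks for the next nearest flower that still has nectar.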
/// <summary>
/// Called every .02 seconds
/// </summary>
private void FixedUpdate()
{
if (nearestFlower != null && !nearestFlower.HasNectar)
UpdateNearestFlower();
}
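3. Attaching the script
In the Unity Editor, attach this script to Hummingbird. Attaching it adds two component panels, as shown below.
In the Behavior Parameters panel, we set the following: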
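- Behavior Name: set to Hummingbird. This is the behavior's identifier; agents with the same Behavior Name learn the same policy, and it must match the name used in the training configuration file later.
- Space Size: the length of the agent's observation vector. From the CollectObservations method above, this is 10.
- Stacked Vectors: leave unchanged. It is the number of consecutive observation vectors stacked together before being fed to the neural network.
- Continuous Actions: the number of continuous actions the agent takes. Six-degree-of-freedom movement minus roll gives 5.
- Discrete Branches: the discrete actions, given as an integer array in which each value is the number of options in that action branch. There are no discrete actions here, so set it to 0.
- Model: the neural network model used for inference (obtained after training). It is not needed during training, so leave it as None.
- Behavior Type: with Default, the agent trains when it is connected to the Python trainer; otherwise it runs inference with the assigned Model, or falls back to heuristic control if no Model is set. With Heuristic Only the agent can only be controlled manually; with Inference Only it is controlled only by the model.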
The modified settings are shown in the figure below.
In the second panel (the HummingbirdAgent script):
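- Set Max Step to 5000, so each episode ends after 5000 steps and a new one begins.
- Keep the force and speed values at their defaults.
- Double-click Assets/Hummingbird/Prefabs/Hummingbird to open the prefab, then drag BeakTip from its Hierarchy window into the Beak Tip field.
- In the Hummingbird Hierarchy window, add a camera, adjust its position and view range as shown in the figure below, and disable its Audio Listener, since this scene has no audio.
- Once the camera is added, drag it into the Agent Camera field.
- Check Training Mode.
The modified parameters are shown in the figure below.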
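Finally, add one more component to the agent: Decision Requester. It requests decisions from the policy and is required for training. A Decision Period of 5 means the neural network makes one decision every 5 steps.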
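As a rough worked example, assume the Academy advances one step per FixedUpdate at the 0.02 s interval that the FixedUpdate comment in the script assumes. Under that assumption the Max Step and Decision Period settings translate as follows:

$$T_{\text{episode}} = \text{MaxStep} \times \Delta t = 5000 \times 0.02\,\text{s} = 100\,\text{s}$$

$$N_{\text{decisions}} = \frac{\text{MaxStep}}{\text{DecisionPeriod}} = \frac{5000}{5} = 1000$$

With Take Actions Between Decisions left enabled on the Decision Requester (its default), OnActionReceived is still called on the intermediate steps, repeating the most recent action, which is why the smoothing in OnActionReceived is expressed per Time.fixedDeltaTime.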