ML-Agents案例之Crawler

本案例源自ML-Agents官方的示例,Github地址:https://github.com/Unity-Technologies/ml-agents

本文基于我前面发的两篇文章,需要对ML-Agents有一定的了解,详情请见:Unity强化学习之ML-Agents的使用ML-Agents命令及配置大全

参考资料:ML-Agents(十)Crawler

上一次运行的3DBall的任务比较简单,只需要把小球停在方块上方,输入维度低,奖励函数设置较简单,因此很快就能训练出比较好的效果。接下来训练一个更具有挑战性的任务。

在这里插入图片描述

如上图所示,我们需要训练的是一个四条腿的仿真机器人,让它学会站立,面向目标行走,最后吃到绿色的方块,并且这个过程越迅速越好。

环境讲解

机器人所处的环境是一片没有摩擦平地(有摩擦更好训练),存在空气阻力,周围围着四面墙,里面必定有一个绿色方块作为机器人的目标。

智能体所处的环境很简单,但是,智能体本身一点也不简单。

在这里插入图片描述

智能体机器人本身分为身体主干和四条腿,每条腿分为前肢和后肢。因此具有八个关节。

配置关节

参考资料:深入了解 Unity 可配置关节 Configurable Joints)可配置关节 (Configurable Joint)

在这里插入图片描述

这里只对用到的重要参数进行讲解。首先需要用到这个组件的是四条条前肢(靠近身体)和四条后肢,后肢是前肢的子物体,这样后肢才会跟着前肢动。对于前肢来说我们需要把Angular Y Motion和Angular X Motion设置为Limited,其它为Locked,也就是两个方向的自由度。后肢只需要Angular X Motion设置为Limited,也就是一个方向的自由度。然后点击Edit Angular Limits按钮,设置关节的位置和可以旋转的角度,这可以通过设置Anchor和Axis来实现。


触地检测

可以看到,身体的每个部位都配置有Ground Contact的脚本,这个脚本可以检测那个部位是否接触了地面。

在这里插入图片描述

using UnityEngine;
using Unity.MLAgents;

namespace Unity.MLAgentsExamples
{

    [DisallowMultipleComponent]
    public class GroundContact : MonoBehaviour
    {
        [HideInInspector] public Agent agent;

        [Header("Ground Check")] public bool agentDoneOnGroundContact; // Whether to reset agent on ground contact.
        public bool penalizeGroundContact; // Whether to penalize on contact.
        public float groundContactPenalty; // Penalty amount (ex: -1).
        public bool touchingGround;
        const string k_Ground = "ground"; // Tag of ground object.

        // 进入碰撞时给touchingGround设为true,并给予惩罚,判断游戏是否结束
        void OnCollisionEnter(Collision col)
        {
            if (col.transform.CompareTag(k_Ground))
            {
                touchingGround = true;
                if (penalizeGroundContact)
                {
                    agent.SetReward(groundContactPenalty);
                }
                if (agentDoneOnGroundContact)
                {
                    agent.EndEpisode();
                }
            }
        }
        /// 退出碰撞时touchGraound设为false。判断不接触地面
        void OnCollisionExit(Collision other)
        {
            if (other.transform.CompareTag(k_Ground))
            {
                touchingGround = false;
            }
        }
    }
}

代码分析

现在我们可以正式来看智能体上都挂了哪些脚本。

首先是万年不变的Behavior Parameters,输入向量是32维,输出连续的动作是20维。

然后是万年不变的Decision Requester,Take Actions Between Decisions设为false。

再然后万年不变的Model Overrider也安排上,允许训练期间覆盖模型。


Joint Drive Controller

下面讲解一下Joint Drive Controller,这个脚本负责控制各个关节。

首先看BodyPart方法:

/// <summary>
    /// 用于存储agent每个身体部位的行动和学习相关信息
    /// </summary>
    [System.Serializable]
    public class BodyPart
    {
        [Header("Body Part Info")] [Space(10)] public ConfigurableJoint joint;//身体的可配置关节组件
        public Rigidbody rb;//刚体
        [HideInInspector] public Vector3 startingPos;//起始位置
        [HideInInspector] public Quaternion startingRot;//起始角度

        [Header("Ground & Target Contact")]
        [Space(10)]
        public GroundContact groundContact;//检测地面接触
        public TargetContact targetContact;//检测目标接触

        [FormerlySerializedAs("thisJDController")]
        [HideInInspector] public JointDriveController thisJdController;//关节组件Controller

        [Header("Current Joint Settings")]
        [Space(10)]
        public Vector3 currentEularJointRotation;//关节当前欧拉角

        [HideInInspector] public float currentStrength;//当前作用力
        public float currentXNormalizedRot;
        public float currentYNormalizedRot;
        public float currentZNormalizedRot;

        [Header("Other Debug Info")]
        [Space(10)]
        public Vector3 currentJointForce;//当前关节作用力

        public float currentJointForceSqrMag;//当前关节作用力大小
        public Vector3 currentJointTorque;//当前关节转矩
        public float currentJointTorqueSqrMag;//当前关节转矩大小
        public AnimationCurve jointForceCurve = new AnimationCurve();//关节作用力曲线
        public AnimationCurve jointTorqueCurve = new AnimationCurve();//关节力矩曲线

        /// <summary>
        /// 数据初始化
        /// </summary>
        public void Reset(BodyPart bp)
        {
            bp.rb.transform.position = bp.startingPos;//位置
            bp.rb.transform.rotation = bp.startingRot;//角度
            bp.rb.velocity = Vector3.zero;//速度
            bp.rb.angularVelocity = Vector3.zero;//角速度
            if (bp.groundContact)
            {//地面接触标志置位
                bp.groundContact.touchingGround = false;
            }

            if (bp.targetContact)
            {//目标接触标志置位
                bp.targetContact.touchingTarget = false;
            }
        }

        /// <summary>
        /// 根据给定的x,y,z角度和力的大小计算扭矩
        /// </summary>
        public void SetJointTargetRotation(float x, float y, float z)
        {
            x = (x + 1f) * 0.5f;
            y = (y + 1f) * 0.5f;
            z = (z + 1f) * 0.5f;

            //Mathf.Lerp(from : float, to : float, t : float) 插值,t=0~1,返回(to-from)*t
            var xRot = Mathf.Lerp(joint.lowAngularXLimit.limit, joint.highAngularXLimit.limit, x);
            var yRot = Mathf.Lerp(-joint.angularYLimit.limit, joint.angularYLimit.limit, y);
            var zRot = Mathf.Lerp(-joint.angularZLimit.limit, joint.angularZLimit.limit, z);

            //Mathf.InverseLerp(from : float, to : float, value : float)反插值,返回value在from和to之间的比例值
            currentXNormalizedRot = Mathf.InverseLerp(joint.lowAngularXLimit.limit, joint.highAngularXLimit.limit, xRot);
            currentYNormalizedRot = Mathf.InverseLerp(-joint.angularYLimit.limit, joint.angularYLimit.limit, yRot);
            currentZNormalizedRot = Mathf.InverseLerp(-joint.angularZLimit.limit, joint.angularZLimit.limit, zRot);

            joint.targetRotation = Quaternion.Euler(xRot, yRot, zRot);//使关节转向目标角度
            currentEularJointRotation = new Vector3(xRot, yRot, zRot);//当前关节欧拉角
        }
        /// <summary>
        /// 设置关节作用力大小
        /// </summary>
        /// <param name="strength"></param>
        public void SetJointStrength(float strength)
        {
            var rawVal = (strength + 1f) * 0.5f * thisJdController.maxJointForceLimit;
            var jd = new JointDrive
            {
                positionSpring = thisJdController.maxJointSpring,//关节最大弹力
                positionDamper = thisJdController.jointDampen,//关节弹性大小
                maximumForce = rawVal//施加的最大力
            };
            joint.slerpDrive = jd;
            currentStrength = jd.maximumForce;//当前施加的力
        }
    } 

这个脚本主要是将多个BodyPart进行管理的作用,同时可以实时更新身体每一部分作用力及转矩,用以Agent收集BodyPart的相关信息。

JointDriveController方法:

 /// <summary>
    /// Joint控制器
    /// </summary>
    public class JointDriveController : MonoBehaviour
    {
        [Header("Joint Drive Settings")]
        [Space(10)]
        public float maxJointSpring;//关节最大弹力大小
        public float jointDampen;//关节抵抗弹力的强度
        public float maxJointForceLimit;//最大作用力
        //float m_FacingDot;//该变量没用到

        //身体部位字典
        [HideInInspector] public Dictionary<Transform, BodyPart> bodyPartsDict = new Dictionary<Transform, BodyPart>();

        /// <summary>
        /// 创建BodyPart对象并将其添加到字典中
        /// </summary>
        public void SetupBodyPart(Transform t)
        {
            var bp = new BodyPart
            {
                rb = t.GetComponent<Rigidbody>(),
                joint = t.GetComponent<ConfigurableJoint>(),
                startingPos = t.position,
                startingRot = t.rotation
            };
            bp.rb.maxAngularVelocity = 100;//最大角速度为100

            //添加地面碰撞检测脚本
            bp.groundContact = t.GetComponent<GroundContact>();
            if (!bp.groundContact)
            {
                bp.groundContact = t.gameObject.AddComponent<GroundContact>();
                bp.groundContact.agent = gameObject.GetComponent<Agent>();
            }
            else
            {
                bp.groundContact.agent = gameObject.GetComponent<Agent>();
            }

            //添加目标碰撞检测脚本
            bp.targetContact = t.GetComponent<TargetContact>();
            if (!bp.targetContact)
            {
                bp.targetContact = t.gameObject.AddComponent<TargetContact>();
            }

            bp.thisJdController = this;
            bodyPartsDict.Add(t, bp);
        }
        /// <summary>
        /// 更新身体每一部分当前的作用力及转矩
        /// </summary>
        public void GetCurrentJointForces()
        {
            foreach (var bodyPart in bodyPartsDict.Values)
            {//轮询身体每部分
                if (bodyPart.joint)
                {
                    bodyPart.currentJointForce = bodyPart.joint.currentForce;//当前关节作用力
                    bodyPart.currentJointForceSqrMag = bodyPart.joint.currentForce.magnitude;//当前关节作用力大小
                    bodyPart.currentJointTorque = bodyPart.joint.currentTorque;//当前关节作用转矩
                    bodyPart.currentJointTorqueSqrMag = bodyPart.joint.currentTorque.magnitude;//当前关节作用转矩大小
                    if (Application.isEditor)
                    {//IDE下,创建关节作用力和关节力矩的曲线
                        if (bodyPart.jointForceCurve.length > 1000)
                        {
                            bodyPart.jointForceCurve = new AnimationCurve();
                        }

                        if (bodyPart.jointTorqueCurve.length > 1000)
                        {
                            bodyPart.jointTorqueCurve = new AnimationCurve();
                        }

                        bodyPart.jointForceCurve.AddKey(Time.time, bodyPart.currentJointForceSqrMag);
                        bodyPart.jointTorqueCurve.AddKey(Time.time, bodyPart.currentJointTorqueSqrMag);
                    }
                }
            }
        }
    }

虽然这个脚本挂载在agent上,但不会自己起作用,只有其他脚本调用时才起作用。


RigidBody Sensor Component

可以看到agents下面还挂载着一个RigidBody Sensor Component的脚本。

using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents.Sensors;

namespace Unity.MLAgents.Extensions.Sensors
{

    public class RigidBodySensorComponent : SensorComponent
    {
        public Rigidbody RootBody;

        /// Optional GameObject used to determine the root of the poses.
        public GameObject VirtualRoot;

        /// Settings defining what types of observations will be generated.
        [SerializeField]
        public PhysicsSensorSettings Settings = PhysicsSensorSettings.Default();

        /// Optional sensor name. This must be unique for each Agent.
        [SerializeField]
        public string sensorName;

        [SerializeField]
        [HideInInspector]
        RigidBodyPoseExtractor m_PoseExtractor;

        /// Creates a PhysicsBodySensor.
        public override ISensor[] CreateSensors()
        {
            var _sensorName = string.IsNullOrEmpty(sensorName) ? $"PhysicsBodySensor:{RootBody?.name}" : sensorName;
            return new ISensor[] { new PhysicsBodySensor(GetPoseExtractor(), Settings, _sensorName) };
        }

        /// Get the DisplayNodes of the hierarchy.
        internal IList<PoseExtractor.DisplayNode> GetDisplayNodes()
        {
            return GetPoseExtractor().GetDisplayNodes();
        }

        /// Lazy construction of the PoseExtractor.
        RigidBodyPoseExtractor GetPoseExtractor()
        {
            if (m_PoseExtractor == null)
            {
                ResetPoseExtractor();
            }

            return m_PoseExtractor;
        }

        /// Reset the pose extractor, trying to keep the enabled state of the corresponding poses the same.
        internal void ResetPoseExtractor()
        {
            // Get the current enabled state of each body, so that we can reinitialize with them.
            Dictionary<Rigidbody, bool> bodyPosesEnabled = null;
            if (m_PoseExtractor != null)
            {
                bodyPosesEnabled = m_PoseExtractor.GetBodyPosesEnabled();
            }
            m_PoseExtractor = new RigidBodyPoseExtractor(RootBody, gameObject, VirtualRoot, bodyPosesEnabled);
        }

        /// Toggle the pose at the given index.
        internal void SetPoseEnabled(int index, bool enabled)
        {
            GetPoseExtractor().SetPoseEnabled(index, enabled);
        }

        internal bool IsTrivial()
        {
            if (ReferenceEquals(RootBody, null))
            {
                // It *is* trivial, but this will happen when the sensor is being set up, so don't warn then.
                return false;
            }
            var joints = RootBody.GetComponentsInChildren<Joint>();
            if (joints.Length == 0)
            {
                if (ReferenceEquals(VirtualRoot, null) || ReferenceEquals(VirtualRoot, RootBody.gameObject))
                {
                    return true;
                }
            }
            return false;
        }
    }

}

在这里插入图片描述

这个组件是新加上去的实验性功能,在ML-Agents.extentions包中而不是在主体包中,其中下面的Hierachy是运行使自动产生的,只要我们把Body物体拖到RootBody上,把OrentationCube拖到VirtualRoot上就能正常使用这个组件。

同样这是一个能自己获取输入的传感器,在CreateSensors方法中,new了一个PhysicsBodySensor,而这个类继承ISensor接口,也就是说它可以自己获取输入。其中ISensor接口的Write方法用于生成实际观察。

当智能体用到关节Joint时,加上该组件可以使智能体训练更好。具体功能待探究。


Crawler Agent

接下来是重头戏Crawler Agent脚本:

在这里插入图片描述

这个组件继承于Agent,是真正实现智能体获取输入,获得输出,定义奖励,定义episode的结束等强化学习关键元素的组件。

我们把其中用到的智能体身体各个部位的Transform,网格渲染,材质一一赋值。然后查看其中的方法都实现了什么:

先看初始化方法Initialize,这个方法定义了游戏开始之前需要做的事情:

 public override void Initialize()
 {
     // 早期版本中没有加入以下两行,但经过研究发现智能体身上加入一个指向物体可以大大增加reward
     // 其中原因值得深究
     SpawnTarget(TargetPrefab, transform.position); //spawn target
     m_OrientationCube = GetComponentInChildren<OrientationCubeController>();

     m_DirectionIndicator = GetComponentInChildren<DirectionIndicator>();
     m_JdController = GetComponent<JointDriveController>();

     //Setup each body part
     m_JdController.SetupBodyPart(body);
     m_JdController.SetupBodyPart(leg0Upper);
     m_JdController.SetupBodyPart(leg0Lower);
     m_JdController.SetupBodyPart(leg1Upper);
     m_JdController.SetupBodyPart(leg1Lower);
     m_JdController.SetupBodyPart(leg2Upper);
     m_JdController.SetupBodyPart(leg2Lower);
     m_JdController.SetupBodyPart(leg3Upper);
     m_JdController.SetupBodyPart(leg3Lower);
 }
// 生成目标方块
void SpawnTarget(Transform prefab, Vector3 pos)
{
	m_Target = Instantiate(prefab, pos, Quaternion.identity, transform.parent);
}

首先是生成一个目标点,然后获取必须的组件,以及各个关节部位的初始化。

然后是每个episode开始的时候执行的OnEpisodeBegin:

public override void OnEpisodeBegin()
{
    // 重置所有关节
    foreach (var bodyPart in m_JdController.bodyPartsDict.Values)
    {
        bodyPart.Reset(bodyPart);
    }

    // 让智能体随机朝着一个方向
    body.rotation = Quaternion.Euler(0, Random.Range(0.0f, 360.0f), 0);

    // 更新智能体身上的一个空物体的坐标和旋转(作用待考究)
    UpdateOrientationObjects();

    // 设置随机的目标速度
    TargetWalkingSpeed = Random.Range(0.1f, m_maxWalkingSpeed);
}

下面就是老朋友CollectObservations了,把相应的输入添加到神经网络的输入:

public override void CollectObservations(VectorSensor sensor)
{
    var cubeForward = m_OrientationCube.transform.forward;

    //velocity we want to match
    var velGoal = cubeForward * TargetWalkingSpeed;
    // 获取刚体的平均速度
    var avgVel = GetAvgVelocity();

    // 输入平均速度和目标速度相差的距离,维度为1
    sensor.AddObservation(Vector3.Distance(velGoal, avgVel));
    // 输入智能体刚体相对于身上cube的平均速度(思考为什么加入这个cube物体会使训练更加有效),维度为3
    sensor.AddObservation(m_OrientationCube.transform.InverseTransformDirection(avgVel));
    // 输入智能体相对于身上cube的速度,维度为3
    sensor.AddObservation(m_OrientationCube.transform.InverseTransformDirection(velGoal));
    // 输入一个四元数旋转,维度为4
    sensor.AddObservation(Quaternion.FromToRotation(body.forward, cubeForward));

    // 输入目标点相对身上cube位置,维度为3
    sensor.AddObservation(m_OrientationCube.transform.InverseTransformPoint(m_Target.transform.position));

    // 发出射线测量身体到地面的距离,维度为1
    RaycastHit hit;
    float maxRaycastDist = 10;
    if (Physics.Raycast(body.position, Vector3.down, out hit, maxRaycastDist))
    {
        sensor.AddObservation(hit.distance / maxRaycastDist);
    }
    else
        sensor.AddObservation(1);

    // 身体的每一个部位输入
    foreach (var bodyPart in m_JdController.bodyPartsList)
    {
        CollectObservationBodyPart(bodyPart, sensor);
    }
}

关于坐标系参考文章:https://www.sohu.com/a/221556633_667928

public void CollectObservationBodyPart(BodyPart bp, VectorSensor sensor)
{   
    // 输入是否接触地面,此处共9个输入
    sensor.AddObservation(bp.groundContact.touchingGround); 

    // 如果不是身体,加入现在的关节力度作为输入,此处共8个输入
    if (bp.rb.transform != body)
    {
        sensor.AddObservation(bp.currentStrength / m_JdController.maxJointForceLimit);
    }
}

总计32个输入维度。

现在看看输出OnActionReceived:

public override void OnActionReceived(ActionBuffers actionBuffers)
{
    // The dictionary with all the body parts in it are in the jdController
    var bpDict = m_JdController.bodyPartsDict;

    var continuousActions = actionBuffers.ContinuousActions;
    var i = -1;
    // Pick a new target joint rotation
    bpDict[leg0Upper].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], 0);
    bpDict[leg1Upper].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], 0);
    bpDict[leg2Upper].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], 0);
    bpDict[leg3Upper].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], 0);
    bpDict[leg0Lower].SetJointTargetRotation(continuousActions[++i], 0, 0);
    bpDict[leg1Lower].SetJointTargetRotation(continuousActions[++i], 0, 0);
    bpDict[leg2Lower].SetJointTargetRotation(continuousActions[++i], 0, 0);
    bpDict[leg3Lower].SetJointTargetRotation(continuousActions[++i], 0, 0);

    // Update joint strength
    bpDict[leg0Upper].SetJointStrength(continuousActions[++i]);
    bpDict[leg1Upper].SetJointStrength(continuousActions[++i]);
    bpDict[leg2Upper].SetJointStrength(continuousActions[++i]);
    bpDict[leg3Upper].SetJointStrength(continuousActions[++i]);
    bpDict[leg0Lower].SetJointStrength(continuousActions[++i]);
    bpDict[leg1Lower].SetJointStrength(continuousActions[++i]);
    bpDict[leg2Lower].SetJointStrength(continuousActions[++i]);
    bpDict[leg3Lower].SetJointStrength(continuousActions[++i]);
}

共设置了八个关节的旋转角度,以及对应的力度。共计20个连续输出。


再看看FixedUpdate,这个函数以固定时间间隔被调用,不受帧率的影响。

void FixedUpdate()
{
    // 更新cube和指示器
    UpdateOrientationObjects();

    // 检查脚是否接触地面,接触了会更换材质
    if (useFootGroundedVisualization)
    {
        foot0.material = m_JdController.bodyPartsDict[leg0Lower].groundContact.touchingGround
            ? groundedMaterial
            : unGroundedMaterial;
        foot1.material = m_JdController.bodyPartsDict[leg1Lower].groundContact.touchingGround
            ? groundedMaterial
            : unGroundedMaterial;
        foot2.material = m_JdController.bodyPartsDict[leg2Lower].groundContact.touchingGround
            ? groundedMaterial
            : unGroundedMaterial;
        foot3.material = m_JdController.bodyPartsDict[leg3Lower].groundContact.touchingGround
            ? groundedMaterial
            : unGroundedMaterial;
    }

    var cubeForward = m_OrientationCube.transform.forward;

    // 现在速度的向量越接近目标速度的向量,奖励越高
    var matchSpeedReward = GetMatchingVelocityReward(cubeForward * TargetWalkingSpeed, GetAvgVelocity());

    // 两个向量点乘,当方向相同时为正,方向相反时为负
    var lookAtTargetReward = (Vector3.Dot(cubeForward, body.forward) + 1) * .5F;
	// 奖励采用相乘的形式,保证训练出来的智能体都面朝着目标并且速度也朝着目标
    AddReward(matchSpeedReward * lookAtTargetReward);
}

其中UpdateOrientationObjects,它时刻都在更新智能体上的cube物体的位置和旋转,使其始终朝着目标。同时更新下方指示器的位置和旋转:

void UpdateOrientationObjects()
{
    m_OrientationCube.UpdateOrientation(body, m_Target);
    if (m_DirectionIndicator)
    {
    	m_DirectionIndicator.MatchOrientation(m_OrientationCube.transform);
	}
}

还有一个GetMatchingVelocityReward方法,输入的是目标速度和实际速度,输出一个奖励,两个速度距离越小,奖励越高:

public float GetMatchingVelocityReward(Vector3 velocityGoal, Vector3 actualVelocity)
{
    //目标速度和实际速度直接的距离,对其范围限制在0到TargetWalkingSpeed
    var velDeltaMagnitude = Mathf.Clamp(Vector3.Distance(actualVelocity, velocityGoal), 0, TargetWalkingSpeed);

    //return the value on a declining sigmoid shaped curve that decays from 1 to 0
    //This reward will approach 1 if it matches perfectly and approach zero as it deviates
    return Mathf.Pow(1 - Mathf.Pow(velDeltaMagnitude / TargetWalkingSpeed, 2), 2);
}

Target Controller

上面继承Agent的主脚本讲解完了,下面是生成目标的脚本,我们要在场地中的随机地点生成一个cube,被吃掉后重新生成。

// 只在程序启动时执行一次
void OnEnable()
{
    m_startingPos = transform.position;
    if (respawnIfTouched)
    {
        MoveTargetToRandomPosition();
    }
}
// 每一帧执行一次
void Update()
{
    if (respawnIfFallsOffPlatform)
    {
        if (transform.position.y < m_startingPos.y - fallDistance)
        {
            Debug.Log($"{transform.name} Fell Off Platform");
            MoveTargetToRandomPosition();
        }
    }
}

// 在一个球形范围内随机移动,固定y轴
public void MoveTargetToRandomPosition()
{
    var newTargetPos = m_startingPos + (Random.insideUnitSphere * spawnRadius);
    newTargetPos.y = m_startingPos.y;
    transform.position = newTargetPos;
}

// 碰撞到智能体时,移动到其他地方
// 此处应该加上碰到cube奖励,但由于前面的奖励设置较完善,不加也能正常训练。
private void OnCollisionEnter(Collision col)
{
    if (col.transform.CompareTag(tagToDetect))
    {
        onCollisionEnterEvent.Invoke(col);
        if (respawnIfTouched)
        {
            MoveTargetToRandomPosition();
        }
    }
}

训练参数配置

没有使用其他附加功能,纯粹的PPO已经能在300万个steps使奖励达到2500以上,智能体动作理想:

behaviors:
  Crawler:
    trainer_type: ppo
    hyperparameters:
      batch_size: 2048
      buffer_size: 20480
      learning_rate: 0.0003
      beta: 0.005
      epsilon: 0.2
      lambd: 0.95
      num_epoch: 3
      learning_rate_schedule: linear
    network_settings:
      normalize: true
      hidden_units: 512
      num_layers: 3
      vis_encode_type: simple
    reward_signals:
      extrinsic:
        gamma: 0.995
        strength: 1.0
    keep_checkpoints: 5
    max_steps: 10000000
    time_horizon: 1000
    summary_freq: 30000

使用SAC算法的配置文件为:

behaviors:
  Crawler:
    trainer_type: sac
    hyperparameters:
      learning_rate: 0.0003
      learning_rate_schedule: constant
      batch_size: 256
      buffer_size: 500000
      buffer_init_steps: 0
      tau: 0.005
      steps_per_update: 20.0
      save_replay_buffer: false
      init_entcoef: 1.0
      reward_signal_steps_per_update: 20.0
    network_settings:
      normalize: true
      hidden_units: 512
      num_layers: 3
      vis_encode_type: simple
    reward_signals:
      extrinsic:
        gamma: 0.995
        strength: 1.0
    keep_checkpoints: 5
    max_steps: 5000000
    time_horizon: 1000
    summary_freq: 30000

使用模仿学习的配置:

behaviors:
  Crawler:
    trainer_type: ppo
    hyperparameters:
      batch_size: 2024
      buffer_size: 20240
      learning_rate: 0.0003
      beta: 0.005
      epsilon: 0.2
      lambd: 0.95
      num_epoch: 3
      learning_rate_schedule: linear
    network_settings:
      normalize: true
      hidden_units: 512
      num_layers: 3
      vis_encode_type: simple
    reward_signals:
      gail:
        gamma: 0.99
        strength: 1.0
        network_settings:
          normalize: true
          hidden_units: 128
          num_layers: 2
          vis_encode_type: simple
        learning_rate: 0.0003
        use_actions: false
        use_vail: false
        demo_path: Project/Assets/ML-Agents/Examples/Crawler/Demos/ExpertCrawler.demo
    keep_checkpoints: 5
    max_steps: 10000000
    time_horizon: 1000
    summary_freq: 30000
    behavioral_cloning:
      demo_path: Project/Assets/ML-Agents/Examples/Crawler/Demos/ExpertCrawler.demo
      steps: 50000
      strength: 0.5
      samples_per_update: 0
  • 9
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

微笑小星

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值