Unity3D ML-Agent-0.8.1 学习七(例子源码分析1)

Unity3D ML-Agent-0.8.1 学习七(例子源码分析1)

写的目的

本篇想分享下看例子中的源码分析,其实也就是一些我理解之后的注释,一些思路,希望对你有帮助。

例子

Basic

在这里插入图片描述
这个例子主要是训练方块去左右移动,获得最大奖励,左边奖励小,右边大,于是最后会让方块就往右边走,是一格格走的。
主要源码分析:BasicAgent.cs

using UnityEngine;
using MLAgents;

public class BasicAgent : Agent
{
    /// <summary>
    /// Reference to the scene academy; queried each FixedUpdate to decide
    /// how often to request decisions (see WaitTimeInference).
    /// </summary>
    [Header("Specific to Basic")]
    private BasicAcademy academy;
    /// <summary>
    /// Seconds to wait between decision requests while in inference mode.
    /// </summary>
    public float timeBetweenDecisionsAtInference;
    /// <summary>
    /// Time accumulated since the last decision request.
    /// </summary>
    private float timeSinceDecision;
    /// <summary>
    /// Current grid cell occupied by the agent.
    /// </summary>
    int position;
    /// <summary>
    /// Grid cell of the small (low-reward) goal.
    /// </summary>
    int smallGoalPosition;
    /// <summary>
    /// Grid cell of the large (high-reward) goal.
    /// </summary>
    int largeGoalPosition;

    /// <summary>
    /// The two goal objects placed in the scene.
    /// </summary>
    public GameObject largeGoal;
    public GameObject smallGoal;

    /// <summary>
    /// Inclusive bounds clamping the agent's cell so it cannot leave the track.
    /// </summary>
    int minPosition;
    int maxPosition;

    public override void InitializeAgent()
    {
        academy = FindObjectOfType(typeof(BasicAcademy)) as BasicAcademy;
    }

    /// <summary>
    /// One-hot encodes the agent's cell over 20 slots: slot `position` is 1,
    /// all others 0 (e.g. with 5 slots, position 3 gives [0,0,0,1,0]).
    /// The agent moves one cell at a time, so the cell index alone fully
    /// describes the state — like a token on a board game track.
    /// </summary>
    public override void CollectObservations()
    {
        AddVectorObs(position, 20);
    }

    /// <summary>
    /// Applies one discrete action: 1 moves left, 2 moves right, anything
    /// else (including 0) stays put. A small per-step penalty pushes the
    /// agent to reach a goal quickly; reaching a goal ends the episode.
    /// </summary>
    /// <param name="vectorAction">Discrete action branch; values start at 0 and
    /// range up to the brain panel's Branch 0 Size (e.g. size 3 gives 0, 1, 2).</param>
    /// <param name="textAction">Unused text action.</param>
    public override void AgentAction(float[] vectorAction, string textAction)
    {
        var movement = (int)vectorAction[0];

        int direction = 0;
        // Map the action to a step: 1 -> left (-1), 2 -> right (+1).
        switch (movement)
        {
            case 1:
                direction = -1;
                break;
            case 2:
                direction = 1;
                break;
        }

        // Advance and clamp to the track bounds so the agent cannot run off.
        position += direction;
        if (position < minPosition) { position = minPosition; }
        if (position > maxPosition) { position = maxPosition; }

        gameObject.transform.position = new Vector3(position - 10f, 0f, 0f);
        // Per-step penalty so that reaching any goal beats standing still.
        AddReward(-0.01f);

        if (position == smallGoalPosition)
        {
            Done();
            AddReward(0.1f);
        }
        // The large goal pays ten times as much.
        if (position == largeGoalPosition)
        {
            Done();
            AddReward(1f);
        }
    }

    /// <summary>
    /// Resets positions at startup and after every episode.
    /// </summary>
    public override void AgentReset()
    {
        position = 10;
        minPosition = 0;
        maxPosition = 20;
        smallGoalPosition = 7;
        largeGoalPosition = 17;
        smallGoal.transform.position = new Vector3(smallGoalPosition - 10f, 0f, 0f);
        largeGoal.transform.position = new Vector3(largeGoalPosition - 10f, 0f, 0f);
    }

    public override void AgentOnDone()
    {

    }

    public void FixedUpdate()
    {
        WaitTimeInference();
    }

    /// <summary>
    /// Requests a decision every physics step during training, but only every
    /// timeBetweenDecisionsAtInference seconds while in inference mode.
    /// </summary>
    private void WaitTimeInference()
    {
        if (!academy.GetIsInference())
        {
            RequestDecision();
        }
        else
        {
            if (timeSinceDecision >= timeBetweenDecisionsAtInference)
            {
                timeSinceDecision = 0f;
                RequestDecision();
            }
            else
            {
                timeSinceDecision += Time.fixedDeltaTime;
            }
        }
    }

}

3DBall

在这里插入图片描述
主要是训练平台让小球不掉下去,需要同时关注小球的速度,位置,平台的角度。
主要源码分析:Ball3DAgent.cs

using UnityEngine;
using MLAgents;

public class Ball3DAgent : Agent
{
    // The ball balanced on this platform, and its cached rigidbody.
    [Header("Specific to Ball3D")]
    public GameObject ball;
    private Rigidbody ballRb;

    /// <summary>
    /// Caches the ball's Rigidbody. Invoked by the base Agent (from OnEnable).
    /// </summary>
    public override void InitializeAgent()
    {
        ballRb = ball.GetComponent<Rigidbody>();
    }

    /// <summary>
    /// Observes the platform tilt, the ball's position relative to the
    /// platform, and the ball's velocity — what the policy needs to judge
    /// whether it can still rebalance in time.
    /// </summary>
    public override void CollectObservations()
    {
        // Platform tilt — raw quaternion z/x components (the same values
        // AgentAction compares against +/-0.25 to limit the tilt).
        AddVectorObs(gameObject.transform.rotation.z);
        AddVectorObs(gameObject.transform.rotation.x);
        // Ball position relative to the platform (a world-space difference,
        // matching the failure thresholds in AgentAction).
        AddVectorObs(ball.transform.position - gameObject.transform.position);
        // Ball velocity.
        AddVectorObs(ballRb.velocity);
    }

    /// <summary>
    /// Applies the chosen tilt action and assigns the reward.
    /// </summary>
    /// <param name="vectorAction">Two continuous values: tilt around z and x.</param>
    /// <param name="textAction">Unused text action.</param>
    public override void AgentAction(float[] vectorAction, string textAction)
    {
        // Continuous action space: clamp each value to [-1, 1] and scale by
        // 2 to get the rotation step around z and x.
        if (brain.brainParameters.vectorActionSpaceType == SpaceType.continuous)
        {
            var actionZ = 2f * Mathf.Clamp(vectorAction[0], -1f, 1f);
            var actionX = 2f * Mathf.Clamp(vectorAction[1], -1f, 1f);

            // Only rotate further while the quaternion component is within
            // +/-0.25, so the platform cannot tilt past its limit.
            if ((gameObject.transform.rotation.z < 0.25f && actionZ > 0f) ||
                (gameObject.transform.rotation.z > -0.25f && actionZ < 0f))
            {
                gameObject.transform.Rotate(new Vector3(0, 0, 1), actionZ);
            }

            if ((gameObject.transform.rotation.x < 0.25f && actionX > 0f) ||
                (gameObject.transform.rotation.x > -0.25f && actionX < 0f))
            {
                gameObject.transform.Rotate(new Vector3(1, 0, 0), actionX);
            }
        }
        // Episode ends with -1 if the ball fell below the platform or left
        // it horizontally; otherwise a small positive reward for surviving.
        if ((ball.transform.position.y - gameObject.transform.position.y) < -2f ||
            Mathf.Abs(ball.transform.position.x - gameObject.transform.position.x) > 3f ||
            Mathf.Abs(ball.transform.position.z - gameObject.transform.position.z) > 3f)
        {
            Done();
            SetReward(-1f);
        }
        else
        {
            SetReward(0.1f);
        }
    }

    /// <summary>
    /// Resets the platform to a small random tilt and drops the ball at a
    /// random spot above it at the start of every episode.
    /// </summary>
    public override void AgentReset()
    {
        // NOTE(review): (0,0,0,0) is not a valid unit quaternion —
        // Quaternion.identity (0,0,0,1) is presumably intended; confirm
        // Unity tolerates this before the Rotate calls below.
        gameObject.transform.rotation = new Quaternion(0f, 0f, 0f, 0f);
        gameObject.transform.Rotate(new Vector3(1, 0, 0), Random.Range(-10f, 10f));
        gameObject.transform.Rotate(new Vector3(0, 0, 1), Random.Range(-10f, 10f));
        ballRb.velocity = new Vector3(0f, 0f, 0f);
        ball.transform.position = new Vector3(Random.Range(-1.5f, 1.5f), 4f, Random.Range(-1.5f, 1.5f))
                                      + gameObject.transform.position;

    }

}

GridWorld

在这里插入图片描述
在四周有墙的环境里,训练方块去寻找绿色的目标,同时避开红色的陷阱;也是一格格走的,并且使用视觉观察来学习。
主要源码分析:GridAcademy.cs

using System.Collections.Generic;
using UnityEngine;
using System.Linq;
using MLAgents;


public class GridAcademy : Academy
{
    /// <summary>
    /// Goal and pit objects currently instantiated on the grid.
    /// </summary>
    [HideInInspector]
    public List<GameObject> actorObjs;

    /// <summary>
    /// Prefab index per actor to spawn: 2 = pit (numObstacles entries),
    /// 1 = goal (numGoals entries). Indexes into `objects`.
    /// </summary>
    [HideInInspector]
    public int[] players;

    /// <summary>
    /// The learning agent.
    /// </summary>
    public GameObject trueAgent;
    /// <summary>
    /// The grid is gridSize x gridSize cells.
    /// </summary>
    public int gridSize;

    /// <summary>
    /// Scene camera object.
    /// </summary>
    public GameObject camObject;


    /// <summary>
    /// Scene camera component.
    /// </summary>
    Camera cam;

    /// <summary>
    /// Camera providing the agent's visual observation.
    /// </summary>
    Camera agentCam;

    /// <summary>
    /// Agent prefab.
    /// </summary>
    public GameObject agentPref;
    /// <summary>
    /// Goal prefab.
    /// </summary>
    public GameObject goalPref;
    /// <summary>
    /// Pit (trap) prefab.
    /// </summary>
    public GameObject pitPref;

    /// <summary>
    /// Prefab lookup [agent, goal, pit]; indexed by the values in `players`.
    /// </summary>
    GameObject[] objects;

    /// <summary>
    /// Floor plane and the four boundary walls (north/south/east/west).
    /// </summary>
    GameObject plane;
    GameObject sN;
    GameObject sS;
    GameObject sE;
    GameObject sW;

    /// <summary>
    /// Caches scene references and reads the configured grid size.
    /// </summary>
    public override void InitializeAcademy()
    {
        // Read parameters configured on the academy's reset-parameters panel.
        gridSize = (int)resetParameters["gridSize"];
        cam = camObject.GetComponent<Camera>();

        objects = new GameObject[3] {agentPref, goalPref, pitPref};

        agentCam = GameObject.Find("agentCam").GetComponent<Camera>();

        actorObjs = new List<GameObject>();

        plane = GameObject.Find("Plane");
        sN = GameObject.Find("sN");
        sS = GameObject.Find("sS");
        sW = GameObject.Find("sW");
        sE = GameObject.Find("sE");
    }

    /// <summary>
    /// Sizes the cameras, floor and walls to the configured grid, and builds
    /// the `players` array (one entry per obstacle/goal to spawn).
    /// </summary>
    public void SetEnvironment()
    {
        // Fit the overview camera to the grid.
        cam.transform.position = new Vector3(-((int)resetParameters["gridSize"] - 1) / 2f, 
                                             (int)resetParameters["gridSize"] * 1.25f, 
                                             -((int)resetParameters["gridSize"] - 1) / 2f);
        cam.orthographicSize = ((int)resetParameters["gridSize"] + 5f) / 2f;

        List<int> playersList = new List<int>();

        // 2 marks a pit, 1 marks a goal (indices into `objects`).
        for (int i = 0; i < (int)resetParameters["numObstacles"]; i++)
        {
            playersList.Add(2);
        }

        for (int i = 0; i < (int)resetParameters["numGoals"]; i++)
        {
            playersList.Add(1);
        }
        players = playersList.ToArray();

        // Scale and position the floor and the four walls to enclose the grid.
        plane.transform.localScale = new Vector3(gridSize / 10.0f, 1f, gridSize / 10.0f);
        plane.transform.position = new Vector3((gridSize - 1) / 2f, -0.5f, (gridSize - 1) / 2f);
        sN.transform.localScale = new Vector3(1, 1, gridSize + 2);
        sS.transform.localScale = new Vector3(1, 1, gridSize + 2);
        sN.transform.position = new Vector3((gridSize - 1) / 2f, 0.0f, gridSize);
        sS.transform.position = new Vector3((gridSize - 1) / 2f, 0.0f, -1);
        sE.transform.localScale = new Vector3(1, 1, gridSize + 2);
        sW.transform.localScale = new Vector3(1, 1, gridSize + 2);
        sE.transform.position = new Vector3(gridSize, 0.0f, (gridSize - 1) / 2f);
        sW.transform.position = new Vector3(-1, 0.0f, (gridSize - 1) / 2f);

        // Center the agent's observation camera above the grid.
        agentCam.orthographicSize = (gridSize) / 2f;
        agentCam.transform.position = new Vector3((gridSize - 1) / 2f, gridSize + 1f, (gridSize - 1) / 2f);

    }

    /// <summary>
    /// Clears last episode's objects and respawns goals, pits and the agent
    /// at distinct random cells.
    /// </summary>
    public override void AcademyReset()
    {
        foreach (GameObject actor in actorObjs)
        {
            DestroyImmediate(actor);
        }
        SetEnvironment();

        actorObjs.Clear();

        // Draw players.Length + 1 distinct flat cell indices in
        // [0, gridSize^2); the HashSet rejects duplicates, and the extra
        // index is reserved for the agent itself.
        // NOTE(review): this loops forever if gridSize^2 <= players.Length —
        // presumably the configured counts always leave room; confirm.
        HashSet<int> numbers = new HashSet<int>();
        while (numbers.Count < players.Length + 1)
        {
            numbers.Add(Random.Range(0, gridSize * gridSize));
        }

        int[] numbersA = Enumerable.ToArray(numbers);

        // Convert each flat index to row (x) / column (z) cell coordinates,
        // e.g. on a 5x5 grid index 10 is row 2, column 0.
        for (int i = 0; i < players.Length; i++)
        {
            // Row.
            int x = (numbersA[i]) / gridSize;
            // Column.
            int y = (numbersA[i]) % gridSize;
            GameObject actorObj = Instantiate(objects[players[i]]);
            actorObj.transform.position = new Vector3(x, -0.25f, y);
            actorObjs.Add(actorObj);
        }
        // The one remaining index places the agent.
        int x_a = (numbersA[players.Length]) / gridSize;
        int y_a = (numbersA[players.Length]) % gridSize;
        trueAgent.transform.position = new Vector3(x_a, -0.25f, y_a);

    }

    public override void AcademyStep()
    {

    }
}

还有GridAgent.cs

using System;
using UnityEngine;
using System.Linq;
using MLAgents;

public class GridAgent : Agent
{
    /// <summary>
    /// Reference to the scene academy (grid size, inference mode, reset).
    /// </summary>
    [Header("Specific to GridWorld")]
    private GridAcademy academy;
    /// <summary>
    /// Seconds to wait between decision requests while in inference mode.
    /// </summary>
    public float timeBetweenDecisionsAtInference;
    /// <summary>
    /// Time accumulated since the last decision request.
    /// </summary>
    private float timeSinceDecision;

    [Tooltip("Because we want an observation right before making a decision, we can force " + 
             "a camera to render before making a decision. Place the agentCam here if using " +
             "RenderTexture as observations.")]
    public Camera renderCamera;

    /// <summary>
    /// When enabled, moves that would walk into a boundary wall are masked
    /// out so the policy cannot select them.
    /// </summary>
    [Tooltip("Selecting will turn on action masking. Note that a model trained with action " +
             "masking turned on may not behave optimally when action masking is turned off.")]
    public bool maskActions = true;

    private const int NoAction = 0;  // do nothing!
    private const int Up = 1;
    private const int Down = 2;
    private const int Left = 3;
    private const int Right = 4;

    public override void InitializeAgent()
    {
        academy = FindObjectOfType(typeof(GridAcademy)) as GridAcademy;
    }

    /// <summary>
    /// Observations come from the camera (visual learning), so no vector
    /// observations are collected here; this hook only applies the action mask.
    /// </summary>
    public override void CollectObservations()
    {
        if (maskActions)
        {
            SetMask();
        }
    }

    /// <summary>
    /// Masks out moves that would collide with a boundary wall. On an n x n
    /// grid the playable cells are 0..n-1 on each axis; one step past either
    /// edge is wall.
    /// </summary>
    private void SetMask()
    {
        var positionX = (int) transform.position.x;
        var positionZ = (int) transform.position.z;
        var maxPosition = academy.gridSize - 1;

        // At the left edge, moving left would hit the wall; the other three
        // edges are handled the same way below.
        if (positionX == 0)
        {
            SetActionMask(Left);
        }

        if (positionX == maxPosition)
        {
            SetActionMask(Right);
        }

        if (positionZ == 0)
        {
            SetActionMask(Down);
        }

        if (positionZ == maxPosition)
        {
            SetActionMask(Up);
        }
    }

    /// <summary>
    /// Applies one discrete move with a small per-step penalty. Reaching the
    /// goal ends the episode with +1, a pit with -1; walking into a wall
    /// leaves the agent in place.
    /// </summary>
    /// <param name="vectorAction">Discrete action branch: 0 = stay, 1-4 = up/down/left/right.</param>
    /// <param name="textAction">Unused text action.</param>
    /// <exception cref="ArgumentException">Thrown for an action value outside 0-4.</exception>
    public override void AgentAction(float[] vectorAction, string textAction)
    {
        AddReward(-0.01f);
        int action = Mathf.FloorToInt(vectorAction[0]);

        // Compute the target cell for the chosen action.
        Vector3 targetPos = transform.position;
        switch (action)
        {
            case NoAction:
                // do nothing
                break;
            case Right:
                targetPos = transform.position + new Vector3(1f, 0, 0f);
                break;
            case Left:
                targetPos = transform.position + new Vector3(-1f, 0, 0f);
                break;
            case Up:
                targetPos = transform.position + new Vector3(0f, 0, 1f);
                break;
            case Down:
                targetPos = transform.position + new Vector3(0f, 0, -1f);
                break;
            default:
                throw new ArgumentException("Invalid action value");
        }

        // Probe what occupies the target cell. The wall check stays useful
        // even with action masking enabled, since masking can be turned off.
        Collider[] blockTest = Physics.OverlapBox(targetPos, new Vector3(0.3f, 0.3f, 0.3f));
        if (!blockTest.Any(col => col.gameObject.CompareTag("wall")))
        {
            transform.position = targetPos;
            // Stepped onto the goal: success.
            if (blockTest.Any(col => col.gameObject.CompareTag("goal")))
            {
                Done();
                SetReward(1f);
            }
            // Stepped into a pit: failure.
            if (blockTest.Any(col => col.gameObject.CompareTag("pit")))
            {
                Done();
                SetReward(-1f);
            }
        }
    }

    /// <summary>
    /// Resets the whole grid via the academy (which re-randomizes all objects).
    /// </summary>
    public override void AgentReset()
    {
        academy.AcademyReset();
    }

    public void FixedUpdate()
    {
        WaitTimeInference();
    }

    /// <summary>
    /// Requests a decision every physics step during training, but only every
    /// timeBetweenDecisionsAtInference seconds while in inference mode.
    /// Renders the observation camera first so the decision sees a fresh frame.
    /// </summary>
    private void WaitTimeInference()
    {
        if (renderCamera != null)
        {
            renderCamera.Render();
        }

        if (!academy.GetIsInference())
        {
            RequestDecision();
        }
        else
        {
            if (timeSinceDecision >= timeBetweenDecisionsAtInference)
            {
                timeSinceDecision = 0f;
                RequestDecision();
            }
            else
            {
                timeSinceDecision += Time.fixedDeltaTime;
            }
        }
    }
}

总结

暂时先这三个例子吧,其他的以后慢慢添加,其实主要也就是看代码,理解思路,以后自己要是做的时候可以参照着来,可能每个游戏都不一样,但是一些基本的东西应该是一样的。

好了,今天就到这里了,希望对学习理解有帮助,大神看见勿喷,仅为自己的学习理解,能力有限,请多包涵,部分图片来自网络,侵删。

  • 2
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值