写作目的
本篇想分享一下我阅读 ML-Agents 示例时的源码分析,其实就是我理解之后写下的一些注释和思路,希望对你有帮助。
例子
Basic
这个例子训练方块左右移动以获得最大奖励:左边的目标奖励小,右边的目标奖励大,所以训练之后方块会一格一格地往右边走。
主要源码分析:BasicAgent.cs
using UnityEngine;
using MLAgents;
public class BasicAgent : Agent
{
    /// <summary>
    /// Number of cells on the track. This is also the length of the one-hot
    /// observation vector, so every reachable position must lie in
    /// [0, NumPositions - 1].
    /// </summary>
    private const int NumPositions = 20;

    /// <summary>
    /// The academy in the scene; queried to find out whether we are running inference.
    /// </summary>
    private BasicAcademy academy;

    /// <summary>
    /// Minimum time, in seconds, between two decision requests while running inference.
    /// </summary>
    [Header("Specific to Basic")]
    public float timeBetweenDecisionsAtInference;

    /// <summary>
    /// Time accumulated since the last decision request (only used during inference).
    /// </summary>
    private float timeSinceDecision;

    /// <summary>
    /// Current cell index of the agent.
    /// </summary>
    private int position;

    /// <summary>
    /// Cell index of the small-reward goal.
    /// </summary>
    private int smallGoalPosition;

    /// <summary>
    /// Cell index of the large-reward goal.
    /// </summary>
    private int largeGoalPosition;

    /// <summary>
    /// Scene objects that visualize the two goals.
    /// </summary>
    public GameObject largeGoal;
    public GameObject smallGoal;

    /// <summary>
    /// Bounds for the agent's cell index so it cannot leave the track.
    /// </summary>
    private int minPosition;
    private int maxPosition;

    public override void InitializeAgent()
    {
        academy = FindObjectOfType<BasicAcademy>();
    }

    /// <summary>
    /// Observes the current cell as a one-hot vector of length NumPositions:
    /// element [position] is 1 and all others are 0. For example, with 5 cells
    /// and position = 3 the observation is [0, 0, 0, 1, 0]. This fits the task
    /// because the agent moves one whole cell at a time.
    /// </summary>
    public override void CollectObservations()
    {
        AddVectorObs(position, NumPositions);
    }

    /// <summary>
    /// Applies one discrete action (0 = stay, 1 = step left, 2 = step right),
    /// then hands out rewards and ends the episode when a goal is reached.
    /// </summary>
    /// <param name="vectorAction">Discrete action; only element 0 is read.</param>
    /// <param name="textAction">Unused.</param>
    public override void AgentAction(float[] vectorAction, string textAction)
    {
        // Discrete branch values start at 0; the branch size is set on the Brain
        // (e.g. a branch size of 3 yields the values 0, 1, 2).
        var movement = (int)vectorAction[0];

        int direction = 0;
        switch (movement)
        {
            case 1:
                direction = -1;
                break;
            case 2:
                direction = 1;
                break;
        }

        // Move one cell and keep the agent inside [minPosition, maxPosition].
        position = Mathf.Clamp(position + direction, minPosition, maxPosition);
        gameObject.transform.position = new Vector3(position - 10f, 0f, 0f);

        // Small penalty on every step, which pushes the agent to reach a goal quickly.
        AddReward(-0.01f);

        if (position == smallGoalPosition)
        {
            Done();
            AddReward(0.1f);
        }

        // The large goal is worth ten times the small one.
        if (position == largeGoalPosition)
        {
            Done();
            AddReward(1f);
        }
    }

    /// <summary>
    /// Puts the agent and both goals back at their start cells; called for the
    /// first episode and after every finished episode.
    /// </summary>
    public override void AgentReset()
    {
        position = 10;
        minPosition = 0;
        // The last valid cell must still fit in the NumPositions-long one-hot observation.
        maxPosition = NumPositions - 1;
        smallGoalPosition = 7;
        largeGoalPosition = 17;
        smallGoal.transform.position = new Vector3(smallGoalPosition - 10f, 0f, 0f);
        largeGoal.transform.position = new Vector3(largeGoalPosition - 10f, 0f, 0f);
    }

    public override void AgentOnDone()
    {
    }

    public void FixedUpdate()
    {
        WaitTimeInference();
    }

    /// <summary>
    /// Requests a decision every physics step while training. During inference it
    /// only requests one once at least timeBetweenDecisionsAtInference seconds have
    /// accumulated, so the movement is slow enough to watch.
    /// </summary>
    private void WaitTimeInference()
    {
        if (!academy.GetIsInference())
        {
            RequestDecision();
        }
        else
        {
            if (timeSinceDecision >= timeBetweenDecisionsAtInference)
            {
                timeSinceDecision = 0f;
                RequestDecision();
            }
            else
            {
                timeSinceDecision += Time.fixedDeltaTime;
            }
        }
    }
}
3DBall
这个例子训练平台保持平衡,让小球不掉下去;智能体需要同时关注小球的速度、小球相对平台的位置以及平台的倾斜角度。
主要源码分析:Ball3DAgent.cs
using UnityEngine;
using MLAgents;
public class Ball3DAgent : Agent
{
    [Header("Specific to Ball3D")]
    public GameObject ball;
    private Rigidbody ballRb;

    /// <summary>
    /// Caches the ball's Rigidbody so its velocity can be observed and reset.
    /// NOTE(review): presumably invoked by the Agent base class when it is
    /// enabled (OnEnable) — confirm against the ML-Agents version in use.
    /// </summary>
    public override void InitializeAgent()
    {
        ballRb = ball.GetComponent<Rigidbody>();
    }

    /// <summary>
    /// Collects 8 values:
    /// the z and x components of the platform's rotation quaternion (not Euler angles),
    /// the ball's position relative to the platform (world-space difference, 3 values),
    /// and the ball's velocity (3 values), which tells how much time is left to react.
    /// </summary>
    public override void CollectObservations()
    {
        // Platform tilt.
        AddVectorObs(gameObject.transform.rotation.z);
        AddVectorObs(gameObject.transform.rotation.x);
        // Ball position relative to the platform.
        AddVectorObs(ball.transform.position - gameObject.transform.position);
        // Ball velocity.
        AddVectorObs(ballRb.velocity);
    }

    /// <summary>
    /// Applies the chosen action, then rewards or penalizes the current state.
    /// </summary>
    /// <param name="vectorAction">
    /// Continuous actions: element 0 tilts around the local z axis, element 1 around the local x axis.
    /// </param>
    /// <param name="textAction">Unused.</param>
    public override void AgentAction(float[] vectorAction, string textAction)
    {
        if (brain.brainParameters.vectorActionSpaceType == SpaceType.continuous)
        {
            // Each action is clamped to [-1, 1] and scaled to at most 2 degrees per step.
            var actionZ = 2f * Mathf.Clamp(vectorAction[0], -1f, 1f);
            var actionX = 2f * Mathf.Clamp(vectorAction[1], -1f, 1f);

            // Only rotate further in a direction while the matching quaternion component
            // is inside +/-0.25. This caps the tilt at roughly 29 degrees per axis.
            if ((gameObject.transform.rotation.z < 0.25f && actionZ > 0f) ||
                (gameObject.transform.rotation.z > -0.25f && actionZ < 0f))
            {
                gameObject.transform.Rotate(new Vector3(0, 0, 1), actionZ);
            }

            if ((gameObject.transform.rotation.x < 0.25f && actionX > 0f) ||
                (gameObject.transform.rotation.x > -0.25f && actionX < 0f))
            {
                gameObject.transform.Rotate(new Vector3(1, 0, 0), actionX);
            }
        }

        // The ball counts as lost when it is more than 2 units below the platform, or
        // more than 3 units away from the platform along x or z.
        if ((ball.transform.position.y - gameObject.transform.position.y) < -2f ||
            Mathf.Abs(ball.transform.position.x - gameObject.transform.position.x) > 3f ||
            Mathf.Abs(ball.transform.position.z - gameObject.transform.position.z) > 3f)
        {
            Done();
            SetReward(-1f);
        }
        else
        {
            // Small reward for every step the ball stays on the platform.
            SetReward(0.1f);
        }
    }

    /// <summary>
    /// Starts a new episode: levels the platform, applies a random tilt of up to
    /// 10 degrees around x and z, and drops a motionless ball from 4 units above
    /// a random point near the platform's center.
    /// </summary>
    public override void AgentReset()
    {
        // Quaternion.identity is "no rotation"; an all-zero quaternion is not a valid rotation.
        gameObject.transform.rotation = Quaternion.identity;
        gameObject.transform.Rotate(new Vector3(1, 0, 0), Random.Range(-10f, 10f));
        gameObject.transform.Rotate(new Vector3(0, 0, 1), Random.Range(-10f, 10f));
        ballRb.velocity = Vector3.zero;
        ball.transform.position = new Vector3(Random.Range(-1.5f, 1.5f), 4f, Random.Range(-1.5f, 1.5f))
                                  + gameObject.transform.position;
    }
}
GridWorld
在四周有墙的网格中,让方块去寻找绿色的目标,同时避开红色的陷阱。它同样是一格一格地移动,而且是基于相机图像的视觉学习。
主要源码分析:GridAcademy.cs
using System.Collections.Generic;
using UnityEngine;
using System.Linq;
using MLAgents;
public class GridAcademy : Academy
{
    /// <summary>
    /// Goal and pit instances spawned for the current episode; destroyed on the next reset.
    /// </summary>
    [HideInInspector]
    public List<GameObject> actorObjs;

    /// <summary>
    /// For each actor to spawn, its index into the prefab table (objects):
    /// one entry of 2 (pit) per "numObstacles", then one entry of 1 (goal) per "numGoals".
    /// </summary>
    [HideInInspector]
    public int[] players;

    /// <summary>
    /// The agent's GameObject; moved to a random free cell on every reset.
    /// </summary>
    public GameObject trueAgent;

    /// <summary>
    /// Side length of the square grid (gridSize x gridSize cells).
    /// Re-read from the "gridSize" reset parameter on every environment setup.
    /// </summary>
    public int gridSize;

    /// <summary>
    /// GameObject carrying the overview camera.
    /// </summary>
    public GameObject camObject;

    /// <summary>
    /// Overview camera, repositioned so the whole grid is in view.
    /// </summary>
    Camera cam;

    /// <summary>
    /// Camera named "agentCam" looking down on the grid (the agent's visual observation camera).
    /// </summary>
    Camera agentCam;

    /// <summary>
    /// Agent prefab (prefab table index 0).
    /// </summary>
    public GameObject agentPref;

    /// <summary>
    /// Goal prefab (prefab table index 1).
    /// </summary>
    public GameObject goalPref;

    /// <summary>
    /// Pit prefab (prefab table index 2).
    /// </summary>
    public GameObject pitPref;

    /// <summary>
    /// Prefab table indexed by the values stored in players: 0 = agent, 1 = goal, 2 = pit.
    /// </summary>
    GameObject[] objects;

    /// <summary>
    /// Floor of the environment.
    /// </summary>
    GameObject plane;

    // Boundary walls on the north/south/east/west sides, resized to the grid.
    GameObject sN;
    GameObject sS;
    GameObject sE;
    GameObject sW;

    /// <summary>
    /// One-time setup: reads the grid size and looks up the cameras, floor and walls by name.
    /// </summary>
    public override void InitializeAcademy()
    {
        // Reset parameters are configured on the academy in the Inspector.
        gridSize = (int)resetParameters["gridSize"];
        cam = camObject.GetComponent<Camera>();
        objects = new GameObject[3] { agentPref, goalPref, pitPref };
        agentCam = GameObject.Find("agentCam").GetComponent<Camera>();
        actorObjs = new List<GameObject>();
        plane = GameObject.Find("Plane");
        sN = GameObject.Find("sN");
        sS = GameObject.Find("sS");
        sW = GameObject.Find("sW");
        sE = GameObject.Find("sE");
    }

    /// <summary>
    /// Lays out the environment for the current reset parameters. It fits the cameras,
    /// floor and walls to the grid size, and rebuilds the players spawn list from
    /// "numObstacles" and "numGoals".
    /// </summary>
    public void SetEnvironment()
    {
        // Re-read the grid size. A changed reset parameter (e.g. from a curriculum) then
        // reaches the floor, walls, agent camera and spawn cells, not only the overview camera.
        gridSize = (int)resetParameters["gridSize"];

        // Fit the overview camera to the grid.
        cam.transform.position = new Vector3(-(gridSize - 1) / 2f,
                                             gridSize * 1.25f,
                                             -(gridSize - 1) / 2f);
        cam.orthographicSize = (gridSize + 5f) / 2f;

        List<int> playersList = new List<int>();
        for (int i = 0; i < (int)resetParameters["numObstacles"]; i++)
        {
            playersList.Add(2);
        }
        for (int i = 0; i < (int)resetParameters["numGoals"]; i++)
        {
            playersList.Add(1);
        }
        players = playersList.ToArray();

        // Fit the floor and the four walls to the grid.
        plane.transform.localScale = new Vector3(gridSize / 10.0f, 1f, gridSize / 10.0f);
        plane.transform.position = new Vector3((gridSize - 1) / 2f, -0.5f, (gridSize - 1) / 2f);
        sN.transform.localScale = new Vector3(1, 1, gridSize + 2);
        sS.transform.localScale = new Vector3(1, 1, gridSize + 2);
        sN.transform.position = new Vector3((gridSize - 1) / 2f, 0.0f, gridSize);
        sS.transform.position = new Vector3((gridSize - 1) / 2f, 0.0f, -1);
        sE.transform.localScale = new Vector3(1, 1, gridSize + 2);
        sW.transform.localScale = new Vector3(1, 1, gridSize + 2);
        sE.transform.position = new Vector3(gridSize, 0.0f, (gridSize - 1) / 2f);
        sW.transform.position = new Vector3(-1, 0.0f, (gridSize - 1) / 2f);
        agentCam.orthographicSize = (gridSize) / 2f;
        agentCam.transform.position = new Vector3((gridSize - 1) / 2f, gridSize + 1f, (gridSize - 1) / 2f);
    }

    /// <summary>
    /// Rebuilds the grid. It destroys the previous goals and pits, re-applies the environment
    /// settings, then places every goal, pit and the agent on distinct random cells.
    /// </summary>
    /// <exception cref="System.InvalidOperationException">
    /// Thrown when the grid has fewer cells than goals + pits + the agent.
    /// </exception>
    public override void AcademyReset()
    {
        // Destroy() is deferred until the end of the frame. DestroyImmediate removes the old
        // actors right away, presumably so later physics queries this frame no longer hit them.
        foreach (GameObject actor in actorObjs)
        {
            DestroyImmediate(actor);
        }

        SetEnvironment();
        actorObjs.Clear();

        // Every goal, every pit and the agent need their own cell. Without enough cells the
        // distinct-cell sampling loop below could never terminate.
        int cellCount = gridSize * gridSize;
        int actorCount = players.Length + 1;
        if (actorCount > cellCount)
        {
            throw new System.InvalidOperationException(
                "GridAcademy: a " + gridSize + "x" + gridSize + " grid cannot hold " + actorCount +
                " actors (goals, pits and the agent).");
        }

        // Draw actorCount distinct cell indices in [0, cellCount); the HashSet rejects repeats.
        HashSet<int> numbers = new HashSet<int>();
        while (numbers.Count < actorCount)
        {
            numbers.Add(Random.Range(0, cellCount));
        }
        int[] numbersA = numbers.ToArray();

        // Cell index n maps to x = n / gridSize and z = n % gridSize.
        // E.g. with gridSize = 5, cell 10 is at x = 2, z = 0 (both counted from 0).
        for (int i = 0; i < players.Length; i++)
        {
            int x = numbersA[i] / gridSize;
            int y = numbersA[i] % gridSize;
            GameObject actorObj = Instantiate(objects[players[i]]);
            actorObj.transform.position = new Vector3(x, -0.25f, y);
            actorObjs.Add(actorObj);
        }

        // The last sampled cell goes to the agent.
        int x_a = numbersA[players.Length] / gridSize;
        int y_a = numbersA[players.Length] % gridSize;
        trueAgent.transform.position = new Vector3(x_a, -0.25f, y_a);
    }

    public override void AcademyStep()
    {
    }
}
还有GridAgent.cs
using System;
using UnityEngine;
using System.Linq;
using MLAgents;
public class GridAgent : Agent
{
    /// <summary>
    /// The GridAcademy in the scene; supplies the grid size and the inference flag, and
    /// is reset whenever this agent resets.
    /// </summary>
    private GridAcademy academy;

    /// <summary>
    /// Minimum time, in seconds, between two decision requests while running inference.
    /// </summary>
    [Header("Specific to GridWorld")]
    public float timeBetweenDecisionsAtInference;

    /// <summary>
    /// Time accumulated since the last decision request (only used during inference).
    /// </summary>
    private float timeSinceDecision;

    [Tooltip("Because we want an observation right before making a decision, we can force " +
             "a camera to render before making a decision. Place the agentCam here if using " +
             "RenderTexture as observations.")]
    public Camera renderCamera;

    /// <summary>
    /// When true, moves that would walk into a wall are masked out, so the policy cannot choose them.
    /// </summary>
    [Tooltip("Selecting will turn on action masking. Note that a model trained with action " +
             "masking turned on may not behave optimally when action masking is turned off.")]
    public bool maskActions = true;

    // Discrete action values.
    private const int NoAction = 0; // do nothing!
    private const int Up = 1;
    private const int Down = 2;
    private const int Left = 3;
    private const int Right = 4;

    public override void InitializeAgent()
    {
        academy = FindObjectOfType<GridAcademy>();
    }

    /// <summary>
    /// Adds no vector observations, because this agent learns from camera images.
    /// It only applies the action mask when masking is enabled.
    /// </summary>
    public override void CollectObservations()
    {
        if (maskActions)
        {
            SetMask();
        }
    }

    /// <summary>
    /// Masks moves that would step off the grid. Valid x/z coordinates are
    /// 0 .. gridSize - 1 (e.g. 0..4 on a 5x5 grid), and everything outside is wall.
    /// </summary>
    private void SetMask()
    {
        var positionX = (int)transform.position.x;
        var positionZ = (int)transform.position.z;
        var maxPosition = academy.gridSize - 1;

        // On the lowest row/column another step would reach -1, i.e. the wall.
        if (positionX == 0)
        {
            SetActionMask(Left);
        }
        if (positionX == maxPosition)
        {
            SetActionMask(Right);
        }
        if (positionZ == 0)
        {
            SetActionMask(Down);
        }
        if (positionZ == maxPosition)
        {
            SetActionMask(Up);
        }
    }

    /// <summary>
    /// Applies one discrete move. The agent pays a small step penalty, moves unless the
    /// target cell holds a wall, and ends the episode with +1 on a goal or -1 on a pit.
    /// </summary>
    /// <param name="vectorAction">Discrete action; element 0 is one of NoAction/Up/Down/Left/Right.</param>
    /// <param name="textAction">Unused.</param>
    /// <exception cref="ArgumentException">Thrown for an action value outside 0..4.</exception>
    public override void AgentAction(float[] vectorAction, string textAction)
    {
        // Small penalty on every step, which pushes the agent to finish quickly.
        AddReward(-0.01f);
        int action = Mathf.FloorToInt(vectorAction[0]);

        // Work out the target cell.
        Vector3 targetPos = transform.position;
        switch (action)
        {
            case NoAction:
                // do nothing
                break;
            case Right:
                targetPos = transform.position + new Vector3(1f, 0, 0f);
                break;
            case Left:
                targetPos = transform.position + new Vector3(-1f, 0, 0f);
                break;
            case Up:
                targetPos = transform.position + new Vector3(0f, 0, 1f);
                break;
            case Down:
                targetPos = transform.position + new Vector3(0f, 0, -1f);
                break;
            default:
                throw new ArgumentException("Invalid action value");
        }

        Collider[] blockTest = Physics.OverlapBox(targetPos, new Vector3(0.3f, 0.3f, 0.3f));

        // Only move when the target cell has no wall; otherwise stay put. Masking already
        // rules out wall moves when it is on. With masking off, this check is what keeps
        // the agent inside the grid.
        if (!blockTest.Any(col => col.gameObject.CompareTag("wall")))
        {
            transform.position = targetPos;

            // Reached a goal.
            if (blockTest.Count(col => col.gameObject.CompareTag("goal")) == 1)
            {
                Done();
                SetReward(1f);
            }

            // Fell into a pit.
            if (blockTest.Count(col => col.gameObject.CompareTag("pit")) == 1)
            {
                Done();
                SetReward(-1f);
            }
        }
    }

    /// <summary>
    /// Regenerates the whole grid through the academy when this agent is reset.
    /// </summary>
    public override void AgentReset()
    {
        academy.AcademyReset();
    }

    public void FixedUpdate()
    {
        WaitTimeInference();
    }

    /// <summary>
    /// Forces the observation camera to render (if one is assigned) so the visual
    /// observation is current. It then requests a decision every physics step while
    /// training, and at most every timeBetweenDecisionsAtInference seconds during inference.
    /// </summary>
    private void WaitTimeInference()
    {
        if (renderCamera != null)
        {
            renderCamera.Render();
        }

        if (!academy.GetIsInference())
        {
            RequestDecision();
        }
        else
        {
            if (timeSinceDecision >= timeBetweenDecisionsAtInference)
            {
                timeSinceDecision = 0f;
                RequestDecision();
            }
            else
            {
                timeSinceDecision += Time.fixedDeltaTime;
            }
        }
    }
}
总结
暂时先写这三个例子,其他的以后再慢慢补充。其实主要就是看代码、理解思路,以后自己做项目时可以参照着来;每个游戏可能都不一样,但一些基本的东西应该是相通的。
好了,今天就到这里了,希望对学习理解有帮助,大神看见勿喷,仅为自己的学习理解,能力有限,请多包涵,部分图片来自网络,侵删。