分享github 工程:https://github.com/IsaWinding/mlagent02.git
1.本文介绍如何在一场战斗中,使用 ML-Agents 来训练英雄的 AI 行为操作。
图中为英雄经过不同训练次数后,生成的各个 .onnx 模型文件。
训练次数较少的时候,训练的英雄会随机在周围的墙边移动
随着训练次数变多,训练的英雄会直接奔向敌方小兵的位置,进行攻击操作。
英雄行为脚本:
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;
/// <summary>
/// ML-Agents agent that drives the <see cref="Unit"/> attached to the same GameObject.
/// It reads three continuous actions: [0] = move X, [1] = move Z, [2] = attack trigger
/// (attack is requested when the value is greater than zero).
/// </summary>
public class UnitAgent : Agent
{
    // Limits (in local space) beyond which the episode is restarted.
    private const float MaxAbsHeight = 2f;
    private const float MaxAbsHorizontal = 20f;

    // The unit controlled by this agent; resolved lazily because ML-Agents callbacks
    // may run before or after Initialize depending on setup.
    private Unit player;

    /// <summary>Caches the controlled unit and sets the per-episode step limit.</summary>
    public override void Initialize()
    {
        player = this.GetComponent<Unit>();
        MaxStep = 20000;
    }

    /// <summary>
    /// Writes the observation vector: own local position (3), own HP (1),
    /// has-target flag (1), target local position (3) and target HP (1).
    /// When there is no living target the last four values are zero.
    /// </summary>
    /// <param name="sensor">Sensor that receives the observations.</param>
    public override void CollectObservations(VectorSensor sensor)
    {
        if (!TryResolvePlayer())
        {
            return;
        }

        sensor.AddObservation(player.transform.localPosition);
        sensor.AddObservation(player.curHp);
        // Same value the original produced through Unity's implicit Object-to-bool
        // conversion, written explicitly so the intent ("has a target") is visible.
        sensor.AddObservation(player.Target != null);

        if (player.Target != null && !player.Target.IsDead())
        {
            sensor.AddObservation(player.Target.transform.localPosition);
            sensor.AddObservation(player.Target.curHp);
        }
        else
        {
            // Keep the observation size constant when no valid target exists.
            sensor.AddObservation(Vector3.zero);
            sensor.AddObservation(0);
        }
    }

    /// <summary>
    /// Applies the received actions to the unit, hands out attack/kill rewards and
    /// ends the episode when the unit is dead or has left the play area.
    /// </summary>
    /// <param name="actions">Action buffers; three continuous actions are read.</param>
    public override void OnActionReceived(ActionBuffers actions)
    {
        if (!TryResolvePlayer())
        {
            return;
        }

        var vectorAction = actions.ContinuousActions;

        var controlSignal = Vector3.zero;
        controlSignal.x = vectorAction[0];
        controlSignal.z = vectorAction[1];
        player.MoveByDir(controlSignal);

        if (vectorAction[2] > 0f)
        {
            player.AtkAjust(
                () => { AddReward(0.1f); },  // reward for a landed hit
                () => { AddReward(1f); });   // reward for a kill
        }

        // Restart on death. (A death penalty, AddReward(-3f), was disabled in the original.)
        if (player.curHp <= 0)
        {
            EndEpisode();
            // Return so the bounds check below cannot end the episode a second time.
            return;
        }

        // Restart when out of bounds. (A penalty, AddReward(-1f), was disabled in the original.)
        // The original tested x twice and never z, so leaving the field along z went unnoticed.
        var localPos = player.transform.localPosition;
        if (Mathf.Abs(localPos.y) >= MaxAbsHeight ||
            Mathf.Abs(localPos.x) >= MaxAbsHorizontal ||
            Mathf.Abs(localPos.z) >= MaxAbsHorizontal)
        {
            EndEpisode();
        }
    }

    /// <summary>Resets the battlefield through the controlled unit at the start of an episode.</summary>
    public override void OnEpisodeBegin()
    {
        if (!TryResolvePlayer())
        {
            return;
        }

        player.ResetBattleField();
    }

    /// <summary>Maps the "Horizontal", "Vertical" and "Jump" input axes to the three actions for manual control.</summary>
    /// <param name="actionsOut">Action buffers to fill.</param>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var continuousActionsOut = actionsOut.ContinuousActions;
        continuousActionsOut[0] = Input.GetAxis("Horizontal");
        continuousActionsOut[1] = Input.GetAxis("Vertical");
        continuousActionsOut[2] = Input.GetAxis("Jump");
    }

    // Makes sure `player` is assigned; returns false when no Unit component is available.
    private bool TryResolvePlayer()
    {
        if (player == null)
        {
            player = this.GetComponent<Unit>();
        }

        return player != null;
    }
}
单位逻辑脚本:
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
/// <summary>
/// Runtime logic of one battlefield unit: target selection, movement, attacking,
/// damage, death and rebirth. It is driven either by an <see cref="AIPolicy"/>
/// (through <see cref="OnAiAction"/>) or by external calls such as
/// <see cref="MoveByDir"/> and <see cref="AtkAjust"/>.
/// </summary>
public class Unit : MonoBehaviour
{
    [HideInInspector]
    public int id;
    public UnitInfo info;
    public AIPolicy aiPolicy;
    private AIPath path;
    public Unit Target;
    private UnitAni unitAni;

    // Shortcuts into the stat block held by `info`.
    public float followDistance { get { return info.flowRange; } }
    public float warnDistance { get { return info.warRange; } }
    public float atkDistance { get { return info.atkRange; } }
    public float moveSpeed { get { return info.moveSpeed; } }
    public float curHp { get { return info.hp; } set { info.hp = value; } }
    public float MaxHp { get { return info.hpMax; } }
    private float atkCd { get { return info.atkCd; } }
    private float atk { get { return info.atk; } }
    public CampType campType { get { return info.campType; } }

    // Earliest time (Time.realtimeSinceStartup) at which the next attack may start.
    private float nextAtkTime = 0f;
    private bool IsNeedBackToHome = false;
    // Set by Attack(); consumed by FixedUpdate to start a callback-less attack.
    private bool isAttack = false;
    // True while an attack started by AtkCurTarget is in flight; blocks movement in FixedUpdate.
    private bool isAttackTarget = false;
    private bool isMoveToTargetPos = false;
    // Position of this unit at the moment the current target was set; used as the "home" anchor.
    private Vector3 findTargetPos;
    private Vector3 targetPos;
    private bool isCanReborn = true;
    private Vector3 bornPos;
    public List<Unit> AllUnits;
    public UnitManager manager;
    private float oldHp;

    /// <summary>Asks the owning manager to start a new episode.</summary>
    public void ResetBattleField()
    {
        manager.OnEpisodeBegin();
    }

    /// <summary>Initializes the unit with its id, stats, spawn position, AI policy and patrol path.</summary>
    /// <param name="pId">Identifier stored in <see cref="id"/>.</param>
    /// <param name="pInfo">Stat block of the unit.</param>
    /// <param name="pBornPos">Position the unit is moved to on rebirth.</param>
    /// <param name="pAIPolicy">Policy run by <see cref="OnAiAction"/>; may be null.</param>
    /// <param name="pPath">Path template; a runtime copy is built from its GameObject list.</param>
    public void CreatUnit(int pId, UnitInfo pInfo, Vector3 pBornPos, AIPolicy pAIPolicy, AIPath pPath)
    {
        id = pId;
        info = pInfo;
        bornPos = pBornPos;
        aiPolicy = pAIPolicy;
        path = pPath.InitByGOList();
        path.Init();
        unitAni = this.gameObject.GetComponent<UnitAni>();
    }

    /// <summary>Tints the unit by setting the main color of the first renderer found on it or its children.</summary>
    /// <param name="pColor">Color to apply.</param>
    public void SetColor(Color pColor)
    {
        // Material is not a Component, so the original GetComponentInChildren<Material>()
        // could never find one; the material has to be reached through a Renderer.
        var unitRenderer = GetComponentInChildren<Renderer>();
        if (unitRenderer != null)
        {
            unitRenderer.material.color = pColor;
        }
    }

    /// <summary>Resets the unit for a new episode: goes idle, cancels any pending rebirth and respawns immediately.</summary>
    public void OnEpisodeBegin()
    {
        Idle();
        isAttackTarget = false;
        CancelInvoke("Reborn");
        Reborn();
    }

    /// <summary>Stores the current unit list and runs the AI policy, unless the unit is dead.</summary>
    /// <param name="allUnits">All units on the battlefield.</param>
    public void OnAiAction(List<Unit> allUnits)
    {
        AllUnits = allUnits;
        if (IsDead())
        {
            return;
        }

        if (aiPolicy != null)
        {
            aiPolicy.OnRun(allUnits);
        }
    }

    /// <summary>Clears the target, stops attacking and moving, and plays the idle animation.</summary>
    public void Idle()
    {
        SetCurTarget(null);
        StopAttack();
        StopMove();
        unitAni.PlayAniByType(AniNameType.Idle, 1);
    }

    public void StopMove() { isMoveToTargetPos = false; }

    public void Attack() { isAttack = true; }

    public void StopAttack() { isAttack = false; }

    /// <summary>Sets the current target and records the unit's current position as the anchor used by the back-to-home logic.</summary>
    /// <param name="pCharacterInput">New target, or null to clear it.</param>
    public void SetCurTarget(Unit pCharacterInput)
    {
        Target = pCharacterInput;
        findTargetPos = this.transform.position;
    }

    /// <summary>
    /// True while a return is in progress, when the target is dead, or when the unit
    /// has moved at least <see cref="followDistance"/> away from its anchor position.
    /// </summary>
    public bool NeedBackToHome()
    {
        if (IsNeedBackToHome)
        {
            return true;
        }

        if (Target == null)
        {
            return false;
        }

        if (Target.IsDead())
        {
            return true;
        }

        var distance = Vector3.Distance(this.transform.position, findTargetPos);
        return distance >= followDistance;
    }

    /// <summary>Targets the first selectable unit within <see cref="warnDistance"/>, or clears the target if there is none.</summary>
    public void SelectAdjustTarget(List<Unit> pAllUnits)
    {
        var unit = GetOneUnitInRange(pAllUnits, warnDistance);
        SetCurTarget(unit);
    }

    public bool IsDead() { return info.hp <= 0; }

    public bool IsHaveTarget() { return Target != null; }

    /// <summary>
    /// Drops the target and walks back to the anchor position recorded when the target
    /// was acquired; the return is considered finished once within 1 unit of it.
    /// </summary>
    public void BackToHome()
    {
        var home = findTargetPos;
        var distance = Vector3.Distance(this.transform.position, home);
        if (distance >= 1)
        {
            IsNeedBackToHome = true;
            SetCurTarget(null);
            // SetCurTarget overwrites findTargetPos with the current position. The original
            // then moved "home" to where the unit already stood, so it never returned.
            // Restore the anchor so the return keeps heading to the real home position.
            findTargetPos = home;
            MoveToTargetPos(home);
        }
        else
        {
            IsNeedBackToHome = false;
        }
    }

    /// <summary>Advances along the patrol path and moves towards its next point; goes idle when the path has no next point.</summary>
    public void OnLoopMove()
    {
        if (path.IsReachNextPoint(this.transform.position))
        {
            path.OnReachNextPoint();
        }

        var nextPos = path.GetNextPoint();
        if (nextPos != null)
        {
            MoveToTargetPos(nextPos.pos);
        }
        else
        {
            Idle();
        }
    }

    /// <summary>Stops attacking and chases the current target (no movement is started when there is no target).</summary>
    public void MoToCurTarget()
    {
        StopAttack();
        // Delegates to the null-safe variant; the original dereferenced Target unconditionally.
        MoveToCurTarget();
    }

    /// <summary>Starts moving towards a point 2 units away in the given direction.</summary>
    /// <param name="pDir">Direction; it is normalized before use.</param>
    public void MoveByDir(Vector3 pDir)
    {
        var destination = this.transform.position + pDir.normalized * 2f;
        MoveToTargetPos(destination);
    }

    /// <summary>Cancels any pending attack and starts moving towards the given world position.</summary>
    public void MoveToTargetPos(Vector3 pTargetPos)
    {
        StopAttack();
        targetPos = pTargetPos;
        isMoveToTargetPos = true;
    }

    public void FaceToCurTarget()
    {
        if (Target != null)
        {
            this.transform.LookAt(Target.transform.position);
        }
    }

    // True when pTarget is within pRange (world-space distance) of this unit.
    private bool IsInRange(Unit pTarget, float pRange)
    {
        var otherPos = pTarget.transform.position;
        var selfPos = this.transform.position;
        return Vector3.Distance(otherPos, selfPos) <= pRange;
    }

    /// <summary>True when there is a living target within the given range.</summary>
    public bool CurTargetInRange(float pRange)
    {
        if (Target != null && !Target.IsDead())
        {
            return IsInRange(Target, pRange);
        }

        return false;
    }

    /// <summary>True when a unit of the given camp may be targeted by this unit (any camp other than its own).</summary>
    public bool IsCanCampSelect(CampType pTargetCampType)
    {
        if (campType == CampType.PlayerA)
        {
            return pTargetCampType == CampType.PlayerB || pTargetCampType == CampType.Monster;
        }
        else if (campType == CampType.PlayerB)
        {
            return pTargetCampType == CampType.PlayerA || pTargetCampType == CampType.Monster;
        }
        else if (campType == CampType.Monster)
        {
            return pTargetCampType == CampType.PlayerA || pTargetCampType == CampType.PlayerB;
        }

        return false;
    }

    /// <summary>Returns the first living, targetable unit within range, or null when none is found.</summary>
    /// <param name="pAllUnits">Candidate units; null is treated as an empty list.</param>
    /// <param name="pRange">Maximum world-space distance.</param>
    public Unit GetOneUnitInRange(List<Unit> pAllUnits, float pRange)
    {
        // AllUnits is only assigned by OnAiAction, so AtkAjust can reach this with null.
        if (pAllUnits == null)
        {
            return null;
        }

        for (var i = 0; i < pAllUnits.Count; i++)
        {
            var candidate = pAllUnits[i];
            if (candidate == null)
            {
                continue;
            }

            if (!candidate.IsDead() && IsCanCampSelect(candidate.campType) && IsInRange(candidate, pRange))
            {
                return candidate;
            }
        }

        return null;
    }

    /// <summary>
    /// Whether a new target should be selected: false while the current target is alive
    /// and within follow range, otherwise true when some unit is within warn range.
    /// </summary>
    public bool NeedSelectTarget(List<Unit> pAllUnits)
    {
        if (Target != null && !Target.IsDead() && IsInRange(Target, followDistance))
        {
            return false;
        }

        var character = GetOneUnitInRange(pAllUnits, warnDistance);
        return character != null;
    }

    // Clears combat state, plays the death animation and schedules a rebirth in 2 seconds
    // (only once per death, guarded by isCanReborn).
    private void OnDead()
    {
        isAttack = false;
        isAttackTarget = false;
        isMoveToTargetPos = false;
        Target = null;
        unitAni.PlayAniByType(AniNameType.Dead, 6);
        if (isCanReborn)
        {
            Invoke("Reborn", 2f);
            isCanReborn = false;
        }
    }

    /// <summary>Reacts to an HP change; triggers death handling when HP reaches zero.</summary>
    public void HpChange(float pOld, float hp)
    {
        if (hp <= 0)
        {
            OnDead();
        }
    }

    // Restores full HP, clears combat state and moves the unit back to its spawn position.
    // Invoked by name through Invoke("Reborn", ...), so the method name must not change.
    private void Reborn()
    {
        isAttack = false;
        isAttackTarget = false;
        isMoveToTargetPos = false;
        Target = null;
        curHp = MaxHp;
        RpcReborn();
    }

    private void RpcReborn()
    {
        this.transform.position = bornPos;
        isCanReborn = true;
    }

    /// <summary>
    /// Applies damage (HP is clamped to [0, MaxHp]), then invokes the kill callback
    /// if HP reached zero, followed by the hit callback.
    /// </summary>
    /// <param name="pDamage">Damage to subtract from HP.</param>
    /// <param name="pAtkCB">Invoked after every hit; may be null.</param>
    /// <param name="pKillCB">Invoked when this hit leaves HP at zero; may be null.</param>
    public void OnDamage(float pDamage, System.Action pAtkCB, System.Action pKillCB)
    {
        oldHp = curHp;
        curHp -= pDamage;
        if (curHp > MaxHp)
        {
            curHp = MaxHp;
        }

        if (curHp < 0)
        {
            curHp = 0;
        }

        HpChange(oldHp, curHp);
        if (curHp <= 0)
        {
            pKillCB?.Invoke();
        }

        pAtkCB?.Invoke();
    }

    /// <summary>
    /// Attack decision used by the agent: attack the target when in attack range, chase it
    /// when in follow range, otherwise pick a new target from <see cref="AllUnits"/> and chase that.
    /// </summary>
    /// <param name="pAtkCB">Invoked when a hit lands.</param>
    /// <param name="pKillCB">Invoked when a hit kills the target.</param>
    public void AtkAjust(System.Action pAtkCB, System.Action pKillCB)
    {
        if (Target != null && !Target.IsDead() && IsInRange(Target, followDistance))
        {
            if (IsInRange(Target, atkDistance))
            {
                AtkCurTarget(pAtkCB, pKillCB);
            }
            else
            {
                MoveToCurTarget();
            }
        }
        else
        {
            SelectAdjustTarget(AllUnits);
            MoveToCurTarget();
        }
    }

    /// <summary>Starts moving towards the current target, if there is one.</summary>
    public void MoveToCurTarget()
    {
        if (Target != null)
        {
            MoveToTargetPos(Target.transform.position);
        }
    }

    /// <summary>
    /// Stops moving, faces the target and, if the cooldown has elapsed, plays the attack
    /// animation whose callback applies the damage. The caller must ensure Target is not null.
    /// </summary>
    public void AtkCurTarget(System.Action pAtkCB, System.Action pKillCB)
    {
        StopMove();
        this.transform.LookAt(Target.transform);
        // NOTE(review): realtimeSinceStartup ignores Time.timeScale, so the cooldown does not
        // scale with accelerated training — confirm this is intended.
        var curTime = Time.realtimeSinceStartup;
        if (curTime >= nextAtkTime)
        {
            isAttackTarget = true;
            nextAtkTime = curTime + atkCd;
            unitAni.PlayAniByType(AniNameType.Attack, 4, () =>
            {
                DoNormalAttack(pAtkCB, pKillCB);
                isAttackTarget = false;
            });
        }
    }

    /// <summary>Deals this unit's attack damage to the current target if it is still alive.</summary>
    public void DoNormalAttack(System.Action pAtkCB, System.Action pKillCB)
    {
        // This runs from a delayed animation callback, so the target may have died in the
        // meantime. Hitting a dead target would fire the kill callback (and reward) again.
        if (Target != null && !Target.IsDead())
        {
            Target.OnDamage(atk, pAtkCB, pKillCB);
        }
    }

    // Velocity vector (direction * moveSpeed) from the current position towards pTargetPos.
    private Vector3 GetMoveDir(Vector3 pTargetPos)
    {
        var dir = pTargetPos - this.transform.position;
        return dir.normalized * moveSpeed;
    }

    // Per-physics-step update: starts pending policy attacks, advances movement and
    // falls back to the idle animation when nothing else is going on.
    private void FixedUpdate()
    {
        if (IsDead())
        {
            return;
        }

        if (isAttack)
        {
            var curTime = Time.realtimeSinceStartup;
            if (curTime >= nextAtkTime)
            {
                nextAtkTime = curTime + atkCd;
                unitAni.PlayAniByType(AniNameType.Attack, 4, () =>
                {
                    isAttack = false;
                    DoNormalAttack(null, null);
                });
            }
        }

        if (isMoveToTargetPos && !isAttackTarget)
        {
            if (Vector3.Distance(targetPos, this.transform.position) <= 1)
            {
                // Close enough: treat the destination as reached.
                isMoveToTargetPos = false;
            }
            else
            {
                // NOTE(review): a world-space direction is added to localPosition; this is only
                // equivalent to moving in world space if the parent is not rotated or scaled.
                this.transform.localPosition += GetMoveDir(targetPos) * Time.deltaTime;
                this.transform.LookAt(targetPos);
                unitAni.PlayAniByType(AniNameType.Move, 3);
            }
        }

        if (!isAttack && !isMoveToTargetPos && !isAttackTarget)
        {
            unitAni.PlayAniByType(AniNameType.Idle, 1);
        }
    }
}
寻路路径脚本:
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
/// <summary>How an AIPath continues once its last point has been reached.</summary>
public enum PathType
{
// Follow the points a single time.
Once = 1,
// After the last point continue with the first one again.
Loop = 2,
// Walk to the last point, then back to the first, and repeat.
PingPong = 3,
}
/// <summary>
/// One waypoint of a path: a position plus links to its neighbouring waypoints.
/// The links are not set by the constructor; they start out null.
/// </summary>
[System.Serializable]
public class PathPoint
{
    public Vector3 pos;
    public PathPoint NextPoint;
    public PathPoint PrePoint;

    /// <summary>Creates a waypoint at the given position.</summary>
    /// <param name="pPos">Position of the waypoint.</param>
    public PathPoint(Vector3 pPos) => this.pos = pPos;
}
/// <summary>
/// A patrol path made of linked <see cref="PathPoint"/>s. Fill <see cref="Paths"/>
/// (directly, via <see cref="SetAdjustTwoPointPath"/> or via <see cref="InitByGOList"/>),
/// call <see cref="Init"/> to link the points, then poll <see cref="IsReachNextPoint"/>,
/// <see cref="OnReachNextPoint"/> and <see cref="GetNextPoint"/> while moving.
/// </summary>
[System.Serializable]
public class AIPath
{
    [HideInInspector]
    public List<PathPoint> Paths = new List<PathPoint>();
    public List<GameObject> PathGos = new List<GameObject>();
    public PathType pathType = PathType.Loop;

    // The waypoint currently being walked towards; null when the path is empty or finished.
    private PathPoint nextPoint;
    // Travel direction for PingPong paths.
    private bool isForward = true;

    /// <summary>
    /// Replaces the points with two points placed <paramref name="pDistance"/> to the right
    /// and to the left of <paramref name="pMiddlePos"/>. Call <see cref="Init"/> afterwards.
    /// </summary>
    public void SetAdjustTwoPointPath(Vector3 pMiddlePos, float pDistance)
    {
        Paths.Clear();
        var posR = pMiddlePos + Vector3.right * pDistance;
        Paths.Add(new PathPoint(posR));
        var posL = pMiddlePos + Vector3.right * -pDistance;
        Paths.Add(new PathPoint(posL));
    }

    /// <summary>
    /// Builds a new path whose points are the current positions of the objects in
    /// <see cref="PathGos"/>. The returned path still needs <see cref="Init"/>.
    /// </summary>
    /// <returns>A new, not yet linked path with the same <see cref="pathType"/>.</returns>
    public AIPath InitByGOList()
    {
        var aiPath = new AIPath();
        // The original dropped the configured type, so every runtime path fell back to Loop.
        aiPath.pathType = pathType;
        foreach (var go in PathGos)
        {
            // Skip unassigned or destroyed entries instead of throwing.
            if (go == null)
            {
                continue;
            }

            aiPath.Paths.Add(new PathPoint(go.transform.position));
        }

        return aiPath;
    }

    /// <summary>
    /// Links every point to its neighbours according to <see cref="pathType"/> and makes the
    /// first point the current destination. Loop paths wrap around; other paths have null
    /// links at both ends. Safe for empty and single-point paths.
    /// </summary>
    public void Init()
    {
        var count = Paths.Count;
        var isLoop = pathType == PathType.Loop;
        for (var i = 0; i < count; i++)
        {
            var point = Paths[i];
            if (isLoop)
            {
                point.NextPoint = Paths[(i + 1) % count];
                point.PrePoint = Paths[(i - 1 + count) % count];
            }
            else
            {
                point.NextPoint = i + 1 < count ? Paths[i + 1] : null;
                point.PrePoint = i > 0 ? Paths[i - 1] : null;
            }
        }

        nextPoint = count > 0 ? Paths[0] : null;
    }

    /// <summary>Returns the current destination, or null when there is none.</summary>
    public PathPoint GetNextPoint()
    {
        return nextPoint;
    }

    /// <summary>
    /// True when <paramref name="pos"/> is within 1 unit of the current destination along the
    /// X axis. Returns false when there is no destination.
    /// </summary>
    public bool IsReachNextPoint(Vector3 pos)
    {
        // A finished Once path (or an empty path) has no destination left.
        if (nextPoint == null)
        {
            return false;
        }

        // NOTE(review): only X is compared, so waypoints that differ only in Y/Z count as
        // reached immediately — confirm paths are meant to run along X only.
        return Mathf.Abs(nextPoint.pos.x - pos.x) <= 1;
    }

    /// <summary>Advances the destination to the following point according to <see cref="pathType"/>.</summary>
    public void OnReachNextPoint()
    {
        if (nextPoint == null)
        {
            return;
        }

        if (pathType == PathType.Once || pathType == PathType.Loop)
        {
            // For Once this becomes null after the last point, which ends the path.
            nextPoint = nextPoint.NextPoint;
        }
        else if (pathType == PathType.PingPong)
        {
            if (isForward)
            {
                if (nextPoint.NextPoint != null)
                {
                    nextPoint = nextPoint.NextPoint;
                }
                else
                {
                    // End reached: turn around. Stay on the point if it has no neighbour.
                    isForward = false;
                    if (nextPoint.PrePoint != null)
                    {
                        nextPoint = nextPoint.PrePoint;
                    }
                }
            }
            else
            {
                if (nextPoint.PrePoint != null)
                {
                    nextPoint = nextPoint.PrePoint;
                }
                else
                {
                    // Start reached: turn around. Stay on the point if it has no neighbour.
                    isForward = true;
                    if (nextPoint.NextPoint != null)
                    {
                        nextPoint = nextPoint.NextPoint;
                    }
                }
            }
        }
    }
}
AI 策略代码:
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
/// <summary>Identifies the kind of AI policy a unit uses.</summary>
public enum AIPolicyType
{
// No policy.
None = 0,
Soldier = 1,
Tower = 2,
Hero = 3
}
/// <summary>
/// Policy whose action list is, in priority order: attack the target,
/// select a target, idle. It contains no movement actions.
/// </summary>
public class TowerPolicy : AIPolicy
{
    /// <summary>Prioritized action list, created on first access.</summary>
    protected override List<AIAction> allActions
    {
        get
        {
            return allActions_ ?? (allActions_ = new List<AIAction>
            {
                new AttackToTargetAction(),
                new SelectTargetAction(),
                new IdleAction(),
            });
        }
    }

    /// <summary>Creates the policy for the given unit.</summary>
    /// <param name="pMonsterAi">Unit controlled by this policy.</param>
    public TowerPolicy(Unit pMonsterAi) : base(pMonsterAi)
    {
    }
}
/// <summary>
/// Policy whose action list is, in priority order: attack the target, chase the target,
/// select a target, patrol along the path, idle.
/// </summary>
public class SoldierPolicy : AIPolicy
{
    /// <summary>Prioritized action list, created on first access.</summary>
    protected override List<AIAction> allActions
    {
        get
        {
            return allActions_ ?? (allActions_ = new List<AIAction>
            {
                new AttackToTargetAction(),
                new MoveToTargetAction(),
                new SelectTargetAction(),
                new MoveAction(),
                new IdleAction(),
            });
        }
    }

    /// <summary>Creates the policy for the given unit.</summary>
    /// <param name="pMonsterAi">Unit controlled by this policy.</param>
    public SoldierPolicy(Unit pMonsterAi) : base(pMonsterAi)
    {
    }
}
/// <summary>
/// Base class of the priority-list AI: on every run the first action whose condition
/// passes is executed. Subclasses supply the ordered list through <see cref="allActions"/>.
/// </summary>
public class AIPolicy
{
    // Backing store for the action list; null in the base class unless a subclass fills it.
    protected List<AIAction> allActions_;

    /// <summary>Actions in priority order (first passing action wins).</summary>
    protected virtual List<AIAction> allActions { get { return allActions_; } }

    // The action chosen by the most recent OnRun, or null if none passed.
    private AIAction curAiAction;

    // The unit this policy controls.
    protected Unit bindAi;

    public AIPolicy() { }

    /// <summary>Creates a policy bound to the given unit.</summary>
    /// <param name="pMonsterAi">Unit controlled by this policy.</param>
    public AIPolicy(Unit pMonsterAi)
    {
        bindAi = pMonsterAi;
    }

    /// <summary>
    /// Picks the first action whose condition passes and executes it. Does nothing when
    /// the policy has no actions or no condition passes.
    /// </summary>
    /// <param name="pAllUnits">All units on the battlefield, forwarded to conditions and actions.</param>
    public void OnRun(List<Unit> pAllUnits)
    {
        // Forget the previous choice. The original re-ran a stale action when nothing passed,
        // and threw when nothing had ever passed or the action list was null.
        curAiAction = null;

        var actions = allActions;
        if (actions == null)
        {
            return;
        }

        for (var i = 0; i < actions.Count; i++)
        {
            if (actions[i].IsPass(bindAi, pAllUnits))
            {
                curAiAction = actions[i];
                break;
            }
        }

        if (curAiAction != null)
        {
            curAiAction.OnAcion(bindAi, pAllUnits);
        }
    }
}
/// <summary>Action that sends the unit back home; gated by <see cref="BackToHomeConditioner"/>.</summary>
public class BackToHomeAction : AIAction
{
    /// <summary>Condition of this action, created on first access.</summary>
    public override AIConditioner Conditioner
    {
        get { return conditioner ?? (conditioner = new BackToHomeConditioner()); }
    }

    /// <summary>Calls <c>BackToHome</c> on the unit.</summary>
    public override void OnAcion(Unit pBindAi, List<Unit> pAllUnits) => pBindAi.BackToHome();
}
/// <summary>Action that attacks the current target; gated by <see cref="AttackToTargetConditioner"/>.</summary>
public class AttackToTargetAction : AIAction
{
    /// <summary>Condition of this action, created on first access.</summary>
    public override AIConditioner Conditioner
    {
        get { return conditioner ?? (conditioner = new AttackToTargetConditioner()); }
    }

    /// <summary>Stops the unit, turns it towards its target and requests an attack.</summary>
    public override void OnAcion(Unit pBindAi, List<Unit> pAllUnits)
    {
        var unit = pBindAi;
        unit.StopMove();
        unit.FaceToCurTarget();
        unit.Attack();
    }
}
/// <summary>Action that chases the current target; gated by <see cref="MoveToTargetConditioner"/>.</summary>
public class MoveToTargetAction : AIAction
{
    /// <summary>Condition of this action, created on first access.</summary>
    public override AIConditioner Conditioner
    {
        get { return conditioner ?? (conditioner = new MoveToTargetConditioner()); }
    }

    /// <summary>Calls <c>MoToCurTarget</c> on the unit.</summary>
    public override void OnAcion(Unit pBindAi, List<Unit> pAllUnits) => pBindAi.MoToCurTarget();
}
/// <summary>Action that picks a new target; gated by <see cref="ForcusTargetConditioner"/>.</summary>
public class SelectTargetAction : AIAction
{
    /// <summary>Condition of this action, created on first access.</summary>
    public override AIConditioner Conditioner
    {
        get { return conditioner ?? (conditioner = new ForcusTargetConditioner()); }
    }

    /// <summary>Calls <c>SelectAdjustTarget</c> on the unit with the full unit list.</summary>
    public override void OnAcion(Unit pBindAi, List<Unit> pAllUnits) => pBindAi.SelectAdjustTarget(pAllUnits);
}
/// <summary>Action that moves the unit along its patrol path; gated by <see cref="MoveConditioer"/>.</summary>
public class MoveAction : AIAction
{
    /// <summary>Condition of this action, created on first access.</summary>
    public override AIConditioner Conditioner
    {
        get { return conditioner ?? (conditioner = new MoveConditioer()); }
    }

    /// <summary>Calls <c>OnLoopMove</c> on the unit.</summary>
    public override void OnAcion(Unit pBindAi, List<Unit> pAllUnits) => pBindAi.OnLoopMove();
}
/// <summary>Action that puts the unit into its idle state; gated by <see cref="IdleConditoner"/>.</summary>
public class IdleAction : AIAction
{
    /// <summary>Condition of this action, created on first access.</summary>
    public override AIConditioner Conditioner
    {
        get { return conditioner ?? (conditioner = new IdleConditoner()); }
    }

    /// <summary>Calls <c>Idle</c> on the unit.</summary>
    public override void OnAcion(Unit pBindAi, List<Unit> pAllUnits) => pBindAi.Idle();
}
/// <summary>
/// Base class of an AI action: a condition (<see cref="Conditioner"/>) plus the behaviour
/// to run when that condition passes. The base action does nothing and always passes.
/// </summary>
public class AIAction
{
    // Cached condition instance; created on first access of Conditioner.
    protected AIConditioner conditioner;

    /// <summary>Condition of this action, created on first access.</summary>
    public virtual AIConditioner Conditioner
    {
        get { return conditioner ?? (conditioner = new AIConditioner()); }
    }

    /// <summary>Executes the action for the given unit; no-op in the base class.</summary>
    public virtual void OnAcion(Unit pBindAi, List<Unit> pAllUnits)
    {
    }

    /// <summary>Evaluates this action's condition for the given unit.</summary>
    public bool IsPass(Unit pBindAi, List<Unit> pAllUnits) => Conditioner.IsPass(pBindAi, pAllUnits);
}
/// <summary>Target is out of follow range: return to the original position.</summary>
public class BackToHomeConditioner : AIConditioner
{
    /// <summary>Passes when the unit reports that it needs to go back home.</summary>
    public override bool IsPass(Unit pBindAi, List<Unit> pAllUnits) => pBindAi.NeedBackToHome();
}
/// <summary>Attack the target.</summary>
public class AttackToTargetConditioner : AIConditioner
{
    /// <summary>Passes when the unit's current target is within its attack distance.</summary>
    public override bool IsPass(Unit pBindAi, List<Unit> pAllUnits)
    {
        var attackRange = pBindAi.atkDistance;
        return pBindAi.CurTargetInRange(attackRange);
    }
}
/// <summary>Chase the target.</summary>
public class MoveToTargetConditioner : AIConditioner
{
    /// <summary>Passes when the unit's current target is within its follow distance.</summary>
    public override bool IsPass(Unit pBindAi, List<Unit> pAllUnits)
    {
        var followRange = pBindAi.followDistance;
        return pBindAi.CurTargetInRange(followRange);
    }
}
/// <summary>Lock onto a target.</summary>
public class ForcusTargetConditioner : AIConditioner
{
    /// <summary>Passes when the unit reports that a new target should be selected.</summary>
    public override bool IsPass(Unit pBindAi, List<Unit> pAllUnits) => pBindAi.NeedSelectTarget(pAllUnits);
}
/// <summary>Patrol.</summary>
public class MoveConditioer : AIConditioner
{
    /// <summary>Passes when the unit has no target.</summary>
    public override bool IsPass(Unit pBindAi, List<Unit> pAllUnits)
    {
        var hasTarget = pBindAi.IsHaveTarget();
        return !hasTarget;
    }
}
/// <summary>Stand by (idle).</summary>
public class IdleConditoner : AIConditioner
{
    /// <summary>Passes when the unit is not dead.</summary>
    public override bool IsPass(Unit pBindAi, List<Unit> pAllUnits)
    {
        var isDead = pBindAi.IsDead();
        return !isDead;
    }
}
/// <summary>Base condition for AI actions; the default implementation always passes.</summary>
public class AIConditioner
{
    /// <summary>Decides whether the owning action may run for the given unit.</summary>
    /// <param name="pBindAi">Unit being evaluated.</param>
    /// <param name="pAllUnits">All units on the battlefield.</param>
    /// <returns>Always true in the base class.</returns>
    public virtual bool IsPass(Unit pBindAi, List<Unit> pAllUnits) => true;
}
大部分核心代码分享了,感兴趣的同学,可以直接下载github 的资源测试了解。
再次分享github 工程:https://github.com/IsaWinding/mlagent02.git