Unity ML-Agents 结合 Matplotlib 生成轨迹图

Unity 端:每个物理步将各智能体位置写入 txt 日志

using System.Collections.Generic;
using Unity.MLAgents;
using UnityEngine;
using Obi;
using System.IO;


/// <summary>
/// Identifies one of the two competing sides in the scene.
/// </summary>
public enum Team : int
{
    /// <summary>Red side (numeric value 0).</summary>
    Red = 0,

    /// <summary>Blue side (numeric value 1).</summary>
    Blue = 1,
}

/// <summary>
/// Episode controller for a multi-agent pursuit scene.
/// <para>
/// The red group holds the pursuers (<see cref="PursuerList"/>). The blue group holds the
/// evader/defender agents (<see cref="EvaderList"/>) and the target agents
/// (<see cref="TargetList"/>). On every physics step this component:
/// <list type="bullet">
/// <item><description>appends all agent positions to a text log;</description></item>
/// <item><description>hands out shaped group rewards;</description></item>
/// <item><description>ends the episode on capture, blue-side collision, interception, or timeout.</description></item>
/// </list>
/// </para>
/// <para>
/// <see cref="FixedUpdate"/> reads exactly four pursuers, plus the first evader and the first
/// target, so the lists must contain at least that many entries.
/// </para>
/// </summary>
public class SatelliteEnvController : MonoBehaviour
{
    // Per-step position log; opened in Start, released in OnDestroy.
    private StreamWriter writer;

    // Reward-shaping / termination thresholds, in the agents' local-position units.
    private const float FormationEdge = 8f;            // desired side length of the pursuer square
    private const float FormationDiagonal = 11.3124f;  // desired diagonal, close to 8 * sqrt(2)
    private const float CaptureDistance = 2f;          // pursuer centroid -> target: capture
    private const float BlueCollisionDistance = 1.5f;  // defender -> target: blue-side crash
    private const float InterceptDistance = 6f;        // defender -> pursuer centroid: interception
    private const float DangerDistance = 5f;           // defender is penalised inside this radius of the target
    private const float TerminalReward = 200f;

    /// <summary>Inspector entry for one pursuer (red) agent.</summary>
    [System.Serializable]
    public class PursuerInfo
    {
        public AgentSatelliteP Agent;
        // Captured from localPosition/localRotation in Start; ResetScene uses Agent.initialPos instead.
        [HideInInspector]
        public Vector3 StartingPos;
        [HideInInspector]
        public Quaternion StartingRot;
        [HideInInspector]
        public Rigidbody Rb;
    }

    /// <summary>Inspector entry for one evader/defender (blue) agent.</summary>
    [System.Serializable]
    public class EvaderInfo
    {
        public AgentSatelliteE Agent;
        // Captured from localPosition/localRotation in Start; ResetScene uses Agent.initialPos instead.
        [HideInInspector]
        public Vector3 StartingPos;
        [HideInInspector]
        public Quaternion StartingRot;
        [HideInInspector]
        public Rigidbody Rb;
    }

    /// <summary>Inspector entry for one target agent (registered in the blue group).</summary>
    [System.Serializable]
    public class TargetInfo
    {
        public AgentSatelliteTar Agent;
        // Captured from localPosition/localRotation in Start; ResetScene uses Agent.initialPos instead.
        [HideInInspector]
        public Vector3 StartingPos;
        [HideInInspector]
        public Quaternion StartingRot;
        [HideInInspector]
        public Rigidbody Rb;
    }

    [Tooltip("Max Environment Steps")] public int MaxEnvironmentSteps = 3000;

    // Agents in the scene, assigned in the Inspector.
    public List<PursuerInfo> PursuerList = new List<PursuerInfo>();
    public List<EvaderInfo> EvaderList = new List<EvaderInfo>();
    public List<TargetInfo> TargetList = new List<TargetInfo>();
    private SimpleMultiAgentGroup m_RedPursuerGroup;
    private SimpleMultiAgentGroup m_BlueEvaderGroup;
    // Physics steps elapsed in the current episode.
    private int m_ResetTimer;
    // Number of ResetScene calls so far (including the initial one in Start).
    private int resetCounter = 0;

    /// <summary>
    /// Creates the red and blue groups and caches each agent's start pose and rigidbody.
    /// Registers pursuers with red, and evaders plus targets with blue.
    /// Then resets the scene and opens the position log.
    /// </summary>
    void Start()
    {
        m_RedPursuerGroup = new SimpleMultiAgentGroup();
        m_BlueEvaderGroup = new SimpleMultiAgentGroup();
        foreach (var item in PursuerList)
        {
            item.StartingPos = item.Agent.transform.localPosition;
            item.StartingRot = item.Agent.transform.localRotation;
            item.Rb = item.Agent.GetComponent<Rigidbody>();
            m_RedPursuerGroup.RegisterAgent(item.Agent);
        }
        foreach (var item in EvaderList)
        {
            item.StartingPos = item.Agent.transform.localPosition;
            item.StartingRot = item.Agent.transform.localRotation;
            item.Rb = item.Agent.GetComponent<Rigidbody>();
            m_BlueEvaderGroup.RegisterAgent(item.Agent);
        }
        foreach (var item in TargetList)
        {
            item.StartingPos = item.Agent.transform.localPosition;
            item.StartingRot = item.Agent.transform.localRotation;
            item.Rb = item.Agent.GetComponent<Rigidbody>();
            // The target is on the defender's side, so it shares the blue group.
            m_BlueEvaderGroup.RegisterAgent(item.Agent);
        }
        ResetScene();

        // Append mode (true): rows from earlier runs are kept in the same file.
        string path = Application.dataPath + "/PositionsLog.txt";
        writer = new StreamWriter(path, true);
    }

    /// <summary>
    /// Runs one environment step:
    /// <list type="number">
    /// <item><description>logs positions;</description></item>
    /// <item><description>gives shaped rewards to both groups;</description></item>
    /// <item><description>ends the episode on the first terminal condition that holds.</description></item>
    /// </list>
    /// </summary>
    void FixedUpdate()
    {
        m_ResetTimer += 1;

        Vector3 targetPos = TargetList[0].Agent.transform.localPosition;
        Vector3 p1 = PursuerList[0].Agent.transform.localPosition;
        Vector3 p2 = PursuerList[1].Agent.transform.localPosition;
        Vector3 p3 = PursuerList[2].Agent.transform.localPosition;
        Vector3 p4 = PursuerList[3].Agent.transform.localPosition;
        Vector3 defenderPos = EvaderList[0].Agent.transform.localPosition;

        // Each row has: time, pursuers 1-4, target, defender (x, y, z for each agent).
        // The invariant culture keeps '.' as the decimal separator. Otherwise comma-decimal
        // locales would break the comma-delimited format.
        writer?.WriteLine(string.Join(",",
            Time.time.ToString(System.Globalization.CultureInfo.InvariantCulture),
            FormatVector(p1), FormatVector(p2), FormatVector(p3), FormatVector(p4),
            FormatVector(targetPos), FormatVector(defenderPos)));

        Vector3 pursuerCentroid = (p1 + p2 + p3 + p4) / 4;
        float distancePE0 = (targetPos - pursuerCentroid).magnitude;   // pursuer centroid -> target
        float disEE = (targetPos - defenderPos).magnitude;             // defender -> target
        float distancePB = (defenderPos - pursuerCentroid).magnitude;  // defender -> pursuer centroid

        // Pursuers 1-2-3-4 should form a square. Consecutive pairs should sit at FormationEdge,
        // and the 1-3 and 2-4 pairs at FormationDiagonal. Each deviation is penalised linearly.
        float formationReward =
            -Mathf.Abs(Vector3.Distance(p1, p2) - FormationEdge)
            - Mathf.Abs(Vector3.Distance(p1, p3) - FormationDiagonal)
            - Mathf.Abs(Vector3.Distance(p1, p4) - FormationEdge)
            - Mathf.Abs(Vector3.Distance(p2, p3) - FormationEdge)
            - Mathf.Abs(Vector3.Distance(p2, p4) - FormationDiagonal)
            - Mathf.Abs(Vector3.Distance(p3, p4) - FormationEdge);

        // Pursuers: bring the formation centroid to the target while holding the formation.
        float pursuerStepReward = -distancePE0 / 200 + formationReward / 20;

        // Defender: close in on the pursuer centroid...
        float defenderStepReward = 1 - distancePB / 50;
        // ...while keeping clear of its own target. The penalty is 0 at DangerDistance and
        // grows linearly, down to -DangerDistance / 200, as the defender gets closer.
        // The old version added +disEE/200 here, which rewarded entering the zone.
        if (disEE <= DangerDistance)
        {
            defenderStepReward += (disEE - DangerDistance) / 200;
        }

        m_RedPursuerGroup.AddGroupReward(pursuerStepReward);
        // The defender and the target share the blue group, so this reward goes to both.
        m_BlueEvaderGroup.AddGroupReward(defenderStepReward);

        // Terminal conditions. At most one fires per step. After a reset, the distances
        // above are stale, so we return immediately.

        // Terminal: the pursuer centroid reached the target (red wins).
        if (distancePE0 <= CaptureDistance)
        {
            FinishEpisode(TerminalReward, 0f, interrupted: false);
            Debug.Log("1");
            Debug.Log(Time.time);
            return;
        }

        // Terminal: the defender crashed into its own target (blue loses).
        if (disEE <= BlueCollisionDistance)
        {
            FinishEpisode(0f, -TerminalReward, interrupted: false);
            return;
        }

        // Terminal: the defender intercepted the pursuer formation (red loses; blue gets 0).
        if (distancePB <= InterceptDistance)
        {
            FinishEpisode(-TerminalReward, 0f, interrupted: false);
            Debug.Log("colli");
            return;
        }

        // Step limit reached without a capture. Red is penalised by its remaining distance.
        // This is a time-out rather than a true terminal state, so the episode is interrupted.
        if (m_ResetTimer >= MaxEnvironmentSteps && MaxEnvironmentSteps > 0)
        {
            FinishEpisode(-distancePE0 * 10 - 100, 0f, interrupted: true);
            Debug.Log("2");
        }
    }

    // Gives the terminal rewards and closes the episode for both groups.
    // EndGroupEpisode is used for real outcomes and GroupEpisodeInterrupted for time-outs.
    // Then flushes the log and respawns every agent.
    private void FinishEpisode(float redReward, float blueReward, bool interrupted)
    {
        m_RedPursuerGroup.AddGroupReward(redReward);
        m_BlueEvaderGroup.AddGroupReward(blueReward);
        if (interrupted)
        {
            m_RedPursuerGroup.GroupEpisodeInterrupted();
            m_BlueEvaderGroup.GroupEpisodeInterrupted();
        }
        else
        {
            m_RedPursuerGroup.EndGroupEpisode();
            m_BlueEvaderGroup.EndGroupEpisode();
        }
        writer?.Flush();
        ResetScene();
    }

    // Formats a vector as "x, y, z" using invariant-culture numbers.
    private static string FormatVector(Vector3 v)
    {
        var culture = System.Globalization.CultureInfo.InvariantCulture;
        return string.Join(", ", v.x.ToString(culture), v.y.ToString(culture), v.z.ToString(culture));
    }

    /// <summary>
    /// Starts a new episode. Clears the step counter and moves every agent to a position
    /// derived from its <c>initialPos</c>, with zero rotation and zero velocity:
    /// <list type="bullet">
    /// <item><description>Pursuers get a uniform jitter in [-1, 1] on each axis.</description></item>
    /// <item><description>Targets go exactly to <c>initialPos</c>.</description></item>
    /// <item><description>Evaders are offset by 4–6 units on x and y and 2–3 units on z, with an independent random sign per axis.</description></item>
    /// </list>
    /// </summary>
    public void ResetScene()
    {
        m_ResetTimer = 0;
        resetCounter++;

        // NOTE(review): SetPositionAndRotation places agents in world space, while Start and
        // FixedUpdate read localPosition. Confirm the environment root sits at the origin.
        foreach (var item in PursuerList)
        {
            var jitter = new Vector3(Random.Range(-1f, 1f), Random.Range(-1f, 1f), Random.Range(-1f, 1f));
            PlaceAgent(item.Agent.transform, item.Rb, item.Agent.initialPos + jitter);
        }
        foreach (var item in TargetList)
        {
            PlaceAgent(item.Agent.transform, item.Rb, item.Agent.initialPos);
        }
        foreach (var item in EvaderList)
        {
            var offset = new Vector3(
                RandomRangeSelector(-6f, -4f, 4f, 6f),
                RandomRangeSelector(-6f, -4f, 4f, 6f),
                RandomRangeSelector(-3f, -2f, 2f, 3f));
            PlaceAgent(item.Agent.transform, item.Rb, item.Agent.initialPos + offset);
        }
    }

    // Teleports an agent to a position with identity rotation and stops its rigidbody.
    private static void PlaceAgent(Transform agentTransform, Rigidbody rb, Vector3 position)
    {
        agentTransform.SetPositionAndRotation(position, Quaternion.identity);
        rb.velocity = Vector3.zero;
        rb.angularVelocity = Vector3.zero;
    }

    /// <summary>
    /// Picks one of two intervals with equal probability, then returns a uniform value from it.
    /// </summary>
    float RandomRangeSelector(float min1, float max1, float min2, float max2)
    {
        // Random.Range(int, int) excludes the max, so this yields 0 or 1.
        return Random.Range(0, 2) == 0 ? Random.Range(min1, max1) : Random.Range(min2, max2);
    }

    /// <summary>Flushes and releases the position log.</summary>
    void OnDestroy()
    {
        if (writer != null)
        {
            writer.Dispose();
            writer = null;
        }
    }
}

Python 端:用 matplotlib 读取日志并绘制三维轨迹图

import numpy as np
import matplotlib.pyplot as plt

# Per-step position log written by SatelliteEnvController.FixedUpdate (Unity side).
# Each row has 19 comma-separated columns:
#   time, attacker1 x/y/z, attacker2 x/y/z, attacker3 x/y/z, attacker4 x/y/z,
#   target x/y/z, defender x/y/z.
# NOTE(review): Unity opens the file in append mode and writes no episode markers, so
# several runs/episodes may be concatenated here and drawn as one continuous line.
# ndmin=2 keeps a one-row log two-dimensional, so the column slicing below still works.
data = np.loadtxt('PositionsLog.txt', delimiter=',', dtype=float, ndmin=2)
if data.size == 0:
    raise SystemExit('PositionsLog.txt contains no position rows')

# Each agent occupies three consecutive columns (x, y, z).
xa1, ya1, za1 = data[:, 1:4].T
xa2, ya2, za2 = data[:, 4:7].T
xa3, ya3, za3 = data[:, 7:10].T
xa4, ya4, za4 = data[:, 10:13].T
xT, yT, zT = data[:, 13:16].T   # target
xD, yD, zD = data[:, 16:19].T   # defender (the blue agent)

# Create the 3D trajectory figure.
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')

# For each attacker: (x, y, z) arrays, legend name, line style, colour.
attackers = [
    ((xa1, ya1, za1), 'Attacker1', '-', 'red'),
    ((xa2, ya2, za2), 'Attacker2', '--', 'green'),
    ((xa3, ya3, za3), 'Attacker3', '-.', 'blue'),
    ((xa4, ya4, za4), 'Attacker4', ':', 'orange'),
]

# Trajectories. Columns 13-15 are the target and columns 16-18 are the defender.
for (x, y, z), name, style, color in attackers:
    ax.plot(x, y, z, label=name, linestyle=style, color=color)
ax.plot(xT, yT, zT, label='Target', linestyle='-', color='purple')
ax.plot(xD, yD, zD, label='Defender', linestyle='-', color='brown')

# Final positions.
for (x, y, z), name, _, color in attackers:
    ax.scatter(x[-1], y[-1], z[-1], color=color, marker='^', s=100, label=f'End {name}')
ax.scatter(xT[-1], yT[-1], zT[-1], color='purple', marker='s', s=100, label='End Target')
ax.scatter(xD[-1], yD[-1], zD[-1], color='brown', marker='o', s=100, label='End Defender')

# Outline the attackers' final square 1-2-3-4-1.
# 1-3 and 2-4 are the formation diagonals in the Unity reward.
ax.plot([xa1[-1], xa2[-1], xa3[-1], xa4[-1], xa1[-1]],
        [ya1[-1], ya2[-1], ya3[-1], ya4[-1], ya1[-1]],
        [za1[-1], za2[-1], za3[-1], za4[-1], za1[-1]],
        linestyle='--', color='black', linewidth=2, label='Connection between End Points')

ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')

# Show the legend and the figure.
ax.legend()
plt.show()
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值