Unity (C# script): write the agents' positions to a txt log file
using System.Collections.Generic;
using Unity.MLAgents;
using UnityEngine;
using Obi;
using System.IO;
/// <summary>
/// Identifies the two opposing sides of the environment.
/// </summary>
public enum Team
{
    /// <summary>Red side, numeric value 0.</summary>
    Red = 0,

    /// <summary>Blue side, numeric value 1.</summary>
    Blue = 1,
}
/// <summary>
/// Environment controller for a two-team ML-Agents scenario.
/// The red group ("attackers"/pursuers) is rewarded for bringing the centroid of its
/// four agents close to the target while holding a formation; the blue group
/// (the "defender"/evader plus the target) is rewarded for keeping the defender close
/// to the attackers' centroid. Every physics step the controller logs all positions
/// to a text file, hands out step rewards, and ends the episode when one of the
/// terminal conditions is met.
/// </summary>
public class SatelliteEnvController : MonoBehaviour
{
    // Number of pursuers that FixedUpdate reads by index.
    private const int RequiredPursuers = 4;

    // Attacker centroid within this distance of the target counts as a capture.
    private const float CaptureDistance = 2f;

    // Defender within this distance of the target counts as a blue-team collision.
    private const float BlueCollisionDistance = 1.5f;

    // Defender within this distance of the attacker centroid counts as an interception.
    private const float InterceptDistance = 6f;

    // Below this defender-target distance the extra "danger" term is added to the blue reward.
    private const float DangerDistance = 5f;

    // Desired pursuer spacing for pairs 1-2, 1-4, 2-3 and 3-4.
    private const float FormationEdge = 8f;

    // Desired pursuer spacing for pairs 1-3 and 2-4.
    private const float FormationDiagonal = 11.3124f;

    // Appends one line of positions per physics step; stays null if it could not be opened.
    private StreamWriter writer;

    /// <summary>Per-pursuer bookkeeping captured in <see cref="Start"/>.</summary>
    [System.Serializable]
    public class PursuerInfo
    {
        public AgentSatelliteP Agent;
        [HideInInspector]
        public Vector3 StartingPos;
        [HideInInspector]
        public Quaternion StartingRot;
        [HideInInspector]
        public Rigidbody Rb;
    }

    /// <summary>Per-evader (defender) bookkeeping captured in <see cref="Start"/>.</summary>
    [System.Serializable]
    public class EvaderInfo
    {
        public AgentSatelliteE Agent;
        [HideInInspector]
        public Vector3 StartingPos;
        [HideInInspector]
        public Quaternion StartingRot;
        [HideInInspector]
        public Rigidbody Rb;
    }

    /// <summary>Per-target bookkeeping captured in <see cref="Start"/>.</summary>
    [System.Serializable]
    public class TargetInfo
    {
        public AgentSatelliteTar Agent;
        [HideInInspector]
        public Vector3 StartingPos;
        [HideInInspector]
        public Quaternion StartingRot;
        [HideInInspector]
        public Rigidbody Rb;
    }

    [Tooltip("Max Environment Steps")] public int MaxEnvironmentSteps = 3000;

    // Lists of agents on the platform.
    public List<PursuerInfo> PursuerList = new List<PursuerInfo>();
    public List<EvaderInfo> EvaderList = new List<EvaderInfo>();
    public List<TargetInfo> TargetList = new List<TargetInfo>();

    private SimpleMultiAgentGroup m_RedPursuerGroup;
    private SimpleMultiAgentGroup m_BlueEvaderGroup;

    // Physics steps elapsed since the last reset.
    private int m_ResetTimer;

    // Number of times ResetScene has run.
    private int resetCounter = 0;

    // Declared to hold the previous episode's initial positions; currently never filled or read.
    private Dictionary<Transform, Vector3> initialPositions = new Dictionary<Transform, Vector3>();

    /// <summary>
    /// Registers every agent with its group, records its starting pose and rigidbody,
    /// resets the scene, and opens the position log in append mode.
    /// </summary>
    void Start()
    {
        m_RedPursuerGroup = new SimpleMultiAgentGroup();
        m_BlueEvaderGroup = new SimpleMultiAgentGroup();

        foreach (var item in PursuerList)
        {
            item.StartingPos = item.Agent.transform.localPosition;
            item.StartingRot = item.Agent.transform.localRotation;
            item.Rb = item.Agent.GetComponent<Rigidbody>();
            m_RedPursuerGroup.RegisterAgent(item.Agent);
        }

        foreach (var item in EvaderList)
        {
            item.StartingPos = item.Agent.transform.localPosition;
            item.StartingRot = item.Agent.transform.localRotation;
            item.Rb = item.Agent.GetComponent<Rigidbody>();
            m_BlueEvaderGroup.RegisterAgent(item.Agent);
        }

        // The target belongs to the blue group together with the defender.
        foreach (var item in TargetList)
        {
            item.StartingPos = item.Agent.transform.localPosition;
            item.StartingRot = item.Agent.transform.localRotation;
            item.Rb = item.Agent.GetComponent<Rigidbody>();
            m_BlueEvaderGroup.RegisterAgent(item.Agent);
        }

        ResetScene();

        // FixedUpdate indexes the lists directly; fail once and clearly instead of
        // throwing an out-of-range exception on every physics step.
        if (!HasRequiredAgents())
        {
            Debug.LogError(
                $"{nameof(SatelliteEnvController)} needs at least {RequiredPursuers} pursuers, " +
                "1 evader and 1 target; disabling the controller.");
            enabled = false;
            return;
        }

        // Create or open the text file; 'true' appends to existing data.
        string path = Application.dataPath + "/PositionsLog.txt";
        writer = new StreamWriter(path, true);
    }

    /// <summary>
    /// Runs once per physics step: logs positions, adds the step rewards to both groups,
    /// and ends the episode if a terminal condition holds.
    /// </summary>
    void FixedUpdate()
    {
        Vector3 targetPos = TargetList[0].Agent.transform.localPosition;
        m_ResetTimer += 1;

        Vector3 red1 = PursuerList[0].Agent.transform.localPosition;
        Vector3 red2 = PursuerList[1].Agent.transform.localPosition;
        Vector3 red3 = PursuerList[2].Agent.transform.localPosition;
        Vector3 red4 = PursuerList[3].Agent.transform.localPosition;
        Vector3 bluePos = EvaderList[0].Agent.transform.localPosition;

        LogPositions(red1, red2, red3, red4, targetPos, bluePos);

        // Centroid of the four pursuers, used for both capture and interception distances.
        Vector3 centroid = (red1 + red2 + red3 + red4) / 4;
        float distancePE0 = (targetPos - centroid).magnitude;   // centroid -> target
        float disEE = (targetPos - bluePos).magnitude;          // defender -> target
        float distancePB = (bluePos - centroid).magnitude;      // defender -> centroid

        // Attackers: formation reward (negative sum of deviations from the desired
        // pairwise spacings) plus guidance toward the target.
        float formationReward =
            -Mathf.Abs((red1 - red2).magnitude - FormationEdge)
            - Mathf.Abs((red1 - red3).magnitude - FormationDiagonal)
            - Mathf.Abs((red1 - red4).magnitude - FormationEdge)
            - Mathf.Abs((red2 - red3).magnitude - FormationEdge)
            - Mathf.Abs((red2 - red4).magnitude - FormationDiagonal)
            - Mathf.Abs((red3 - red4).magnitude - FormationEdge);
        float redStepReward = -distancePE0 / 200f + formationReward / 20f;
        // An extra attacker term for avoiding the defender used to sit here and is disabled.

        // Defender: guidance toward intercepting the attackers, plus a "danger" term
        // that is only active while the defender is close to the target.
        float dangerReward = disEE <= DangerDistance ? disEE / 200f : 0f;
        float blueStepReward = dangerReward + (1 - distancePB / 50f);

        m_RedPursuerGroup.AddGroupReward(redStepReward);
        m_BlueEvaderGroup.AddGroupReward(blueStepReward);

        // Terminal conditions are mutually exclusive and checked in priority order.
        // All distances above were measured before any reset, so letting more than one
        // branch run in the same step would add a terminal reward to the episode that
        // has just started and reset the scene twice.
        if (distancePE0 <= CaptureDistance)
        {
            // Attackers captured the target.
            FinishEpisode(200f, 0f);
            Debug.Log("1");
            Debug.Log(Time.time);
        }
        else if (disEE <= BlueCollisionDistance)
        {
            // Blue team collided with itself (defender and target).
            FinishEpisode(0f, -200f);
        }
        else if (distancePB <= InterceptDistance)
        {
            // Defender intercepted the attackers.
            FinishEpisode(-200f, 0f);
            Debug.Log("colli");
        }
        else if (m_ResetTimer >= MaxEnvironmentSteps && MaxEnvironmentSteps > 0)
        {
            // Attackers failed to capture the target within the step budget.
            FinishEpisode(-distancePE0 * 10f - 100f, 0f);
            Debug.Log("2");
        }
    }

    /// <summary>
    /// Resets the step timer and puts every agent back at its initial position with
    /// zero velocity. Pursuers get a uniform offset in [-1, 1] per axis, evaders get an
    /// offset drawn from two disjoint intervals per axis, targets get no offset.
    /// </summary>
    public void ResetScene()
    {
        // NOTE(review): positions are applied with SetPositionAndRotation (a world-space
        // call) from Agent.initialPos, while FixedUpdate reads localPosition and the
        // StartingPos/StartingRot captured in Start are never used. Confirm that
        // initialPos is expressed in the intended space.
        m_ResetTimer = 0;
        resetCounter++;

        foreach (var item in PursuerList)
        {
            // Keep the X, Y, Z draw order so seeded runs stay reproducible.
            float offsetX = Random.Range(-1f, 1f);
            float offsetY = Random.Range(-1f, 1f);
            float offsetZ = Random.Range(-1f, 1f);
            Vector3 spawn = item.Agent.initialPos + new Vector3(offsetX, offsetY, offsetZ);
            PlaceAtRest(item.Agent.transform, item.Rb, spawn);
        }

        foreach (var item in TargetList)
        {
            PlaceAtRest(item.Agent.transform, item.Rb, item.Agent.initialPos);
        }

        foreach (var item in EvaderList)
        {
            float offsetX = RandomRangeSelector(-6f, -4f, 4f, 6f);
            float offsetY = RandomRangeSelector(-6f, -4f, 4f, 6f);
            float offsetZ = RandomRangeSelector(-3f, -2f, 2f, 3f);
            Vector3 spawn = item.Agent.initialPos + new Vector3(offsetX, offsetY, offsetZ);
            PlaceAtRest(item.Agent.transform, item.Rb, spawn);
        }
    }

    /// <summary>
    /// Picks one of the two intervals with a random integer that is 0 or 1 and returns a
    /// random value drawn from the chosen interval.
    /// </summary>
    float RandomRangeSelector(float min1, float max1, float min2, float max2)
    {
        return Random.Range(0, 2) == 0 ? Random.Range(min1, max1) : Random.Range(min2, max2);
    }

    /// <summary>Closes the position log when the object is destroyed.</summary>
    void OnDestroy()
    {
        if (writer != null)
        {
            writer.Close();
            writer = null;
        }
    }

    // True when the lists hold every agent that FixedUpdate reads by index.
    private bool HasRequiredAgents()
    {
        return PursuerList.Count >= RequiredPursuers
            && EvaderList.Count >= 1
            && TargetList.Count >= 1;
    }

    // Adds the terminal rewards, interrupts both group episodes and resets the scene.
    private void FinishEpisode(float redReward, float blueReward)
    {
        m_RedPursuerGroup.AddGroupReward(redReward);
        m_BlueEvaderGroup.AddGroupReward(blueReward);
        m_RedPursuerGroup.GroupEpisodeInterrupted();
        m_BlueEvaderGroup.GroupEpisodeInterrupted();
        ResetScene();
    }

    // Moves an agent to the given position with zero rotation and clears its velocities.
    private static void PlaceAtRest(Transform agentTransform, Rigidbody body, Vector3 position)
    {
        agentTransform.SetPositionAndRotation(position, Quaternion.identity);
        body.velocity = Vector3.zero;
        body.angularVelocity = Vector3.zero;
    }

    // Writes one CSV row: time, four pursuer positions, target position, evader position.
    // Does nothing when the log could not be opened, so the reward logic keeps running.
    private void LogPositions(Vector3 red1, Vector3 red2, Vector3 red3, Vector3 red4, Vector3 target, Vector3 blue)
    {
        if (writer == null)
        {
            return;
        }

        string time = Time.time.ToString(System.Globalization.CultureInfo.InvariantCulture);
        writer.WriteLine(string.Join(
            ",",
            time,
            FormatVector(red1),
            FormatVector(red2),
            FormatVector(red3),
            FormatVector(red4),
            FormatVector(target),
            FormatVector(blue)));
    }

    // Formats a vector as "x, y, z" without parentheses. Invariant culture keeps '.' as
    // the decimal separator so the comma-delimited row stays parseable in every locale.
    private static string FormatVector(Vector3 v)
    {
        var culture = System.Globalization.CultureInfo.InvariantCulture;
        return string.Join(", ", v.x.ToString(culture), v.y.ToString(culture), v.z.ToString(culture));
    }
}
matplotlib (Python script): plot the logged trajectories
import numpy as np
import matplotlib.pyplot as plt
# Read the position log written by the Unity script.
# Each row holds 19 comma-separated values: time, then x/y/z for attackers 1-4,
# the target, and the defender (the order the log line is written in).
# ndmin=2 keeps a one-row log two-dimensional so the column slicing below works.
data = np.loadtxt('PositionsLog.txt', delimiter=',', dtype=float, ndmin=2)
if data.shape[0] == 0 or data.shape[1] < 19:
    raise ValueError(
        "PositionsLog.txt must contain at least one row of 19 columns, "
        f"got array of shape {data.shape}"
    )

# Extract the x, y, z columns of every agent.
xa1, ya1, za1 = data[:, 1], data[:, 2], data[:, 3]
xa2, ya2, za2 = data[:, 4], data[:, 5], data[:, 6]
xa3, ya3, za3 = data[:, 7], data[:, 8], data[:, 9]
xa4, ya4, za4 = data[:, 10], data[:, 11], data[:, 12]
xT, yT, zT = data[:, 13], data[:, 14], data[:, 15]
xD, yD, zD = data[:, 16], data[:, 17], data[:, 18]

# Create the 3D trajectory figure.
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')

# (label, x, y, z, line style, colour, end-point marker) for every trajectory.
# Columns 13-15 are the target and 16-18 the defender, so each gets its own trace
# and a label that matches the data it shows.
trajectories = [
    ('Attacker1', xa1, ya1, za1, '-', 'red', '^'),
    ('Attacker2', xa2, ya2, za2, '--', 'green', '^'),
    ('Attacker3', xa3, ya3, za3, '-.', 'blue', '^'),
    ('Attacker4', xa4, ya4, za4, ':', 'orange', '^'),
    ('Target', xT, yT, zT, '-', 'purple', 's'),
    ('Defender', xD, yD, zD, '--', 'brown', 'o'),
]

# Draw the trajectory lines.
for label, xs, ys, zs, linestyle, color, _ in trajectories:
    ax.plot(xs, ys, zs, label=label, linestyle=linestyle, color=color)

# Mark the end point of every trajectory.
for label, xs, ys, zs, _, color, marker in trajectories:
    ax.scatter(xs[-1], ys[-1], zs[-1], color=color, marker=marker, s=100,
               label='End ' + label)

# Connect the end points of the four attacker trajectories with a dashed loop.
ax.plot([xa1[-1], xa2[-1], xa3[-1], xa4[-1], xa1[-1]],
        [ya1[-1], ya2[-1], ya3[-1], ya4[-1], ya1[-1]],
        [za1[-1], za2[-1], za3[-1], za4[-1], za1[-1]],
        linestyle='--', color='black', linewidth=2, label='Connection between End Points')

# Axis labels.
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')

# Show the legend and the figure.
ax.legend()
plt.show()