from envforgym import SimEnv
from stable_baselines3 import PPO
from geometry_msgs.msg import Twist
from torch.utils.tensorboard import SummaryWriter
import rospy
import matplotlib.pyplot as plt
class Test:
    """Collect velocity commands received on a Twist-typed topic.

    The latest ``linear.x`` / ``angular.z`` values are always kept in
    ``linear_velocity`` / ``angular_velocity``. Both values are also
    appended to the ``linear`` / ``angular`` history lists, but only for
    messages whose ``linear.x`` is strictly positive.
    """

    def __init__(self):
        # Latest command seen, overwritten by every callback.
        self.linear_velocity = 0
        self.angular_velocity = 0
        # History of commands with strictly positive forward speed.
        self.linear = []
        self.angular = []

    def control_callback(self, data):
        """Store the velocities from *data* and record them if moving forward.

        *data* must expose ``linear.x`` and ``angular.z`` (as a
        ``geometry_msgs/Twist`` message does).
        """
        self.linear_velocity = data.linear.x
        self.angular_velocity = data.angular.z
        # Guard clause: skip non-positive (and NaN) forward speeds.
        if not self.linear_velocity > 0:
            return
        self.linear.append(self.linear_velocity)
        self.angular.append(self.angular_velocity)
if __name__ == '__main__':
    # Run one episode with a trained PPO policy while recording the /cmd_vel
    # commands, then plot the recorded linear/angular velocities.
    env = SimEnv()
    test = Test()
    agent = PPO.load('/home/houjj/RLnavigation/src/models/PPO_2/368640', env=env)
    obs = env.reset()[0]
    terminated = False
    truncated = False
    # NOTE(review): rospy.init_node() is not called in this script; presumably
    # SimEnv initialises the ROS node -- confirm.
    # Despite its name, this subscribes to the /cmd_vel command topic, not odometry.
    odom_receiver = rospy.Subscriber("/cmd_vel", Twist, test.control_callback, queue_size=1)
    try:
        # Gymnasium ends an episode when EITHER flag is set; looping on
        # `terminated` alone keeps stepping after a time-limit truncation.
        while not (terminated or truncated):
            # NOTE(review): SB3's predict() samples stochastically unless
            # deterministic=True is passed; confirm which is intended here.
            action, _state = agent.predict(obs)
            obs, rew, terminated, truncated, info = env.step(action)
    finally:
        # Stop receiving messages so the history lists stop growing
        # (from new messages) while they are being plotted.
        odom_receiver.unregister()
    plt.figure()
    plt.subplot(2, 1, 1)
    plt.plot(test.linear)
    plt.title('linear_velocity')
    plt.subplot(2, 1, 2)
    plt.plot(test.angular)
    plt.title('angular_velocity')
    plt.show()
# 上面这个脚本可以正常运行,下面这个却不行,暂不清楚原因,在此记录一下。
from envforgym import SimEnv
from stable_baselines3 import PPO
from geometry_msgs.msg import Twist
from torch.utils.tensorboard import SummaryWriter
import rospy
def control_callback(data):
    """Subscriber callback: copy the latest Twist command into module globals.

    ``data.linear.x`` goes to ``linear_velocity`` and ``data.angular.z`` to
    ``angular_velocity``, where the main loop reads them.
    """
    global linear_velocity, angular_velocity
    linear_velocity = data.linear.x
    angular_velocity = data.angular.z
if __name__ == '__main__':
    # Run one episode with a trained PPO policy and log the most recent
    # /cmd_vel command at every step to TensorBoard.
    env = SimEnv()
    agent = PPO.load('/home/houjj/RLnavigation/src/models/PPO/307200', env=env)
    log_dir = 'testlog'
    # Module-level globals that control_callback overwrites from rospy's
    # subscriber thread.
    linear_velocity = 0
    angular_velocity = 0
    # The context manager closes (and flushes) the writer even if the episode raises.
    with SummaryWriter(log_dir=log_dir) as writer:
        obs = env.reset()[0]
        terminated = False
        truncated = False
        step = 0
        # Despite its name, this subscribes to the /cmd_vel command topic.
        # rospy runs subscriber callbacks on their own threads, so no spin()
        # is needed for the globals to be updated while this loop runs.
        odom_receiver = rospy.Subscriber("/cmd_vel", Twist, control_callback, queue_size=1)
        try:
            # Gymnasium ends an episode when EITHER flag is set.
            while not (terminated or truncated):
                step += 1
                action, _state = agent.predict(obs)
                obs, rew, terminated, truncated, info = env.step(action)
                # Read each global once so the printed and logged values match.
                # NOTE(review): these hold whatever command arrived last, which
                # may lag the action just taken -- confirm timing if it matters.
                lin, ang = linear_velocity, angular_velocity
                print("linear_velocity:", lin, "angular_velocity:", ang)
                writer.add_scalar("linear_velocity", lin, step)
                writer.add_scalar("angular_velocity", ang, step)
        finally:
            odom_receiver.unregister()
    print("done")
    # rospy.spin() is intentionally not called: in rospy it only blocks until
    # node shutdown, and here it would keep the script hanging after "done".