"""
FSM State Implementations
Concrete implementations of different FSM states
"""
from typing import Dict
import numpy as np
import onnxruntime as ort
from FSM.fsm_base import FSMState, FSMStateName
from common.joystick import XboxFlag
from common.robot_data import RobotData
import math
import os
import yaml
class FSMStateMLP(FSMState):
    """MLP policy state driven by an ONNX network (mirrors the C++ version).

    Reads configuration from ``config/mlp.yaml`` next to this file,
    maintains a rolling history of ``num_hist_`` proprioceptive frames and
    runs policy inference every ``decimation_`` control ticks, writing the
    desired joint positions and PD gains into the shared ``RobotData``.
    """

    def __init__(self, robot_data: RobotData):
        super().__init__(robot_data)
        # Locate the YAML config relative to this source file.
        current_dir = os.path.dirname(os.path.abspath(__file__))
        config_path = os.path.join(current_dir, "config", "mlp.yaml")
        with open(config_path, 'r') as f:
            policy_config = yaml.safe_load(f)
        # Core sizes / timing (defaults follow the C++ implementation).
        self.action_num_ = policy_config.get('actions_size', 12)
        self.motor_num_ = policy_config.get('motor_num', 29)
        self.dt_ = policy_config.get('dt', 0.002)
        # Observation history sizing.
        size_config = policy_config.get('size', {})
        self.num_hist_ = size_config.get('num_hist', 15)
        self.obs_size_ = size_config.get('observations_size', 47)
        # Control parameters.
        control_config = policy_config.get('control', {})
        self.action_scale_ = control_config.get('action_scale', 0.5)
        self.gait_cycle_period_ = control_config.get('gait_cycle_period', 0.9)
        self.decimation_ = control_config.get('decimation', 5)
        # Normalization / clipping parameters.
        norm_config = policy_config.get('normalization', {})
        clip_config = norm_config.get('clip_scales', {})
        obs_config = norm_config.get('obs_scales', {})
        self.clip_obs_ = clip_config.get('clip_observations', 100.0)
        self.clip_act_ = clip_config.get('clip_actions', 100.0)
        self.lin_vel_scale_ = obs_config.get('lin_vel', 2.0)
        self.ang_vel_scale_ = obs_config.get('ang_vel', 1.0)
        self.dof_pos_scale_ = obs_config.get('dof_pos', 1.0)
        self.dof_vel_scale_ = obs_config.get('dof_vel', 0.05)
        # Default joint angles: only the first action_num_ entries are used,
        # exactly like the C++ code.
        init_config = policy_config.get('init_state', {})
        default_angles_list = init_config.get('default_joint_angles', [0.0] * self.action_num_)
        self.default_joint_angles_ = np.array(default_angles_list[:self.action_num_], dtype=np.float32)
        # PD gains for all driven motors.
        gains_config = policy_config.get('gains', {})
        self.kp = np.array(gains_config.get('kp', [300.0] * self.motor_num_))
        self.kd = np.array(gains_config.get('kd', [10.0] * self.motor_num_))
        # Flat observation buffers: num_hist_ stacked frames of obs_size_ each.
        self.observations_ = np.zeros(self.obs_size_ * self.num_hist_, dtype=np.float32)
        self.proprio_hist_buf_ = np.zeros(self.obs_size_ * self.num_hist_, dtype=np.float32)
        self.last_actions_ = np.zeros(self.action_num_, dtype=np.float32)
        self.actions_ = np.zeros(self.action_num_, dtype=np.float32)
        # First-iteration flags, same semantics as the C++ version.
        self.is_first_obs_ = True
        self.is_first_action_ = True
        self.phase_locked = False
        # ONNX model session.
        self.model_path = os.path.join(current_dir, "model", policy_config["model_path"])
        self._init_onnx_session()

    def _init_onnx_session(self):
        """Create the ONNX Runtime session; on failure fall back to None."""
        try:
            self.ort_session_ = ort.InferenceSession(self.model_path)
            # Cache the model input name once so the hot inference path
            # avoids a per-call get_inputs() lookup.
            self.input_name_ = self.ort_session_.get_inputs()[0].name
            print(f"[FSMStateMLP-ONNX] ONNX model loaded successfully: {self.model_path}")
        except Exception as e:
            print(f"[FSMStateMLP] Failed to load ONNX model: {e}")
            self.ort_session_ = None
            self.input_name_ = None

    def on_enter(self):
        """Reset first-iteration flags when entering the MLP state."""
        print("[FSMStateMLP] enter")
        self.is_first_obs_ = True
        self.is_first_action_ = True

    def run(self, flag: XboxFlag):
        """One control tick: run the policy (decimated) and write joint targets.

        NOTE(review): the per-tick print below is costly inside a real-time
        control loop; consider removing it or gating it behind a debug flag.
        """
        print("[FSMStateMLP] run")
        # Run policy inference only every decimation_ control steps.
        # NOTE(review): float division + int() can jitter near tick
        # boundaries; an integer tick counter would be more robust.
        if int(self.robot_data_.time_now_ / self.dt_) % self.decimation_ == 0:
            self.compute_observation(flag)
            self.compute_actions()
        # Write desired joint state. Index offset matches
        # C++: robot_data_->q_d_(35 - motor_num_ + i).
        for i in range(self.action_num_):
            joint_idx = 35 - self.motor_num_ + i
            self.robot_data_.q_d_[joint_idx] = (
                self.actions_[i] * self.action_scale_ + self.default_joint_angles_[i]
            )
            self.robot_data_.q_dot_d_[joint_idx] = 0.0
            self.robot_data_.tau_d_[joint_idx] = 0.0
            self.last_actions_[i] = self.actions_[i]
        # PD gains for the driven motors.
        self.robot_data_.joint_kp_p_[:self.motor_num_] = self.kp
        self.robot_data_.joint_kd_p_[:self.motor_num_] = self.kd

    def compute_observation(self, flag: XboxFlag):
        """Build one proprioceptive frame and push it into the history buffer."""
        t_now = float(self.robot_data_.time_now_)
        # Gait phase in [0, gait_cycle_period_), exactly like C++.
        phase = math.fmod(t_now, self.gait_cycle_period_)
        cmd_norm = math.sqrt(
            flag.x_speed_command * flag.x_speed_command +
            flag.y_speed_command * flag.y_speed_command +
            flag.yaw_speed_command * flag.yaw_speed_command
        )
        # Unlock the phase as soon as a meaningful velocity command arrives.
        if cmd_norm >= 0.05:
            self.phase_locked = False
        # Lock the phase to 0 when standing still and the cycle is about to
        # wrap (phase close to the period). NOTE(review): this only catches
        # the end-of-cycle side of the wrap; presumably intentional C++
        # parity — confirm against the reference implementation.
        tolerance = 0.1
        if cmd_norm < 0.05 and abs(phase - self.gait_cycle_period_) < tolerance:
            self.phase_locked = True
        if self.phase_locked:
            phase = 0
        # Command vector: gait clock (sin/cos) + velocity commands.
        command = np.array([
            math.sin(2 * math.pi * phase),
            math.cos(2 * math.pi * phase),
            flag.x_speed_command,
            flag.y_speed_command,
            flag.yaw_speed_command
        ], dtype=np.float32)
        print(f'Input command: {command}')
        # IMU orientation; imu_data_ is indexed [yaw, pitch, roll, gx, gy, gz]
        # judging by the reordering below — TODO confirm against RobotData.
        rpy = np.array([
            self.robot_data_.imu_data_[2],  # roll
            self.robot_data_.imu_data_[1],  # pitch
            self.robot_data_.imu_data_[0]   # yaw
        ], dtype=np.float32) * 1.0
        gyro = np.array([
            self.robot_data_.imu_data_[3],
            self.robot_data_.imu_data_[4],
            self.robot_data_.imu_data_[5]
        ], dtype=np.float32) * self.ang_vel_scale_
        # Proprio layout (C++):
        # command(5), (q - q_default)*scale(action_num_),
        # q_dot*scale(action_num_), last_actions(action_num_), gyro(3), rpy(3)
        joint_start_idx = 35 - self.motor_num_  # same offset as C++
        # Fix: slice width follows action_num_ (was hard-coded 12, which
        # silently breaks if actions_size != 12 in the config).
        joint_pos = (
            self.robot_data_.q_a_[joint_start_idx:joint_start_idx + self.action_num_].astype(np.float32) -
            self.default_joint_angles_
        ) * self.dof_pos_scale_
        joint_vel = (
            self.robot_data_.q_dot_a_[joint_start_idx:joint_start_idx + self.action_num_].astype(np.float32)
        ) * self.dof_vel_scale_
        proprio = np.concatenate([
            command,             # 5 elements
            joint_pos,           # action_num_ elements
            joint_vel,           # action_num_ elements
            self.last_actions_,  # action_num_ elements
            gyro,                # 3 elements
            rpy                  # 3 elements
        ])                       # 47 total when action_num_ == 12
        # History buffer management, exactly like C++.
        if self.is_first_obs_:
            # Bootstrap: fill every history slot with the first frame.
            for i in range(self.num_hist_):
                start_idx = i * self.obs_size_
                end_idx = start_idx + self.obs_size_
                self.proprio_hist_buf_[start_idx:end_idx] = proprio
            self.is_first_obs_ = False
        else:
            # Shift left by one frame and append the newest at the tail.
            shift_size = (self.num_hist_ - 1) * self.obs_size_
            self.proprio_hist_buf_[:shift_size] = self.proprio_hist_buf_[self.obs_size_:]
            self.proprio_hist_buf_[shift_size:] = proprio
        # Clip observations to the configured range.
        self.observations_ = np.clip(self.proprio_hist_buf_, -self.clip_obs_, self.clip_obs_)

    def compute_actions(self):
        """Run ONNX inference on the stacked observation and clip the actions."""
        if self.ort_session_ is None:
            return
        try:
            # Batch dimension of 1 for the network input.
            input_data = self.observations_.reshape(1, -1).astype(np.float32)
            outputs = self.ort_session_.run(None, {self.input_name_: input_data})
            output_data = outputs[0][0]
            # Vectorized clip — same result as the C++ per-element loop.
            self.actions_ = np.clip(
                np.asarray(output_data[:self.action_num_], dtype=np.float32),
                -self.clip_act_, self.clip_act_
            )
            if self.is_first_action_:
                # One-time dump of the newest observation frame for debugging.
                print("[FSMStateMLP-ONNX] First Observation:")
                for i in range(self.obs_size_):
                    print(f"{self.observations_[i]:.6f} ", end="")
                print()
                self.is_first_action_ = False
        except Exception as e:
            print(f"[FSMStateMLP] ONNX Runtime inference error: {e}")

    def on_exit(self):
        """Leave the MLP state (no cleanup required)."""
        print("[FSMStateMLP] exit")

    def check_transition(self, flag: XboxFlag) -> FSMStateName:
        """Map the joystick FSM command to the next state, or None to stay."""
        if flag.fsm_state_command == "gotoSTOP":
            return FSMStateName.STOP
        elif flag.fsm_state_command == "gotoMLP":
            return FSMStateName.MLP
        elif flag.fsm_state_command == "gotoZERO":
            return FSMStateName.ZERO
        else:
            return None  # no transition