You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

652 lines
26 KiB

"""MuJoCo simulation environment and loop for the G1 (and H1) humanoid robots.
DefaultEnv owns the MuJoCo model/data, computes PD torques from Unitree SDK
commands, steps physics, and publishes observations back via the SDK bridge.
BaseSimulator wraps DefaultEnv with rate-limiting and viewer/image update loops.
"""
import os
import pathlib
from pathlib import Path
import pickle
import tempfile
from threading import Lock, Thread
import time
from typing import Any, Dict
import xml.etree.ElementTree as ET

import mujoco
import mujoco.viewer
import numpy as np
from scipy.spatial.transform import Rotation
from unitree_sdk2py.core.channel import ChannelFactoryInitialize

from gear_sonic.utils.mujoco_sim.metric_utils import check_contact, check_height
from gear_sonic.utils.mujoco_sim.robot import Robot
from gear_sonic.utils.mujoco_sim.sim_utils import get_subtree_body_names
from gear_sonic.utils.mujoco_sim.unitree_sdk2py_bridge import ElasticBand, UnitreeSdk2Bridge
# Base directory against which config["ROBOT_SCENE"] is resolved (see
# DefaultEnv.init_scene). Computed as four `.parent` hops from this file's
# resolved path -- presumably the repository root when this module lives at
# <root>/gear_sonic/utils/mujoco_sim/; TODO confirm against the repo layout.
GEAR_SONIC_ROOT = Path(__file__).resolve().parent.parent.parent.parent
class DefaultEnv:
    """Base environment class that handles simulation environment setup and step.

    Owns the MuJoCo model/data pair, turns the commands held by the Unitree SDK
    bridge into PD torques, steps the physics and publishes the resulting
    low-level state back through the bridge.
    """

    # Root height (qpos[2]) below which a floating-base robot counts as fallen.
    FALL_HEIGHT_THRESHOLD = 0.2

    # Substrings that mark a joint as a body (non-hand) joint.
    _BODY_JOINT_KEYWORDS = ("hip", "knee", "ankle", "waist", "shoulder", "elbow", "wrist")

    # Class-level fallbacks so that code paths which skip the elastic band or
    # the offscreen renderers never hit an AttributeError on these attributes.
    elastic_band = None
    band_attached_link = None
    renderers = None

    def __init__(
        self,
        config: Dict[str, Any],
        env_name: str = "default",
        camera_configs: Dict[str, Any] = None,
        onscreen: bool = False,
        offscreen: bool = False,
        enable_image_publish: bool = False,
    ):
        """Build the robot description, load the scene and optional renderers.

        Args:
            config: Simulation configuration; keys read here include
                "SIMULATE_DT" and the optional "IMAGE_DT".
            env_name: Name stored on the instance.
            camera_configs: Mapping of camera name to its settings (at least
                "height" and "width" when offscreen rendering is enabled).
                ``None`` means "no cameras".
            onscreen: Launch the passive MuJoCo viewer.
            offscreen: Create one offscreen renderer per configured camera.
            enable_image_publish: Accepted for interface compatibility; it is
                not read by this class.
        """
        self.config = config
        self.env_name = env_name
        self.robot = Robot(self.config)
        self.num_body_dof = self.robot.NUM_JOINTS
        self.num_hand_dof = self.robot.NUM_HAND_JOINTS
        self.sim_dt = self.config["SIMULATE_DT"]
        self.obs = None
        self.torques = np.zeros(self.num_body_dof + self.num_hand_dof * 2)
        self.torque_limit = np.array(self.robot.MOTOR_EFFORT_LIMIT_LIST)
        # A fresh dict per instance: a `{}` default would be shared by every
        # environment constructed without explicit camera settings.
        self.camera_configs = camera_configs if camera_configs is not None else {}
        self.reward_lock = Lock()
        self.unitree_bridge = None
        self.onscreen = onscreen
        self.init_scene()
        self.last_reward = 0
        self.offscreen = offscreen
        if self.offscreen:
            self.init_renderers()
        self.image_dt = self.config.get("IMAGE_DT", 0.033333)
        self.image_publish_process = None

    def start_image_publish_subprocess(self, start_method: str = "spawn", camera_port: int = 5555):
        """Start the image publishing subprocess for the configured cameras.

        Args:
            start_method: Multiprocessing start method used when the config
                does not define "MP_START_METHOD" (the config value wins).
            camera_port: Port handed to the publisher as ``zmq_port``.
        """
        from gear_sonic.utils.mujoco_sim.image_publish_utils import ImagePublishProcess

        if not self.camera_configs:
            print(
                "Warning: No camera configs provided, image publishing subprocess will not be started"
            )
            return
        # The argument used to be overwritten unconditionally, so callers could
        # never choose the start method; it now serves as the fallback.
        start_method = self.config.get("MP_START_METHOD", start_method)
        self.image_publish_process = ImagePublishProcess(
            camera_configs=self.camera_configs,
            image_dt=self.image_dt,
            zmq_port=camera_port,
            start_method=start_method,
            verbose=self.config.get("verbose", False),
        )
        self.image_publish_process.start_process()

    def _load_last_xml_root(self):
        """Dump the loaded model with ``mj_saveLastXML`` and return the parsed XML root.

        The temporary file is closed before MuJoCo writes to it and is removed
        even when saving or parsing fails.
        """
        fd, temp_xml_path = tempfile.mkstemp(suffix=".xml")
        os.close(fd)
        try:
            mujoco.mj_saveLastXML(temp_xml_path, self.mj_model)
            return ET.parse(temp_xml_path).getroot()
        finally:
            os.remove(temp_xml_path)

    def _get_dof_indices_by_class(self):
        """Return ``{joint class name: [dof address, ...]}`` for joints that declare a class."""
        root = self._load_last_xml_root()
        joint_class_map = {}
        for joint_element in root.findall(".//joint[@class]"):
            joint_name = joint_element.get("name")
            joint_class = joint_element.get("class")
            if not (joint_name and joint_class):
                continue
            joint_id = mujoco.mj_name2id(self.mj_model, mujoco.mjtObj.mjOBJ_JOINT, joint_name)
            if joint_id == -1:
                # Name from the XML is unknown to the compiled model; skip it.
                continue
            joint_class_map.setdefault(joint_class, []).append(self.mj_model.jnt_dofadr[joint_id])
        return joint_class_map

    def _get_default_dof_properties(self):
        """Return ``{default class: {"damping"/"armature"/"frictionloss": float}}``.

        Only classes whose ``<joint>`` default defines at least one of the
        three attributes are included.
        """
        root = self._load_last_xml_root()
        default_dof_properties = {}
        for default_element in root.findall(".//default/default[@class]"):
            class_name = default_element.get("class")
            joint_element = default_element.find("joint")
            if not class_name or joint_element is None:
                continue
            properties = {
                key: float(joint_element.get(key))
                for key in ("damping", "armature", "frictionloss")
                if key in joint_element.attrib
            }
            if properties:
                default_dof_properties[class_name] = properties
        return default_dof_properties

    def init_scene(self):
        """Initialize the default robot scene.

        Loads the model named by config["ROBOT_SCENE"], detects the root joint
        type, optionally sets up the elastic band and the onscreen viewer, and
        collects the joint indices of the body and of each hand.

        Raises:
            ValueError: If the model has neither a "floating_base_joint" nor a
                "constrained_base_joint".
        """
        xml_path = str(pathlib.Path(GEAR_SONIC_ROOT) / self.config["ROBOT_SCENE"])
        self.mj_model = mujoco.MjModel.from_xml_path(xml_path)
        self.mj_data = mujoco.MjData(self.mj_model)
        self.mj_model.opt.timestep = self.sim_dt
        self.torso_index = mujoco.mj_name2id(self.mj_model, mujoco.mjtObj.mjOBJ_BODY, "torso_link")
        self.root_body = "pelvis"
        self.root_body_id = self.mj_model.body(self.root_body).id
        self.joint_class_map = self._get_dof_indices_by_class()
        self.perform_sysid_search = self.config.get("perform_sysid_search", False)

        joint_names = [self.mj_model.joint(i).name for i in range(self.mj_model.njnt)]
        # Check for static root link (fixed base)
        self.use_floating_root_link = "floating_base_joint" in joint_names
        self.use_constrained_root_link = "constrained_base_joint" in joint_names

        # MuJoCo qpos/qvel arrays start with root DOFs before joint DOFs:
        # floating base has 7 qpos (pos + quat) and 6 qvel (lin + ang velocity)
        if self.use_floating_root_link:
            self.qpos_offset = 7
            self.qvel_offset = 6
        elif self.use_constrained_root_link:
            self.qpos_offset = 1
            self.qvel_offset = 1
        else:
            raise ValueError(
                "No root link found -- "
                "The absolute static root will make the simulation unstable."
            )

        # Always define the band attributes so sim_step/handle_keyboard_button
        # can test them even when the band is disabled.
        self.elastic_band = None
        self.band_attached_link = None
        viewer_kwargs = {"show_left_ui": False, "show_right_ui": False}
        # Enable the elastic band
        if self.config["ENABLE_ELASTIC_BAND"] and self.use_floating_root_link:
            self.elastic_band = ElasticBand()
            robot_type = self.config["ROBOT_TYPE"]
            if "g1" in robot_type:
                band_body = "pelvis" if self.config["enable_waist"] else "torso_link"
            elif "h1" in robot_type:
                band_body = "torso_link"
            else:
                band_body = "base_link"
            self.band_attached_link = self.mj_model.body(band_body).id
            viewer_kwargs["key_callback"] = self.elastic_band.MujuocoKeyCallback

        if self.onscreen:
            self.viewer = mujoco.viewer.launch_passive(
                self.mj_model, self.mj_data, **viewer_kwargs
            )
        else:
            mujoco.mj_forward(self.mj_model, self.mj_data)
            self.viewer = None

        if self.viewer:
            self.viewer.cam.azimuth = 120
            self.viewer.cam.elevation = -30
            self.viewer.cam.distance = 2.0
            self.viewer.cam.lookat = np.array([0, 0, 0.5])
            self.viewer.cam.type = mujoco.mjtCamera.mjCAMERA_TRACKING
            self.viewer.cam.trackbodyid = self.mj_model.body("pelvis").id

        body_joint_index = []
        left_hand_index = []
        right_hand_index = []
        for joint_id, name in enumerate(joint_names):
            if any(part_name in name for part_name in self._BODY_JOINT_KEYWORDS):
                body_joint_index.append(joint_id)
            elif "left_hand" in name:
                left_hand_index.append(joint_id)
            elif "right_hand" in name:
                right_hand_index.append(joint_id)
        assert len(body_joint_index) == self.robot.NUM_JOINTS
        assert len(left_hand_index) == self.robot.NUM_HAND_JOINTS
        assert len(right_hand_index) == self.robot.NUM_HAND_JOINTS
        self.body_joint_index = np.array(body_joint_index)
        self.left_hand_index = np.array(left_hand_index)
        self.right_hand_index = np.array(right_hand_index)

    def init_renderers(self):
        """Create one ``mujoco.Renderer`` per entry of ``camera_configs``."""
        self.renderers = {
            camera_name: mujoco.Renderer(
                self.mj_model, height=camera_config["height"], width=camera_config["width"]
            )
            for camera_name, camera_config in self.camera_configs.items()
        }

    @staticmethod
    def _pd_torque(cmd, q, dq):
        """PD control: tau = tau_ff + kp * (q_des - q) + kd * (dq_des - dq)."""
        return cmd.tau + cmd.kp * (cmd.q - q) + cmd.kd * (cmd.dq - dq)

    def compute_body_torques(self) -> np.ndarray:
        """Return PD torques for the body motors (zeros while no command is available).

        Joint state comes from ``sensordata`` when the bridge has
        ``use_sensor`` set (positions first, velocities ``num_body_motor``
        entries later) and from ``qpos``/``qvel`` otherwise.
        """
        body_torques = np.zeros(self.num_body_dof)
        bridge = self.unitree_bridge
        if bridge is None or not bridge.low_cmd:
            return body_torques
        num_motor = bridge.num_body_motor
        for i in range(num_motor):
            if bridge.use_sensor:
                q = self.mj_data.sensordata[i]
                dq = self.mj_data.sensordata[i + num_motor]
            else:
                q = self.mj_data.qpos[self.body_joint_index[i] + self.qpos_offset - 1]
                dq = self.mj_data.qvel[self.body_joint_index[i] + self.qvel_offset - 1]
            body_torques[i] = self._pd_torque(bridge.low_cmd.motor_cmd[i], q, dq)
        return body_torques

    def get_head_pose(self) -> np.ndarray:
        """Return 7 values: a position offset from "torso_link" plus its quaternion as [x, y, z, w]."""
        root_pos = self.mj_data.body("torso_link").xpos.copy()
        # Reorder quaternion from MuJoCo [w,x,y,z] to scipy [x,y,z,w]
        root_quat = self.mj_data.body("torso_link").xquat.copy()[[1, 2, 3, 0]]
        head_pos = root_pos + Rotation.from_quat(root_quat).apply(np.array([0.0, 0.0, -0.044]))
        return np.concatenate((head_pos, root_quat))

    def get_root_vel(self) -> np.ndarray:
        """Return the first six entries of ``qvel`` (a view, not a copy)."""
        return self.mj_data.qvel[:6]

    def compute_hand_torques(self) -> np.ndarray:
        """Return PD torques for both hands, left hand first then right hand."""
        left_hand_torques = np.zeros(self.num_hand_dof)
        right_hand_torques = np.zeros(self.num_hand_dof)
        bridge = self.unitree_bridge
        # NOTE(review): the guard tests low_cmd while the loop reads the hand
        # commands -- presumably the bridge always provides them; confirm.
        if bridge is not None and bridge.low_cmd:
            for i in range(bridge.num_hand_motor):
                left_hand_torques[i] = self._pd_torque(
                    bridge.left_hand_cmd.motor_cmd[i],
                    self.mj_data.qpos[self.left_hand_index[i] + self.qpos_offset - 1],
                    self.mj_data.qvel[self.left_hand_index[i] + self.qvel_offset - 1],
                )
                right_hand_torques[i] = self._pd_torque(
                    bridge.right_hand_cmd.motor_cmd[i],
                    self.mj_data.qpos[self.right_hand_index[i] + self.qpos_offset - 1],
                    self.mj_data.qvel[self.right_hand_index[i] + self.qvel_offset - 1],
                )
        return np.concatenate((left_hand_torques, right_hand_torques))

    def compute_body_qpos(self) -> np.ndarray:
        """Return the commanded body joint positions (zeros while no command is available)."""
        body_qpos = np.zeros(self.num_body_dof)
        bridge = self.unitree_bridge
        if bridge is not None and bridge.low_cmd:
            for i in range(bridge.num_body_motor):
                body_qpos[i] = bridge.low_cmd.motor_cmd[i].q
        return body_qpos

    def compute_hand_qpos(self) -> np.ndarray:
        """Return the commanded hand joint positions, left hand first then right hand."""
        hand_qpos = np.zeros(self.num_hand_dof * 2)
        bridge = self.unitree_bridge
        if bridge is not None and bridge.low_cmd:
            for i in range(bridge.num_hand_motor):
                hand_qpos[i] = bridge.left_hand_cmd.motor_cmd[i].q
                hand_qpos[i + self.num_hand_dof] = bridge.right_hand_cmd.motor_cmd[i].q
        return hand_qpos

    def prepare_obs(self) -> Dict[str, Any]:
        """Collect the observation dictionary that is published as low-level state.

        Returns:
            Dict with the floating-base pose/velocity/acceleration (zeros for a
            non-floating root), the torso ("secondary IMU") quaternion and
            local-frame velocity, body and hand joint states, and "time".
        """
        obs = {}
        if self.use_floating_root_link:
            obs["floating_base_pose"] = self.mj_data.qpos[:7]
            obs["floating_base_vel"] = self.mj_data.qvel[:6]
            obs["floating_base_acc"] = self.mj_data.qacc[:6]
        else:
            obs["floating_base_pose"] = np.zeros(7)
            obs["floating_base_vel"] = np.zeros(6)
            obs["floating_base_acc"] = np.zeros(6)

        obs["secondary_imu_quat"] = self.mj_data.xquat[self.torso_index]
        torso_link = self.mj_model.body("torso_link").id
        torso_vel = np.zeros(6)
        mujoco.mj_objectVelocity(
            self.mj_model, self.mj_data, mujoco.mjtObj.mjOBJ_BODY, torso_link, torso_vel, 1
        )
        # mj_objectVelocity returns [ang_vel, lin_vel]; swap to [lin_vel, ang_vel]
        obs["secondary_imu_vel"] = np.concatenate((torso_vel[3:], torso_vel[:3]))

        # Use the detected root offsets (not a hard-coded 7/6) so body joints
        # are addressed the same way as in compute_body_torques and the hand
        # entries below, including for a constrained root.
        body_qpos_adr = self.body_joint_index + self.qpos_offset - 1
        body_dof_adr = self.body_joint_index + self.qvel_offset - 1
        obs["body_q"] = self.mj_data.qpos[body_qpos_adr]
        obs["body_dq"] = self.mj_data.qvel[body_dof_adr]
        obs["body_ddq"] = self.mj_data.qacc[body_dof_adr]
        obs["body_tau_est"] = self.mj_data.actuator_force[self.body_joint_index - 1]

        if self.num_hand_dof > 0:
            for side, joint_index in (
                ("left_hand", self.left_hand_index),
                ("right_hand", self.right_hand_index),
            ):
                dof_adr = joint_index + self.qvel_offset - 1
                obs[side + "_q"] = self.mj_data.qpos[joint_index + self.qpos_offset - 1]
                obs[side + "_dq"] = self.mj_data.qvel[dof_adr]
                obs[side + "_ddq"] = self.mj_data.qacc[dof_adr]
                obs[side + "_tau_est"] = self.mj_data.actuator_force[joint_index - 1]

        obs["time"] = self.mj_data.time
        return obs

    def sim_step(self):
        """Publish the current state, apply band force and torques, then step physics.

        Requires a bridge to have been installed with ``set_unitree_bridge``.
        """
        self.obs = self.prepare_obs()
        self.unitree_bridge.PublishLowState(self.obs)
        if self.unitree_bridge.joystick:
            self.unitree_bridge.PublishWirelessController()

        if self.elastic_band:
            if self.elastic_band.enable and self.use_floating_root_link:
                band_vel = np.zeros(6)
                mujoco.mj_objectVelocity(
                    self.mj_model,
                    self.mj_data,
                    mujoco.mjtObj.mjOBJ_BODY,
                    self.band_attached_link,
                    band_vel,
                    0,
                )
                # Layout handed to the band: position, quaternion, then the
                # velocity reordered from [ang, lin] to [lin, ang].
                pose = np.concatenate(
                    [
                        self.mj_data.xpos[self.band_attached_link],
                        self.mj_data.xquat[self.band_attached_link],
                        band_vel[3:],
                        band_vel[:3],
                    ]
                )
                self.mj_data.xfrc_applied[self.band_attached_link] = self.elastic_band.Advance(pose)
            else:
                self.mj_data.xfrc_applied[self.band_attached_link] = np.zeros(6)

        body_torques = self.compute_body_torques()
        hand_torques = self.compute_hand_torques()
        # -1: actuator array is 0-based while joint indices from the model are 1-based
        self.torques[self.body_joint_index - 1] = body_torques
        if self.num_hand_dof > 0:
            self.torques[self.left_hand_index - 1] = hand_torques[: self.num_hand_dof]
            self.torques[self.right_hand_index - 1] = hand_torques[self.num_hand_dof :]
        self.torques = np.clip(self.torques, -self.torque_limit, self.torque_limit)

        if self.config["FREE_BASE"]:
            # Prepend 6 zeros for the floating-base root DOF actuators
            self.mj_data.ctrl = np.concatenate((np.zeros(6), self.torques))
        else:
            self.mj_data.ctrl = self.torques
        mujoco.mj_step(self.mj_model, self.mj_data)
        self.check_fall()

    def apply_perturbation(self, key):
        """Add a unit body-frame velocity kick to the floating base for an arrow key.

        "up"/"down" push along +x/-x and "left"/"right" along +y/-y of the
        base frame. Does nothing for a non-floating root, where ``qpos[3:7]``
        is not the base quaternion.
        """
        if not self.use_floating_root_link:
            return
        kicks = {
            "up": (1.0, 0.0),
            "down": (-1.0, 0.0),
            "left": (0.0, 1.0),
            "right": (0.0, -1.0),
        }
        perturbation_x_body, perturbation_y_body = kicks.get(key, (0.0, 0.0))
        vel_body = np.array([perturbation_x_body, perturbation_y_body, 0.0])
        vel_world = np.zeros(3)
        base_quat = self.mj_data.qpos[3:7]
        mujoco.mju_rotVecQuat(vel_world, vel_body, base_quat)
        self.mj_data.qvel[0] += vel_world[0]
        self.mj_data.qvel[1] += vel_world[1]
        mujoco.mj_forward(self.mj_model, self.mj_data)

    def update_viewer(self):
        """Sync the onscreen viewer, if one was launched."""
        if self.viewer is not None:
            self.viewer.sync()

    def update_viewer_camera(self):
        """Toggle the viewer camera between tracking and free mode."""
        if self.viewer is None:
            return
        if self.viewer.cam.type == mujoco.mjtCamera.mjCAMERA_TRACKING:
            self.viewer.cam.type = mujoco.mjtCamera.mjCAMERA_FREE
        else:
            self.viewer.cam.type = mujoco.mjtCamera.mjCAMERA_TRACKING

    def update_reward(self):
        """Refresh the cached reward; this base environment always stores 0."""
        with self.reward_lock:
            self.last_reward = 0

    def get_reward(self):
        """Return the most recently cached reward."""
        with self.reward_lock:
            return self.last_reward

    def set_unitree_bridge(self, unitree_bridge):
        """Install the SDK bridge used for commands and state publishing."""
        self.unitree_bridge = unitree_bridge

    def get_privileged_obs(self):
        """Return privileged observations; empty for this base environment."""
        return {}

    def update_render_caches(self):
        """Render every configured camera and forward the images to the publisher.

        Returns:
            Dict mapping ``"<camera_name>_image"`` to the rendered frame, or an
            empty dict when no offscreen renderers exist (e.g. ``offscreen``
            was False), instead of failing on the missing renderers.
        """
        if not self.renderers:
            return {}
        render_caches = {}
        for camera_name, camera_config in self.camera_configs.items():
            renderer = self.renderers[camera_name]
            # A "params" entry overrides the named model camera.
            camera = camera_config["params"] if "params" in camera_config else camera_name
            renderer.update_scene(self.mj_data, camera=camera)
            render_caches[camera_name + "_image"] = renderer.render()
        if self.image_publish_process is not None:
            self.image_publish_process.update_shared_memory(render_caches)
        return render_caches

    def handle_keyboard_button(self, key):
        """Dispatch a key: band controls, "backspace" reset, "v" camera toggle, arrow pushes."""
        if self.elastic_band:
            self.elastic_band.handle_keyboard_button(key)
        if key == "backspace":
            self.reset()
        if key == "v":
            self.update_viewer_camera()
        if key in ["up", "down", "left", "right"]:
            self.apply_perturbation(key)

    def check_fall(self):
        """Reset the simulation when the floating base drops below the fall threshold.

        The check reads ``qpos[2]`` as the root height, which only holds for a
        floating base; for any other root type ``fall`` is simply cleared.
        """
        self.fall = False
        if not self.use_floating_root_link:
            return
        if self.mj_data.qpos[2] < self.FALL_HEIGHT_THRESHOLD:
            self.fall = True
            print(f"Warning: Robot has fallen, height: {self.mj_data.qpos[2]:.3f} m")
        if self.fall:
            self.reset()

    def check_self_collision(self):
        """Return whether any two bodies in the root body's subtree are in contact."""
        robot_bodies = get_subtree_body_names(self.mj_model, self.mj_model.body(self.root_body).id)
        self_collision, contact_bodies = check_contact(
            self.mj_model, self.mj_data, robot_bodies, robot_bodies, return_all_contact_bodies=True
        )
        if self_collision:
            print(f"Warning: Self-collision detected: {contact_bodies}")
        return self_collision

    def reset(self):
        """Reset the MuJoCo data to the model's initial state."""
        mujoco.mj_resetData(self.mj_model, self.mj_data)
class BaseSimulator:
    """Base simulator class that handles initialization and running of simulations.

    Wraps a ``DefaultEnv`` with the SDK bridge setup and a rate-limited loop
    that steps physics and periodically updates the viewer, reward and images.
    """

    def __init__(
        self, config: Dict[str, Any], env_name: str = "default", redis_client=None, **kwargs
    ):
        """Create the environment, initialize the SDK channel and the bridge.

        Args:
            config: Simulation configuration; "SIMULATE_DT", "DOMAIN_ID" and
                "USE_JOYSTICK" are required, the "*_DT" periods and
                "INTERFACE" are optional.
            env_name: Environment to build; only "default" is accepted.
            redis_client: Optional client; when given, the push flags are
                cleared here and the head pose is published from ``start``.
            **kwargs: Forwarded to ``DefaultEnv``.

        Raises:
            ValueError: If ``env_name`` is not "default".
        """
        self.config = config
        self.env_name = env_name
        self.redis_client = redis_client
        if self.redis_client is not None:
            self.redis_client.set("push_left_hand", "false")
            self.redis_client.set("push_right_hand", "false")
            self.redis_client.set("push_torso", "false")
        # Create rate objects
        self.sim_dt = self.config["SIMULATE_DT"]
        self.reward_dt = self.config.get("REWARD_DT", 0.02)
        self.image_dt = self.config.get("IMAGE_DT", 0.033333)
        self.viewer_dt = self.config.get("VIEWER_DT", 0.02)
        self._running = True
        self._closed = False
        self.robot = Robot(self.config)
        # Create the environment
        if env_name == "default":
            self.sim_env = DefaultEnv(config, env_name, **kwargs)
        else:
            raise ValueError(
                f"Invalid environment name: {env_name}. "
                f"Only 'default' is supported in this minimal build."
            )
        # Best effort: a failure here is reported but does not abort construction.
        try:
            if self.config.get("INTERFACE", None):
                ChannelFactoryInitialize(self.config["DOMAIN_ID"], self.config["INTERFACE"])
            else:
                ChannelFactoryInitialize(self.config["DOMAIN_ID"])
        except Exception as e:
            print(f"Note: Channel factory initialization attempt: {e}")
        self.init_unitree_bridge()
        self.sim_env.set_unitree_bridge(self.unitree_bridge)
        self.init_subscriber()
        self.init_publisher()
        self.sim_thread = None

    @staticmethod
    def _steps_per_period(period: float, sim_dt: float) -> int:
        """Return how many simulation steps fit in ``period`` (never less than 1).

        The small epsilon keeps ratios such as 0.3 / 0.1 (2.999...) from being
        truncated to 2, and the lower bound of 1 prevents a modulo-by-zero when
        the period is shorter than one simulation step.
        """
        return max(1, int(period / sim_dt + 1e-9))

    def start_as_thread(self):
        """Run ``start`` in a new thread kept in ``self.sim_thread``."""
        self.sim_thread = Thread(target=self.start)
        self.sim_thread.start()

    def start_image_publish_subprocess(self, start_method: str = "spawn", camera_port: int = 5555):
        """Delegate to the environment's image publishing subprocess launcher."""
        self.sim_env.start_image_publish_subprocess(start_method, camera_port)

    def init_subscriber(self):
        """Hook for subclasses; does nothing here."""
        pass

    def init_publisher(self):
        """Hook for subclasses; does nothing here."""
        pass

    def init_unitree_bridge(self):
        """Create the SDK bridge and attach the joystick when configured."""
        self.unitree_bridge = UnitreeSdk2Bridge(self.config)
        if self.config["USE_JOYSTICK"]:
            self.unitree_bridge.SetupJoystick(
                device_id=self.config["JOYSTICK_DEVICE"], js_type=self.config["JOYSTICK_TYPE"]
            )

    def start(self):
        """Main simulation loop.

        Runs until ``close`` is called, the viewer window is closed, or a
        KeyboardInterrupt arrives; always calls ``close`` on exit.
        """
        # The step ratios are loop invariants, so compute them once.
        viewer_every = self._steps_per_period(self.viewer_dt, self.sim_dt)
        reward_every = self._steps_per_period(self.reward_dt, self.sim_dt)
        image_every = self._steps_per_period(self.image_dt, self.sim_dt)
        sim_cnt = 0
        last_head_publish = time.monotonic()
        try:
            while self._running and (
                self.sim_env.viewer is None or self.sim_env.viewer.is_running()
            ):
                step_start = time.monotonic()
                self.sim_env.sim_step()

                # Publish the head pose to redis at most ten times per second.
                now = time.monotonic()
                if now - last_head_publish > 1 / 10.0 and self.redis_client is not None:
                    head_pose = self.sim_env.get_head_pose()
                    self.redis_client.set("head_pos", pickle.dumps(head_pose[:3]))
                    self.redis_client.set("head_quat", pickle.dumps(head_pose[3:]))
                    last_head_publish = now

                if sim_cnt % viewer_every == 0:
                    self.sim_env.update_viewer()
                if sim_cnt % reward_every == 0:
                    self.sim_env.update_reward()
                if sim_cnt % image_every == 0:
                    self.sim_env.update_render_caches()

                # Simple rate limiter (replaces ROS rate)
                sleep_time = self.sim_dt - (time.monotonic() - step_start)
                if sleep_time > 0:
                    time.sleep(sleep_time)
                sim_cnt += 1
        except KeyboardInterrupt:
            print("Simulator interrupted by user.")
        finally:
            self.close()

    def __del__(self):
        self.close()

    def reset(self):
        """Reset the wrapped environment."""
        self.sim_env.reset()

    def close(self):
        """Stop the loop and release the image publisher and the viewer once.

        Safe to call repeatedly (``start`` and ``__del__`` both call it) and on
        an instance whose ``__init__`` failed before the environment existed.
        """
        self._running = False
        if getattr(self, "_closed", False):
            return
        self._closed = True
        sim_env = getattr(self, "sim_env", None)
        if sim_env is None:
            return
        # Separate guards so a failing publisher shutdown cannot leave the
        # viewer open.
        try:
            if sim_env.image_publish_process is not None:
                sim_env.image_publish_process.stop()
        except Exception as e:
            print(f"Warning during close: {e}")
        try:
            if sim_env.viewer is not None:
                sim_env.viewer.close()
        except Exception as e:
            print(f"Warning during close: {e}")

    def get_privileged_obs(self):
        """Return the wrapped environment's privileged observations."""
        return self.sim_env.get_privileged_obs()

    def handle_keyboard_button(self, key):
        """Forward a key press to the wrapped environment."""
        self.sim_env.handle_keyboard_button(key)