You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

586 lines
23 KiB

import sys
from typing import Any, Dict, Tuple
import gymnasium as gym
from gymnasium.envs.registration import register
import numpy as np
from robocasa.environments.locomanipulation import REGISTERED_LOCOMANIPULATION_ENVS
from robocasa.models.robots import GR00T_LOCOMANIP_ENVS_ROBOTS
from robosuite.environments.robot_env import RobotEnv
from scipy.spatial.transform import Rotation as R
from gr00t_wbc.control.envs.g1.utils.joint_safety import JointSafetyMonitor
from gr00t_wbc.control.envs.robocasa.utils.controller_utils import (
update_robosuite_controller_configs,
)
from gr00t_wbc.control.envs.robocasa.utils.robocasa_env import ( # noqa: F401
ALLOWED_LANGUAGE_CHARSET,
Gr00tLocomanipRoboCasaEnv,
)
from gr00t_wbc.control.robot_model.instantiation import get_robot_type_and_model
from gr00t_wbc.control.utils.n1_utils import (
prepare_gym_space_for_eval,
prepare_observation_for_eval,
)
from gr00t_wbc.data.constants import RS_VIEW_CAMERA_HEIGHT, RS_VIEW_CAMERA_WIDTH
class SyncEnv(gym.Env):
MAX_MUJOCO_STATE_LEN = 800
def __init__(self, env_name, robot_name, **kwargs):
self.env_name = env_name
self.robot_name = robot_name
self.onscreen = kwargs.get("onscreen", True)
self.enable_gravity_compensation = kwargs.pop("enable_gravity_compensation", False)
self.gravity_compensation_joints = kwargs.pop("gravity_compensation_joints", ["arms"])
_, self.robot_model = get_robot_type_and_model(
robot_name, enable_waist_ik=kwargs.pop("enable_waist", False)
)
env_kwargs = {
"onscreen": kwargs.get("onscreen", True),
"offscreen": kwargs.get("offscreen", False),
"renderer": kwargs.get("renderer", "mjviewer"),
"render_camera": kwargs.get("render_camera", "frontview"),
"camera_names": kwargs.get("camera_names", ["frontview"]),
"camera_heights": kwargs.get("camera_heights", None),
"camera_widths": kwargs.get("camera_widths", None),
"controller_configs": kwargs["controller_configs"],
"control_freq": kwargs.get("control_freq", 50),
"translucent_robot": kwargs.get("translucent_robot", True),
"ik_indicator": kwargs.get("ik_indicator", False),
"randomize_cameras": kwargs.get("randomize_cameras", True),
}
self.env = Gr00tLocomanipRoboCasaEnv(
env_name, robot_name, robot_model=self.robot_model, **env_kwargs
)
self.init_cache()
self.reset()
@property
def base_env(self) -> RobotEnv:
return self.env.env
def overwrite_floating_base_action(self, navigate_cmd):
if self.base_env.robots[0].robot_model.default_base == "FloatingLeggedBase":
self.env.unwrapped.overridden_floating_base_action = navigate_cmd
def get_mujoco_state_info(self):
mujoco_state = self.base_env.sim.get_state().flatten()
assert len(mujoco_state) < SyncEnv.MAX_MUJOCO_STATE_LEN
padding_width = SyncEnv.MAX_MUJOCO_STATE_LEN - len(mujoco_state)
padded_mujoco_state = np.pad(
mujoco_state, (0, padding_width), mode="constant", constant_values=0
)
max_mujoco_state_len = SyncEnv.MAX_MUJOCO_STATE_LEN
mujoco_state_len = len(mujoco_state)
mujoco_state = padded_mujoco_state.copy()
return max_mujoco_state_len, mujoco_state_len, mujoco_state
def reset_to(self, state: Dict[str, Any]) -> Dict[str, Any] | None:
if hasattr(self.base_env, "reset_to"):
self.base_env.reset_to(state)
else:
# todo: maybe update robosuite to have reset_to()
env = self.base_env
if "model_file" in state:
xml = env.edit_model_xml(state["model_file"])
env.reset_from_xml_string(xml)
env.sim.reset()
if "states" in state:
env.sim.set_state_from_flattened(state["states"])
env.sim.forward()
obs = self.env.force_update_observation(timestep=0)
self.cache["obs"] = obs
return
def get_state(self) -> Dict[str, Any]:
return self.base_env.get_state()
def is_success(self):
"""
Check if the task condition(s) is reached. Should return a dictionary
{ str: bool } with at least a "task" key for the overall task success,
and additional optional keys corresponding to other task criteria.
"""
# First, try to use the base environment's is_success method if it exists
if hasattr(self.base_env, "is_success"):
return self.base_env.is_success()
# Fall back to using _check_success if available
elif hasattr(self.base_env, "_check_success"):
succ = self.base_env._check_success()
if isinstance(succ, dict):
assert "task" in succ
return succ
return {"task": succ}
# If neither method exists, return failure
else:
return {"task": False}
def init_cache(self):
self.cache = {
"obs": None,
"reward": None,
"terminated": None,
"truncated": None,
"info": None,
}
def reset(self, seed=None, options=None) -> Tuple[Dict[str, any], Dict[str, any]]:
self.init_cache()
obs, info = self.env.reset(seed=seed, options=options)
self.cache["obs"] = obs
self.cache["reward"] = 0
self.cache["terminated"] = False
self.cache["truncated"] = False
self.cache["info"] = info
return self.observe(), info
def observe(self) -> Dict[str, any]:
# Get observations from body and hands
assert (
self.cache["obs"] is not None
), "Observation cache is not initialized, please reset the environment first"
raw_obs = self.cache["obs"]
# Body and hand joint measurements come in actuator order, so we need to convert them to joint order
whole_q = self.robot_model.get_configuration_from_actuated_joints(
body_actuated_joint_values=raw_obs["body_q"],
left_hand_actuated_joint_values=raw_obs["left_hand_q"],
right_hand_actuated_joint_values=raw_obs["right_hand_q"],
)
whole_dq = self.robot_model.get_configuration_from_actuated_joints(
body_actuated_joint_values=raw_obs["body_dq"],
left_hand_actuated_joint_values=raw_obs["left_hand_dq"],
right_hand_actuated_joint_values=raw_obs["right_hand_dq"],
)
whole_ddq = self.robot_model.get_configuration_from_actuated_joints(
body_actuated_joint_values=raw_obs["body_ddq"],
left_hand_actuated_joint_values=raw_obs["left_hand_ddq"],
right_hand_actuated_joint_values=raw_obs["right_hand_ddq"],
)
whole_tau_est = self.robot_model.get_configuration_from_actuated_joints(
body_actuated_joint_values=raw_obs["body_tau_est"],
left_hand_actuated_joint_values=raw_obs["left_hand_tau_est"],
right_hand_actuated_joint_values=raw_obs["right_hand_tau_est"],
)
eef_obs = self.get_eef_obs(whole_q)
obs = {
"q": whole_q,
"dq": whole_dq,
"ddq": whole_ddq,
"tau_est": whole_tau_est,
"floating_base_pose": raw_obs["floating_base_pose"],
"floating_base_vel": raw_obs["floating_base_vel"],
"floating_base_acc": raw_obs["floating_base_acc"],
"wrist_pose": np.concatenate([eef_obs["left_wrist_pose"], eef_obs["right_wrist_pose"]]),
}
# Add state keys for model input
obs = prepare_observation_for_eval(self.robot_model, obs)
obs["annotation.human.task_description"] = raw_obs["language.language_instruction"]
if hasattr(self.base_env, "get_privileged_obs_keys"):
for key in self.base_env.get_privileged_obs_keys():
obs[key] = raw_obs[key]
for key in raw_obs.keys():
if key.endswith("_image"):
obs[key] = raw_obs[key]
# TODO: add video.key without _image suffix for evaluation, convert to uint8, remove later
obs[f"video.{key.replace('_image', '')}"] = raw_obs[key]
return obs
def step(
self, action: Dict[str, any]
) -> Tuple[Dict[str, any], float, bool, bool, Dict[str, any]]:
self.queue_action(action)
return self.get_step_info()
def get_observation(self):
return self.base_env._get_observations() # assumes base env is robosuite
def get_step_info(self) -> Tuple[Dict[str, any], float, bool, bool, Dict[str, any]]:
return (
self.observe(),
self.cache["reward"],
self.cache["terminated"],
self.cache["truncated"],
self.cache["info"],
)
def convert_q_to_actuated_joint_order(self, q: np.ndarray) -> np.ndarray:
body_q = self.robot_model.get_body_actuated_joints(q)
left_hand_q = self.robot_model.get_hand_actuated_joints(q, side="left")
right_hand_q = self.robot_model.get_hand_actuated_joints(q, side="right")
whole_q = np.zeros_like(q)
whole_q[self.robot_model.get_joint_group_indices("body")] = body_q
whole_q[self.robot_model.get_joint_group_indices("left_hand")] = left_hand_q
whole_q[self.robot_model.get_joint_group_indices("right_hand")] = right_hand_q
return whole_q
def set_ik_indicator(self, teleop_cmd):
"""Set the IK indicators for the simulator"""
if "left_wrist" in teleop_cmd and "right_wrist" in teleop_cmd:
left_wrist_input_pose = teleop_cmd["left_wrist"]
right_wrist_input_pose = teleop_cmd["right_wrist"]
ik_wrapper = self.base_env
ik_wrapper.set_target_poses_outside_env([left_wrist_input_pose, right_wrist_input_pose])
def render(self):
if self.base_env.viewer is not None:
self.base_env.viewer.update()
if self.onscreen:
self.base_env.render()
def queue_action(self, action: Dict[str, any]):
# action is in pinocchio joint order, we need to convert it to actuator order
action_q = self.convert_q_to_actuated_joint_order(action["q"])
# Compute gravity compensation torques if enabled
tau_q = np.zeros_like(action_q)
if self.enable_gravity_compensation and self.robot_model is not None:
try:
# Get current robot configuration from cache (more efficient than observe())
raw_obs = self.cache["obs"]
# Convert from actuated joint order to joint order for Pinocchio
current_q_joint_order = self.robot_model.get_configuration_from_actuated_joints(
body_actuated_joint_values=raw_obs["body_q"],
left_hand_actuated_joint_values=raw_obs["left_hand_q"],
right_hand_actuated_joint_values=raw_obs["right_hand_q"],
)
# Compute gravity compensation in joint order using current robot configuration
gravity_torques_joint_order = self.robot_model.compute_gravity_compensation_torques(
current_q_joint_order, joint_groups=self.gravity_compensation_joints
)
# Convert gravity torques to actuated joint order
gravity_torques_actuated = self.convert_q_to_actuated_joint_order(
gravity_torques_joint_order
)
# Add gravity compensation to torques
tau_q += gravity_torques_actuated
except Exception as e:
print(f"Error applying gravity compensation in sync_env: {e}")
obs, reward, terminated, truncated, info = self.env.step({"q": action_q, "tau": tau_q})
self.cache["obs"] = obs
self.cache["reward"] = reward
self.cache["terminated"] = terminated
self.cache["truncated"] = truncated
self.cache["info"] = info
def queue_state(self, state: Dict[str, any]):
# This function is for debugging or cross-playback between sim and real only.
state_q = self.convert_q_to_actuated_joint_order(state["q"])
obs, reward, terminated, truncated, info = self.env.unwrapped.step_only_kinematics(
{"q": state_q}
)
self.cache["obs"] = obs
self.cache["reward"] = reward
self.cache["terminated"] = terminated
self.cache["truncated"] = truncated
self.cache["info"] = info
@property
def observation_space(self) -> gym.Space:
# @todo: check if the low and high bounds are correct for body_obs.
q_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.robot_model.num_dofs,))
dq_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.robot_model.num_dofs,))
ddq_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.robot_model.num_dofs,))
tau_est_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.robot_model.num_dofs,))
floating_base_pose_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(7,))
floating_base_vel_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(6,))
floating_base_acc_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(6,))
wrist_pose_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(7 + 7,))
obs_space = gym.spaces.Dict(
{
"floating_base_pose": floating_base_pose_space,
"floating_base_vel": floating_base_vel_space,
"floating_base_acc": floating_base_acc_space,
"q": q_space,
"dq": dq_space,
"ddq": ddq_space,
"tau_est": tau_est_space,
"wrist_pose": wrist_pose_space,
}
)
obs_space = prepare_gym_space_for_eval(self.robot_model, obs_space)
obs_space["annotation.human.task_description"] = gym.spaces.Text(
max_length=256, charset=ALLOWED_LANGUAGE_CHARSET
)
if hasattr(self.base_env, "get_privileged_obs_keys"):
for key, shape in self.base_env.get_privileged_obs_keys().items():
space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=shape)
obs_space[key] = space
robocasa_obs_space = self.env.observation_space
for key in robocasa_obs_space.keys():
if key.endswith("_image"):
space = gym.spaces.Box(
low=-np.inf, high=np.inf, shape=robocasa_obs_space[key].shape
)
obs_space[key] = space
# TODO: add video.key without _image suffix for evaluation, remove later
space_uint = gym.spaces.Box(low=0, high=255, shape=space.shape, dtype=np.uint8)
obs_space[f"video.{key.replace('_image', '')}"] = space_uint
return obs_space
def reset_obj_pos(self):
# For Tairan's goal-reaching task, a hacky way to reset the object position is needed.
if hasattr(self.base_env, "reset_obj_pos"):
self.base_env.reset_obj_pos()
@property
def action_space(self) -> gym.Space:
return self.env.action_space
def close(self):
self.env.close()
def __repr__(self):
return (
f"SyncEnv(env_name={self.env_name}, \n"
f" observation_space={self.observation_space}, \n"
f" action_space={self.action_space})"
)
def get_joint_gains(self):
controller = self.base_env.robots[0].composite_controller
gains = {}
key_mapping = {
"left": "left_arm",
"right": "right_arm",
"legs": "legs",
"torso": "waist",
"head": "neck",
}
for k in controller.part_controllers.keys():
if hasattr(controller.part_controllers[k], "kp"):
if k in key_mapping:
gains[key_mapping[k]] = controller.part_controllers[k].kp
else:
gains[k] = controller.part_controllers[k].kp
gains.update(
{
"left_hand": self.base_env.sim.model.actuator_gainprm[
self.base_env.robots[0]._ref_actuators_indexes_dict["left_gripper"], 0
],
"right_hand": self.base_env.sim.model.actuator_gainprm[
self.base_env.robots[0]._ref_actuators_indexes_dict["right_gripper"], 0
],
}
)
joint_gains = np.zeros(self.robot_model.num_dofs)
for k in gains.keys():
joint_gains[self.robot_model.get_joint_group_indices(k)] = gains[k]
return joint_gains
def get_joint_damping(self):
controller = self.base_env.robots[0].composite_controller
damping = {}
key_mapping = {
"left": "left_arm",
"right": "right_arm",
"legs": "legs",
"torso": "waist",
"head": "neck",
}
for k in controller.part_controllers.keys():
if hasattr(controller.part_controllers[k], "kd"):
if k in key_mapping:
damping[key_mapping[k]] = controller.part_controllers[k].kd
else:
damping[k] = controller.part_controllers[k].kd
damping.update(
{
"left_hand": -self.base_env.sim.model.actuator_biasprm[
self.base_env.robots[0]._ref_actuators_indexes_dict["left_gripper"], 2
],
"right_hand": -self.base_env.sim.model.actuator_biasprm[
self.base_env.robots[0]._ref_actuators_indexes_dict["right_gripper"], 2
],
}
)
joint_damping = np.zeros(self.robot_model.num_dofs)
for k in damping.keys():
joint_damping[self.robot_model.get_joint_group_indices(k)] = damping[k]
return joint_damping
def get_eef_obs(self, q: np.ndarray) -> Dict[str, np.ndarray]:
self.robot_model.cache_forward_kinematics(q)
eef_obs = {}
for side in ["left", "right"]:
wrist_placement = self.robot_model.frame_placement(
self.robot_model.supplemental_info.hand_frame_names[side]
)
wrist_pos, wrist_quat = wrist_placement.translation[:3], R.from_matrix(
wrist_placement.rotation
).as_quat(scalar_first=True)
eef_obs[f"{side}_wrist_pose"] = np.concatenate([wrist_pos, wrist_quat])
return eef_obs
class G1SyncEnv(SyncEnv):
def __init__(
self,
env_name,
robot_name,
**kwargs,
):
renderer = kwargs.get("renderer", "mjviewer")
if renderer == "mjviewer":
default_render_camera = ["robot0_oak_egoview"]
elif renderer in ["mujoco", "rerun"]:
default_render_camera = [
"robot0_oak_egoview",
"robot0_oak_left_monoview",
"robot0_oak_right_monoview",
]
else:
raise NotImplementedError
default_camera_names = [
"robot0_oak_egoview",
"robot0_oak_left_monoview",
"robot0_oak_right_monoview",
]
default_camera_heights = [
RS_VIEW_CAMERA_HEIGHT,
RS_VIEW_CAMERA_HEIGHT,
RS_VIEW_CAMERA_HEIGHT,
]
default_camera_widths = [
RS_VIEW_CAMERA_WIDTH,
RS_VIEW_CAMERA_WIDTH,
RS_VIEW_CAMERA_WIDTH,
]
kwargs.update(
{
"onscreen": kwargs.get("onscreen", True),
"offscreen": kwargs.get("offscreen", False),
"render_camera": kwargs.get("render_camera", default_render_camera),
"camera_names": kwargs.get("camera_names", default_camera_names),
"camera_heights": kwargs.get("camera_heights", default_camera_heights),
"camera_widths": kwargs.get("camera_widths", default_camera_widths),
"translucent_robot": kwargs.get("translucent_robot", False),
}
)
super().__init__(env_name=env_name, robot_name=robot_name, **kwargs)
# Initialize safety monitor (visualization disabled) - G1 specific
self.safety_monitor = JointSafetyMonitor(
self.robot_model,
enable_viz=False,
env_type="sim", # G1SyncEnv is always simulation
)
self.safety_monitor.ramp_duration_steps = 0
self.safety_monitor.startup_complete = True
self.safety_monitor.LOWER_BODY_VELOCITY_LIMIT = (
1e10 # disable lower body velocity limits since it impacts WBC
)
self.last_safety_ok = True # Track last safety status from queue_action
@property
def observation_space(self):
obs_space = super().observation_space
obs_space["torso_quat"] = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(4,))
obs_space["torso_ang_vel"] = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(3,))
return obs_space
def observe(self):
obs = super().observe()
obs["torso_quat"] = self.cache["obs"]["secondary_imu_quat"]
obs["torso_ang_vel"] = self.cache["obs"]["secondary_imu_vel"][3:6]
return obs
def queue_action(self, action: Dict[str, any]):
# Safety check before queuing action
obs = self.observe()
safety_result = self.safety_monitor.handle_violations(obs, action)
action = safety_result["action"]
# Save safety status for efficient access
self.last_safety_ok = not safety_result.get("shutdown_required", False)
# Check if shutdown is required
if safety_result["shutdown_required"]:
self.safety_monitor.trigger_system_shutdown()
# Call parent queue_action with potentially modified action
super().queue_action(action)
def get_joint_safety_status(self) -> bool:
"""Get current joint safety status from the last queue_action safety check.
Returns:
bool: True if joints are safe (no shutdown required), False if unsafe
"""
return self.last_safety_ok
def create_gym_sync_env_class(env, robot, robot_alias, wbc_version):
class_name = f"{env}_{robot}_{wbc_version}"
id_name = f"gr00tlocomanip_{robot_alias}/{class_name}"
if robot_alias.startswith("g1"):
env_class_type = G1SyncEnv
elif robot_alias.startswith("gr1"):
env_class_type = globals().get("GR1SyncEnv", SyncEnv)
else:
env_class_type = SyncEnv
controller_configs = update_robosuite_controller_configs(
robot=robot,
wbc_version=wbc_version,
)
env_class_type = type(
class_name,
(env_class_type,),
{
"__init__": lambda self, **kwargs: super(self.__class__, self).__init__(
env_name=env,
robot_name=robot,
controller_configs=controller_configs,
**kwargs,
)
},
)
current_module = sys.modules["gr00t_wbc.control.envs.robocasa.sync_env"]
setattr(current_module, class_name, env_class_type)
register(
id=id_name, # Unique ID for the environment
entry_point=f"gr00t_wbc.control.envs.robocasa.sync_env:{class_name}",
)
WBC_VERSION = "gear_wbc"
for ENV in REGISTERED_LOCOMANIPULATION_ENVS:
for ROBOT, ROBOT_ALIAS in GR00T_LOCOMANIP_ENVS_ROBOTS.items():
create_gym_sync_env_class(ENV, ROBOT, ROBOT_ALIAS, WBC_VERSION)
if __name__ == "__main__":
env = gym.make("gr00tlocomanip_g1_sim/PnPBottle_g1_gear_wbc")
print(env.observation_space)