You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
443 lines
19 KiB
443 lines
19 KiB
import os
|
|
from typing import Any, Dict, List, Tuple
|
|
|
|
from gymnasium import spaces
|
|
import mujoco
|
|
import numpy as np
|
|
import robocasa
|
|
from robocasa.utils.gym_utils.gymnasium_basic import (
|
|
RoboCasaEnv,
|
|
create_env_robosuite,
|
|
)
|
|
from robocasa.wrappers.ik_wrapper import IKWrapper
|
|
from robosuite.controllers import load_composite_controller_config
|
|
from robosuite.utils.log_utils import ROBOSUITE_DEFAULT_LOGGER
|
|
|
|
from decoupled_wbc.control.envs.robocasa.utils.cam_key_converter import CameraKeyMapper
|
|
from decoupled_wbc.control.envs.robocasa.utils.robot_key_converter import Gr00tObsActionConverter
|
|
from decoupled_wbc.control.robot_model.robot_model import RobotModel
|
|
|
|
# Characters permitted in the "language.language_instruction" text observation.
# Written as adjacent literals (joined at compile time) grouped by character
# class: lowercase, uppercase, digits, then whitespace and punctuation.
ALLOWED_LANGUAGE_CHARSET = (
    "abcdefghijklmnopqrstuvwxyz"
    "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    "0123456789"
    " ,.\n\t[]{}()!?'_:"
)
|
|
|
|
|
|
class Gr00tLocomanipRoboCasaEnv(RoboCasaEnv):
    """RoboCasa environment that speaks the Gr00t observation/action layout.

    The wrapped robosuite environment is created in ``__init__`` and stored in
    ``self.env``. Observations coming out of it are re-keyed and re-ordered by
    one ``Gr00tObsActionConverter`` per robot; actions arrive as a dict with a
    flat joint vector under ``"q"`` (and optionally ``"tau"``) and are
    converted back into the per-part action dict the base class steps with.

    NOTE(review): ``__init__`` does not call ``super().__init__()``; it builds
    ``self.env`` itself. Presumably this is deliberate to avoid constructing
    the environment twice -- confirm against ``RoboCasaEnv``.
    """

    def __init__(
        self,
        env_name: str,
        robots_name: str,
        robot_model: RobotModel,  # gr00t robot model
        input_space: str = "JOINT_SPACE",  # either "JOINT_SPACE" or "EEF_SPACE"
        camera_names: List[str] | None = None,
        camera_heights: List[int] | None = None,
        camera_widths: List[int] | None = None,
        onscreen: bool = False,
        offscreen: bool = False,
        dump_rollout_dataset_dir: str | None = None,
        rollout_hdf5: str | None = None,
        rollout_trainset: int | None = None,
        controller_configs: str | None = None,
        ik_indicator: bool = False,
        **kwargs,
    ):
        """Create the simulated environment, the converters and the spaces.

        Args:
            env_name: Name of the task forwarded to ``create_env_robosuite``.
            robots_name: Robot identifier(s) joined by ``"_"``; the first
                token selects the controller config's robot.
            robot_model: Gr00t robot model handed to every converter.
            input_space: ``"JOINT_SPACE"`` switches the composite controller
                to type ``"BASIC"``; any other value keeps the loaded config.
            camera_names: Cameras to render. ``None`` means ``["egoview"]``.
            camera_heights: Per-camera heights; ``None`` takes them from the
                camera key mapper.
            camera_widths: Per-camera widths; ``None`` takes them from the
                camera key mapper.
            onscreen: Forwarded to ``create_env_robosuite`` as ``onscreen``.
            offscreen: Forwarded as ``enable_render``; when false, image
                observations are replaced by black frames.
            dump_rollout_dataset_dir: Stored on the instance, unused here.
            rollout_hdf5: Stored on the instance, unused here.
            rollout_trainset: Stored on the instance, unused here.
            controller_configs: Path to a composite controller config. When
                ``None`` a default is chosen from ``robots_name``.
            ik_indicator: Wrap the environment in ``IKWrapper`` when true.
            **kwargs: Forwarded unchanged to ``create_env_robosuite``.

        Raises:
            ValueError: If ``controller_configs`` is ``None`` and
                ``robots_name`` contains neither ``"G1"`` nor ``"GR1"``.
        """
        # A list literal as the default would be shared by every instance
        # (it is stored on ``self`` below), so the default is created per call.
        if camera_names is None:
            camera_names = ["egoview"]

        # ========= Create env =========
        if controller_configs is None:
            if "G1" in robots_name:
                controller_configs = (
                    "robocasa/examples/third_party_controller/default_mink_ik_g1_wbc.json"
                )
            elif "GR1" in robots_name:
                controller_configs = (
                    "robocasa/examples/third_party_controller/default_mink_ik_gr1_smallkd.json"
                )
            else:
                # Raised explicitly (not asserted) so the check survives ``python -O``.
                raise ValueError(f"Unsupported robot name: {robots_name}")
        # Resolve relative to the directory containing the robocasa package;
        # os.path.join leaves an absolute ``controller_configs`` untouched.
        controller_configs = os.path.join(
            os.path.dirname(robocasa.__file__),
            "../",
            controller_configs,
        )
        controller_configs = load_composite_controller_config(
            controller=controller_configs,
            robot=robots_name.split("_")[0],
        )
        if input_space == "JOINT_SPACE":
            controller_configs["type"] = "BASIC"
            controller_configs["composite_controller_specific_configs"] = {}
        controller_configs["control_delta"] = False

        self.camera_key_mapper = CameraKeyMapper()
        self.camera_names = camera_names

        # get_camera_config(name) is indexed as (mapped key, width, height).
        if camera_widths is None:
            self.camera_widths = [
                self.camera_key_mapper.get_camera_config(name)[1] for name in camera_names
            ]
        else:
            self.camera_widths = camera_widths
        if camera_heights is None:
            self.camera_heights = [
                self.camera_key_mapper.get_camera_config(name)[2] for name in camera_names
            ]
        else:
            self.camera_heights = camera_heights

        self.env, self.env_kwargs = create_env_robosuite(
            env_name=env_name,
            robots=robots_name.split("_"),
            controller_configs=controller_configs,
            camera_names=camera_names,
            camera_widths=self.camera_widths,
            camera_heights=self.camera_heights,
            enable_render=offscreen,
            onscreen=onscreen,
            **kwargs,  # Forward kwargs to create_env_robosuite
        )

        if ik_indicator:
            self.env = IKWrapper(self.env, ik_indicator=True)

        # ========= create converters first to get total DOFs =========
        # For now, assume single robot (multi-robot support can be added later)
        self.obs_action_converter: List[Gr00tObsActionConverter] = [
            Gr00tObsActionConverter(
                robot_model=robot_model,
                robosuite_robot_model=robot,
            )
            for robot in self.env.robots
        ]

        self.body_dofs = sum(converter.body_dof for converter in self.obs_action_converter)
        self.gripper_dofs = sum(converter.gripper_dof for converter in self.obs_action_converter)
        self.total_dofs = self.body_dofs + self.gripper_dofs
        self.body_nu = sum(converter.body_nu for converter in self.obs_action_converter)
        self.gripper_nu = sum(converter.gripper_nu for converter in self.obs_action_converter)
        self.total_nu = self.body_nu + self.gripper_nu

        # ========= create spaces to match total DOFs =========
        self.get_observation_space()
        self.get_action_space()

        self.enable_render = offscreen
        self.render_obs_key = f"{camera_names[0]}_image"
        self.render_cache = None

        self.dump_rollout_dataset_dir = dump_rollout_dataset_dir
        self.gr00t_exporter = None
        self.np_exporter = None

        self.rollout_hdf5 = rollout_hdf5
        self.rollout_trainset = rollout_trainset
        self.rollout_initial_state = {}

        # Debug switch: flip to True here to dump the spaces once at start-up
        # and the first action/observation pair seen by ``step``.
        self.verbose = False
        if self.verbose:
            for k, v in self.observation_space.items():
                print("{OBS}", k, v)
            for k, v in self.action_space.items():
                print("{ACTION}", k, v)

        self.overridden_floating_base_action = None

    def get_observation_space(self):
        """Build ``self.observation_space`` and return it.

        The space is filled by item assignment on an initially empty
        ``spaces.Dict`` so keys keep the insertion order used below.

        Returns:
            The newly created ``spaces.Dict`` (also stored on the instance).
        """

        def unbounded_box(shape):
            return spaces.Box(low=-np.inf, high=np.inf, shape=shape, dtype=np.float32)

        # Gripper totals cover both hands; each hand gets half.
        hand_dofs = self.gripper_dofs // 2
        hand_nu = self.gripper_nu // 2

        state_shapes = {
            "time": (1,),
            "floating_base_pose": (7,),
            "floating_base_vel": (6,),
            "floating_base_acc": (6,),
            "body_q": (self.body_dofs,),
            "body_dq": (self.body_dofs,),
            "body_ddq": (self.body_dofs,),
            "body_tau_est": (self.body_nu,),
            "left_hand_q": (hand_dofs,),
            "left_hand_dq": (hand_dofs,),
            "left_hand_ddq": (hand_dofs,),
            "left_hand_tau_est": (hand_nu,),
            "right_hand_q": (hand_dofs,),
            "right_hand_dq": (hand_dofs,),
            "right_hand_ddq": (hand_dofs,),
            "right_hand_tau_est": (hand_nu,),
        }

        self.observation_space = spaces.Dict({})
        for key, shape in state_shapes.items():
            self.observation_space[key] = unbounded_box(shape)

        self.observation_space["language.language_instruction"] = spaces.Text(
            max_length=256, charset=ALLOWED_LANGUAGE_CHARSET
        )

        # Add camera observation spaces (keyed by the mapped camera name)
        for camera_name, w, h in zip(self.camera_names, self.camera_widths, self.camera_heights):
            k = self.camera_key_mapper.get_camera_config(camera_name)[0]
            self.observation_space[f"{k}_image"] = spaces.Box(
                low=0, high=255, shape=(h, w, 3), dtype=np.uint8
            )

        # Add extra privileged observation spaces
        if hasattr(self.env, "get_privileged_obs_keys"):
            for key, shape in self.env.get_privileged_obs_keys().items():
                self.observation_space[key] = unbounded_box(shape)

        # Add robot-specific observation spaces
        if hasattr(self.env.robots[0].robot_model, "torso_body"):
            self.observation_space["secondary_imu_quat"] = unbounded_box((4,))
            self.observation_space["secondary_imu_vel"] = unbounded_box((6,))

        return self.observation_space

    def get_action_space(self):
        """Build ``self.action_space`` (one ``"q"`` vector of ``total_dofs``) and return it."""
        self.action_space = spaces.Dict(
            {"q": spaces.Box(low=-np.inf, high=np.inf, shape=(self.total_dofs,), dtype=np.float32)}
        )
        return self.action_space

    def reset(self, seed=None, options=None):
        """Reset via the base class and return ``(gr00t_observation, info)``.

        The episode's ``"lang"`` metadata entry (empty string when missing) is
        written to the robosuite logger.
        """
        raw_obs, info = super().reset(seed=seed, options=options)
        obs = self.get_gr00t_observation(raw_obs)

        lang = self.env.get_ep_meta().get("lang", "")
        ROBOSUITE_DEFAULT_LOGGER.info(f"Instruction: {lang}")

        return obs, info

    def step(
        self, action: Dict[str, Any]
    ) -> Tuple[Dict[str, Any], float, bool, bool, Dict[str, Any]]:
        """Convert a Gr00t action, step the base environment, convert the result.

        Args:
            action: Mapping with the joint vector under ``"q"`` and,
                optionally, a torque vector under ``"tau"``
                (``action={"q": xxx, "tau": xxx}``).

        Returns:
            ``(obs, reward, terminated, truncated, info)`` where ``obs`` is
            the Gr00t-formatted observation.
        """
        if self.verbose:
            for k, v in action.items():
                print("<ACTION>", k, v)

        joint_action_vec = action["q"]
        torque_vec = action.get("tau", None)

        action_dict = {}
        for converter, robot in zip(self.obs_action_converter, self.env.robots):
            pf = robot.robot_model.naming_prefix
            part_actions = converter.gr00t_to_robocasa_action_dict(joint_action_vec)
            action_dict.update({f"{pf}{k}": v for k, v in part_actions.items()})
            if torque_vec is not None:
                # Torques go through the same converter and are stored next
                # to the position targets under a "_tau" suffix.
                part_torques = converter.gr00t_to_robocasa_action_dict(torque_vec)
                action_dict.update({f"{pf}{k}_tau": v for k, v in part_torques.items()})

        if self.overridden_floating_base_action is not None:
            action_dict["robot0_base"] = self.overridden_floating_base_action

        raw_obs, reward, terminated, truncated, info = super().step(action_dict)
        obs = self.get_gr00t_observation(raw_obs)

        if self.verbose:
            for k, v in obs.items():
                print("<OBS>", k, v.shape if k.startswith("video.") else v)
        # Verbose dumping is one-shot: it is switched off after the first step.
        self.verbose = False

        return obs, reward, terminated, truncated, info

    def step_only_kinematics(
        self, action: Dict[str, Any]
    ) -> Tuple[Dict[str, Any], float, bool, bool, Dict[str, Any]]:
        """Write the commanded joint positions into the simulator state.

        Body and gripper entries of ``sim.data.qpos`` are overwritten from
        ``action["q"]``, ``mujoco.mj_forward`` is called, and a freshly
        computed observation is returned.

        Returns:
            ``(obs, 0, False, False, {"success": False})``.
        """
        joint_action_vec = action["q"]
        all_joint_names = np.array(self.env.sim.model.joint_names)

        for converter, robot in zip(self.obs_action_converter, self.env.robots):
            body_joint_names = all_joint_names[robot._ref_joint_indexes]
            body_q = converter.gr00t_to_robocasa_joint_order(body_joint_names, joint_action_vec)
            self.env.sim.data.qpos[robot._ref_joint_pos_indexes] = body_q

            for side in ("left", "right"):
                gripper_joint_names = all_joint_names[
                    robot._ref_joints_indexes_dict[side + "_gripper"]
                ]
                gripper_q = converter.gr00t_to_robocasa_joint_order(
                    gripper_joint_names, joint_action_vec
                )
                self.env.sim.data.qpos[robot._ref_gripper_joint_pos_indexes[side]] = gripper_q

        mujoco.mj_forward(self.env.sim.model._model, self.env.sim.data._data)

        obs = self.force_update_observation()
        return obs, 0, False, False, {"success": False}

    def force_update_observation(self, timestep=0):
        """Force the wrapped env to recompute observations and return them in Gr00t format."""
        raw_obs = self.env._get_observations(force_update=True, timestep=timestep)
        obs = self.get_basic_observation(raw_obs)
        obs = self.get_gr00t_observation(obs)
        return obs

    def get_basic_observation(self, raw_obs):
        """Normalise a raw observation dict in place and return it.

        Entries whose key ends in ``"_image"`` are flipped vertically and
        copied; every other entry is cast to ``float32``. When rendering is
        disabled each camera image is replaced by a black frame. The image
        under ``self.render_obs_key`` is cached in ``self.render_cache`` and
        the episode instruction is stored under ``"language"``.
        """
        # gather_robot_observations(self.env, format_gripper_space=False) is
        # intentionally not merged in here: it takes a lot of time.

        for obs_name, obs_value in raw_obs.items():
            if obs_name.endswith("_image"):
                # Images are (H, W, C); flip them upside down.
                raw_obs[obs_name] = np.copy(obs_value[::-1, :, :])
            else:
                raw_obs[obs_name] = obs_value.astype(np.float32)

        # Return black image if rendering is disabled
        if not self.enable_render:
            for name, height, width in zip(
                self.camera_names, self.camera_heights, self.camera_widths
            ):
                raw_obs[f"{name}_image"] = np.zeros((height, width, 3), dtype=np.uint8)

        self.render_cache = raw_obs[self.render_obs_key]
        raw_obs["language"] = self.env.get_ep_meta().get("lang", "")

        return raw_obs

    def convert_body_q(self, q: np.ndarray) -> np.ndarray:
        """Reorder a body vector of robot 0 into the Gr00t actuated order.

        ``q`` is expected in the order of the robot's ``_ref_joint_indexes``;
        the reordering itself is delegated to the converter.
        """
        robot = self.env.robots[0]
        # These joint names are in the order of the observation vector.
        joint_names = np.array(self.env.sim.model.joint_names)[robot._ref_joint_indexes]
        return self.obs_action_converter[0].robocasa_to_gr00t_actuated_order(
            joint_names, q, "body"
        )

    def convert_gripper_q(self, q: np.ndarray, side: str = "left") -> np.ndarray:
        """Reorder a gripper joint vector of robot 0 into the Gr00t actuated order.

        Args:
            q: Values in the order of the ``"<side>_gripper"`` joint indexes.
            side: ``"left"`` or ``"right"``.
        """
        robot = self.env.robots[0]
        part = side + "_gripper"
        joint_names = np.array(self.env.sim.model.joint_names)[
            robot._ref_joints_indexes_dict[part]
        ]
        return self.obs_action_converter[0].robocasa_to_gr00t_actuated_order(
            joint_names, q, part
        )

    def convert_gripper_tau(self, tau: np.ndarray, side: str = "left") -> np.ndarray:
        """Reorder a gripper actuator vector of robot 0 into the Gr00t actuated order.

        Args:
            tau: Values in the order of the ``"<side>_gripper"`` actuators.
            side: ``"left"`` or ``"right"``.
        """
        robot = self.env.robots[0]
        part = side + "_gripper"
        model = self.env.sim.model
        # Name each actuator after the joint in the first slot of its trnid row.
        actuated_joint_names = [
            model.joint_id2name(model.actuator_trnid[i][0])
            for i in robot._ref_actuators_indexes_dict[part]
        ]
        return self.obs_action_converter[0].robocasa_to_gr00t_actuated_order(
            actuated_joint_names, tau, part
        )

    def _actuator_indices_by_group(self, robot):
        """Split a robot's actuator indexes into sorted (body, left gripper, right gripper) lists."""
        body_idx = []
        left_gripper_idx = []
        right_gripper_idx = []
        for part_name, actuator_idx in robot._ref_actuators_indexes_dict.items():
            if "left_gripper" in part_name:
                left_gripper_idx.extend(actuator_idx)
            elif "right_gripper" in part_name:
                right_gripper_idx.extend(actuator_idx)
            elif "base" in part_name:
                # Base actuators are skipped; they are only expected to exist
                # for robots whose default base is "FloatingLeggedBase".
                assert (
                    len(actuator_idx) == 0 or robot.robot_model.default_base == "FloatingLeggedBase"
                )
            else:
                body_idx.extend(actuator_idx)
        return sorted(body_idx), sorted(left_gripper_idx), sorted(right_gripper_idx)

    def get_gr00t_observation(self, raw_obs: Dict[str, Any]) -> Dict[str, Any]:
        """Assemble the Gr00t observation dict for robot 0.

        Args:
            raw_obs: Observation dict holding the ``robot0_*`` joint and
                gripper entries, one ``"<camera>_image"`` per camera and a
                ``"language"`` entry.

        Returns:
            A new dict with floating-base state, body and hand joint state,
            actuator force estimates, simulation time, camera images (under
            their mapped names), optional privileged and torso-IMU entries,
            and the language instruction.
        """
        obs = {}
        sim = self.env.sim

        if sim.model.jnt_type[0] == mujoco.mjtJoint.mjJNT_FREE:
            # If the first joint is a free joint, the floating base state is
            # the head of qpos/qvel/qacc. np.array(...) copies the slices so
            # the returned observation is a snapshot and does not alias the
            # simulator buffers.
            obs["floating_base_pose"] = np.array(sim.data.qpos[:7])
            obs["floating_base_vel"] = np.array(sim.data.qvel[:6])
            obs["floating_base_acc"] = np.array(sim.data.qacc[:6])
        else:
            # Otherwise, fetch the pose of the "robot0_base" body.
            root_body_id = sim.model.body_name2id("robot0_base")
            root_pos = sim.data.body_xpos[root_body_id]
            root_quat = sim.data.body_xquat[root_body_id]  # quaternion in wxyz format

            # Combine position and quaternion to form the 7-DOF pose
            obs["floating_base_pose"] = np.concatenate([root_pos, root_quat])
            # set vel and acc to 0
            obs["floating_base_vel"] = np.zeros(6)
            obs["floating_base_acc"] = np.zeros(6)

        obs["body_q"] = self.convert_body_q(raw_obs["robot0_joint_pos"])
        obs["body_dq"] = self.convert_body_q(raw_obs["robot0_joint_vel"])
        obs["body_ddq"] = self.convert_body_q(raw_obs["robot0_joint_acc"])

        obs["left_hand_q"] = self.convert_gripper_q(raw_obs["robot0_left_gripper_qpos"], "left")
        obs["left_hand_dq"] = self.convert_gripper_q(raw_obs["robot0_left_gripper_qvel"], "left")
        obs["left_hand_ddq"] = self.convert_gripper_q(raw_obs["robot0_left_gripper_qacc"], "left")
        obs["right_hand_q"] = self.convert_gripper_q(raw_obs["robot0_right_gripper_qpos"], "right")
        obs["right_hand_dq"] = self.convert_gripper_q(raw_obs["robot0_right_gripper_qvel"], "right")
        obs["right_hand_ddq"] = self.convert_gripper_q(
            raw_obs["robot0_right_gripper_qacc"], "right"
        )

        body_tau_idx, left_tau_idx, right_tau_idx = self._actuator_indices_by_group(
            self.env.robots[0]
        )
        # NOTE(review): body forces are in sorted actuator order but are
        # reordered with the joint-name list used for positions; this assumes
        # the two orders line up one-to-one -- confirm.
        obs["body_tau_est"] = self.convert_body_q(sim.data.actuator_force[body_tau_idx])
        obs["right_hand_tau_est"] = self.convert_gripper_tau(
            sim.data.actuator_force[right_tau_idx], "right"
        )
        obs["left_hand_tau_est"] = self.convert_gripper_tau(
            sim.data.actuator_force[left_tau_idx], "left"
        )

        obs["time"] = sim.data.time

        # Add camera images under their mapped names
        for camera_name in self.camera_names:
            mapped_camera_name = self.camera_key_mapper.get_camera_config(camera_name)[0]
            obs[f"{mapped_camera_name}_image"] = raw_obs[f"{camera_name}_image"]

        # Add privileged observations
        if hasattr(self.env, "get_privileged_obs_keys"):
            for key in self.env.get_privileged_obs_keys():
                obs[key] = raw_obs[key]

        # Add robot-specific observations
        if hasattr(self.env.robots[0].robot_model, "torso_body"):
            obs["secondary_imu_quat"] = raw_obs["robot0_torso_link_imu_quat"]
            obs["secondary_imu_vel"] = raw_obs["robot0_torso_link_imu_vel"]

        obs["language.language_instruction"] = raw_obs["language"]

        return obs
|