You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

591 lines
23 KiB

from datetime import datetime
import pathlib
import shutil
from typing import Any, Dict, Optional
import numpy as np
from gr00t_wbc.control.envs.robocasa.sync_env import G1SyncEnv, SyncEnv
from gr00t_wbc.control.envs.robocasa.utils.controller_utils import (
get_body_ik_solver_settings_type,
update_robosuite_controller_configs,
)
from gr00t_wbc.control.main.constants import DEFAULT_BASE_HEIGHT, DEFAULT_NAV_CMD
from gr00t_wbc.control.main.teleop.configs.configs import SyncSimDataCollectionConfig
from gr00t_wbc.control.policy.teleop_policy import TeleopPolicy
from gr00t_wbc.control.policy.wbc_policy_factory import get_wbc_policy
from gr00t_wbc.control.robot_model.instantiation import get_robot_type_and_model
from gr00t_wbc.control.robot_model.robot_model import RobotModel
from gr00t_wbc.control.teleop.solver.hand.instantiation.g1_hand_ik_instantiation import (
instantiate_g1_hand_ik_solver,
)
from gr00t_wbc.control.teleop.teleop_retargeting_ik import TeleopRetargetingIK
from gr00t_wbc.control.utils.episode_state import EpisodeState
from gr00t_wbc.control.utils.text_to_speech import TextToSpeech
from gr00t_wbc.data.exporter import Gr00tDataExporter
from gr00t_wbc.data.utils import get_dataset_features, get_modality_config
# Length of the "observation.sim.mujoco_state" feature vector declared in get_features().
MAX_MUJOCO_STATE_LEN = 800
# Keyboard input that toggles data collection in manual-control mode.
COLLECTION_KEY = "c"
# Keyboard input that discards the current episode and starts a new one.
SKIP_KEY = "x"
class EpisodeManager:
    """Manages episode state transitions, done flags, and task completion hold counts.

    This class encapsulates the logic for:
    - Episode state management (IDLE -> RECORDING -> NEED_TO_SAVE)
    - Done flag handling
    - Task completion hold count tracking
    - Data collection triggering based on manual/auto mode
    - Step counting within episodes

    ``task_completion_hold_count`` encodes the save countdown: ``-1`` means no
    success is latched, a positive value is the number of extra steps still to
    be recorded, and ``0`` means the episode is saved on the next call to
    ``check_export_and_completion``.
    """

    def __init__(self, config: SyncSimDataCollectionConfig):
        """Store the config and start with no latched success and a zero step count."""
        self.config = config
        self.task_completion_hold_count = -1
        self.done = False
        self.step_count = 0

        # Episode state and text-to-speech are used in both manual and automatic modes
        self.episode_state = EpisodeState()
        self.text_to_speech = TextToSpeech()

    def should_collect_data(self) -> bool:
        """Determine if data should be collected at this timestep.

        Returns:
            bool: False when data collection is disabled in the config; otherwise
            True while the episode state is RECORDING or NEED_TO_SAVE.
        """
        if not self.config.data_collection:
            return False

        # Manual and auto mode share the same rule: collect while RECORDING or NEED_TO_SAVE
        collecting_states = [
            self.episode_state.RECORDING,
            self.episode_state.NEED_TO_SAVE,
        ]
        return self.episode_state.get_state() in collecting_states

    def increment_step(self):
        """Increment the step counter."""
        self.step_count += 1

    def reset_step_count(self):
        """Reset the step counter to 0."""
        self.step_count = 0

    def get_step_count(self) -> int:
        """Get the current step count."""
        return self.step_count

    def handle_collection_trigger(
        self, wbc_goal: dict, keyboard_input: str | None, step_info: dict
    ):
        """Handle data collection start/stop triggers.

        In manual mode the episode state advances when ``wbc_goal`` carries a truthy
        "toggle_data_collection" entry or ``keyboard_input`` equals ``COLLECTION_KEY``.
        In auto mode recording starts as soon as the state is IDLE, and stops when the
        success hold countdown has reached zero or the CI test step budget is used up.

        Args:
            wbc_goal: WBC goal dictionary
            keyboard_input: Keyboard input from user (can be None)
            step_info: Step information from environment containing success flag
        """
        self.done = False
        state_changed = False

        if self.config.manual_control:
            if wbc_goal.get("toggle_data_collection", False) or (
                keyboard_input and keyboard_input == COLLECTION_KEY
            ):
                # IDLE -> RECORDING or RECORDING -> NEED_TO_SAVE
                self.episode_state.change_state()
                state_changed = True
        else:
            # Auto mode: automatically start collecting data for new env
            if self.episode_state.get_state() == self.episode_state.IDLE:
                self.episode_state.change_state()  # IDLE -> RECORDING
                state_changed = True

            self.done = step_info.get("success", False)
            if (
                self.episode_state.get_state() == self.episode_state.RECORDING
                and self.done
                and self.task_completion_hold_count == 0
            ):
                self.episode_state.change_state()  # RECORDING -> NEED_TO_SAVE
                state_changed = True

            # Check for CI test completion
            if self.config.ci_test and self.step_count >= self._get_ci_test_steps():
                self.episode_state.change_state()  # RECORDING -> NEED_TO_SAVE
                state_changed = True

        if state_changed:
            if self.episode_state.get_state() == self.episode_state.RECORDING:
                self.text_to_speech.print_and_say("Started recording episode")
            elif self.episode_state.get_state() == self.episode_state.NEED_TO_SAVE:
                # Force an immediate save on the next check_export_and_completion call
                self.done = True
                self.task_completion_hold_count = 0
                self.text_to_speech.print_and_say("Stopping recording, preparing to save")

    def check_export_and_completion(self, exporter: Gr00tDataExporter) -> bool:
        """Check if episode should be exported and update completion state.

        Args:
            exporter: Data exporter instance

        Returns:
            bool: True if environment needs to be reset
        """
        need_reset = False

        # Check if we should save the episode
        if self.task_completion_hold_count == 0:
            exporter.save_episode()
            need_reset = True
            self.task_completion_hold_count = -1

            if self.episode_state.get_state() == self.episode_state.NEED_TO_SAVE:
                self.episode_state.change_state()  # NEED_TO_SAVE -> IDLE
                self.text_to_speech.print_and_say("Episode saved.")

        # State machine to check for having a success for N consecutive timesteps
        elif self.done:
            print(
                f"Task success detected! Will collect {self.config.success_hold_steps} additional steps..."
            )
            print(f"currently {self.task_completion_hold_count}")
            if self.task_completion_hold_count > 0:
                self.task_completion_hold_count -= 1  # latched state, decrement count
                print(f"Task completed! Collecting {self.task_completion_hold_count} more steps...")
            else:
                # reset count on first success timestep
                self.task_completion_hold_count = self.config.success_hold_steps
                print(
                    f"Task success detected! Will collect {self.config.success_hold_steps} additional steps..."
                )
        else:
            self.task_completion_hold_count = -1  # null the counter if there's no success

        return need_reset

    def handle_skip(
        self, wbc_goal: dict, keyboard_input: str | None, exporter: Gr00tDataExporter
    ) -> bool:
        """Handle episode skip/abort.

        Args:
            wbc_goal: WBC goal dictionary
            keyboard_input: Keyboard input from user (can be None)
            exporter: Data exporter instance

        Returns:
            bool: True if episode was skipped and needs reset
        """
        skip_requested = wbc_goal.get("toggle_data_abort", False) or (
            keyboard_input and keyboard_input == SKIP_KEY
        )
        if not skip_requested:
            return False

        exporter.skip_and_start_new_episode()
        self.episode_state.reset_state()
        self.text_to_speech.print_and_say("Episode discarded, starting new episode")
        return True

    def _get_ci_test_steps(self) -> int:
        """Get CI test steps based on CI test mode.

        Returns:
            int: 50 when the mode is "unit" (also the fallback when the config has no
            ``ci_test_mode`` attribute), otherwise 500 (pre_merge).
        """
        # The config is read by attribute everywhere else in this module, so use
        # getattr rather than a dict-style .get() that the config may not provide.
        ci_test_mode = getattr(self.config, "ci_test_mode", "unit")
        return 50 if ci_test_mode == "unit" else 500
class CITestManager:
    """Accumulates motion and wrist-tracking statistics during a CI run and asserts on them."""

    def __init__(self, config: SyncSimDataCollectionConfig):
        """Derive the CI step budget and tracking-check flag from ``config.ci_test_mode``."""
        self.config = config
        self.ci_test_steps = 50 if config.ci_test_mode == "unit" else 500
        self.enable_tracking_check = config.ci_test_mode == "pre_merge"
        self.upper_body_speed_hist = []
        self.last_q_upper_body = None
        self.end_effector_tracking_errors = []

    def check_upper_body_motion(
        self, robot_model: RobotModel, wbc_action: dict, config: SyncSimDataCollectionConfig
    ):
        """Record the mean upper-body joint speed for this step.

        Args:
            robot_model: Robot model used to look up the "upper_body" joint indices
            wbc_action: WBC action; its "q" entry is indexed with those joint indices
            config: Configuration object providing ``control_frequency``

        Raises:
            AssertionError: If the mean of the upper-body joint targets is exactly zero.
        """
        upper_body_joint_indices = robot_model.get_joint_group_indices("upper_body")
        q_upper_body = wbc_action["q"][upper_body_joint_indices]

        # On the first call there is no previous sample, so the recorded speed is zero
        if self.last_q_upper_body is None:
            self.last_q_upper_body = q_upper_body.copy()

        mean_joint_delta = np.abs(q_upper_body - self.last_q_upper_body).mean()
        self.upper_body_speed_hist.append(mean_joint_delta * config.control_frequency)
        self.last_q_upper_body = q_upper_body.copy()

        assert q_upper_body.mean() != 0, "Upper body joints should not be zero"

    @staticmethod
    def _wrist_tracking_error(target_pose, actual_pos, actual_quat):
        """Return (position error, rotation angle error) between a target and an actual wrist pose."""
        from scipy.spatial.transform import Rotation as R

        # Quaternion is passed scalar-first to scipy
        actual_rot = R.from_quat(actual_quat, scalar_first=True).as_matrix()
        pos_error = np.linalg.norm(
            target_pose[:3, 3] - np.asarray(actual_pos, dtype=np.float64)
        )
        # Angle of the relative rotation; clip guards arccos against round-off
        rot_diff = actual_rot @ target_pose[:3, :3].T
        rot_error = np.arccos(np.clip((np.trace(rot_diff) - 1) / 2, -1, 1))
        return pos_error, rot_error

    def check_end_effector_tracking(
        self, teleop_cmd: dict, obs: dict, config: SyncSimDataCollectionConfig, i: int
    ):
        """Check end effector tracking error and validate thresholds.

        The worst of the left/right wrist errors is recorded on every call. Once
        ``i >= self.ci_test_steps`` the accumulated errors and the upper-body speed
        history are validated against fixed thresholds.

        Args:
            teleop_cmd: Teleoperation command containing target poses; "left_wrist" and
                "right_wrist" are indexed as transforms (rotation in ``[:3, :3]``,
                translation in ``[:3, 3]``)
            obs: Environment observation containing current poses; "wrist_pose" is read
                as 7 values for the left wrist followed by 7 for the right
                (position, then scalar-first quaternion)
            config: Configuration object (not used by this method)
            i: Current step index

        Raises:
            AssertionError: If any tracking threshold is exceeded, the mean upper-body
                speed is too low, or fewer than ``ci_test_steps`` samples were collected.
        """
        # Get target poses from teleop command (replay data format)
        target_left_wrist = teleop_cmd.get("left_wrist")
        target_right_wrist = teleop_cmd.get("right_wrist")

        # Extract left and right wrist poses from environment observation (7 for left + 7 for right)
        wrist_pose = obs.get("wrist_pose")
        left_pos_error, left_rot_error = self._wrist_tracking_error(
            target_left_wrist, wrist_pose[:3], wrist_pose[3:7]
        )
        right_pos_error, right_rot_error = self._wrist_tracking_error(
            target_right_wrist, wrist_pose[7:10], wrist_pose[10:14]
        )

        # Store max error for this timestep
        self.end_effector_tracking_errors.append(
            (max(left_pos_error, right_pos_error), max(left_rot_error, right_rot_error))
        )

        if i < self.ci_test_steps:
            return

        max_pos_errors = [error[0] for error in self.end_effector_tracking_errors]
        max_rot_errors = [error[1] for error in self.end_effector_tracking_errors]
        max_pos_error = np.max(max_pos_errors)
        max_rot_error = np.max(max_rot_errors)
        average_pos_error = np.mean(max_pos_errors)
        average_rot_error = np.mean(max_rot_errors)

        # More realistic thresholds based on observed data
        max_pos_threshold = 0.07  # 7cm threshold
        max_rot_threshold = np.deg2rad(17)  # 17 degree threshold
        average_pos_threshold = 0.05  # 5cm threshold
        average_rot_threshold = np.deg2rad(12)  # 12 degree threshold

        print(f" Position errors - Max: {max_pos_error:.4f}")
        print(f" Rotation errors - Max: {np.rad2deg(max_rot_error):.2f}°")
        print(f" Average position error: {average_pos_error:.4f}")
        print(f" Average rotation error: {np.rad2deg(average_rot_error):.2f}°")

        assert (
            max_pos_error < max_pos_threshold
        ), "Maximum end effector position tracking error exceeds threshold"
        assert (
            max_rot_error < max_rot_threshold
        ), "Maximum end effector rotation tracking error exceeds threshold"
        assert (
            average_pos_error < average_pos_threshold
        ), "Average end effector position tracking error exceeds threshold"
        assert (
            average_rot_error < average_rot_threshold
        ), "Average end effector rotation tracking error exceeds threshold"
        assert (
            np.array(self.upper_body_speed_hist).mean() > 0.03
        ), "Mean upper body joint velocities should be larger. Robot might not be moving."
        print("End effector tracking validation passed.")

        # Ensure end effector tracking validation has run on enough samples. The list
        # cannot be empty here (a sample was appended above), so only the count is
        # checked. Raised explicitly so the check is not stripped under `python -O`.
        num_samples = len(self.end_effector_tracking_errors)
        if num_samples < self.ci_test_steps:
            raise AssertionError(
                f"Only {num_samples} end effector tracking samples collected"
            )
def get_features(
    robot_model: RobotModel,
    save_img_obs: bool = False,
    image_obs_configs: dict[str, dict] | None = None,
) -> dict[str, dict]:
    """Build the dataset feature spec used by the data exporter.

    Starts from ``get_dataset_features(robot_model)``, adds the simulation-specific
    "observation.sim.*" entries, removes the "observation.img_state_delta" and
    "observation.images.ego_view" entries, and optionally adds one video feature
    per image observation.

    Args:
        robot_model: Robot model; also used to size "observation.sim.target_upper_body_pose"
        save_img_obs: Whether to add video features for the image observations
        image_obs_configs: Mapping from image observation key to a config dict with a
            "shape" entry. ``None`` (the default) is treated as empty.

    Returns:
        dict[str, dict]: Feature name -> feature description.

    Raises:
        KeyError: If the base features lack one of the two entries that are removed.
    """
    # A None default avoids sharing one mutable dict object across calls
    if image_obs_configs is None:
        image_obs_configs = {}

    int32_scalar = {"dtype": "int32", "shape": (1,)}
    num_upper_body_joints = len(robot_model.get_joint_group_indices("upper_body"))

    features = get_dataset_features(robot_model)
    features.update(
        {
            "observation.sim.seed": dict(int32_scalar),
            "observation.sim.max_mujoco_state_len": dict(int32_scalar),
            "observation.sim.mujoco_state_len": dict(int32_scalar),
            "observation.sim.mujoco_state": {
                "dtype": "float64",
                "shape": (MAX_MUJOCO_STATE_LEN,),
            },
            "observation.sim.left_wrist": {"dtype": "float64", "shape": (16,)},
            "observation.sim.right_wrist": {"dtype": "float64", "shape": (16,)},
            "observation.sim.left_fingers": {"dtype": "float64", "shape": (400,)},
            "observation.sim.right_fingers": {"dtype": "float64", "shape": (400,)},
            "observation.sim.target_upper_body_pose": {
                "dtype": "float64",
                "shape": (num_upper_body_joints,),
            },
            # TODO: support different reduced robot models
        }
    )

    features.pop("observation.img_state_delta")
    features.pop("observation.images.ego_view")

    if save_img_obs:
        for key, value in image_obs_configs.items():
            features[f"observation.images.{key.replace('_image', '')}"] = {
                "dtype": "video",
                "shape": value["shape"],
                "names": ["height", "width", "channel"],
            }

    return features
def generate_frame(
    obs: Dict[str, Any],
    wbc_action: Dict[str, Any],
    seed: int,
    mujoco_state: np.ndarray,
    mujoco_state_len: int,
    max_mujoco_state_len: int,
    teleop_cmd: Dict[str, Any],
    wbc_goal: Dict[str, Any],
    save_img_obs: bool = False,
):
    """Assemble one dataset frame from the current observation, action and commands.

    Args:
        obs: Environment observation; "q" and "wrist_pose" are read, plus every
            key ending in "image" when ``save_img_obs`` is set
        wbc_action: WBC action; "q" is stored as the action
        seed: Value stored under "observation.sim.seed"
        mujoco_state: State array stored as float64
        mujoco_state_len: Value stored under "observation.sim.mujoco_state_len"
        max_mujoco_state_len: Value stored under "observation.sim.max_mujoco_state_len"
        teleop_cmd: Teleop command providing wrist targets, finger positions and an
            optional "base_height_command"
        wbc_goal: WBC goal providing "wrist_pose", "target_upper_body_pose" and an
            optional "navigate_cmd"
        save_img_obs: Whether image observations are copied into the frame

    Returns:
        dict: Frame keyed by dataset feature name.
    """
    f64 = np.float64
    frame = {}

    # Robot state and action
    frame["observation.state"] = np.array(obs["q"], dtype=f64)
    frame["observation.eef_state"] = np.array(obs["wrist_pose"], dtype=f64)
    frame["action"] = np.array(wbc_action["q"], dtype=f64)
    frame["action.eef"] = np.array(wbc_goal["wrist_pose"], dtype=f64)

    # Teleop commands, falling back to the module defaults when absent
    navigate_cmd = wbc_goal.get("navigate_cmd", DEFAULT_NAV_CMD)
    frame["teleop.navigate_command"] = np.array(navigate_cmd, dtype=f64)
    base_height_cmd = teleop_cmd.get("base_height_command", [DEFAULT_BASE_HEIGHT])
    frame["teleop.base_height_command"] = np.array(base_height_cmd, dtype=f64).reshape(
        1,
    )

    # Simulation bookkeeping
    frame["observation.sim.seed"] = np.array([seed], dtype=np.int32)
    frame["observation.sim.mujoco_state_len"] = np.array([mujoco_state_len], dtype=np.int32)
    frame["observation.sim.max_mujoco_state_len"] = np.array(
        [max_mujoco_state_len], dtype=np.int32
    )
    frame["observation.sim.mujoco_state"] = mujoco_state.astype(f64)

    # Teleop targets, flattened
    for side in ("left", "right"):
        frame[f"observation.sim.{side}_wrist"] = teleop_cmd[f"{side}_wrist"].flatten().astype(f64)
    for side in ("left", "right"):
        finger_positions = teleop_cmd[f"{side}_fingers"]["position"]
        frame[f"observation.sim.{side}_fingers"] = finger_positions.flatten().astype(f64)

    frame["observation.sim.target_upper_body_pose"] = wbc_goal["target_upper_body_pose"].astype(
        f64
    )
    # NOTE: "target_time" and "interpolation_garbage_collection_time" from wbc_goal
    # are currently not recorded.

    if save_img_obs:
        frame.update(
            {
                f"observation.images.{key.replace('_image', '')}": value
                for key, value in obs.items()
                if key.endswith("image")
            }
        )

    return frame
def get_data_exporter(
    config: SyncSimDataCollectionConfig,
    obs: Dict[str, Any],
    robot_model: RobotModel,
    save_path: Optional[pathlib.Path] = None,
) -> Gr00tDataExporter:
    """Create the data exporter for a collection run.

    Args:
        config: Collection config (robot, task name, control frequency, image and
            directory-removal options)
        obs: A sample observation; keys ending in "image" define the image features
        robot_model: Robot model used for the feature and modality specs
        save_path: Output directory. When None, a timestamped directory under
            ``./outputs/`` is used.

    Returns:
        Gr00tDataExporter: Exporter created via ``Gr00tDataExporter.create``.
    """
    if save_path is None:
        timestamp = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
        save_path = pathlib.Path(
            f"./outputs/{timestamp}-{config.robot}-sim-{config.task_name}/"
        )

    if config.remove_existing_dir and save_path.exists():
        shutil.rmtree(save_path)

    # One entry per image observation, recording its array shape
    image_obs_configs = {
        name: {"dtype": "uint8", "shape": image.shape}
        for name, image in obs.items()
        if name.endswith("image")
    }

    # TODO: use standardized keys for training dataset
    modality_config = get_modality_config(robot_model)
    features = get_features(
        robot_model, save_img_obs=config.save_img_obs, image_obs_configs=image_obs_configs
    )

    return Gr00tDataExporter.create(
        save_root=save_path,
        fps=config.control_frequency,
        features=features,
        modality_config=modality_config,
        task=config.task_name,
        script_config=config,
        robot_type=config.robot,
        vcodec="libx264",  # Use a common codec that should be available
    )
def get_env(config: SyncSimDataCollectionConfig, **kwargs) -> SyncEnv:
    """Instantiate the synchronous simulation environment for the configured robot.

    Args:
        config: Collection config providing robot, task and simulation settings
        **kwargs: Extra environment arguments; entries derived from ``config``
            overwrite caller-supplied values with the same key

    Returns:
        SyncEnv: The created environment.

    Raises:
        ValueError: If the resolved robot type is not "g1".
    """
    robot_type, _ = get_robot_type_and_model(config.robot, enable_waist_ik=config.enable_waist)
    print("Instantiating environment:", config.env_name, config.robot)

    controller_configs = update_robosuite_controller_configs(
        robot=config.robot,
        wbc_version=config.wbc_version,
        enable_gravity_compensation=config.enable_gravity_compensation,
    )

    # Config-derived settings take precedence over anything the caller passed in
    kwargs["ik_indicator"] = config.ik_indicator
    kwargs["control_freq"] = config.control_frequency
    kwargs["renderer"] = config.renderer
    kwargs["controller_configs"] = controller_configs
    kwargs["enable_waist"] = config.enable_waist
    kwargs["enable_gravity_compensation"] = config.enable_gravity_compensation
    kwargs["gravity_compensation_joints"] = config.gravity_compensation_joints
    kwargs["sim_freq"] = config.sim_frequency

    if robot_type != "g1":
        raise ValueError(f"Unsupported robot type: {robot_type}")

    return G1SyncEnv(
        env_name=config.task_name,  # TODO: should merge with config.env_name
        robot_name=config.robot,
        **kwargs,
    )
def get_env_name(robot: str, task_name: str, enable_waist_ik: bool = False) -> str:
    """Return the environment name string for a robot and task.

    Args:
        robot: Robot name; also resolved to a robot type via ``get_robot_type_and_model``
        task_name: Task name embedded in the environment name
        enable_waist_ik: Forwarded to ``get_robot_type_and_model``

    Returns:
        str: ``"gr00tlocomanip_<robot_type>_sim/<task_name>_<robot>_Env"``.
    """
    resolved_type, _model = get_robot_type_and_model(robot, enable_waist_ik=enable_waist_ik)
    return f"gr00tlocomanip_{resolved_type}_sim/{task_name}_{robot}_Env"
def get_body_teleoped_joint_groups(robot: str) -> list[str]:
    """Return the joint groups that body teleoperation drives for ``robot``.

    Args:
        robot: Robot name; one of "G1FixedLowerBody", "G1FixedBase", "G1ArmsOnly", "G1"

    Returns:
        list[str]: ``["arms"]`` for the three restricted variants, ``["upper_body"]`` for "G1".

    Raises:
        KeyError: If ``robot`` is not one of the known names.
    """
    arms_only_robots = ("G1FixedLowerBody", "G1FixedBase", "G1ArmsOnly")
    groups_by_robot = {name: ["arms"] for name in arms_only_robots}
    groups_by_robot["G1"] = ["upper_body"]
    return groups_by_robot[robot]
def get_teleop_policy(
    robot_type: str,
    robot_model: RobotModel,
    config: SyncSimDataCollectionConfig,
    activate_keyboard_listener: bool = True,
) -> TeleopPolicy:
    """Build the teleoperation policy with its retargeting IK for the given robot.

    Args:
        robot_type: Robot type; only "g1" is supported
        robot_model: Robot model shared by the retargeting IK and the policy
        config: Collection config providing device, streamer and replay settings
        activate_keyboard_listener: Forwarded to ``TeleopPolicy``

    Returns:
        TeleopPolicy: The constructed teleop policy.

    Raises:
        ValueError: If ``robot_type`` is not "g1".
    """
    if robot_type != "g1":
        raise ValueError(f"Invalid robot type: {robot_type}")
    left_hand_ik_solver, right_hand_ik_solver = instantiate_g1_hand_ik_solver()

    # Initializing the teleop policy will block the main process until the Leap Motion is ready.
    ik_kwargs = dict(
        robot_model=robot_model,
        left_hand_ik_solver=left_hand_ik_solver,
        right_hand_ik_solver=right_hand_ik_solver,
        enable_visualization=config.enable_visualization,
        body_active_joint_groups=get_body_teleoped_joint_groups(config.robot),
        body_ik_solver_settings_type=get_body_ik_solver_settings_type(config.robot),
    )
    retargeting_ik = TeleopRetargetingIK(**ik_kwargs)

    return TeleopPolicy(
        robot_model=robot_model,
        retargeting_ik=retargeting_ik,
        body_control_device=config.body_control_device,
        hand_control_device=config.hand_control_device,
        body_streamer_ip=config.body_streamer_ip,
        body_streamer_keyword=config.body_streamer_keyword,
        enable_real_device=config.enable_real_device,
        replay_data_path=config.replay_data_path,
        replay_speed=config.replay_speed,
        activate_keyboard_listener=activate_keyboard_listener,
    )
def get_wbc_config(config: SyncSimDataCollectionConfig):
    """Load the WBC YAML config and force the upper-body policy type to "identity".

    Args:
        config: Collection config exposing ``load_wbc_yaml()``

    Returns:
        The loaded WBC config with "upper_body_policy_type" overridden.
    """
    loaded_wbc_config = config.load_wbc_yaml()
    loaded_wbc_config["upper_body_policy_type"] = "identity"
    return loaded_wbc_config
def get_policies(
    config: SyncSimDataCollectionConfig,
    robot_type: str,
    robot_model: RobotModel,
    activate_keyboard_listener: bool = True,
):
    """Create and activate the WBC policy, and create the teleop policy.

    The teleop policy is activated here only when manual control is disabled.

    Args:
        config: Collection config
        robot_type: Robot type passed to both policy factories
        robot_model: Robot model passed to both policy factories
        activate_keyboard_listener: Forwarded to ``get_teleop_policy``

    Returns:
        tuple: ``(wbc_policy, teleop_policy)``.
    """
    wbc_policy = get_wbc_policy(robot_type, robot_model, get_wbc_config(config), init_time=0.0)
    wbc_policy.activate_policy()

    teleop_policy = get_teleop_policy(robot_type, robot_model, config, activate_keyboard_listener)
    auto_mode = not config.manual_control
    if auto_mode:
        teleop_policy.activate_policy()

    return wbc_policy, teleop_policy