You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
156 lines
6.4 KiB
156 lines
6.4 KiB
from gr00t_wbc.control.robot_model.robot_model import RobotModel
|
|
from gr00t_wbc.data.constants import RS_VIEW_CAMERA_HEIGHT, RS_VIEW_CAMERA_WIDTH
|
|
|
|
|
|
# Joint groups exposed in both the "state" and "action" sections, in output order.
_JOINT_GROUPS = (
    "left_leg",
    "right_leg",
    "waist",
    "left_arm",
    "left_hand",
    "right_arm",
    "right_hand",
)


def _joint_group_range(robot_model: "RobotModel", group: str) -> tuple:
    """
    Return the half-open ``(start, end)`` index range covering joint ``group``.

    The modality config can only describe a group as a single ``start``/``end`` slice of
    the flat joint vector, so the group's indices must form one contiguous run.

    Raises:
        ValueError: if the group has no joint indices or they are not contiguous.
    """
    indices = sorted(robot_model.get_joint_group_indices(group))
    if not indices:
        raise ValueError(f"Joint group {group!r} has no joint indices.")
    start, end = indices[0], indices[-1] + 1
    # Sorted ints fill [start, end) exactly when the distinct count equals the span;
    # otherwise the slice would silently include joints belonging to other groups.
    if len(set(indices)) != end - start:
        raise ValueError(
            f"Joint group {group!r} indices {indices} are not contiguous and cannot be "
            "represented as a single start/end slice."
        )
    return start, end


def get_modality_config(robot_model: "RobotModel", add_stereo_camera: bool = False) -> dict:
    """
    Get the modality config for the robot model.

    Args:
        robot_model: Robot model whose ``get_joint_group_indices`` supplies the joint
            indices of every group in ``_JOINT_GROUPS``.
        add_stereo_camera: If True, also register the left/right mono ego-view videos.

    Returns:
        A dict with ``"state"``, ``"action"``, ``"video"`` and ``"annotation"`` sections.
        Joint groups map to ``{"start", "end"}`` slices of the joint vector. Wrist entries
        slice the 14-element end-effector vector (``observation.eef_state`` for state,
        ``action.eef`` for action).

    Raises:
        ValueError: if a joint group is empty or its indices are not contiguous.
    """
    ranges = {group: _joint_group_range(robot_model, group) for group in _JOINT_GROUPS}

    def joint_entries() -> dict:
        # Build new dicts on every call so "state" and "action" never share objects.
        return {group: {"start": start, "end": end} for group, (start, end) in ranges.items()}

    def wrist_entries(original_key: str) -> dict:
        # Layout of the 14-element eef vector: [l_pos(3), l_quat(4), r_pos(3), r_quat(4)].
        return {
            "left_wrist_pos": {"start": 0, "end": 3, "original_key": original_key},
            "left_wrist_abs_quat": {
                "start": 3,
                "end": 7,
                "original_key": original_key,
                "rotation_type": "quaternion",
            },
            "right_wrist_pos": {"start": 7, "end": 10, "original_key": original_key},
            "right_wrist_abs_quat": {
                "start": 10,
                "end": 14,
                "original_key": original_key,
                "rotation_type": "quaternion",
            },
        }

    modality_config = {
        "state": {**joint_entries(), **wrist_entries("observation.eef_state")},
        "action": {
            **joint_entries(),
            **wrist_entries("action.eef"),
            "base_height_command": {
                "start": 0,
                "end": 1,
                "original_key": "teleop.base_height_command",
            },
            "navigate_command": {"start": 0, "end": 3, "original_key": "teleop.navigate_command"},
        },
        "video": {"ego_view": {"original_key": "observation.images.ego_view"}},
        "annotation": {"human.task_description": {"original_key": "task_index"}},
    }

    if add_stereo_camera:
        modality_config["video"].update(
            {
                "ego_view_left_mono": {"original_key": "observation.images.ego_view_left_mono"},
                "ego_view_right_mono": {"original_key": "observation.images.ego_view_right_mono"},
            }
        )

    return modality_config
|
|
|
|
|
|
def get_dataset_features(robot_model: RobotModel, add_stereo_camera: bool = False) -> dict:
    """
    Build the dataset feature specification for the robot model.

    Args:
        robot_model: Supplies ``num_joints`` (shape of the joint state/action vectors)
            and ``joint_names`` (their element names).
        add_stereo_camera: If True, also declare the left/right mono ego-view videos.

    Returns:
        A dict mapping each dataset key to a ``{"dtype", "shape", "names"}`` spec.
    """

    def camera_spec() -> dict:
        # Return a fresh dict (and shape list) per camera so entries never alias.
        return {
            "dtype": "video",
            "shape": [RS_VIEW_CAMERA_HEIGHT, RS_VIEW_CAMERA_WIDTH, 3],
            "names": ["height", "width", "channel"],
        }

    def float64_spec(shape: tuple, names) -> dict:
        return {"dtype": "float64", "shape": shape, "names": names}

    def wrist_names() -> list:
        # Sub-vector names of the 14-element end-effector record (new list per call).
        return [
            "left_wrist_pos",
            "left_wrist_abs_quat",
            "right_wrist_pos",
            "right_wrist_abs_quat",
        ]

    features = {
        "observation.images.ego_view": camera_spec(),
        "observation.state": float64_spec((robot_model.num_joints,), robot_model.joint_names),
        "observation.eef_state": float64_spec((14,), wrist_names()),
        "action": float64_spec((robot_model.num_joints,), robot_model.joint_names),
        "action.eef": float64_spec((14,), wrist_names()),
        "observation.img_state_delta": {
            "dtype": "float32",
            "shape": (1,),
            "names": "img_state_delta",
        },
        "teleop.navigate_command": float64_spec((3,), ["lin_vel_x", "lin_vel_y", "ang_vel_z"]),
        "teleop.base_height_command": float64_spec((1,), "base_height_command"),
    }

    if add_stereo_camera:
        stereo_keys = (
            "observation.images.ego_view_left_mono",
            "observation.images.ego_view_right_mono",
        )
        for stereo_key in stereo_keys:
            features[stereo_key] = camera_spec()

    return features
|