You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
226 lines
8.6 KiB
226 lines
8.6 KiB
from typing import Any, Dict, SupportsFloat, Tuple
|
|
|
|
import gymnasium as gym
|
|
from gymnasium import spaces
|
|
import numpy as np
|
|
|
|
from decoupled_wbc.control.main.constants import DEFAULT_BASE_HEIGHT, DEFAULT_NAV_CMD
|
|
from decoupled_wbc.control.main.teleop.configs.configs import SyncSimDataCollectionConfig
|
|
from decoupled_wbc.control.policy.wbc_policy_factory import get_wbc_policy
|
|
from decoupled_wbc.control.robot_model import RobotModel
|
|
from decoupled_wbc.control.robot_model.instantiation import get_robot_type_and_model
|
|
|
|
|
|
class WholeBodyControlWrapper(gym.Wrapper):
    """Gymnasium wrapper that drives the wrapped env through a whole-body-control policy.

    The dict passed to :meth:`step` is merged into a WBC goal by
    ``concat_action``; the policy converts that goal into the action that is
    forwarded to the wrapped environment, and each observation returned by the
    environment is handed back to the policy.
    """

    def __init__(self, env, script_config):
        """Wrap ``env`` and build the WBC policy described by ``script_config``.

        ``script_config`` is stored by reference and its ``"robot"`` entry is
        overwritten with ``env.unwrapped.robot_name``, so the caller's mapping
        is modified.
        """
        super().__init__(env)
        self.script_config = script_config
        script_config["robot"] = env.unwrapped.robot_name
        self.wbc_policy = self.setup_wbc_policy()
        self._action_space = self._wbc_action_space()

    @property
    def robot_model(self) -> RobotModel:
        """Robot model exposed by the innermost (unwrapped) environment."""
        return self.env.unwrapped.robot_model  # type: ignore

    def reset(self, **kwargs):
        """Reset the wrapped env, rebuild the WBC policy and seed it with the first observation.

        Returns:
            The ``(obs, info)`` pair produced by the wrapped environment's reset.
        """
        initial_obs, info = self.env.reset(**kwargs)
        # A brand-new policy is created on every reset rather than reusing the old one.
        fresh_policy = self.setup_wbc_policy()
        self.wbc_policy = fresh_policy
        fresh_policy.set_observation(initial_obs)
        return initial_obs, info

    def step(
        self, action: Dict[str, Any]
    ) -> Tuple[Any, SupportsFloat, bool, bool, Dict[str, Any]]:
        """Translate ``action`` into a WBC goal, step the env and update the policy.

        Args:
            action: Mapping accepted by ``concat_action`` (per-joint-group
                targets plus optional navigation / base-height commands).

        Returns:
            The value returned by the parent ``step`` call, unchanged.
        """
        merged = concat_action(self.robot_model, action)

        goal_keys = ("navigate_cmd", "base_height_command", "target_upper_body_pose")
        wbc_goal = {name: merged[name] for name in goal_keys if name in merged}
        self.wbc_policy.set_goal(wbc_goal)

        transition = super().step(self.wbc_policy.get_action())
        # Element 0 of the step result is the observation the policy consumes next.
        self.wbc_policy.set_observation(transition[0])
        return transition

    def setup_wbc_policy(self):
        """Create and activate a WBC policy from ``self.script_config``.

        Also stores the number of ``"upper_body"`` joint indices of the freshly
        instantiated robot model in ``self.total_dofs``.

        Returns:
            The activated policy object produced by ``get_wbc_policy``.
        """
        robot_type, robot_model = get_robot_type_and_model(
            self.script_config["robot"],
            enable_waist_ik=self.script_config.get("enable_waist", False),
        )

        # Values pushed into the config after it is built from script_config
        # (presumably overriding whatever the script requested -- confirm
        # against SyncSimDataCollectionConfig.update).
        forced_options = {
            "save_img_obs": False,
            "ik_indicator": False,
            "enable_real_device": False,
            "replay_data_path": None,
        }
        config = SyncSimDataCollectionConfig.from_dict(self.script_config)
        config.update(forced_options)

        wbc_config = config.load_wbc_yaml()
        wbc_config["upper_body_policy_type"] = "identity"

        policy = get_wbc_policy(robot_type, robot_model, wbc_config, init_time=0.0)
        self.total_dofs = len(robot_model.get_joint_group_indices("upper_body"))
        policy.activate_policy()
        return policy

    def _get_joint_group_size(self, group_name: str) -> int:
        """Return how many joint indices ``group_name`` has, memoised per instance."""
        try:
            cache = self._joint_group_size_cache
        except AttributeError:
            # First call on this instance: create the per-instance cache lazily.
            cache = self._joint_group_size_cache = {}

        if group_name not in cache:
            cache[group_name] = len(self.robot_model.get_joint_group_indices(group_name))
        return cache[group_name]

    def _wbc_action_space(self) -> spaces.Dict:
        """Build the dict action space: navigation, base height and per-joint-group targets.

        The ``"action.waist"`` entry is only present when ``"waist"`` is listed
        in the robot's ``"upper_body_no_hands"`` joint group.
        """

        def unbounded_box(size: int) -> spaces.Box:
            # Every entry is an unbounded float32 vector; only the length differs.
            return spaces.Box(low=-np.inf, high=np.inf, shape=(size,), dtype=np.float32)

        action_space: Dict[str, spaces.Space] = {
            "action.navigate_command": unbounded_box(3),
            "action.base_height_command": unbounded_box(1),
        }

        group_for_key = {
            "action.left_hand": "left_hand",
            "action.right_hand": "right_hand",
            "action.left_arm": "left_arm",
            "action.right_arm": "right_arm",
        }
        for space_key, group_name in group_for_key.items():
            action_space[space_key] = unbounded_box(self._get_joint_group_size(group_name))

        robot_info = self.robot_model.supplemental_info  # type: ignore[attr-defined]
        upper_body_groups = robot_info.joint_groups["upper_body_no_hands"]["groups"]
        if "waist" in upper_body_groups:
            action_space["action.waist"] = unbounded_box(self._get_joint_group_size("waist"))

        return spaces.Dict(action_space)
|
|
|
|
|
|
def concat_action(robot_model: RobotModel, goal: Dict[str, Any]) -> Dict[str, Any]:
    """Combine individual joint-group targets into the upper-body action vector.

    Every ``"action."`` occurrence is stripped from the keys of ``goal``.
    ``"navigate_command"`` and ``"base_height_command"`` are pulled out (with
    ``DEFAULT_NAV_CMD`` / ``DEFAULT_BASE_HEIGHT`` as fallbacks); every remaining
    key is treated as a joint-group name and its value is written into a
    zero-initialised vector of ``robot_model.num_dofs`` entries at the indices
    the robot model reports for that group.  Joints not covered by any supplied
    group therefore stay at zero.

    Args:
        robot_model: Provides ``num_dofs`` and ``get_joint_group_indices``.
        goal: Mapping of command / joint-group names to values.  The leading
            (batch) dimensions of the output are taken from the first value.
            An empty mapping is accepted and yields an unbatched all-zero pose
            with default commands.

    Returns:
        A dict with ``"navigate_cmd"``, ``"base_height_command"`` (as an
        ``np.ndarray``) and ``"target_upper_body_pose"`` (the full vector
        restricted to the ``"upper_body"`` joint indices).
    """
    processed_goal = {key.replace("action.", ""): value for key, value in goal.items()}

    if processed_goal:
        first_value = next(iter(processed_goal.values()))
        # np.shape also accepts lists and Python scalars, not only arrays
        # exposing a ``.shape`` attribute.
        batch_shape = tuple(np.shape(first_value)[:-1])
    else:
        # An empty goal previously escaped as a bare StopIteration; fall back
        # to an unbatched action built entirely from defaults instead.
        batch_shape = ()
    action = np.zeros(batch_shape + (robot_model.num_dofs,))

    action_dict = {}
    action_dict["navigate_cmd"] = processed_goal.pop("navigate_command", DEFAULT_NAV_CMD)
    action_dict["base_height_command"] = np.array(
        processed_goal.pop("base_height_command", DEFAULT_BASE_HEIGHT)
    )

    # Whatever is left after popping the two commands is a joint-group target.
    for joint_group, value in processed_goal.items():
        indices = robot_model.get_joint_group_indices(joint_group)
        action[..., indices] = value

    upper_body_indices = robot_model.get_joint_group_indices("upper_body")
    action_dict["target_upper_body_pose"] = action[..., upper_body_indices]
    return action_dict
|
|
|
|
|
|
def prepare_observation_for_eval(robot_model: RobotModel, obs: dict) -> dict:
    """Attach per-joint-group views of ``obs["q"]`` under ``"state.<group>"`` keys.

    Args:
        robot_model: Provides ``num_joints`` and ``get_joint_group_indices``.
        obs: Observation mapping containing ``"q"``, whose last axis must have
            ``robot_model.num_joints`` entries.  It is modified in place.

    Returns:
        The same ``obs`` object, extended with ``state.left_arm``,
        ``state.right_arm``, ``state.waist``, ``state.left_leg``,
        ``state.right_leg``, ``state.left_hand`` and ``state.right_hand``.

    Raises:
        AssertionError: If ``"q"`` is missing or its last axis has the wrong size.
    """
    assert "q" in obs, "q is not in the observation"

    joint_positions = obs["q"]
    assert joint_positions.shape[-1] == robot_model.num_joints, "q has wrong shape"

    group_names = (
        "left_arm",
        "right_arm",
        "waist",
        "left_leg",
        "right_leg",
        "left_hand",
        "right_hand",
    )

    # Slice every group before writing anything, so a failing index lookup
    # leaves ``obs`` untouched.
    group_slices = {
        f"state.{name}": joint_positions[..., robot_model.get_joint_group_indices(name)]
        for name in group_names
    }
    for state_key, group_q in group_slices.items():
        obs[state_key] = group_q

    return obs
|
|
|
|
|
|
def prepare_gym_space_for_eval(
    robot_model: RobotModel, gym_space: gym.spaces.Dict
) -> gym.spaces.Dict:
    """Add an unbounded ``Box`` per joint group to ``gym_space`` under ``"state.<group>"``.

    Args:
        robot_model: Provides ``get_joint_group_indices``; the number of
            indices of each group determines the length of its ``Box``.
        gym_space: Dict space to extend.  It is modified in place.

    Returns:
        The same ``gym_space`` object, extended with ``state.left_arm``,
        ``state.right_arm``, ``state.waist``, ``state.left_leg``,
        ``state.right_leg``, ``state.left_hand`` and ``state.right_hand``.
    """
    group_names = (
        "left_arm",
        "right_arm",
        "waist",
        "left_leg",
        "right_leg",
        "left_hand",
        "right_hand",
    )

    # Build all spaces first, then register them, so a failing index lookup
    # leaves ``gym_space`` untouched.  ``dtype`` is left at the Box default.
    group_spaces = {
        f"state.{name}": spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=(len(robot_model.get_joint_group_indices(name)),),
        )
        for name in group_names
    }
    for state_key, box in group_spaces.items():
        gym_space[state_key] = box

    return gym_space
|