You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
63 lines
2.4 KiB
63 lines
2.4 KiB
from typing import Any, Dict, Tuple
|
|
|
|
from unitree_sdk2py.core.channel import ChannelFactoryInitialize
|
|
|
|
from decoupled_wbc.control.envs.g1.sim.unitree_sdk2py_bridge import UnitreeSdk2Bridge
|
|
from decoupled_wbc.control.envs.robocasa.async_env_server import RoboCasaEnvServer
|
|
from decoupled_wbc.control.robot_model.instantiation import get_robot_type_and_model
|
|
|
|
|
|
class RoboCasaG1EnvServer(RoboCasaEnvServer):
    """RoboCasa environment server specialised for the "G1" robot.

    Observations are published and actions are received through a
    ``UnitreeSdk2Bridge`` instance created in :meth:`init_channel`.
    """

    def __init__(
        self, env_name: str, wbc_config: Dict[str, Any], env_kwargs: Dict[str, Any], **kwargs
    ):
        """Build the G1 robot model, fill in camera defaults and initialise the base server.

        Args:
            env_name: Forwarded unchanged to ``RoboCasaEnvServer.__init__``.
            wbc_config: Controller configuration. ``"enable_waist"`` is read here;
                ``"DOMAIN_ID"`` and the optional ``"INTERFACE"`` are read later in
                :meth:`init_channel`. The same dict is handed to ``UnitreeSdk2Bridge``.
            env_kwargs: Environment keyword arguments. NOTE: this dict is modified
                in place when ``"camera_names"`` or ``"render_camera"`` is missing
                or ``None``.
            **kwargs: Forwarded unchanged to ``RoboCasaEnvServer.__init__``.

        Raises:
            ImportError: If ``UnitreeSdk2Bridge`` is ``None``.
            KeyError: If ``wbc_config`` has no ``"enable_waist"`` entry.
        """
        if UnitreeSdk2Bridge is None:
            raise ImportError("UnitreeSdk2Bridge is required for RoboCasaG1EnvServer")

        self.wbc_config = wbc_config
        _, robot_model = get_robot_type_and_model(
            "G1",
            enable_waist_ik=wbc_config["enable_waist"],
        )

        # Default set of cameras when the caller did not choose any.
        if env_kwargs.get("camera_names", None) is None:
            env_kwargs["camera_names"] = [
                "robot0_oak_egoview",
                "robot0_oak_left_monoview",
                "robot0_oak_right_monoview",
                "robot0_rs_tppview",
            ]

        # Default render camera: a single camera name for the "mjviewer"
        # renderer (also the assumed renderer when none is given), a list of
        # two camera names for any other renderer.
        if env_kwargs.get("render_camera", None) is None:
            if env_kwargs.get("renderer", "mjviewer") == "mjviewer":
                env_kwargs["render_camera"] = "robot0_oak_egoview"
            else:
                env_kwargs["render_camera"] = [
                    "robot0_oak_egoview",
                    "robot0_rs_tppview",
                ]

        super().__init__(env_name, "G1", robot_model, env_kwargs=env_kwargs, **kwargs)

    def init_channel(self):
        """Initialise the Unitree channel factory (best effort) and create the bridge.

        ``ChannelFactoryInitialize`` is called with ``wbc_config["DOMAIN_ID"]`` and,
        when ``wbc_config["INTERFACE"]`` is set to a truthy value, with that
        interface as well. Any exception raised during this step is logged as a
        warning and otherwise ignored, so that an already-initialised factory does
        not prevent the bridge from being created.
        """
        # Function-scope import keeps this block self-contained.
        import logging

        try:
            interface = self.wbc_config.get("INTERFACE", None)
            if interface:
                ChannelFactoryInitialize(self.wbc_config["DOMAIN_ID"], interface)
            else:
                ChannelFactoryInitialize(self.wbc_config["DOMAIN_ID"])
        except Exception as exc:
            # Failing because the factory is already initialised is acceptable,
            # but the error is reported instead of silently dropped so that
            # genuine problems (e.g. a missing "DOMAIN_ID") remain visible.
            logging.getLogger(__name__).warning(
                "ChannelFactoryInitialize failed (%r); "
                "continuing on the assumption that it is already initialized",
                exc,
            )

        self.channel_bridge = UnitreeSdk2Bridge(config=self.wbc_config)

    def publish_obs(self):
        """Publish the cached observation through the bridge's ``PublishLowState``."""
        # NOTE(review): the cache is read without holding self.cache_lock (the
        # locking was commented out in the original code) - confirm this is safe.
        obs = self.caches["obs"]
        self.channel_bridge.PublishLowState(obs)

    def get_action(self) -> Tuple[Dict[str, Any], bool, bool]:
        """Fetch the latest action from the bridge.

        Returns:
            A tuple ``({"q": q}, ready, is_new_action)`` built from the three
            values returned by ``channel_bridge.GetAction()``.
        """
        q, ready, is_new_action = self.channel_bridge.GetAction()
        return {"q": q}, ready, is_new_action

    def reset(self):
        """Reset the base environment server, then reset the channel bridge."""
        # NOTE(review): any value returned by the base-class reset() is
        # discarded here - confirm callers do not rely on it.
        super().reset()
        self.channel_bridge.reset()
|