You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

305 lines
11 KiB

from abc import abstractmethod
import threading
import time
from typing import Any, Dict, Tuple
import mujoco
import numpy as np
import rclpy
from decoupled_wbc.control.envs.g1.sim.image_publish_utils import ImagePublishProcess
from decoupled_wbc.control.envs.robocasa.utils.robocasa_env import (
Gr00tLocomanipRoboCasaEnv,
) # noqa: F401
from decoupled_wbc.control.robot_model.robot_model import RobotModel
from decoupled_wbc.control.utils.keyboard_dispatcher import KeyboardListenerSubscriber
class RoboCasaEnvServer:
    """Run a RoboCasa simulation environment loop, usually in a background thread.

    The loop (`start`) talks to the controller through three hooks that a
    subclass must provide: `init_channel`, `publish_obs` and `get_action`.
    When `onscreen` is requested, the wrapped environment is created with its
    own on-screen rendering disabled and a passive MuJoCo viewer is driven
    from this class instead.

    NOTE: the hooks are decorated with ``@abstractmethod`` but this class does
    not derive from ``abc.ABC``, so nothing is enforced at instantiation time;
    an unimplemented hook raises ``NotImplementedError`` when it is called.
    """

    def __init__(
        self,
        env_name: str,
        robot_name: str,
        robot_model: "RobotModel",
        env_kwargs: Dict[str, Any],
        **kwargs,
    ):
        """Create the environment, the ROS2 node/rate and the initial caches.

        Args:
            env_name: Forwarded to ``Gr00tLocomanipRoboCasaEnv``.
            robot_name: Forwarded to ``Gr00tLocomanipRoboCasaEnv``.
            robot_model: Forwarded to ``Gr00tLocomanipRoboCasaEnv``.
            env_kwargs: Keyword arguments for the environment. The keys
                ``onscreen``, ``render_camera`` and ``control_freq`` are also
                read here. The mapping is copied, not modified in place.
            **kwargs: Optional settings read here: ``sim_freq`` (default 200),
                ``sync_mode`` (False), ``steps_per_action`` (1),
                ``image_dt`` (0.04), ``viewer_freq`` (50), ``verbose`` (True).
        """
        # Work on a shallow copy so the caller's dict is not silently changed
        # when the "onscreen" flag is overridden below.
        env_kwargs = dict(env_kwargs)

        # initialize environment
        self.onscreen = bool(env_kwargs.get("onscreen", False))
        self.render_camera = None
        if self.onscreen:
            # onscreen render in the main thread: the env itself must not render
            env_kwargs["onscreen"] = False
            self.render_camera = env_kwargs.get("render_camera", None)

        self.env_name = env_name
        self.env = Gr00tLocomanipRoboCasaEnv(env_name, robot_name, robot_model, **env_kwargs)
        self.init_caches()
        self.cache_lock = threading.Lock()

        # initialize channel
        self.init_channel()

        # initialize ROS2 node
        if not rclpy.ok():
            rclpy.init()
            self.node = rclpy.create_node("sim_robocasa")
            self.thread = threading.Thread(target=rclpy.spin, args=(self.node,), daemon=True)
            self.thread.start()
        else:
            self.thread = None
            executor = rclpy.get_global_executor()
            self.node = executor.get_nodes()[0]  # will only take the first node

        self.control_freq = env_kwargs.get("control_freq", 1 / 0.02)
        self.sim_freq = kwargs.get("sim_freq", 1 / 0.005)
        self.control_rate = self.node.create_rate(self.control_freq)

        self.running = False
        self.sim_thread = None
        self.sync_lock = threading.Lock()
        self.sync_mode = kwargs.get("sync_mode", False)
        self.steps_per_action = kwargs.get("steps_per_action", 1)

        self.image_dt = kwargs.get("image_dt", 0.04)
        self.image_publish_process = None

        self.viewer_freq = kwargs.get("viewer_freq", 1 / 0.02)
        self.viewer = None
        self.verbose = kwargs.get("verbose", True)

        # Initialize keyboard listener for env reset
        self.keyboard_listener = KeyboardListenerSubscriber()

        self.reset()

    @property
    def base_env(self):
        """Return the environment wrapped by ``self.env`` (its ``env`` attribute)."""
        return self.env.env

    @staticmethod
    def _period_in_steps(ratio: float) -> int:
        """Turn a possibly fractional step ratio into a modulo period of at least 1.

        Rounding (instead of truncating) absorbs floating-point error such as
        2.9999999 and the clamp avoids a modulo-by-zero when ``ratio < 1``.
        """
        return max(1, int(round(ratio)))

    def start_image_publish_subprocess(self, start_method: str = "spawn", camera_port: int = 5555):
        """Initialize image publishing subprocess if cameras are configured.

        Args:
            start_method: Passed to ``ImagePublishProcess`` as ``start_method``.
            camera_port: Passed to ``ImagePublishProcess`` as ``zmq_port``.
        """
        if len(self.env.camera_names) == 0:
            print(
                "Warning: No camera configs provided, image publishing subprocess will not be started"
            )
            return

        # Build camera configs from env camera settings
        camera_configs = {}
        for env_cam_name in self.env.camera_names:
            camera_config = self.env.camera_key_mapper.get_camera_config(env_cam_name)
            mapped_cam_name, cam_width, cam_height = camera_config
            camera_configs[mapped_cam_name] = {"height": cam_height, "width": cam_width}

        self.image_publish_process = ImagePublishProcess(
            camera_configs=camera_configs,
            image_dt=self.image_dt,
            zmq_port=camera_port,
            start_method=start_method,
            verbose=self.verbose,
        )
        self.image_publish_process.start_process()

    def update_render_caches(self, obs: Dict[str, Any]):
        """Forward the image entries of ``obs`` to the image publishing process.

        Does nothing when no publishing process has been started or when
        ``obs`` holds no ``*_image`` numpy arrays.
        """
        if self.image_publish_process is None:
            return

        # Extract image observations from obs dict
        render_caches = {
            k: v for k, v in obs.items() if k.endswith("_image") and isinstance(v, np.ndarray)
        }

        # Update shared memory if image publishing process is available
        if render_caches:
            self.image_publish_process.update_shared_memory(render_caches)

    def init_caches(self):
        """Create the cache dict holding the latest step/reset results (all ``None``)."""
        self.caches = {
            "obs": None,
            "reward": None,
            "terminated": None,
            "truncated": None,
            "info": None,
        }

    def reset(self, **kwargs):
        """Reset the environment, refresh the caches and (re)create the viewer.

        Args:
            **kwargs: Forwarded to ``self.env.reset``.

        Returns:
            The ``(obs, info)`` pair returned by ``self.env.reset``.
        """
        if self.viewer is not None:
            self.viewer.close()
            self.viewer = None

        obs, info = self.env.reset(**kwargs)

        # Same lock as the step path in `start`, so readers never observe a
        # half-updated set of cache entries.
        with self.cache_lock:
            self.caches["obs"] = obs
            self.caches["reward"] = 0
            self.caches["terminated"] = False
            self.caches["truncated"] = False
            self.caches["info"] = info

        # initialize viewer
        if self.onscreen:
            # `mujoco.viewer` is a submodule that plain `import mujoco` does not
            # load; import it explicitly so this does not depend on another
            # module having imported it first.
            import mujoco.viewer

            self.viewer = mujoco.viewer.launch_passive(
                self.base_env.sim.model._model,
                self.base_env.sim.data._data,
                show_left_ui=False,
                show_right_ui=False,
            )
            self.viewer.opt.geomgroup[0] = 0  # disable collision visualization
            if self.render_camera is not None:
                self.viewer.cam.type = mujoco.mjtCamera.mjCAMERA_FIXED
                self.viewer.cam.fixedcamid = self.base_env.sim.model._model.cam(
                    self.render_camera
                ).id

        return obs, info

    @abstractmethod
    def init_channel(self):
        """Set up the communication channel; called once from ``__init__``."""
        raise NotImplementedError("init_channel must be implemented by the subclass")

    @abstractmethod
    def publish_obs(self):
        """Publish the current observation; called once per loop pass in ``start``."""
        raise NotImplementedError("publish_obs must be implemented by the subclass")

    @abstractmethod
    def get_action(self) -> Tuple[Dict[str, Any], bool, bool]:
        """Return ``(action, ready, is_new_action)`` for the current loop pass.

        ``ready`` is True if an action has been received from the control loop;
        ``is_new_action`` is True if it differs from the previous action.
        """
        raise NotImplementedError("get_action must be implemented by the subclass")

    def start_as_thread(self):
        """Start the simulation loop in a daemon thread (no-op if already running)."""
        if self.sim_thread is not None and self.sim_thread.is_alive():
            return
        self.sim_thread = threading.Thread(target=self.start)
        self.sim_thread.daemon = True
        self.sim_thread.start()

    def set_sync_mode(self, sync_mode: bool, steps_per_action: int = 4):
        """Set the sync mode and the per-action step limit used by ``start``."""
        with self.sync_lock:
            self.sync_mode = sync_mode
            self.steps_per_action = steps_per_action

    def _check_keyboard_input(self):
        """Read one keyboard message and reset the environment on ``"k"``."""
        key = self.keyboard_listener.read_msg()
        if key == "k":
            print("\033[1;32m[Sim env]\033[0m Resetting sim environment")
            self.reset()

    def _report_sim_rate(self, elapsed: float) -> None:
        """Print the simulation rate achieved over the last 100 applied steps.

        Args:
            elapsed: Wall-clock seconds taken by those 100 steps.
        """
        rate_scale = self.sim_freq / self.control_freq
        sim_hz = 100.0 / elapsed * rate_scale if elapsed > 0 else float("inf")
        if self.verbose:
            print(f"sim FPS: {sim_hz}")
        if elapsed > (110.0 / self.control_freq):  # for tolerance
            print(
                f"Warning: Sim runs at {sim_hz:.1f}Hz, "
                f"but should run at {self.sim_freq:.1f}Hz"
            )

    def start(self):
        """Function executed by the simulation thread.

        Runs until ``self.running`` is cleared (see ``stop``). Each pass polls
        the keyboard, publishes observations, fetches an action, optionally
        steps the environment, and then sleeps on ``self.control_rate``.
        """
        iter_idx = 0  # number of environment steps applied so far
        steps_per_cur_action = 0
        t_start = time.monotonic()
        self.running = True

        while self.running:
            # Check keyboard input for state transitions
            self._check_keyboard_input()

            # Publish observations and get new action
            self.publish_obs()
            action, ready, is_new_action = self.get_action()

            with self.sync_lock:
                sync_mode = self.sync_mode
                max_steps_per_action = self.steps_per_action

            # Process action if ready and within step limits
            action_should_apply = ready and (
                (not sync_mode) or steps_per_cur_action < max_steps_per_action
            )

            if action_should_apply:
                obs, reward, terminated, truncated, info = self.env.step(action)
                with self.cache_lock:
                    self.caches["obs"] = obs
                    self.caches["reward"] = reward
                    self.caches["terminated"] = terminated
                    self.caches["truncated"] = truncated
                    self.caches["info"] = info

                if reward == 1.0 and iter_idx % 50 == 0:
                    print("\033[92mTask successful. Can save data now.\033[0m")

                iter_idx += 1
                steps_per_cur_action += 1
                if self.verbose and sync_mode:
                    print("steps_per_cur_action: ", steps_per_cur_action)

                # Update render caches at image publishing rate
                image_period = self._period_in_steps(self.image_dt * self.control_freq)
                if iter_idx % image_period == 0:
                    with self.cache_lock:
                        obs_copy = self.caches["obs"].copy()
                    self.update_render_caches(obs_copy)

            # Reset step counter for new actions
            if is_new_action:
                steps_per_cur_action = 0

            # Update viewer at specified frequency
            viewer_period = self._period_in_steps(self.control_freq / self.viewer_freq)
            if self.onscreen and iter_idx % viewer_period == 0:
                self.viewer.sync()

            # Check if we're meeting the desired control frequency
            if iter_idx % 100 == 0:
                end_time = time.monotonic()
                # `iter_idx` only advances on applied steps. Report only when a
                # step was just applied, otherwise an idle loop sitting on a
                # multiple of 100 (including 0) would print a bogus rate on
                # every pass. The window start is still refreshed while idle so
                # idle time is not charged to the next window.
                if action_should_apply:
                    self._report_sim_rate(end_time - t_start)
                t_start = end_time

            # reset obj pos every 200 steps (applied steps, not idle passes)
            if action_should_apply and iter_idx % 200 == 0:
                if hasattr(self.base_env, "reset_obj_pos"):
                    self.base_env.reset_obj_pos()

            self.control_rate.sleep()

    def get_privileged_obs(self):
        """Return the privileged and image entries of the cached observation.

        Includes the keys listed by ``base_env.get_privileged_obs_keys()`` when
        the base environment defines it, plus every key ending in ``_image``.
        """
        obs = {}
        with self.cache_lock:
            cached_obs = self.caches["obs"]
            if hasattr(self.base_env, "get_privileged_obs_keys"):
                for key in self.base_env.get_privileged_obs_keys():
                    obs[key] = cached_obs[key]

            for key in cached_obs.keys():
                if key.endswith("_image"):
                    obs[key] = cached_obs[key]

        return obs

    def stop(self):
        """Stop the simulation thread, waiting at most one second for it to exit."""
        self.running = False
        if self.sim_thread is not None:
            self.sim_thread.join(timeout=1.0)  # Wait for thread to finish with timeout
            self.sim_thread = None

    def close(self):
        """Stop the loop and release the image publisher, viewer and environment."""
        self.stop()
        if self.image_publish_process is not None:
            self.image_publish_process.stop()
        # Check the viewer itself rather than the `onscreen` flag so that a
        # missing or already closed viewer cannot raise here.
        if self.viewer is not None:
            self.viewer.close()
            self.viewer = None
        self.env.close()

    def get_reward(self):
        """Return ``base_env.reward()``."""
        return self.base_env.reward()