You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

201 lines
6.9 KiB

"""Offloads rendered MuJoCo camera images to a subprocess via shared memory + ZMQ."""
import multiprocessing as mp
from multiprocessing import shared_memory
import time
from typing import Any, Dict
import numpy as np
from gear_sonic.utils.mujoco_sim.sensor_server import ImageMessageSchema, SensorServer
def get_multiprocessing_info(verbose: bool = True):
    """Report which multiprocessing start method is currently in effect.

    Args:
        verbose: When true, first print every start method supported on this
            platform.

    Returns:
        The active start method name, as given by ``mp.get_start_method()``.
    """
    if not verbose:
        return mp.get_start_method()
    supported = mp.get_all_start_methods()
    print(f"Available start methods: {supported}")
    return mp.get_start_method()
class ImagePublishProcess:
    """Subprocess for publishing images using shared memory and ZMQ.

    Data flow:

    - The parent creates one ``SharedMemory`` block per camera. Each block is
      ``height * width * 3`` bytes, viewed as a ``uint8`` array of shape
      ``(height, width, 3)``.
    - ``update_shared_memory`` copies rendered images into those blocks and
      sets ``data_ready_event``.
    - The worker launched by ``start_process`` attaches to the same blocks,
      snapshots them when signalled, and publishes them through a
      ``SensorServer``.
    - ``stop`` shuts the worker down and releases the blocks.

    NOTE: writer and reader are not synchronized by a lock. A published
    snapshot may therefore mix pixels from two consecutive frames when a write
    overlaps the worker's copy.
    """

    def __init__(
        self,
        camera_configs: Dict[str, Any],
        image_dt: float,
        zmq_port: int = 5555,
        start_method: str = "spawn",
        verbose: bool = False,
    ):
        """Allocate one shared-memory block per camera, plus the IPC events.

        Args:
            camera_configs: Mapping of camera name to a config dict. Only its
                ``"height"`` and ``"width"`` entries are read here.
            image_dt: Forwarded to the worker, which waits at most
                ``min(image_dt, 0.1)`` seconds per loop iteration.
            zmq_port: Port passed to the worker's ``SensorServer``.
            start_method: ``multiprocessing`` start method for the worker,
                passed to ``mp.get_context``.
            verbose: Print the chosen multiprocessing context.

        Raises:
            KeyError: If a camera config lacks ``"height"`` or ``"width"``.
            Exception: Any error from ``SharedMemory`` creation (e.g. a
                ``ValueError`` for a zero-sized camera). Blocks already created
                for earlier cameras are released before it propagates.
        """
        self.camera_configs = camera_configs
        self.image_dt = image_dt
        self.zmq_port = zmq_port
        self.verbose = verbose
        self.shared_memory_blocks = {}
        self.shared_memory_info = {}
        self.process = None
        self.mp_context = mp.get_context(start_method)
        if self.verbose:
            print(f"Using multiprocessing context: {start_method}")
        # Freshly created events start unset; no explicit clear() is needed.
        self.stop_event = self.mp_context.Event()
        self.data_ready_event = self.mp_context.Event()
        try:
            for camera_name, camera_config in camera_configs.items():
                height = camera_config["height"]
                width = camera_config["width"]
                size = height * width * 3
                shm = shared_memory.SharedMemory(create=True, size=size)
                self.shared_memory_blocks[camera_name] = shm
                self.shared_memory_info[camera_name] = {
                    "name": shm.name,
                    "size": size,
                    "shape": (height, width, 3),
                    "dtype": np.uint8,
                }
        except BaseException:
            # Do not leak the segments already created for earlier cameras.
            self._release_shared_memory()
            raise

    def start_process(self):
        """Start the image publishing subprocess.

        Runs ``_image_publish_worker`` in a process created from
        ``self.mp_context``. The worker receives the shared-memory descriptors,
        both events, and the publishing settings.
        """
        self.process = self.mp_context.Process(
            target=self._image_publish_worker,
            args=(
                self.shared_memory_info,
                self.image_dt,
                self.zmq_port,
                self.stop_event,
                self.data_ready_event,
                self.verbose,
            ),
        )
        self.process.start()

    def update_shared_memory(self, render_caches: Dict[str, np.ndarray]):
        """Update shared memory with new rendered images.

        For each configured camera, ``render_caches[f"{camera}_image"]`` is
        copied into that camera's block. Cameras without an entry are left
        untouched. ``data_ready_event`` is set if at least one image was
        written.

        Args:
            render_caches: Mapping of ``"<camera>_image"`` keys to images. Each
                image must match, or broadcast to, the camera's
                ``(height, width, 3)`` shape. Non-``uint8`` images are assumed to
                be normalized to ``[0, 1]``; they are scaled by 255 and
                saturated to ``[0, 255]``.

        Raises:
            ValueError: If an image cannot be broadcast into its camera buffer
                (raised by ``np.copyto``).
        """
        images_updated = 0
        for camera_name in self.camera_configs:
            image_key = f"{camera_name}_image"
            if image_key not in render_caches:
                continue
            image = render_caches[image_key]
            if image.dtype != np.uint8:
                # Clip before casting: converting out-of-range floats to uint8
                # is platform-dependent in NumPy and usually wraps around.
                image = np.clip(image * 255, 0, 255).astype(np.uint8)
            info = self.shared_memory_info[camera_name]
            shared_array = np.ndarray(
                info["shape"],
                dtype=info["dtype"],
                buffer=self.shared_memory_blocks[camera_name].buf,
            )
            np.copyto(shared_array, image)
            images_updated += 1
        if images_updated > 0:
            self.data_ready_event.set()

    def stop(self):
        """Stop the image publishing subprocess and release shared memory.

        Sets ``stop_event``, then escalates if the worker does not exit:
        join (5 s), then terminate plus join (2 s), then kill plus join.
        Finally closes and unlinks every shared-memory block this instance
        created.
        """
        self.stop_event.set()
        if self.process and self.process.is_alive():
            self.process.join(timeout=5)
            if self.process.is_alive():
                self.process.terminate()
                self.process.join(timeout=2)
                if self.process.is_alive():
                    self.process.kill()
                    self.process.join()
        self._release_shared_memory()

    def _release_shared_memory(self):
        """Close and unlink every owned shared-memory block, then forget them."""
        for camera_name, shm in self.shared_memory_blocks.items():
            # Attempt close() and unlink() independently. A failing close()
            # (e.g. a BufferError while a view of shm.buf is still alive) must
            # not prevent unlink(), or the segment would outlive the process.
            for release in (shm.close, shm.unlink):
                try:
                    release()
                except Exception as e:
                    print(f"Warning: Failed to cleanup shared memory for {camera_name}: {e}")
        self.shared_memory_blocks.clear()

    @staticmethod
    def _image_publish_worker(
        shared_memory_info, image_dt, zmq_port, stop_event, data_ready_event, verbose
    ):
        """Worker function that runs in the subprocess.

        Attaches to the parent's shared-memory blocks, then loops until
        ``stop_event`` is set. Each time ``data_ready_event`` fires, it copies
        every camera image out of shared memory, serializes the images with a
        per-camera timestamp, and sends them through a ``SensorServer`` bound
        to ``zmq_port``.

        Args:
            shared_memory_info: Per-camera ``name``/``size``/``shape``/``dtype``
                records built by ``__init__``.
            image_dt: Caps each wait for new data at ``min(image_dt, 0.1)``
                seconds, so ``stop_event`` is re-checked regularly.
            zmq_port: Port passed to ``SensorServer.start_server``.
            stop_event: Set by the parent to request shutdown.
            data_ready_event: Set by the parent after new images are written.
            verbose: Currently unused; the worker's prints are unconditional.
        """
        # Pre-bind everything the ``finally`` block touches, so a failure during
        # setup cannot turn cleanup into a misleading NameError.
        sensor_server = None
        shm_blocks = {}
        shared_arrays = {}
        try:
            # Imported once here rather than on every publish iteration.
            from gear_sonic.utils.mujoco_sim.sensor_server import ImageUtils

            sensor_server = SensorServer()
            sensor_server.start_server(port=zmq_port)
            for camera_name, info in shared_memory_info.items():
                shm = shared_memory.SharedMemory(name=info["name"])
                shm_blocks[camera_name] = shm
                shared_arrays[camera_name] = np.ndarray(
                    info["shape"], dtype=info["dtype"], buffer=shm.buf
                )
            print(
                f"Image publishing subprocess started with {len(shared_arrays)} cameras "
                f"on ZMQ port {zmq_port}"
            )

            timeout = min(image_dt, 0.1)
            # The publish rate is averaged over a window of frames and measured
            # on a monotonic clock. This is far less noisy than a single frame
            # interval, and wall-clock adjustments cannot skew it.
            report_every = 50
            frames_since_report = 0
            last_report_time = time.monotonic()
            while not stop_event.is_set():
                if not data_ready_event.wait(timeout=timeout):
                    # wait() returns immediately for a non-positive timeout; the
                    # short sleep keeps the loop from busy-spinning in that case.
                    time.sleep(0.001)
                    continue
                # Clear before copying. A write that finishes after this point
                # sets the event again, so the next iteration picks it up.
                data_ready_event.clear()
                current_time = time.time()
                try:
                    image_copies = {
                        name: arr.copy() for name, arr in shared_arrays.items()
                    }
                    image_msg = ImageMessageSchema(
                        timestamps={name: current_time for name in image_copies},
                        images=image_copies,
                    )
                    serialized_data = image_msg.serialize()
                    for camera_name, image_copy in image_copies.items():
                        serialized_data[f"{camera_name}"] = ImageUtils.encode_image(
                            image_copy
                        )
                    sensor_server.send_message(serialized_data)
                except Exception as e:
                    print(f"Error publishing images: {e}")
                    continue

                frames_since_report += 1
                if frames_since_report >= report_every:
                    now = time.monotonic()
                    elapsed = now - last_report_time
                    if elapsed > 0:
                        print("Image publish frequency: ", frames_since_report / elapsed)
                    frames_since_report = 0
                    last_report_time = now
        except KeyboardInterrupt:
            print("Image publisher interrupted by user")
        finally:
            # Drop the ndarray views before unmapping the blocks, so no array
            # is left pointing at released memory.
            shared_arrays.clear()
            # Release each resource independently, so one failure does not
            # skip the rest.
            for shm in shm_blocks.values():
                try:
                    shm.close()
                except Exception as e:
                    print(f"Error during subprocess cleanup: {e}")
            if sensor_server is not None:
                try:
                    sensor_server.stop_server()
                except Exception as e:
                    print(f"Error during subprocess cleanup: {e}")