You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
534 lines
23 KiB
534 lines
23 KiB
"""Joint safety monitor for G1 robot.

This module implements safety monitoring for arm and finger joint velocities using
joint groups defined in the robot model's supplemental info. Leg joints are not monitored.

Velocity violations on monitored joints are critical: when running on the real robot
(env_type="real") the monitor exits the process. Arm position-limit violations are only
reported as non-critical warnings. The monitor also ramps the commanded positions of
monitored joints from their first observed values over a short startup window.
"""
|
|
|
|
from datetime import datetime
|
|
import sys
|
|
import time
|
|
from typing import Dict, List, Optional, Tuple
|
|
|
|
import numpy as np
|
|
|
|
from decoupled_wbc.data.viz.rerun_viz import RerunViz
|
|
|
|
|
|
class JointSafetyMonitor:
    """Monitor joint velocities and positions for G1 robot arms and hands.

    Every joint in the robot model's "arms" and "hands" joint groups gets a symmetric
    velocity limit; exceeding it is a critical violation. Arm joints additionally get
    position limits from the robot model; exceeding those is only a warning. The monitor
    also linearly ramps commanded positions of monitored joints during startup.
    """

    # Velocity limits in rad/s
    ARM_VELOCITY_LIMIT = 6.0  # rad/s for arm joints
    HAND_VELOCITY_LIMIT = 50.0  # rad/s for finger joints

    def __init__(self, robot_model, enable_viz: bool = False, env_type: str = "real"):
        """Initialize joint safety monitor.

        Args:
            robot_model: The robot model containing joint information. Must provide
                ``supplemental_info``, ``joint_names``, ``joint_to_dof_index``,
                ``get_joint_group_indices``, ``lower_joint_limits``,
                ``upper_joint_limits`` and ``is_floating_base_model``.
            enable_viz: If True, enable rerun visualization (default False)
            env_type: Environment type - "sim" or "real" (default "real"). Only
                "real" exits the process on a critical violation.

        Raises:
            ValueError: If ``robot_model.supplemental_info`` is None.
        """
        self.robot_model = robot_model
        self.safety_margin = 1.0  # Hardcoded safety margin (1.0 = use the full limits)
        self.enable_viz = enable_viz
        self.env_type = env_type

        # Startup ramping parameters
        self.control_frequency = 50  # Hz, hardcoded from run_g1_control_loop.py
        self.ramp_duration_steps = int(2.0 * self.control_frequency)  # 2 s * 50 Hz = 100 steps
        self.startup_counter = 0
        self.initial_positions = None
        self.startup_complete = False

        # Per-joint {"min": ..., "max": ...} limits keyed by joint name
        self.velocity_limits = {}
        self.position_limits = {}
        self._initialize_limits()

        # Violations found by the most recent check_safety() call
        self.violations = []

        # Per-side joint names and dof indices used for visualization. The index
        # lists stay None when the joint groups cannot be resolved, which disables
        # plotting.
        self.right_arm_indices = None
        self.right_arm_joint_names = []
        self.left_arm_indices = None
        self.left_arm_joint_names = []
        self.right_hand_indices = None
        self.right_hand_joint_names = []
        self.left_hand_indices = None
        self.left_hand_joint_names = []
        try:
            (
                self.right_arm_joint_names,
                self.right_arm_indices,
                self.left_arm_joint_names,
                self.left_arm_indices,
            ) = self._split_joint_group_by_side("arms")
            (
                self.right_hand_joint_names,
                self.right_hand_indices,
                self.left_hand_joint_names,
                self.left_hand_indices,
            ) = self._split_joint_group_by_side("hands")
        except Exception as e:
            # Visualization is optional: report the problem instead of failing.
            print(f"[JointSafetyMonitor] Warning: Could not initialize arm/hand visualization: {e}")

        # Use single tensor_key for each plot
        self.right_arm_pos_key = "right_arm_qpos"
        self.left_arm_pos_key = "left_arm_qpos"
        self.right_arm_vel_key = "right_arm_dq"
        self.left_arm_vel_key = "left_arm_dq"
        self.right_hand_pos_key = "right_hand_qpos"
        self.left_hand_pos_key = "left_hand_qpos"
        self.right_hand_vel_key = "right_hand_dq"
        self.left_hand_vel_key = "left_hand_dq"

        # Define a consistent color palette for up to 8 joints (tab10 + extra)
        self.joint_colors = [
            [31, 119, 180],  # blue
            [255, 127, 14],  # orange
            [44, 160, 44],  # green
            [214, 39, 40],  # red
            [148, 103, 189],  # purple
            [140, 86, 75],  # brown
            [227, 119, 194],  # pink
            [127, 127, 127],  # gray (for 8th joint if needed)
        ]

        # Set once a plotting error has been reported, so the control loop is not
        # flooded with the same warning every step.
        self._viz_error_reported = False

        # Initialize Rerun visualization only if enabled
        self.viz = None
        if self.enable_viz:
            try:
                self.viz = RerunViz(
                    image_keys=[],
                    tensor_keys=[
                        self.right_arm_pos_key,
                        self.left_arm_pos_key,
                        self.right_arm_vel_key,
                        self.left_arm_vel_key,
                        self.right_hand_pos_key,
                        self.left_hand_pos_key,
                        self.right_hand_vel_key,
                        self.left_hand_vel_key,
                    ],
                    window_size=10.0,
                    app_name="joint_safety_monitor",
                )
            except Exception as e:
                print(f"[JointSafetyMonitor] Warning: Could not initialize rerun visualization: {e}")
                self.viz = None

    def _split_joint_group_by_side(self, group_name: str):
        """Split a joint group into right/left joint names and their dof indices.

        Joints whose names start with neither "right_" nor "left_" are dropped.

        Returns:
            (right_names, right_dof_indices, left_names, left_dof_indices)

        Raises:
            Whatever ``get_joint_group_indices`` raises for an unknown group, or
            KeyError if a joint is missing from ``joint_to_dof_index``.
        """
        group_indices = self.robot_model.get_joint_group_indices(group_name)
        names = [self.robot_model.joint_names[i] for i in group_indices]
        right_names = [name for name in names if name.startswith("right_")]
        left_names = [name for name in names if name.startswith("left_")]
        dof_index = self.robot_model.joint_to_dof_index
        return (
            right_names,
            [dof_index[name] for name in right_names],
            left_names,
            [dof_index[name] for name in left_names],
        )

    def _initialize_limits(self):
        """Populate ``velocity_limits`` and ``position_limits`` from the joint groups.

        Arm joints get symmetric velocity limits of ``ARM_VELOCITY_LIMIT`` (scaled by
        ``safety_margin``) plus position limits from the robot model, shrunk
        symmetrically by ``safety_margin``. Hand joints get velocity limits of
        ``HAND_VELOCITY_LIMIT`` only. A missing joint group (ValueError from
        ``get_joint_group_indices``) is printed as a warning.

        Raises:
            ValueError: If the robot model has no ``supplemental_info``.
        """
        if self.robot_model.supplemental_info is None:
            raise ValueError("Robot model must have supplemental_info to use joint groups")

        # Get arm joint indices from robot model joint groups
        try:
            arm_indices = self.robot_model.get_joint_group_indices("arms")
            arm_joint_names = [self.robot_model.joint_names[i] for i in arm_indices]
            arm_vel_limit = self.ARM_VELOCITY_LIMIT * self.safety_margin

            for joint_name in arm_joint_names:
                self.velocity_limits[joint_name] = {"min": -arm_vel_limit, "max": arm_vel_limit}

                if joint_name not in self.robot_model.joint_to_dof_index:
                    continue
                joint_idx = self.robot_model.joint_to_dof_index[joint_name]
                # Joint-limit arrays exclude the floating base; its 7 dof entries
                # (presumably 3 position + 4 quaternion) precede the joints.
                limit_idx = joint_idx - (7 if self.robot_model.is_floating_base_model else 0)
                if not 0 <= limit_idx < len(self.robot_model.lower_joint_limits):
                    continue

                pos_min = self.robot_model.lower_joint_limits[limit_idx]
                pos_max = self.robot_model.upper_joint_limits[limit_idx]
                # Shrink the range symmetrically; safety_margin == 1.0 keeps it as is.
                margin = (pos_max - pos_min) * (1.0 - self.safety_margin) / 2.0
                self.position_limits[joint_name] = {
                    "min": pos_min + margin,
                    "max": pos_max - margin,
                }
        except ValueError as e:
            print(f"[JointSafetyMonitor] Warning: Could not find 'arms' joint group: {e}")

        # Get hand joint indices from robot model joint groups
        try:
            hand_indices = self.robot_model.get_joint_group_indices("hands")
            hand_vel_limit = self.HAND_VELOCITY_LIMIT * self.safety_margin
            for i in hand_indices:
                # Velocity limits only for hands (no position limits for now)
                joint_name = self.robot_model.joint_names[i]
                self.velocity_limits[joint_name] = {"min": -hand_vel_limit, "max": hand_vel_limit}
        except ValueError as e:
            print(f"[JointSafetyMonitor] Warning: Could not find 'hands' joint group: {e}")

    def _collect_violations(
        self, values, joint_names, limits_by_joint: Dict, vtype: str, critical: bool
    ) -> None:
        """Append a violation to ``self.violations`` for every out-of-range joint.

        ``values[i]`` is compared with the limits of ``joint_names[i]``. Joints
        without an entry in ``limits_by_joint``, and indices past the end of
        ``values``, are skipped.
        """
        # NOTE(review): obs arrays are indexed here by position in joint_names, while
        # get_safe_action and visualization index obs["q"] via joint_to_dof_index.
        # These agree only if both orderings coincide (e.g. no floating-base
        # offset) -- confirm against the observation layout.
        for i, joint_name in enumerate(joint_names):
            limits = limits_by_joint.get(joint_name)
            if limits is None or i >= len(values):
                continue
            value = values[i]
            # Negated range test so that NaN readings are flagged instead of passing.
            if not (limits["min"] <= value <= limits["max"]):
                self.violations.append(
                    {
                        "joint": joint_name,
                        "type": vtype,
                        "value": value,
                        "limit_min": limits["min"],
                        "limit_max": limits["max"],
                        "exceeded_by": self._calculate_exceeded_percentage(
                            value, limits["min"], limits["max"]
                        ),
                        "critical": critical,
                    }
                )

    def check_safety(self, obs: Dict, action: Dict) -> Tuple[bool, List[Dict]]:
        """Check if current velocities and positions are within safe bounds.

        Args:
            obs: Observation dictionary; optional "dq" (velocities) and "q"
                (positions) arrays are checked.
            action: Action dictionary (currently unused; kept for interface stability).

        Returns:
            (is_safe, violations): Tuple of safety status and list of violations.
            Velocity violations (including NaN velocities) are critical and make
            is_safe False. Position violations are warnings only and do not affect
            is_safe. The list is also stored in ``self.violations``.
        """
        self.violations = []
        joint_names = self.robot_model.joint_names

        # Current joint velocities (critical - triggers shutdown)
        if "dq" in obs:
            self._collect_violations(
                obs["dq"], joint_names, self.velocity_limits, "velocity", critical=True
            )

        # Current joint positions (warning only - no shutdown)
        if "q" in obs:
            self._collect_violations(
                obs["q"], joint_names, self.position_limits, "position", critical=False
            )

        is_safe = not any(v["critical"] for v in self.violations)
        return is_safe, self.violations

    def _calculate_exceeded_percentage(
        self, value: float, limit_min: float, limit_max: float
    ) -> float:
        """Return by how many percent ``value`` lies beyond the violated limit.

        The excess is relative to the magnitude of the violated limit. If that limit
        is zero, the excess is relative to the width of [limit_min, limit_max]
        instead, avoiding a division by zero. If the range is also degenerate,
        ``inf`` is returned. Values inside the limits (and NaN) yield 0.0.
        """
        if value < limit_min:
            excess, reference = limit_min - value, limit_min
        elif value > limit_max:
            excess, reference = value - limit_max, limit_max
        else:
            return 0.0

        scale = abs(reference)
        if scale == 0:
            scale = limit_max - limit_min
        if scale <= 0:
            return float("inf")
        return excess / scale * 100

    def get_safe_action(self, obs: Dict, original_action: Dict) -> Dict:
        """Generate a safe action with startup ramping for smooth initialization.

        On the first calls, the commanded position of every monitored joint (all keys
        of ``velocity_limits``, i.e. arm AND hand joints) is linearly interpolated
        from the first observed position toward the commanded target. This lasts for
        ``ramp_duration_steps`` calls. After that, actions pass through unchanged.

        Args:
            obs: Observation dictionary; ``obs["q"]`` from the first call that has it
                supplies the starting positions.
            original_action: The commanded action. It is never modified.

        Returns:
            A shallow copy of ``original_action``. While ramping, its "q" entry is a
            new array holding the interpolated targets.
        """
        safe_action = original_action.copy()
        if self.startup_complete:
            return safe_action

        if self.initial_positions is None and "q" in obs:
            # Store initial positions from first observation
            self.initial_positions = obs["q"].copy()

        if (
            self.startup_counter < self.ramp_duration_steps
            and self.initial_positions is not None
            and "q" in safe_action
        ):
            # dict.copy() is shallow: copy "q" so the caller's array is not overwritten.
            ramped_q = original_action["q"].copy()
            safe_action["q"] = ramped_q
            # Ramp factor: 0.0 at start -> approaches 1.0 at the end
            ramp_factor = self.startup_counter / self.ramp_duration_steps

            dof_index = self.robot_model.joint_to_dof_index
            for joint_name in self.velocity_limits:
                if joint_name not in dof_index:
                    continue
                joint_idx = dof_index[joint_name]
                if joint_idx < len(ramped_q) and joint_idx < len(self.initial_positions):
                    initial_pos = self.initial_positions[joint_idx]
                    target_pos = original_action["q"][joint_idx]
                    ramped_q[joint_idx] = initial_pos + ramp_factor * (target_pos - initial_pos)

            self.startup_counter += 1
        else:
            # The ramp has finished, or it could not run because no initial positions
            # or no "q" command were available. Either way, stop ramping from now on.
            self.startup_complete = True

        return safe_action

    def get_violation_report(self, violations: Optional[List[Dict]] = None) -> str:
        """Generate a formatted error report for violations.

        Args:
            violations: List of violations to report (uses self.violations if None).
                Entries without a "critical" key are treated as critical.

        Returns:
            Formatted error message string
        """
        if violations is None:
            violations = self.violations

        if not violations:
            return "No violations detected."

        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
        has_critical = any(v.get("critical", True) for v in violations)

        if has_critical:
            parts = [f"Joint safety bounds exceeded!\nTimestamp: {timestamp}\nViolations:\n"]
        else:
            parts = [f"Joint position warnings!\nTimestamp: {timestamp}\nWarnings:\n"]

        for violation in violations:
            joint = violation["joint"]
            vtype = violation["type"]
            value = violation["value"]
            exceeded = violation["exceeded_by"]
            limit_min = violation["limit_min"]
            limit_max = violation["limit_max"]

            if vtype == "velocity":
                parts.append(
                    f" - {joint}: {vtype}={value:.3f} rad/s "
                    f"(limit: ±{limit_max:.3f} rad/s) - "
                    f"EXCEEDED by {exceeded:.1f}%\n"
                )
            elif vtype == "position":
                parts.append(
                    f" - {joint}: {vtype}={value:.3f} rad "
                    f"(limits: [{limit_min:.3f}, {limit_max:.3f}] rad) - "
                    f"EXCEEDED by {exceeded:.1f}%\n"
                )

        # Add appropriate action message
        if has_critical:
            parts.append("Action: Safe mode engaged (kp=0, tau=0). System shutdown initiated.\n")
            parts.append("Please restart Docker container to resume operation.")
        else:
            parts.append("Action: Position warning only. Robot continues operation.")

        return "".join(parts)

    def _plot_joint_states(self, obs: Dict) -> None:
        """Send per-side arm/hand positions and velocities to the rerun viewer.

        Best-effort: does nothing unless the viewer and all joint index lists exist
        and ``obs`` has both "q" and "dq". Plotting errors are reported once and
        otherwise ignored, so visualization can never stop the control loop.
        """
        index_lists = (
            self.right_arm_indices,
            self.left_arm_indices,
            self.right_hand_indices,
            self.left_hand_indices,
        )
        if (
            self.viz is None
            or any(indices is None for indices in index_lists)
            or "q" not in obs
            or "dq" not in obs
        ):
            return

        try:
            q, dq = obs["q"], obs["dq"]
            tensor_dict = {
                self.right_arm_pos_key: q[self.right_arm_indices],
                self.left_arm_pos_key: q[self.left_arm_indices],
                self.right_arm_vel_key: dq[self.right_arm_indices],
                self.left_arm_vel_key: dq[self.left_arm_indices],
                self.right_hand_pos_key: q[self.right_hand_indices],
                self.left_hand_pos_key: q[self.left_hand_indices],
                self.right_hand_vel_key: dq[self.right_hand_indices],
                self.left_hand_vel_key: dq[self.left_hand_indices],
            }
            self.viz.plot_tensors(tensor_dict, time.time())
        except Exception as e:
            if not self._viz_error_reported:
                print(f"[JointSafetyMonitor] Warning: Visualization failed: {e}")
                self._viz_error_reported = True

    def handle_violations(self, obs: Dict, action: Dict) -> Dict:
        """Check safety, apply startup ramping, and decide whether to continue.

        On a critical (velocity) violation with env_type == "real", the report is
        printed and the process exits via ``trigger_system_shutdown``. In any other
        environment the shutdown decision is only returned.

        Args:
            obs: Observation dictionary
            action: Action dictionary (not modified)

        Returns:
            Dict with keys:
                - 'safe_to_continue': bool - whether robot should continue operation
                - 'action': Dict - potentially modified safe action
                - 'shutdown_required': bool - whether system shutdown is needed
        """
        is_safe, violations = self.check_safety(obs, action)

        # Apply startup ramping (always, regardless of violations)
        safe_action = self.get_safe_action(obs, action)

        # Visualize arm and hand joint positions and velocities if enabled
        if self.enable_viz:
            self._plot_joint_states(obs)

        if not violations:
            return {"safe_to_continue": True, "action": safe_action, "shutdown_required": False}

        # Handle critical violations (velocity) - trigger shutdown
        critical_violations = [v for v in violations if v.get("critical", True)]
        if not is_safe and critical_violations:
            if self.env_type == "real":
                error_msg = self.get_violation_report(critical_violations)
                print(f"[SAFETY VIOLATION] {error_msg}")
                self.trigger_system_shutdown()

            return {"safe_to_continue": False, "action": safe_action, "shutdown_required": True}

        # Only position violations - continue with safe action
        return {"safe_to_continue": True, "action": safe_action, "shutdown_required": False}

    def trigger_system_shutdown(self):
        """Trigger system shutdown after safety violation by exiting with status 1."""
        print("\n[SAFETY] Initiating system shutdown due to safety violation...")
        sys.exit(1)
|
|
|
|
|
|
def main():
    """Smoke-test the joint safety monitor against the G1 robot model.

    Runs three scenarios and prints the monitor's verdicts for each: an all-zero
    state, an arm joint above its velocity limit, and an arm joint beyond its
    position limit. The monitor is created with env_type="sim", so the critical
    velocity scenario reports ``shutdown_required=True`` instead of exiting the
    process. Any exception, including a missing robot-model package, is printed
    with its traceback rather than raised.
    """
    print("Testing joint safety monitor with joint groups...")

    try:
        from decoupled_wbc.control.robot_model.instantiation.g1 import instantiate_g1_robot_model

        # Instantiate robot model
        robot_model = instantiate_g1_robot_model()
        print(f"Robot model created with {len(robot_model.joint_names)} joints")

        # Create safety monitor. "sim" keeps a critical violation from calling sys.exit().
        safety_monitor = JointSafetyMonitor(robot_model, env_type="sim")
        print("Safety monitor created successfully!")
        print(f"Monitoring {len(safety_monitor.velocity_limits)} joints")

        # Print monitored joints
        print("\nVelocity limits:")
        for joint_name, limits in safety_monitor.velocity_limits.items():
            print(f" - {joint_name}: ±{limits['max']:.2f} rad/s")

        print(f"\nPosition limits (arms only): {len(safety_monitor.position_limits)} joints")
        for joint_name, limits in safety_monitor.position_limits.items():
            print(f" - {joint_name}: [{limits['min']:.3f}, {limits['max']:.3f}] rad")

        # Test safety checking with safe values
        print("\n--- Testing Safety Checking ---")

        # Create mock observation with safe values
        safe_obs = {
            "q": np.zeros(robot_model.num_dofs),  # All joints at zero position
            "dq": np.zeros(robot_model.num_dofs),  # All joints at zero velocity
        }
        safe_action = {"q": np.zeros(robot_model.num_dofs)}

        # Test handle_violations method
        result = safety_monitor.handle_violations(safe_obs, safe_action)
        print(
            f"Safe values test: safe_to_continue={result['safe_to_continue']}, "
            f"shutdown_required={result['shutdown_required']}"
        )

        # Test with unsafe velocity: put the left shoulder pitch 1 rad/s above the
        # arm limit. The check is strict, so a value equal to the limit would pass.
        left_shoulder_idx = robot_model.dof_index("left_shoulder_pitch_joint")
        unsafe_obs = {**safe_obs, "dq": np.zeros(robot_model.num_dofs)}
        unsafe_obs["dq"][left_shoulder_idx] = JointSafetyMonitor.ARM_VELOCITY_LIMIT + 1.0

        print("\nUnsafe velocity test:")
        result = safety_monitor.handle_violations(unsafe_obs, safe_action)
        print(
            f"  safe_to_continue={result['safe_to_continue']}, shutdown_required={result['shutdown_required']}"
        )

        # Test with unsafe position only
        unsafe_pos_obs = {**safe_obs, "q": np.zeros(robot_model.num_dofs)}
        # Set left shoulder pitch position below its lower limit (about -3.089 rad)
        unsafe_pos_obs["q"][left_shoulder_idx] = -4.0

        print("\nUnsafe position test:")
        result = safety_monitor.handle_violations(unsafe_pos_obs, safe_action)
        print(
            f"  safe_to_continue={result['safe_to_continue']}, shutdown_required={result['shutdown_required']}"
        )

        print("\nAll tests completed successfully!")

    except Exception as e:
        print(f"Test failed with error: {e}")
        import traceback

        traceback.print_exc()


if __name__ == "__main__":
    main()
|