You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

297 lines
11 KiB

import numbers
import time as time_module
from typing import Any, Dict, Optional, Union
import gymnasium as gym
import numpy as np
import scipy.interpolate as si
from decoupled_wbc.control.base.policy import Policy
class InterpolationPolicy(Policy):
    """Rate-limited, linearly interpolated playback of timed feature goals.

    Every feature (a named 1D vector) is concatenated, in sorted key order,
    into a single vector tracked by a ``PoseTrajectoryInterpolator``.
    Goals are fed in with :meth:`set_goal` and actions are sampled with
    :meth:`get_action`; :meth:`__call__` is intentionally unsupported.
    """

    def __init__(
        self,
        init_time: float,
        init_values: dict[str, np.ndarray],
        max_change_rate: float,
    ):
        """
        Args:
            init_time: The time of recording the initial values.
            init_values: The initial values of the features.
                The keys are the names of the features, and the values
                are the initial values of the features (1D array; a (1, D)
                array is squeezed to (D,)). The dict itself is not modified.
            max_change_rate: The maximum change rate, forwarded to
                ``PoseTrajectoryInterpolator.schedule_waypoint``.
        """
        super().__init__()
        # Private shallow copy (keeps the caller's key order): last_action is
        # updated in place on every get_action() call and must not write
        # through to the caller's dict.
        self.last_action = dict(init_values)  # Vecs are 1D arrays
        self.concat_order = sorted(self.last_action.keys())
        self.concat_dims = []
        for key in self.concat_order:
            vec = np.array(self.last_action[key])
            if vec.ndim == 2 and vec.shape[0] == 1:
                vec = vec[0]
            self.last_action[key] = vec
            assert vec.ndim == 1, f"The shape of {key} should be (D,). Got {vec.shape}."
            self.concat_dims.append(vec.shape[0])
        # (1, D_total) row holding the initial values of all features.
        self.init_values_concat = self._concat_vecs(self.last_action, 1)
        self.max_change_rate = max_change_rate
        self.reset(init_time)

    def reset(self, init_time: Optional[float] = None):
        """Restart the trajectory at the initial values.

        Args:
            init_time: Time stamp of the initial waypoint. Defaults to
                ``time.monotonic()`` evaluated at call time.
        """
        if init_time is None:
            # Resolved per call: a ``time_module.monotonic()`` default in the
            # signature would be frozen at import time.
            init_time = time_module.monotonic()
        self.interp = PoseTrajectoryInterpolator(np.array([init_time]), self.init_values_concat)
        self.last_waypoint_time = init_time
        # Keep last_action consistent with the restarted trajectory so that
        # features omitted from the next goal are held at the initial values
        # rather than at stale pre-reset values.
        self.last_action.update(self._unconcat_vecs(self.init_values_concat[0].copy()))

    def _concat_vecs(self, values: dict[str, np.ndarray], length: int) -> np.ndarray:
        """
        Concatenate the vectors into a 2D array to be used for interpolation.

        Features are taken in ``self.concat_order``; keys of ``values`` not in
        that order are ignored, and missing features are filled from
        ``self.last_action``.

        Args:
            values: The values to concatenate.
            length: The length of the concatenated vectors (time dimension).
        Returns:
            The concatenated vectors (T, D) arrays.
        """
        concat_vecs = []
        for key in self.concat_order:
            if key in values:
                vec = np.array(values[key])
                if vec.ndim == 1:
                    # If the vector is 1D, tile it to the length of the time dimension
                    vec = np.tile(vec, (length, 1))
                assert vec.ndim == 2, f"The shape of {key} should be (T, D). Got {vec.shape}."
                concat_vecs.append(vec)
            else:
                # If the vector is not in the values, use the last action
                # Since the last action is 1D, we need to tile it to the length of the time dimension
                concat_vecs.append(np.tile(self.last_action[key], (length, 1)))
        return np.concatenate(concat_vecs, axis=1)  # Vecs are 2D (T, D) arrays

    def _unconcat_vecs(self, concat_vec: np.ndarray) -> dict[str, np.ndarray]:
        """Split a concatenated 1D vector back into per-feature 1D slices.

        Returns a new dict keyed by ``self.concat_order``; the values are
        slices (views) of ``concat_vec``.
        """
        assert (
            concat_vec.ndim == 1
        ), f"The shape of the concatenated vectors should be (D,). Got {concat_vec.shape}."
        action = {}
        curr_idx = 0
        for key, dim in zip(self.concat_order, self.concat_dims):
            action[key] = concat_vec[curr_idx : curr_idx + dim]
            curr_idx += dim
        return action  # Vecs are 1D arrays

    def __call__(
        self, observation: Dict[str, Any], goal: Dict[str, Any], time: float
    ) -> Dict[str, np.ndarray]:
        """Unsupported: use :meth:`set_goal` and :meth:`get_action` instead."""
        raise NotImplementedError(
            "`InterpolationPolicy` accepts goal and provide action in two separate methods."
        )

    def set_goal(self, goal: Dict[str, Any]) -> None:
        """Schedule goal waypoint(s) on the trajectory.

        Args:
            goal: Must contain ``"target_time"`` (a scalar, or a list of
                times) and ``"interpolation_garbage_collection_time"``. Every
                other key is a feature name. Its value is one vector for a
                scalar ``target_time``, or a list of vectors with the same
                length as ``target_time``. Features absent from the goal are
                held at ``last_action``; unknown feature names are ignored.
                Without ``"target_time"`` the call is a no-op. The caller's
                dict is not modified.
        """
        if "target_time" not in goal:
            return
        assert (
            "interpolation_garbage_collection_time" in goal
        ), "`interpolation_garbage_collection_time` is required."
        # Work on a shallow copy so the caller's dict is left untouched.
        goal = dict(goal)
        target_time = goal.pop("target_time")
        interpolation_garbage_collection_time = goal.pop("interpolation_garbage_collection_time")
        if isinstance(target_time, list):
            for key, vec in goal.items():
                assert isinstance(vec, list)
                assert len(vec) == len(target_time), (
                    f"The length of {key} and `target_time` should be the same. "
                    f"Got {len(vec)} and {len(target_time)}."
                )
        else:
            # Promote a single waypoint to a length-1 sequence.
            target_time = [target_time]
            goal = {key: [vec] for key, vec in goal.items()}
        # Concatenate all vectors in goal
        concat_vecs = self._concat_vecs(goal, len(target_time))
        assert concat_vecs.shape[0] == len(target_time), (
            f"The length of the concatenated goal and `target_time` should be the same. "
            f"Got {concat_vecs.shape[0]} and {len(target_time)}."
        )
        for tt, vec in zip(target_time, concat_vecs):
            # Waypoints already behind the garbage-collection horizon are skipped.
            if tt < interpolation_garbage_collection_time:
                continue
            self.interp = self.interp.schedule_waypoint(
                pose=vec,
                time=tt,
                max_change_rate=self.max_change_rate,
                interpolation_garbage_collection_time=interpolation_garbage_collection_time,
                last_waypoint_time=self.last_waypoint_time,
            )
            self.last_waypoint_time = tt

    def get_action(self, time: Optional[float] = None) -> dict[str, Any]:
        """Get the next action based on the (current) monotonic time.

        Args:
            time: Query time; defaults to ``time.monotonic()``.
        Returns:
            ``self.last_action``, updated in place with the interpolated
            per-feature vectors (the same dict object on every call).
        """
        if time is None:
            time = time_module.monotonic()
        concat_vec = self.interp(time)
        self.last_action.update(self._unconcat_vecs(concat_vec))
        return self.last_action

    def observation_space(self) -> gym.spaces.Dict:
        """Return the observation space (not implemented; returns None)."""
        pass

    def action_space(self) -> gym.spaces.Dict:
        """Return the action space (not implemented; returns None)."""
        pass

    def close(self) -> None:
        """Clean up resources (nothing to release)."""
        pass
class PoseTrajectoryInterpolator:
    """Piecewise-linear interpolator over time-stamped pose vectors.

    Wraps ``scipy.interpolate.interp1d`` along axis 0. Queries outside the
    stored time range are clamped to the first/last waypoint. A trajectory
    with a single waypoint is handled separately and always returns that
    waypoint's pose.
    """

    def __init__(self, times: np.ndarray, poses: np.ndarray):
        """
        Args:
            times: (T,) waypoint times, non-decreasing, T >= 1.
            poses: (T, D) waypoint poses, one row per time.
        """
        assert len(times) >= 1
        assert len(poses) == len(times)
        times = np.asarray(times)
        poses = np.asarray(poses)
        self.num_joint = len(poses[0])
        if len(times) == 1:
            # special treatment for single step interpolation
            self.single_step = True
            self._times = times
            self._poses = poses
        else:
            self.single_step = False
            assert np.all(times[1:] >= times[:-1])
            self.pose_interp = si.interp1d(times, poses, axis=0, assume_sorted=True)

    @property
    def times(self) -> np.ndarray:
        """(T,) waypoint times."""
        if self.single_step:
            return self._times
        return self.pose_interp.x

    @property
    def poses(self) -> np.ndarray:
        """(T, D) waypoint poses."""
        if self.single_step:
            return self._poses
        return self.pose_interp.y

    def trim(self, start_t: float, end_t: float) -> "PoseTrajectoryInterpolator":
        """Return a new interpolator restricted to ``[start_t, end_t]``.

        The result has waypoints at ``start_t`` and ``end_t`` (poses evaluated
        on this trajectory, clamped outside its range) plus every existing
        waypoint strictly between them.
        """
        assert start_t <= end_t
        times = self.times
        should_keep = (start_t < times) & (times < end_t)
        all_times = np.concatenate([[start_t], times[should_keep], [end_t]])
        # Collapse the duplicate produced when start_t == end_t so the new
        # trajectory has no zero-length segment.
        all_times = np.unique(all_times)
        all_poses = self(all_times)
        return PoseTrajectoryInterpolator(times=all_times, poses=all_poses)

    def schedule_waypoint(
        self,
        pose,
        time,
        max_change_rate=np.inf,
        interpolation_garbage_collection_time=None,
        last_waypoint_time=None,
    ) -> "PoseTrajectoryInterpolator":
        """Return a new trajectory that ends with ``pose``.

        Args:
            pose: (D,) target pose.
            time: Requested arrival time. The arrival is pushed later when
                needed so that no dimension moves faster than
                ``max_change_rate`` from the trimmed trajectory's end pose.
            max_change_rate: Scalar or per-dimension (D,) maximum change per
                unit time. At least one entry must be positive.
            interpolation_garbage_collection_time: Waypoints before this time
                are trimmed away. If ``time`` is not after it, nothing is
                scheduled and ``self`` is returned.
            last_waypoint_time: Time of the previously scheduled waypoint;
                requires ``interpolation_garbage_collection_time``.
        Returns:
            A new ``PoseTrajectoryInterpolator`` (``self`` is not modified),
            or ``self`` when the waypoint is discarded.
        """
        # Normalize scalars (including 0-d arrays), lists and arrays to one
        # rate per dimension.
        max_change_rate = np.asarray(max_change_rate)
        if max_change_rate.ndim == 0:
            max_change_rate = np.full(self.num_joint, max_change_rate)
        assert max_change_rate.shape == (self.num_joint,)
        assert np.max(max_change_rate) > 0
        if last_waypoint_time is not None:
            assert interpolation_garbage_collection_time is not None

        # Trim the current trajectory to [start_time, end_time]; end_time is
        # the latest kept time before the new waypoint.
        start_time = self.times[0]
        end_time = self.times[-1]
        assert start_time <= end_time

        if interpolation_garbage_collection_time is not None:
            if time <= interpolation_garbage_collection_time:
                # The target is not after the garbage-collection time:
                # leave the trajectory unchanged.
                return self
            # From here on: interpolation_garbage_collection_time < time.
            start_time = max(interpolation_garbage_collection_time, start_time)
            if last_waypoint_time is not None:
                if time <= last_waypoint_time:
                    # New waypoint does not come after the previous one:
                    # discard everything after the garbage-collection time.
                    end_time = interpolation_garbage_collection_time
                else:
                    end_time = max(last_waypoint_time, interpolation_garbage_collection_time)
            else:
                end_time = interpolation_garbage_collection_time

        end_time = min(end_time, time)
        start_time = min(start_time, end_time)

        # Invariants established by the two min() calls above:
        #   start_time <= end_time <= time
        #   interpolation_garbage_collection_time <= start_time (when given)
        assert start_time <= end_time
        assert end_time <= time
        if last_waypoint_time is not None:
            if time <= last_waypoint_time:
                assert end_time == interpolation_garbage_collection_time
            else:
                assert end_time == max(last_waypoint_time, interpolation_garbage_collection_time)
        if interpolation_garbage_collection_time is not None:
            assert interpolation_garbage_collection_time <= start_time
            assert interpolation_garbage_collection_time <= time

        # All waypoints of trimmed_interp lie in [start_time, end_time], i.e.
        # no later than `time`.
        trimmed_interp = self.trim(start_time, end_time)

        # Duration: the requested gap, or longer if the rate limit demands it.
        end_pose = trimmed_interp(end_time)
        duration = time - end_time
        pose_min_duration = np.max(np.abs(end_pose - pose) / max_change_rate)
        duration = max(duration, pose_min_duration)
        assert duration >= 0
        new_waypoint_time = end_time + duration

        times = np.append(trimmed_interp.times, [new_waypoint_time], axis=0)
        poses = np.append(trimmed_interp.poses, [pose], axis=0)
        return PoseTrajectoryInterpolator(times, poses)

    def __call__(self, t: Union[numbers.Number, np.ndarray]) -> np.ndarray:
        """Evaluate the trajectory at time(s) ``t``.

        Args:
            t: A scalar time (Python/NumPy number or 0-d array) or an array
                of times. Times outside the waypoint range are clamped.
        Returns:
            A (D,) pose for scalar ``t``, otherwise one pose row per time.
        """
        is_single = np.ndim(t) == 0
        t = np.atleast_1d(t)
        if self.single_step:
            pose = np.zeros((len(t), self.num_joint))
            pose[:] = self._poses[0]
        else:
            t = np.clip(t, self.times[0], self.times[-1])
            pose = self.pose_interp(t)
        if is_single:
            pose = pose[0]
        return pose