You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
47 lines
1.4 KiB
47 lines
1.4 KiB
from pathlib import Path
|
|
import pickle
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from gr00t_wbc.control.policy.interpolation_policy import (
|
|
InterpolationPolicy,
|
|
)
|
|
|
|
|
|
def get_test_data_path(filename: str) -> str:
    """Return the path to a replay-data file used by these tests.

    The path is built relative to the directory containing this module:
    three levels up, then into ``replay_data``. The ``..`` segments are left
    unresolved, so the result is absolute only when ``__file__`` is.

    Args:
        filename: Name of the file inside the ``replay_data`` directory.

    Returns:
        The constructed path as a string.
    """
    test_dir = Path(__file__).parent
    return str(test_dir / ".." / ".." / ".." / "replay_data" / filename)
|
|
|
|
|
|
@pytest.fixture
def logged_data():
    """Load the recorded interpolation calls from ``interpolation_data.pkl``.

    Returns:
        The object unpickled from the file. ``test_replay_logged_data`` reads
        its ``"init_args"`` and ``"calls"`` entries.
    """
    data_path = get_test_data_path("interpolation_data.pkl")
    # NOTE: pickle.load can execute arbitrary code embedded in the file, so
    # this must only ever be pointed at trusted, repository-provided data.
    with open(data_path, "rb") as f:
        return pickle.load(f)
|
|
|
|
|
|
def test_replay_logged_data(logged_data):
    """Replay the logged calls and check the policy reproduces the logged poses.

    An ``InterpolationPolicy`` is built from the logged ``init_args`` with an
    unbounded ``max_change_rate``. Every logged call is then replayed in
    order, including the first one: ``get_action`` calls are compared against
    the logged result, and any other call is forwarded to ``set_goal``.
    """
    init_args = logged_data["init_args"]
    interp = InterpolationPolicy(
        init_time=init_args["curr_t"],
        init_values={"target_pose": init_args["curr_pose"]},
        max_change_rate=np.inf,
    )

    num_checked = 0
    for index, call in enumerate(logged_data["calls"]):
        if call["type"] == "get_action":
            action = interp.get_action(**call["args"])
            # The log stores the expected pose under "q", while the policy
            # output is read from "target_pose".
            np.testing.assert_allclose(
                action["target_pose"],
                call["result"]["q"],
                rtol=1e-9,
                atol=1e-9,
                err_msg=f"mismatch at logged call #{index}",
            )
            num_checked += 1
        else:
            # Every non-"get_action" entry is replayed as a goal update.
            interp.set_goal(**call["args"])

    # Fail rather than pass vacuously if the log held nothing to compare.
    assert num_checked > 0, "logged data contains no get_action calls to verify"
|