You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

522 lines
18 KiB

import json
from pathlib import Path
import shutil
import tempfile
import time
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
import numpy as np
import pytest
from gr00t_wbc.data.exporter import DataCollectionInfo, Gr00tDataExporter
@pytest.fixture
def test_features():
    """Return the feature schema dict handed to the exporter in these tests.

    The schema declares one video feature (64x64x3) and two float32 vector
    features of length 8 ("observation.state" and "action").
    """
    vector_len = 8

    # Small images for faster tests
    ego_view_spec = {
        "dtype": "video",
        "shape": [64, 64, 3],
        "names": ["height", "width", "channel"],
    }
    state_spec = {
        "dtype": "float32",
        "shape": (vector_len,),
        "names": [f"x{i}" for i in range(1, vector_len + 1)],
    }
    action_spec = {
        "dtype": "float32",
        "shape": (vector_len,),
        "names": [f"a{i}" for i in range(1, vector_len + 1)],
    }

    features = {}
    features["observation.images.ego_view"] = ego_view_spec
    features["observation.state"] = state_spec
    features["action"] = action_spec
    return features
@pytest.fixture
def test_modality_config():
    """Return the modality configuration dict handed to the exporter in these tests.

    The "state" and "action" entries each split their vector into "feature1"
    and "feature2" index ranges; "video" and "annotation" map a modality name
    to the original dataset key.
    """

    def _vector_split():
        # A fresh dict per call so "state" and "action" never share nested objects.
        # NOTE(review): feature2 ends at 9 although the vectors declared in
        # test_features have 8 elements -- confirm this is intentional.
        return {
            "feature1": {"start": 0, "end": 4},
            "feature2": {"start": 4, "end": 9},
        }

    config = {}
    config["state"] = _vector_split()
    config["action"] = _vector_split()
    config["video"] = {"rs_view": {"original_key": "observation.images.ego_view"}}
    config["annotation"] = {"human.task_description": {"original_key": "task_index"}}
    return config
@pytest.fixture
def test_data_collection_info():
    """Return a DataCollectionInfo populated with placeholder test values."""
    placeholder_fields = {
        "teleoperator_username": "test_user",
        "support_operator_username": "test_user",
        "robot_type": "test_robot",
        "lower_body_policy": "test_policy",
        "wbc_model_path": "test_path",
    }
    return DataCollectionInfo(**placeholder_fields)
def get_test_frame(step: int):
    """Build one synthetic frame whose contents are derived from ``step``.

    Returns a dict with:
      * "observation.images.ego_view": a 64x64x3 uint8 image filled with
        ``step % 255``, except row ``step % 64`` which holds ``255 - step % 255``.
      * "observation.state" and "action": separate float32 vectors of length 8,
        every element equal to ``step``.
    """
    fill_value = step % 255
    marker_row = step % 64

    # A simple, small image that will encode quickly.
    image = np.full((64, 64, 3), fill_value, dtype=np.uint8)
    # One contrasting row makes the frame distinguishable and verifiable.
    image[marker_row, :, :] = 255 - fill_value

    step_vector = np.ones(8, dtype=np.float32) * step

    frame = {"observation.images.ego_view": image}
    frame["observation.state"] = step_vector
    # Copy so state and action stay independent arrays.
    frame["action"] = step_vector.copy()
    return frame
@pytest.fixture
def temp_dir():
    """Yield a ``dataset`` path inside a fresh temporary directory.

    Only the parent temporary directory is created here; the yielded
    ``<tmp>/dataset`` path itself is not. The parent is removed once the test
    using the fixture has finished.
    """
    scratch_root = tempfile.mkdtemp()
    dataset_root = Path(scratch_root) / "dataset"
    yield dataset_root
    # Teardown: remove the whole scratch tree, including anything the test wrote.
    shutil.rmtree(scratch_root)
class TestInterruptAndResume:
    """Scenarios in which a recording session stops early and a new exporter is created.

    Every test drives the real Gr00tDataExporter against the ``temp_dir`` path.
    """

    # Skip this test if ffmpeg is not installed
    @pytest.mark.skipif(
        shutil.which("ffmpeg") is None, reason="ffmpeg not installed, skipping test"
    )
    def test_interrupted_mid_episode(
        self, temp_dir, test_features, test_modality_config, test_data_collection_info
    ):
        """
        Interrupt a session part-way through an episode, then record again with a new exporter.

        The first session saves every episode before the interruption point and
        raises a simulated KeyboardInterrupt inside the interrupted episode. The
        second session creates a new exporter on the same directory and records
        the full set of episodes. No mocks are used.
        """
        episodes_total = 2
        frames_each = 5

        # Fixed interruption point: fourth frame (index 3) of the second episode (index 1).
        interrupt_episode = 1
        interrupt_frame = 3
        print(f"Will interrupt at episode {interrupt_episode}, frame {interrupt_frame}")

        # Arguments shared by both sessions.
        session_kwargs = {
            "save_root": temp_dir,
            "fps": 30,
            "features": test_features,
            "modality_config": test_modality_config,
            "task": "test_task",
            "robot_type": "test_robot",
            "vcodec": "libx264",  # Use a common codec that should be available
        }

        # Bookkeeping for the first session, checked after the interruption.
        completed_episodes = []
        first_session_frames = 0

        # ---- First session: record until the simulated interruption ----
        try:
            first_exporter = Gr00tDataExporter.create(
                data_collection_info=test_data_collection_info, **session_kwargs
            )
            for episode in range(episodes_total):
                for frame in range(frames_each):
                    at_interrupt_point = (episode, frame) == (interrupt_episode, interrupt_frame)
                    if at_interrupt_point:
                        print(f"Simulating interruption at episode {episode}, frame {frame}")
                        raise KeyboardInterrupt("Simulated interruption")
                    first_exporter.add_frame(get_test_frame(frame))
                    first_session_frames += 1
                first_exporter.save_episode()
                completed_episodes.append(episode)
        except KeyboardInterrupt:
            # Nothing is consolidated or saved here: the session was interrupted.
            print(f"Recording interrupted at episode {interrupt_episode}, frame {interrupt_frame}")
            print(f"Completed episodes: {completed_episodes}")

        # Only episodes before the interruption were completed.
        assert len(completed_episodes) == interrupt_episode
        expected_first_session_frames = interrupt_episode * frames_each + interrupt_frame
        assert first_session_frames == expected_first_session_frames

        # Let file system operations complete
        time.sleep(0.5)

        # ---- Second session: new exporter pointing at the same directory ----
        second_exporter = Gr00tDataExporter.create(**session_kwargs)

        # The interrupted episode had frames added but was never saved, so the
        # second session records all episodes from the beginning.
        second_session_frames = 0
        second_session_episodes = 0
        for _ in range(episodes_total):
            for frame in range(frames_each):
                second_exporter.add_frame(get_test_frame(frame))
                second_session_frames += 1
            second_exporter.save_episode()
            second_session_episodes += 1

        assert second_session_frames == episodes_total * frames_each
        assert second_session_episodes == episodes_total

        # The video files for the first `episodes_total` episode indices must exist.
        for episode_idx in range(episodes_total):
            relative_video = second_exporter.meta.get_video_file_path(
                episode_idx, "observation.images.ego_view"
            )
            video_path = second_exporter.root / relative_video
            assert video_path.exists(), f"Video file not found: {video_path}"

    @pytest.mark.skipif(
        shutil.which("ffmpeg") is None, reason="ffmpeg not installed, skipping test"
    )
    def test_interrupted_after_episode_completion(
        self, temp_dir, test_features, test_modality_config, test_data_collection_info
    ):
        """
        Stop after one fully saved episode, then append a second episode with a new exporter.

        Uses the real Gr00tDataExporter implementation.
        """
        camera_key = "observation.images.ego_view"
        session_kwargs = {
            "save_root": temp_dir,
            "fps": 30,
            "features": test_features,
            "modality_config": test_modality_config,
            "task": "test_task",
            "vcodec": "libx264",
        }

        # ---- First session: one complete episode of five frames ----
        first_exporter = Gr00tDataExporter.create(
            data_collection_info=test_data_collection_info, **session_kwargs
        )
        for step in range(5):
            first_exporter.add_frame(get_test_frame(step))
        first_exporter.save_episode()

        # Let file system operations complete
        time.sleep(0.5)

        # Episode 0 must have produced a video file.
        video_path = first_exporter.root / first_exporter.meta.get_video_file_path(0, camera_key)
        assert video_path.exists(), f"First episode video file not found: {video_path}"

        # ---- Second session: new exporter on the same directory, one more episode ----
        second_exporter = Gr00tDataExporter.create(**session_kwargs)
        for step in range(5):
            second_exporter.add_frame(get_test_frame(step))
        second_exporter.save_episode()

        # The new episode is expected at index 1, after the one saved earlier.
        video_path = second_exporter.root / second_exporter.meta.get_video_file_path(1, camera_key)
        assert video_path.exists(), f"Second episode video file not found: {video_path}"

    @pytest.mark.skipif(
        shutil.which("ffmpeg") is None, reason="ffmpeg not installed, skipping test"
    )
    def test_interrupted_no_episode_completion(
        self, temp_dir, test_features, test_modality_config, test_data_collection_info
    ):
        """
        Stop before any episode is saved and check that a new exporter cannot be created.

        Uses the real Gr00tDataExporter implementation.
        """
        session_kwargs = {
            "save_root": temp_dir,
            "fps": 30,
            "features": test_features,
            "modality_config": test_modality_config,
            "task": "test_task",
            "vcodec": "libx264",
        }

        # ---- First session: add three frames and never save ----
        first_exporter = Gr00tDataExporter.create(
            data_collection_info=test_data_collection_info, **session_kwargs
        )
        for step in range(3):
            first_exporter.add_frame(get_test_frame(step))
        # Deliberately no save_episode(): this simulates the interruption.
        # (Original author's note: the episode buffer lives only in memory and is lost.)

        # Let file system operations complete
        time.sleep(0.5)

        # No video may exist for episode 0 because it was never saved.
        relative_video = first_exporter.meta.get_video_file_path(0, "observation.images.ego_view")
        video_path = first_exporter.root / relative_video
        assert not video_path.exists(), f"Episode should not have been saved: {video_path}"

        # ---- Second session: creation is expected to fail with ValueError ----
        # (Original author's note: no meta file exists, so the dataset cannot be resumed.)
        with pytest.raises(ValueError):
            _ = Gr00tDataExporter.create(**session_kwargs)
@pytest.mark.skipif(shutil.which("ffmpeg") is None, reason="ffmpeg not installed, skipping test")
def test_full_workflow(temp_dir, test_features, test_modality_config, test_data_collection_info):
    """
    Record a small dataset end to end and verify everything that was written.

    According to the original author this mirrors the workflow of the
    record_session.py example. The test checks, in order:
      1. ``meta/modality.json`` exists and equals the modality config passed in;
      2. a video file exists for every episode and the number of ``.mp4`` files
         equals the number of episodes;
      3. each frame read back through LeRobotDataset matches the frame that was
         recorded (images within a tolerance, vectors via ``np.allclose``);
      4. the stored ``data_collection_info`` equals the one passed in.
    """
    NUM_EPISODES = 2
    FRAMES_PER_EPISODE = 3

    # Create the exporter
    exporter = Gr00tDataExporter.create(
        save_root=temp_dir,
        fps=20,
        features=test_features,
        modality_config=test_modality_config,
        task="test_task",
        data_collection_info=test_data_collection_info,
        robot_type="dummy",
    )

    # Create a small dataset; every episode replays the same frame sequence.
    for _ in range(NUM_EPISODES):
        for frame_index in range(FRAMES_PER_EPISODE):
            exporter.add_frame(get_test_frame(frame_index))
        exporter.save_episode()

    # check modality config
    modality_config_path = exporter.root / "meta" / "modality.json"
    assert modality_config_path.exists(), f"{modality_config_path} does not exists."
    with open(modality_config_path, "rb") as f:
        actual_modality_config = json.load(f)
    assert (
        actual_modality_config == test_modality_config
    ), f"Modality configs don't match.\nActual: {actual_modality_config}\nExpected: {test_modality_config}"

    # Verify that each episode produced its video file
    for episode_idx in range(NUM_EPISODES):
        video_path = exporter.root / exporter.meta.get_video_file_path(
            episode_idx, "observation.images.ego_view"
        )
        assert video_path.exists(), f"Video file not found: {video_path}"

    # Check that the expected number of episodes exists (one .mp4 per episode)
    episode_count = len(list(exporter.root.glob("**/*.mp4")))
    assert episode_count == NUM_EPISODES, f"Expected {NUM_EPISODES} episodes, found {episode_count}"

    # Check the values of the dataset
    dataset = LeRobotDataset(
        repo_id="dataset",
        root=temp_dir,
    )
    for episode_idx in range(NUM_EPISODES):
        for frame_idx in range(FRAMES_PER_EPISODE):
            expected_frame = get_test_frame(frame_idx)
            # Episodes all have the same length, so the flat index is a simple offset.
            actual_frame = dataset[episode_idx * FRAMES_PER_EPISODE + frame_idx]

            # Move the first axis last and rescale by 255 before comparing with the
            # recorded uint8 image.
            # NOTE(review): presumably the dataset returns channel-first images scaled
            # to [0, 1] -- confirm against LeRobotDataset.
            actual_image_frame = actual_frame["observation.images.ego_view"].permute(1, 2, 0) * 255
            assert np.allclose(
                actual_image_frame.numpy(), expected_frame["observation.images.ego_view"], atol=10
            )  # Allow some tolerance for video compression
            assert np.allclose(
                actual_frame["observation.state"], expected_frame["observation.state"]
            )
            assert np.allclose(actual_frame["action"], expected_frame["action"])

    # validate data_collection_info
    assert dataset.meta.info["data_collection_info"] == test_data_collection_info.to_dict()
@pytest.mark.skipif(shutil.which("ffmpeg") is None, reason="ffmpeg not installed, skipping test")
def test_overwrite_existing_dataset_false(
    temp_dir, test_features, test_modality_config, test_data_collection_info
):
    """
    Check that a second exporter on the same directory appends when overwrite_existing is unset.
    """

    def _record(target, episode_total, frames_each):
        # Record `episode_total` episodes of `frames_each` frames into `target`.
        for _ in range(episode_total):
            for step in range(frames_each):
                target.add_frame(get_test_frame(step))
            target.save_episode()

    first_episode_total, first_frames_each = 2, 3
    second_episode_total, second_frames_each = 3, 2

    # ---- first dataset ----
    # !! `overwrite_existing` should always be set to false by default
    # So the argument is deliberately left out of both create() calls.
    # This test ensures that
    #   i. the default behavior is overwrite_existing=False
    #  ii. the dataset appends to the existing dataset (instead of overwriting)
    first_exporter = Gr00tDataExporter.create(
        save_root=temp_dir,
        fps=20,
        features=test_features,
        modality_config=test_modality_config,
        task="test_task",
        data_collection_info=test_data_collection_info,
        robot_type="dummy",
    )
    _record(first_exporter, first_episode_total, first_frames_each)

    # ---- second dataset ----
    # Drop the first exporter before opening the same directory again.
    del first_exporter
    second_exporter = Gr00tDataExporter.create(
        save_root=temp_dir,
        fps=20,
        features=test_features,
        modality_config=test_modality_config,
        task="test_task",
        robot_type="dummy",
    )
    _record(second_exporter, second_episode_total, second_frames_each)

    # Verify that the episodes of both sessions are on disk: one video and one
    # parquet file per episode.
    expected_episode_total = first_episode_total + second_episode_total
    video_files = list(second_exporter.root.glob("**/*.mp4"))
    parquet_files = list(second_exporter.root.glob("**/*.parquet"))
    assert len(video_files) == expected_episode_total
    assert len(parquet_files) == expected_episode_total
# Guarded like the other video-producing tests in this module: the assertions
# below count .mp4 files, so skip instead of failing when ffmpeg is missing.
@pytest.mark.skipif(shutil.which("ffmpeg") is None, reason="ffmpeg not installed, skipping test")
def test_overwrite_existing_dataset_true(
    temp_dir, test_features, test_modality_config, test_data_collection_info
):
    """
    Check that an exporter created with overwrite_existing=True replaces the existing dataset.

    A first dataset is recorded and verified on disk. A second exporter is then
    created on the same directory with ``overwrite_existing=True``; afterwards
    only the episodes of the second session may remain.
    """
    # first dataset
    FIRST_NUM_EPISODES = 2
    FIRST_FRAMES_PER_EPISODE = 3
    exporter = Gr00tDataExporter.create(
        save_root=temp_dir,
        fps=20,
        features=test_features,
        modality_config=test_modality_config,
        task="test_task",
        data_collection_info=test_data_collection_info,
        robot_type="dummy",
    )

    # Create a first dataset
    for _ in range(FIRST_NUM_EPISODES):
        for frame_index in range(FIRST_FRAMES_PER_EPISODE):
            exporter.add_frame(get_test_frame(frame_index))
        exporter.save_episode()

    # verify that the dataset is written to the disk (one video and one parquet per episode)
    assert len(list(exporter.root.glob("**/*.mp4"))) == FIRST_NUM_EPISODES
    assert len(list(exporter.root.glob("**/*.parquet"))) == FIRST_NUM_EPISODES

    # second dataset
    SECOND_NUM_EPISODES = 3
    SECOND_FRAMES_PER_EPISODE = 2

    # re-initialize the exporter on the same directory, this time asking for an overwrite
    del exporter
    exporter = Gr00tDataExporter.create(
        save_root=temp_dir,
        fps=20,
        features=test_features,
        modality_config=test_modality_config,
        task="test_task",
        data_collection_info=test_data_collection_info,
        robot_type="dummy",
        overwrite_existing=True,
    )
    for _ in range(SECOND_NUM_EPISODES):
        for frame_index in range(SECOND_FRAMES_PER_EPISODE):
            exporter.add_frame(get_test_frame(frame_index))
        exporter.save_episode()

    # verify that the dataset is overwritten: only the second session's episodes remain
    assert len(list(exporter.root.glob("**/*.mp4"))) == SECOND_NUM_EPISODES
    assert len(list(exporter.root.glob("**/*.parquet"))) == SECOND_NUM_EPISODES
# Guarded like the other video-producing tests in this module: the assertions
# below count .mp4 files, so skip instead of failing when ffmpeg is missing.
@pytest.mark.skipif(shutil.which("ffmpeg") is None, reason="ffmpeg not installed, skipping test")
def test_save_episode_as_discarded_and_skip(
    temp_dir, test_features, test_modality_config, test_data_collection_info
):
    """
    Check saving an episode as discarded and skipping an episode.

    Ten episodes are recorded and finished in a rotating pattern:
      * ``episode_index % 3 == 0``: ``save_episode_as_discarded()``, expected to
        be written to disk and listed in ``discarded_episode_indices``;
      * ``episode_index % 3 == 1``: ``skip_and_start_new_episode()``, expected to
        leave nothing on disk;
      * otherwise: a regular ``save_episode()``.
    """
    FIRST_NUM_EPISODES = 10
    FIRST_FRAMES_PER_EPISODE = 3
    exporter = Gr00tDataExporter.create(
        save_root=temp_dir,
        fps=20,
        features=test_features,
        modality_config=test_modality_config,
        task="test_task",
        data_collection_info=test_data_collection_info,
        robot_type="dummy",
    )

    # Number of episodes written so far; also the index the next written episode gets.
    saved_episodes = 0
    discarded_episode_indices = []
    for episode_index in range(FIRST_NUM_EPISODES):
        for frame_index in range(FIRST_FRAMES_PER_EPISODE):
            exporter.add_frame(get_test_frame(frame_index))

        if episode_index % 3 == 0:
            exporter.save_episode_as_discarded()
            # Record the index before incrementing: skipped episodes consume no
            # index, so this differs from episode_index.
            discarded_episode_indices.append(saved_episodes)
            saved_episodes += 1
        elif episode_index % 3 == 1:
            exporter.skip_and_start_new_episode()
        else:
            exporter.save_episode()
            saved_episodes += 1

    # verify that the dataset is written to the disk: discarded and regular
    # episodes each have a video and a parquet file, skipped ones have neither
    assert len(list(exporter.root.glob("**/*.mp4"))) == saved_episodes
    assert len(list(exporter.root.glob("**/*.parquet"))) == saved_episodes

    # The metadata must list exactly the indices that were saved as discarded.
    dataset = LeRobotDataset(
        repo_id="dataset",
        root=temp_dir,
    )
    assert dataset.meta.info["discarded_episode_indices"] == discarded_episode_indices