|
|
|
@ -39,9 +39,11 @@ class EpisodeWriter(): |
|
|
|
self.data_info() |
|
|
|
self.text_desc() |
|
|
|
|
|
|
|
self.is_available = True # Indicates whether the class is available for new operations |
|
|
|
# Initialize the queue and worker thread |
|
|
|
self.item_data_queue = Queue(maxsize=100) |
|
|
|
self.stop_worker = False |
|
|
|
self.need_save = False # Flag to indicate when save_episode is triggered |
|
|
|
self.worker_thread = Thread(target=self.process_queue) |
|
|
|
self.worker_thread.start() |
|
|
|
|
|
|
|
@ -78,13 +80,17 @@ class EpisodeWriter(): |
|
|
|
|
|
|
|
def create_episode(self): |
|
|
|
""" |
|
|
|
Create a new episode, each episode needs to specify the episode_id. |
|
|
|
text: Text descriptions of operation goals, steps, etc. The text description of each episode is the same. |
|
|
|
goal: operation goal |
|
|
|
desc: description |
|
|
|
steps: operation steps |
|
|
|
Create a new episode. |
|
|
|
Returns: |
|
|
|
bool: True if the episode is successfully created, False otherwise. |
|
|
|
Note: |
|
|
|
Once successfully created, this function will only be available again after save_episode complete its save task. |
|
|
|
""" |
|
|
|
if not self.is_available: |
|
|
|
print("==> The class is currently unavailable for new operations. Please wait until ongoing tasks are completed.") |
|
|
|
return False # Return False if the class is unavailable |
|
|
|
|
|
|
|
# Reset episode-related data and create necessary directories |
|
|
|
self.item_id = -1 |
|
|
|
self.episode_data = [] |
|
|
|
self.episode_id = self.episode_id + 1 |
|
|
|
@ -101,6 +107,10 @@ class EpisodeWriter(): |
|
|
|
if self.rerun_log: |
|
|
|
self.online_logger = RerunLogger(prefix="online/", IdxRangeBoundary = 60, memory_limit="300MB") |
|
|
|
|
|
|
|
self.is_available = False # After the episode is created, the class is marked as unavailable until the episode is successfully saved |
|
|
|
print(f"==> New episode created: {self.episode_dir}") |
|
|
|
return True # Return True if the episode is successfully created |
|
|
|
|
|
|
|
def add_item(self, colors, depths=None, states=None, actions=None, tactiles=None, audios=None): |
|
|
|
# Increment the item ID |
|
|
|
self.item_id += 1 |
|
|
|
@ -119,6 +129,7 @@ class EpisodeWriter(): |
|
|
|
|
|
|
|
def process_queue(self): |
|
|
|
while not self.stop_worker or not self.item_data_queue.empty(): |
|
|
|
# Process items in the queue |
|
|
|
try: |
|
|
|
item_data = self.item_data_queue.get(timeout=1) |
|
|
|
try: |
|
|
|
@ -127,7 +138,11 @@ class EpisodeWriter(): |
|
|
|
print(f"Error processing item_data (idx={item_data['idx']}): {e}") |
|
|
|
self.item_data_queue.task_done() |
|
|
|
except Empty: |
|
|
|
continue |
|
|
|
pass |
|
|
|
|
|
|
|
# Check if save_episode was triggered |
|
|
|
if self.need_save and self.item_data_queue.empty(): |
|
|
|
self._save_episode() |
|
|
|
|
|
|
|
def _process_item_data(self, item_data): |
|
|
|
idx = item_data['idx'] |
|
|
|
@ -169,20 +184,32 @@ class EpisodeWriter(): |
|
|
|
|
|
|
|
def save_episode(self): |
|
|
|
""" |
|
|
|
with open("./hmm.json",'r',encoding='utf-8') as json_file: |
|
|
|
model=json.load(json_file) |
|
|
|
Trigger the save operation. This sets the save flag, and the process_queue thread will handle it. |
|
|
|
""" |
|
|
|
self.need_save = True # Set the save flag |
|
|
|
print(f"==> Episode saved start...") |
|
|
|
|
|
|
|
def _save_episode(self): |
|
|
|
""" |
|
|
|
Save the episode data to a JSON file. |
|
|
|
""" |
|
|
|
# Wait for the queue to be processed |
|
|
|
self.item_data_queue.join() |
|
|
|
# save |
|
|
|
self.data['info'] = self.info |
|
|
|
self.data['text'] = self.text |
|
|
|
self.data['data'] = self.episode_data |
|
|
|
with open(self.json_path, 'w', encoding='utf-8') as jsonf: |
|
|
|
jsonf.write(json.dumps(self.data, indent=4, ensure_ascii=False)) |
|
|
|
self.need_save = False # Reset the save flag |
|
|
|
self.is_available = True # Mark the class as available after saving |
|
|
|
print(f"==> Episode saved successfully to {self.json_path}.") |
|
|
|
|
|
|
|
def close(self): |
|
|
|
""" |
|
|
|
Stop the worker thread and ensure all tasks are completed. |
|
|
|
""" |
|
|
|
self.item_data_queue.join() |
|
|
|
# Signal the worker thread to stop and join the thread |
|
|
|
if not self.is_available: # If self.is_available is False, it means there is still data not saved. |
|
|
|
self.save_episode() |
|
|
|
while not self.is_available: |
|
|
|
time.sleep(0.01) |
|
|
|
self.stop_worker = True |
|
|
|
self.worker_thread.join() |