|
|
@ -183,7 +183,7 @@ class SyncEnv(gym.Env): |
|
|
# Add state keys for model input |
|
|
# Add state keys for model input |
|
|
obs = prepare_observation_for_eval(self.robot_model, obs) |
|
|
obs = prepare_observation_for_eval(self.robot_model, obs) |
|
|
|
|
|
|
|
|
obs["language.language_instruction"] = raw_obs["language.language_instruction"] |
|
|
|
|
|
|
|
|
obs["annotation.human.task_description"] = raw_obs["language.language_instruction"] |
|
|
|
|
|
|
|
|
if hasattr(self.base_env, "get_privileged_obs_keys"): |
|
|
if hasattr(self.base_env, "get_privileged_obs_keys"): |
|
|
for key in self.base_env.get_privileged_obs_keys(): |
|
|
for key in self.base_env.get_privileged_obs_keys(): |
|
|
@ -205,7 +205,7 @@ class SyncEnv(gym.Env): |
|
|
def get_observation(self): |
|
|
def get_observation(self): |
|
|
return self.base_env._get_observations() # assumes base env is robosuite |
|
|
return self.base_env._get_observations() # assumes base env is robosuite |
|
|
|
|
|
|
|
|
def get_step_info(self) -> Dict[str, any]: |
|
|
|
|
|
|
|
|
def get_step_info(self) -> Tuple[Dict[str, any], float, bool, bool, Dict[str, any]]: |
|
|
return ( |
|
|
return ( |
|
|
self.observe(), |
|
|
self.observe(), |
|
|
self.cache["reward"], |
|
|
self.cache["reward"], |
|
|
@ -319,7 +319,7 @@ class SyncEnv(gym.Env): |
|
|
|
|
|
|
|
|
obs_space = prepare_gym_space_for_eval(self.robot_model, obs_space) |
|
|
obs_space = prepare_gym_space_for_eval(self.robot_model, obs_space) |
|
|
|
|
|
|
|
|
obs_space["language.language_instruction"] = gym.spaces.Text( |
|
|
|
|
|
|
|
|
obs_space["annotation.human.task_description"] = gym.spaces.Text( |
|
|
max_length=256, charset=ALLOWED_LANGUAGE_CHARSET |
|
|
max_length=256, charset=ALLOWED_LANGUAGE_CHARSET |
|
|
) |
|
|
) |
|
|
|
|
|
|
|
|
|