You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
276 lines
10 KiB
276 lines
10 KiB
import os
|
|
import time
|
|
import xml.etree.ElementTree as ET
|
|
|
|
import mujoco
|
|
import numpy as np
|
|
import robosuite
|
|
import robosuite.utils.transform_utils as T
|
|
from robosuite.models.objects import MujocoXMLObject
|
|
from robosuite.utils.mjcf_utils import array_to_string, string_to_array
|
|
from robosuite.environments.robot_env import RobotEnv
|
|
|
|
|
|
class MJCFObject(MujocoXMLObject):
    """
    Blender object with support for changing the scaling.

    Every geom of the loaded model gets the configured solref / solimp / density /
    friction (and optionally margin, priority, and - for group-1 geoms - rgba).
    Geoms located at ``./body/body/geom`` whose name starts with
    ``<naming_prefix>spawn_`` are collected into ``self.spawns`` as spawn zones;
    individual spawn zones can be disabled so random sampling skips them.
    """

    def __init__(
        self,
        name,
        mjcf_path,
        scale=1.0,
        solimp=(0.998, 0.998, 0.001),
        solref=(0.001, 1),
        density=100,
        friction=(0.95, 0.3, 0.1),
        margin=None,
        rgba=None,
        priority=None,
        static=False,
    ):
        """
        Args:
            name (str): object name, forwarded to MujocoXMLObject.
            mjcf_path (str): path of the source MJCF xml file.
            scale (float | int | list | tuple | np.ndarray): either a single scale
                factor applied to x, y and z, or a length-3 (x, y, z) sequence.
            solimp, solref, density, friction: values written onto every geom.
            margin: if not None, written as the ``margin`` of every geom.
            rgba: if not None, written as the ``rgba`` of every group "1" geom.
            priority: if not None, written as the ``priority`` of every geom.
            static (bool): if True the object gets no free joint.

        Raises:
            TypeError: if @scale is neither a real scalar nor a sequence.
            ValueError: if @scale is a sequence whose length is not 3.
        """
        # get scale in x, y, z
        scale = self._scale_to_xyz(scale)

        # These attributes are read by _get_geoms, which is invoked while the
        # parent class processes the xml, so they must exist before super().__init__.
        self.solimp = solimp
        self.solref = solref
        self.density = density
        self.friction = friction
        self.margin = margin
        self.priority = priority
        self.rgba = rgba

        # read default xml and postprocess any robosuite asset paths
        folder = os.path.dirname(mjcf_path)
        root = ET.parse(mjcf_path).getroot()
        xml_str = self.postprocess_model_xml(ET.tostring(root, encoding="utf8").decode("utf8"))

        # The rewritten xml is placed next to the original (presumably so that
        # relative asset paths keep resolving); timestamp + pid keeps the name unique.
        time_str = str(time.time()).replace(".", "_")
        new_xml_path = os.path.join(folder, "{}_{}.xml".format(time_str, os.getpid()))

        try:
            # xml_str declares encoding='utf8', so write it as UTF-8 regardless of locale.
            with open(new_xml_path, "w", encoding="utf-8") as f:
                f.write(xml_str)

            # initialize object with new xml we wrote
            super().__init__(
                fname=new_xml_path,
                name=name,
                joints=None if static else [dict(type="free", damping="0.0005")],
                obj_type="all",
                duplicate_collision_geoms=False,
                scale=scale,
            )
        finally:
            # clean up xml - it is not needed after construction, and must not be
            # left behind when construction fails either
            if os.path.exists(new_xml_path):
                os.remove(new_xml_path)

        self.spawns = []
        self._disabled_spawns = set()
        spawn_prefix = "{}spawn_".format(self.naming_prefix)
        for geom in self.worldbody.findall("./body/body/geom"):
            geom_name = geom.get("name")
            if geom_name and geom_name.startswith(spawn_prefix):
                self.spawns.append(geom)

    @staticmethod
    def _scale_to_xyz(scale):
        """Convert @scale (scalar or length-3 sequence) into an (x, y, z) numpy array."""
        is_scalar = isinstance(scale, (int, float, np.integer, np.floating))
        if is_scalar and not isinstance(scale, bool):
            return np.array([float(scale)] * 3)
        if isinstance(scale, (tuple, list, np.ndarray)):
            if len(scale) != 3:
                raise ValueError("got invalid scale: {}".format(scale))
            return np.array(tuple(scale))
        raise TypeError("got invalid scale: {}".format(scale))

    def postprocess_model_xml(self, xml_str):
        """
        New version of postprocess model xml that only replaces robosuite file paths if necessary
        (otherwise there is an error with the "max" operation).

        Every ``mesh`` / ``texture`` asset whose ``file`` path contains a ``robosuite``
        path component is re-rooted onto the installed robosuite package: everything
        after the LAST ``robosuite`` component is appended to the package directory.

        Args:
            xml_str (str): MJCF xml string.

        Returns:
            str: the (possibly modified) xml string. Models without an ``<asset>``
            element are returned re-serialized but otherwise unchanged.
        """
        path_split = os.path.split(robosuite.__file__)[0].split("/")

        root = ET.fromstring(xml_str)
        asset = root.find("asset")
        if asset is None:
            # nothing to rewrite (e.g. a model built only from primitive geoms)
            return ET.tostring(root, encoding="utf8").decode("utf8")

        # replace mesh and texture file paths
        for elem in asset.findall("mesh") + asset.findall("texture"):
            old_path = elem.get("file")
            if old_path is None:
                continue

            old_path_split = old_path.split("/")
            # maybe replace all paths to robosuite assets
            robosuite_indices = [i for i, part in enumerate(old_path_split) if part == "robosuite"]
            if robosuite_indices:
                ind = max(robosuite_indices)  # last occurrence index
                elem.set("file", "/".join(path_split + old_path_split[ind + 1 :]))

        return ET.tostring(root, encoding="utf8").decode("utf8")

    def _get_geoms(self, root, _parent=None):
        """
        Helper function to recursively search through element tree starting at @root and returns
        a list of (parent, child) tuples where the child is a geom element.

        In addition to the parent-class behavior, every returned geom is overwritten with the
        solref / solimp / density / friction configured on this object, plus margin and priority
        when set, and rgba (only for geoms whose group is "1") when set.

        Args:
            root (ET.Element): Root of xml element tree to start recursively searching through

            _parent (ET.Element): Parent of the root element tree. Should not be used externally; only set
                during the recursive call

        Returns:
            list: array of (parent, child) tuples where the child element is a geom type
        """
        geom_pairs = super(MJCFObject, self)._get_geoms(root=root, _parent=_parent)

        # modify geoms according to the attributes
        for _, element in geom_pairs:
            element.set("solref", array_to_string(self.solref))
            element.set("solimp", array_to_string(self.solimp))
            element.set("density", str(self.density))
            element.set("friction", array_to_string(self.friction))
            if self.margin is not None:
                element.set("margin", str(self.margin))

            if (self.rgba is not None) and (element.get("group") == "1"):
                element.set("rgba", array_to_string(self.rgba))

            if self.priority is not None:
                # set high priority
                element.set("priority", str(self.priority))

        return geom_pairs

    def get_joint(self, joint_name: str):
        """Return the joint element named @joint_name (looked up via _get_elements_by_name)."""
        _, _, joints = self._get_elements_by_name(
            geom_names=[], body_names=[], joint_names=[joint_name]
        )
        return joints[joint_name]

    def _horizontal_radius_xy(self) -> np.ndarray:
        """
        Return the (x, y) part of the ``pos`` of the object's horizontal radius site.

        Raises:
            AttributeError: if the model has no ``<prefix>horizontal_radius_site`` directly
                under the root body (AttributeError is kept so ``hasattr`` on the
                ``horizontal_radius`` property behaves as before).
        """
        site_name = "{}horizontal_radius_site".format(self.naming_prefix)
        site = self.worldbody.find("./body/site[@name='{}']".format(site_name))
        if site is None:
            raise AttributeError("object model has no site named '{}'".format(site_name))
        return string_to_array(site.get("pos"))[:2]

    @property
    def horizontal_radius(self):
        """Euclidean norm of the (x, y) position of the horizontal radius site."""
        return np.linalg.norm(self._horizontal_radius_xy())

    def get_bbox_points(self, trans=None, rot=None) -> list[np.ndarray]:
        """
        Get the full 8 bounding box points of the object.

        Args:
            trans: optional translation added to every point (defaults to the origin).
            rot: optional orientation as a QUATERNION (it is converted with T.quat2mat),
                not a rotation matrix; assumed to follow robosuite's quaternion
                convention - TODO confirm. Defaults to the identity rotation.

        Returns:
            list[np.ndarray]: the 8 corner points, see _get_bbox_points for ordering.
        """
        return self._get_bbox_points(
            bottom_offset=self.bottom_offset,
            top_offset=self.top_offset,
            radius=self._horizontal_radius_xy(),
            trans=trans,
            rot=rot,
        )

    @staticmethod
    def _get_bbox_points(
        bottom_offset, top_offset, radius, trans=None, rot=None
    ) -> list[np.ndarray]:
        """
        Helper function to get the full 8 bounding box points of the object.

        The box is centered at the midpoint of @bottom_offset and @top_offset, has x / y
        half extents radius[0] / radius[1], and z half extent top_offset[2] - center[2].
        Each corner offset is rotated by @rot (a quaternion, converted with T.quat2mat;
        identity if None) and then translated by @trans (origin if None).

        Point order: p0 (all-negative corner), then p0's neighbours along +x, +y and +z,
        followed by the remaining four corners.
        """
        center = np.mean([bottom_offset, top_offset], axis=0)
        half_size = np.array([radius[0], radius[1], top_offset[2] - center[2]])

        bbox_offsets = [
            center + half_size * np.array([-1, -1, -1]),  # p0
            center + half_size * np.array([1, -1, -1]),  # px
            center + half_size * np.array([-1, 1, -1]),  # py
            center + half_size * np.array([-1, -1, 1]),  # pz
            center + half_size * np.array([1, 1, 1]),
            center + half_size * np.array([-1, 1, 1]),
            center + half_size * np.array([1, -1, 1]),
            center + half_size * np.array([1, 1, -1]),
        ]

        if trans is None:
            trans = np.array([0, 0, 0])
        rot = np.eye(3) if rot is None else T.quat2mat(rot)

        return [np.matmul(rot, p) + trans for p in bbox_offsets]

    @staticmethod
    def get_spawn_bottom_offset(site: ET.Element) -> np.ndarray:
        """
        Get bottom offset of the spawn zone.

        Args:
            site (ET.Element): spawn element (the entries of ``self.spawns`` are geoms).

        Returns:
            np.ndarray: its ``pos`` lowered by the z half-size (last ``size`` entry) when
            it is a box; otherwise its ``pos`` unchanged. A missing ``pos`` is treated as
            MuJoCo's default "0 0 0".
        """
        site_pos = string_to_array(site.get("pos", "0 0 0"))
        if site.get("type") == "box":
            site_size = string_to_array(site.get("size"))
        else:
            site_size = np.zeros(3)
        return site_pos - np.array([0, 0, site_size[-1]])

    def get_random_spawn(self, rng, exclude_disabled: bool = False) -> tuple[int, ET.Element]:
        """
        Get random spawn site.

        Args:
            rng: random generator providing ``choice`` over a list of indices.
            exclude_disabled (bool): if True, spawn zones disabled through
                set_spawn_active are not eligible.

        Returns:
            tuple: (spawn index, spawn element).
        """
        options = list(range(len(self.spawns)))
        if exclude_disabled:
            options = [o for o in options if o not in self._disabled_spawns]
        spawn_id = rng.choice(options)
        return spawn_id, self.spawns[spawn_id]

    def set_spawn_active(self, spawn_id: int, active: bool):
        """
        Update the activity state of a spawn site. Disabled sites are excluded from random sampling.
        """
        if active:
            self._disabled_spawns.discard(spawn_id)
        else:
            self._disabled_spawns.add(spawn_id)

    def closest_spawn_id(self, env: RobotEnv, obj: "MJCFObject", max_distance: float = 1.0) -> int:
        """
        Return the index of the spawn zone that @obj is touching, or -1 if there is none.

        Returns -1 immediately if this object has no spawn zones, if env.check_contact
        reports no contact between this object and @obj, or if @obj has no contact geoms.
        Otherwise, spawn zones are visited in order of increasing distance between the
        spawn geom position and @obj's root body position. For each one, the distance
        from the spawn geom to every contact geom of @obj is computed with
        mujoco.mj_geomDistance (capped at @max_distance). The first spawn zone that ANY
        object geom reaches (distance <= 0) is returned. The search stops with -1 once a
        spawn zone is at least @max_distance away from ALL object geoms.
        """
        if not self.spawns or not env.check_contact(self, obj):
            return -1

        model, data = env.sim.model, env.sim.data
        obj_geom_ids = [model.geom_name2id(g) for g in obj.contact_geoms]
        if not obj_geom_ids:
            return -1

        obj_pos = data.body_xpos[model.body_name2id(obj.root_body)]
        spawn_order = sorted(
            range(len(self.spawns)),
            key=lambda i: np.linalg.norm(data.get_geom_xpos(self.spawns[i].get("name")) - obj_pos),
        )

        for spawn_id in spawn_order:
            spawn_geom_id = model.geom_name2id(self.spawns[spawn_id].get("name"))
            closest = max_distance
            for obj_geom_id in obj_geom_ids:
                real_distance = mujoco.mj_geomDistance(
                    m=model._model,
                    d=data._data,
                    geom1=spawn_geom_id,
                    geom2=obj_geom_id,
                    distmax=max_distance,
                    fromto=None,
                )
                if real_distance <= 0:
                    return spawn_id
                closest = min(closest, real_distance)
            # Only give up when the WHOLE object is out of range of this spawn zone;
            # a single distant geom of a multi-geom object says nothing about the others.
            if closest >= max_distance:
                return -1
        return -1