You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
311 lines
13 KiB
311 lines
13 KiB
import collections
|
|
import os
|
|
import threading
|
|
import time
|
|
|
|
import mujoco
|
|
import mujoco.viewer
|
|
import numpy as np
|
|
import onnxruntime as ort
|
|
from pynput import keyboard as pkb
|
|
import torch
|
|
import yaml
|
|
|
|
|
|
class GearWbcController:
    """Run an ONNX policy on a MuJoCo model, with commands edited from the keyboard.

    The controller loads ``g1_gear_wbc.yaml`` from ``config_path``, builds the
    MuJoCo model and ONNX policy it references, and steps the simulation in a
    passive viewer. The first ``num_actions`` joints are PD-tracked towards the
    policy targets; any remaining joints are PD-held at zero with fixed gains.
    """

    # Control period (s) used to advance the gait clock.
    # NOTE(review): presumably equals simulation_dt * control_decimation — confirm.
    _CONTROL_DT = 0.02

    # Lower bound for freq_cmd set from the keyboard; the gait logic divides by it.
    _MIN_FREQ_CMD = 0.1

    # Keyboard bindings for vector commands: key -> (index, increment).
    _LOCO_KEYS = {
        "w": (0, 0.2),
        "s": (0, -0.2),
        "a": (1, 0.5),
        "d": (1, -0.5),
        "q": (2, 0.5),
        "e": (2, -0.5),
    }
    _RPY_KEYS = {
        "3": (0, 0.2),
        "4": (0, -0.2),
        "5": (1, 0.2),
        "6": (1, -0.2),
        "7": (2, 0.2),
        "8": (2, -0.2),
    }
    # Keyboard bindings for scalar commands: key -> increment.
    _HEIGHT_KEYS = {"1": 0.05, "2": -0.05}
    _FREQ_KEYS = {"m": 0.1, "n": -0.1}

    # Fixed PD gains for joints beyond num_actions (not driven by the policy).
    _ARM_KP = 100.0
    _ARM_KD = 0.5

    def __init__(self, config_path):
        """Load config, model and policy, then start the keyboard listener.

        Args:
            config_path: Directory containing ``g1_gear_wbc.yaml``; the relative
                ``policy_path`` / ``xml_path`` entries are resolved against it.
        """
        self.CONFIG_PATH = config_path
        # Guards control_dict between the keyboard thread and the sim loop.
        self.cmd_lock = threading.Lock()
        self.config = self.load_config(os.path.join(self.CONFIG_PATH, "g1_gear_wbc.yaml"))

        # Fresh copies: keyboard edits must never mutate self.config, because
        # the "z" reset key reads the defaults back from it.
        self.control_dict = self._default_commands(self.config)

        self.model = mujoco.MjModel.from_xml_path(self.config["xml_path"])
        self.data = mujoco.MjData(self.model)
        self.model.opt.timestep = self.config["simulation_dt"]
        # qpos begins with 7 floating-base values (position + quaternion in [3:7]).
        self.n_joints = self.data.qpos.shape[0] - 7
        self.torso_index = mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_BODY, "torso_link")
        self.base_index = mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_BODY, "pelvis")

        self.action = np.zeros(self.config["num_actions"], dtype=np.float32)
        self.target_dof_pos = self.config["default_angles"].copy()
        self.policy = self.load_onnx_policy(self.config["policy_path"])

        # Gait-clock state, advanced by compute_observation().
        self.gait_indices = torch.zeros(1, dtype=torch.float32)
        self.counter = 0
        self.just_started = 0.0
        self.walking_mask = False
        self.frozen_FL = False
        self.frozen_FR = False

        self.single_obs, self.single_obs_dim = self.compute_observation(
            self.data, self.config, self.action, self.control_dict, self.n_joints
        )
        history_len = self.config["obs_history_len"]
        self.obs_history = collections.deque(
            [np.zeros(self.single_obs_dim, dtype=np.float32)] * history_len,
            maxlen=history_len,
        )
        self.obs = np.zeros(self.config["num_obs"], dtype=np.float32)

        self.keyboard_listener(self.control_dict, self.config)

    @staticmethod
    def _default_commands(config):
        """Return a new command dict holding *copies* of the configured defaults.

        ``rpy_cmd`` defaults to ``[0.0, 0.0, 0.0]`` and ``freq_cmd`` to ``1.5``
        when absent from ``config``.
        """
        return {
            "loco_cmd": np.array(config["cmd_init"], dtype=np.float32),
            "height_cmd": config["height_cmd"],
            "rpy_cmd": list(config.get("rpy_cmd", [0.0, 0.0, 0.0])),
            "freq_cmd": config.get("freq_cmd", 1.5),
        }

    def _apply_key(self, k, control_dict, config):
        """Apply one character key press to ``control_dict`` in place.

        Unrecognised keys (including ``None``) leave ``control_dict`` unchanged.
        The caller is expected to hold ``self.cmd_lock``.
        """
        if k in self._LOCO_KEYS:
            idx, delta = self._LOCO_KEYS[k]
            control_dict["loco_cmd"][idx] += delta
        elif k in self._RPY_KEYS:
            idx, delta = self._RPY_KEYS[k]
            control_dict["rpy_cmd"][idx] += delta
        elif k in self._HEIGHT_KEYS:
            control_dict["height_cmd"] += self._HEIGHT_KEYS[k]
        elif k in self._FREQ_KEYS:
            # Keep the frequency positive: compute_observation divides by it.
            control_dict["freq_cmd"] = max(
                control_dict["freq_cmd"] + self._FREQ_KEYS[k], self._MIN_FREQ_CMD
            )
        elif k == "z":
            defaults = self._default_commands(config)
            # Slice-assign so existing array/list objects keep their identity.
            control_dict["loco_cmd"][:] = defaults["loco_cmd"]
            control_dict["height_cmd"] = defaults["height_cmd"]
            control_dict["rpy_cmd"][:] = defaults["rpy_cmd"]
            control_dict["freq_cmd"] = defaults["freq_cmd"]

    def keyboard_listener(self, control_dict, config):
        """Start a daemon pynput listener that edits ``control_dict`` on key presses.

        Bindings: w/s -> loco_cmd[0] +/-0.2; a/d -> loco_cmd[1] +/-0.5;
        q/e -> loco_cmd[2] +/-0.5; 1/2 -> height_cmd +/-0.05;
        3/4, 5/6, 7/8 -> rpy_cmd[0], [1], [2] +/-0.2;
        m/n -> freq_cmd +/-0.1 (floored at _MIN_FREQ_CMD); z -> reset to config.

        Returns:
            The started listener (also stored as ``self._keyboard_listener``).
        """

        def on_press(key):
            try:
                k = key.char
            except AttributeError:
                return  # Special keys ignored

            with self.cmd_lock:
                self._apply_key(k, control_dict, config)
                status = (
                    f"Current Commands: loco_cmd = {control_dict['loco_cmd']}, "
                    f"height_cmd = {control_dict['height_cmd']}, "
                    f"rpy_cmd = {control_dict['rpy_cmd']}, "
                    f"freq_cmd = {control_dict['freq_cmd']}"
                )
            print(status)

        listener = pkb.Listener(on_press=on_press)
        listener.daemon = True
        listener.start()
        self._keyboard_listener = listener
        return listener

    def load_config(self, config_path):
        """Load the YAML config and normalise paths and numeric arrays.

        ``policy_path`` and ``xml_path`` are joined onto ``self.CONFIG_PATH``
        (falling back to the YAML file's directory). ``kps``, ``kds``,
        ``default_angles``, ``cmd_scale`` and ``cmd_init`` become float32 arrays.
        """
        with open(config_path, "r", encoding="utf-8") as f:
            config = yaml.safe_load(f)

        # Use the instance's directory rather than a script-only global, so the
        # class also works when imported.
        base_dir = getattr(self, "CONFIG_PATH", None) or os.path.dirname(config_path)
        for path_key in ("policy_path", "xml_path"):
            config[path_key] = os.path.join(base_dir, config[path_key])

        for key in ("kps", "kds", "default_angles", "cmd_scale", "cmd_init"):
            config[key] = np.array(config[key], dtype=np.float32)

        return config

    def pd_control(self, target_q, q, kp, target_dq, dq, kd):
        """Return PD torques ``(target_q - q) * kp + (target_dq - dq) * kd``."""
        return (target_q - q) * kp + (target_dq - dq) * kd

    def quat_rotate_inverse(self, q, v):
        """Rotate 3-vector ``v`` by the inverse (conjugate) of quaternion ``q = (w, x, y, z)``."""
        w, x, y, z = q
        # Components of the conjugate quaternion, fed to the rotation-matrix form.
        cw, cx, cy, cz = w, -x, -y, -z
        return np.array(
            [
                v[0] * (cw**2 + cx**2 - cy**2 - cz**2)
                + v[1] * 2 * (cx * cy - cw * cz)
                + v[2] * 2 * (cx * cz + cw * cy),
                v[0] * 2 * (cx * cy + cw * cz)
                + v[1] * (cw**2 - cx**2 + cy**2 - cz**2)
                + v[2] * 2 * (cy * cz - cw * cx),
                v[0] * 2 * (cx * cz - cw * cy)
                + v[1] * 2 * (cy * cz + cw * cx)
                + v[2] * (cw**2 - cx**2 - cy**2 + cz**2),
            ]
        )

    def get_gravity_orientation(self, quat):
        """Return the world gravity direction ``(0, 0, -1)`` rotated by the inverse of ``quat``."""
        gravity_vec = np.array([0.0, 0.0, -1.0])
        return self.quat_rotate_inverse(quat, gravity_vec)

    def compute_observation(self, d, config, action, control_dict, n_joints):
        """Build one observation frame and advance the gait-clock state.

        Layout (``n = n_joints``, ``a = len(action)``):
            [0:8]                 loco_cmd[:3] * cmd_scale, height_cmd, freq_cmd, rpy_cmd
            [8:11]                qvel[3:6] * ang_vel_scale
            [11:14]               gravity rotated by the inverse of qpos[3:7]
            [14:20]               zeros (torso angular-velocity / gravity features disabled)
            [20:20+n]             (joint pos - default_angles) * dof_pos_scale
            [20+n:20+2n]          joint vel * dof_vel_scale
            [20+2n:20+2n+a]       previous action
            [20+2n+a:20+2n+a+2]   gait clock (FL, FR)

        Side effects: updates ``gait_indices``, ``just_started``, ``walking_mask``,
        ``frozen_FL``/``frozen_FR`` and ``clock_inputs`` on ``self``.

        Returns:
            ``(single_obs, single_obs_dim)``: a float32 vector and its length.
        """
        command = np.zeros(8, dtype=np.float32)
        command[:3] = control_dict["loco_cmd"][:3] * config["cmd_scale"]
        command[3] = control_dict["height_cmd"]
        command[4] = control_dict["freq_cmd"]
        command[5:8] = control_dict["rpy_cmd"]

        dt = self._CONTROL_DT

        # --- gait phase bookkeeping ---
        is_static = np.linalg.norm(command[:3]) < 0.1
        just_entered_walk = (not is_static) and (not self.walking_mask)
        self.walking_mask = not is_static

        if just_entered_walk:
            # Restart the gait clock from a fixed phase when walking begins.
            self.just_started = 0.0
            self.gait_indices = torch.tensor([-0.25])
        if is_static:
            self.just_started = 0.0
        else:
            self.just_started += dt
            # Walking releases any foot clock frozen while standing.
            self.frozen_FL = False
            self.frozen_FR = False

        self.gait_indices = torch.remainder(self.gait_indices + dt * command[4], 1.0)

        duration = 0.5  # phase-warp breakpoint
        phase = 0.5  # FL/FR phase offset

        gait_FR = self.gait_indices.clone()
        gait_FL = torch.remainder(gait_FR + phase, 1.0)
        if self.just_started < (0.5 / command[4]):
            # Within half a gait period of starting (and always while static),
            # hold FR at phase 0.25, i.e. the sine peak.
            gait_FR = torch.tensor([0.25])

        gait_pair = []
        for fi in (gait_FL, gait_FR):
            if fi.item() < duration:
                gait_pair.append(fi * (0.5 / duration))
            else:
                gait_pair.append(0.5 + (fi - duration) * (0.5 / (1 - duration)))

        clock = []
        for fi, frozen_attr in zip(gait_pair, ("frozen_FL", "frozen_FR")):
            clk = torch.sin(2 * np.pi * fi)
            # While standing, latch each foot clock once it reaches its peak.
            if is_static and not getattr(self, frozen_attr) and clk.item() > 0.98:
                setattr(self, frozen_attr, True)
            if getattr(self, frozen_attr):
                clk = torch.tensor([1.0])
            clock.append(clk)
        self.clock_inputs = torch.stack(clock).unsqueeze(0)

        # --- proprioception ---
        qj = d.qpos[7 : 7 + n_joints].copy()
        dqj = d.qvel[6 : 6 + n_joints].copy()
        quat = d.qpos[3:7].copy()
        omega = d.qvel[3:6].copy()

        padded_defaults = np.zeros(n_joints, dtype=np.float32)
        num_defaults = min(len(config["default_angles"]), n_joints)
        padded_defaults[:num_defaults] = config["default_angles"][:num_defaults]

        num_act = len(action)
        single_obs_dim = 20 + 2 * n_joints + num_act + 2
        single_obs = np.zeros(single_obs_dim, dtype=np.float32)
        single_obs[0:8] = command
        single_obs[8:11] = omega * config["ang_vel_scale"]
        single_obs[11:14] = self.get_gravity_orientation(quat)
        # single_obs[14:20] stays zero (torso features disabled).
        joint_start = 20
        single_obs[joint_start : joint_start + n_joints] = (qj - padded_defaults) * config["dof_pos_scale"]
        single_obs[joint_start + n_joints : joint_start + 2 * n_joints] = dqj * config["dof_vel_scale"]
        action_start = joint_start + 2 * n_joints
        single_obs[action_start : action_start + num_act] = action
        single_obs[action_start + num_act :] = self.clock_inputs.cpu().numpy().reshape(2)

        return single_obs, single_obs_dim

    def load_onnx_policy(self, path):
        """Load an ONNX model and return a callable ``tensor -> tensor`` wrapper.

        The wrapper feeds its input to the model's first input and returns the
        first output as a CPU tensor.
        """
        session = ort.InferenceSession(path)
        input_name = session.get_inputs()[0].name

        def run_inference(input_tensor):
            ort_inputs = {input_name: input_tensor.detach().cpu().numpy()}
            ort_outs = session.run(None, ort_inputs)
            # Stay on CPU: callers convert straight to NumPy, and a hard-coded
            # CUDA device fails on machines without a GPU.
            return torch.tensor(ort_outs[0])

        return run_inference

    def _apply_pd_torques(self):
        """Write PD torques to data.ctrl: policy joints track target_dof_pos, the rest hold zero."""
        num_actions = self.config["num_actions"]
        kps = self.config["kps"]
        self.data.ctrl[:num_actions] = self.pd_control(
            self.target_dof_pos,
            self.data.qpos[7 : 7 + num_actions],
            kps,
            np.zeros_like(kps),
            self.data.qvel[6 : 6 + num_actions],
            self.config["kds"],
        )

        num_extra = self.n_joints - num_actions
        if num_extra > 0:
            self.data.ctrl[num_actions:] = self.pd_control(
                np.zeros(num_extra, dtype=np.float32),
                self.data.qpos[7 + num_actions : 7 + self.n_joints],
                np.full(num_extra, self._ARM_KP),
                np.zeros(num_extra),
                self.data.qvel[6 + num_actions : 6 + self.n_joints],
                np.full(num_extra, self._ARM_KD),
            )

    def _policy_step(self):
        """Build the stacked observation, query the policy and update joint targets."""
        # Snapshot the commands under the lock so the keyboard thread cannot
        # change them while the observation is being built.
        with self.cmd_lock:
            current_cmd = {
                key: (value.copy() if hasattr(value, "copy") else value)
                for key, value in self.control_dict.items()
            }

        single_obs, _ = self.compute_observation(
            self.data, self.config, self.action, current_cmd, self.n_joints
        )
        self.obs_history.append(single_obs)

        # Flatten the history (oldest first) into the policy input buffer.
        dim = self.single_obs_dim
        for i, hist_obs in enumerate(self.obs_history):
            self.obs[i * dim : (i + 1) * dim] = hist_obs

        obs_tensor = torch.from_numpy(self.obs).unsqueeze(0)
        self.action = self.policy(obs_tensor).cpu().detach().numpy().squeeze()
        self.target_dof_pos = self.action * self.config["action_scale"] + self.config["default_angles"]

    def run(self):
        """Step the simulation in a passive viewer until it closes or time runs out.

        Every physics step applies PD torques. Every ``control_decimation`` steps
        the policy is queried for new joint targets. ``simulation_duration`` is
        compared against wall-clock time.
        """
        self.counter = 0
        decimation = self.config["control_decimation"]
        time_limit = self.config["simulation_duration"]

        with mujoco.viewer.launch_passive(self.model, self.data) as viewer:
            start = time.time()
            while viewer.is_running() and time.time() - start < time_limit:
                self._apply_pd_torques()
                mujoco.mj_step(self.model, self.data)

                self.counter += 1
                if self.counter % decimation == 0:
                    self._policy_step()

                viewer.sync()
|
|
|
|
|
|
if __name__ == "__main__":
    # Robot assets live in <parent of this script's directory>/resources/robots/g1.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    repo_root = os.path.dirname(script_dir)
    CONFIG_PATH = os.path.join(repo_root, "resources", "robots", "g1")

    controller = GearWbcController(CONFIG_PATH)
    controller.run()
|