You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
303 lines
13 KiB
303 lines
13 KiB
import sys
|
|
import time
|
|
import collections
|
|
import yaml
|
|
import torch
|
|
import numpy as np
|
|
import sys
|
|
import time
|
|
import collections
|
|
import yaml
|
|
import torch
|
|
import numpy as np
|
|
import mujoco
|
|
import mujoco.viewer
|
|
import onnxruntime as ort
|
|
import threading
|
|
from pynput import keyboard as pkb
|
|
import os
|
|
|
|
class GearWbcController:
    """MuJoCo runner for the G1 GEAR whole-body-control ONNX policy.

    Loads ``g1_gear_wbc.yaml`` from ``config_path``, builds the MuJoCo model,
    runs a joint-level PD loop every physics step and queries the policy every
    ``control_decimation`` steps. A ``pynput`` keyboard listener edits the
    commands in ``control_dict`` from a background thread; ``cmd_lock`` guards
    every read and write of those commands.
    """

    # Time added to the gait phase bookkeeping per policy step.
    # NOTE(review): presumably simulation_dt * control_decimation — confirm.
    GAIT_DT = 0.02

    # Fallbacks used both at start-up and by the 'z' reset key.
    _DEFAULT_RPY_CMD = (0.0, 0.0, 0.0)
    _DEFAULT_FREQ_CMD = 1.5

    # Keyboard bindings: char -> (control_dict key, component index or None, delta).
    _KEY_BINDINGS = {
        'w': ('loco_cmd', 0, 0.2),
        's': ('loco_cmd', 0, -0.2),
        'a': ('loco_cmd', 1, 0.5),
        'd': ('loco_cmd', 1, -0.5),
        'q': ('loco_cmd', 2, 0.5),
        'e': ('loco_cmd', 2, -0.5),
        '1': ('height_cmd', None, 0.05),
        '2': ('height_cmd', None, -0.05),
        '3': ('rpy_cmd', 0, 0.2),
        '4': ('rpy_cmd', 0, -0.2),
        '5': ('rpy_cmd', 1, 0.2),
        '6': ('rpy_cmd', 1, -0.2),
        '7': ('rpy_cmd', 2, 0.2),
        '8': ('rpy_cmd', 2, -0.2),
        'm': ('freq_cmd', None, 0.1),
        'n': ('freq_cmd', None, -0.1),
    }

    def __init__(self, config_path):
        """Load config, model and policy, and start the keyboard listener.

        Args:
            config_path: directory containing ``g1_gear_wbc.yaml``; relative
                ``policy_path`` / ``xml_path`` entries are resolved against it.
        """
        self.CONFIG_PATH = config_path
        self.cmd_lock = threading.Lock()
        self.config = self.load_config(os.path.join(self.CONFIG_PATH, "g1_gear_wbc.yaml"))

        # Copy the defaults: the key handlers mutate these containers in place,
        # and aliasing them with the config would corrupt what 'z' restores.
        self.control_dict = {
            'loco_cmd': self.config['cmd_init'].copy(),
            "height_cmd": self.config['height_cmd'],
            "rpy_cmd": list(self.config.get('rpy_cmd', self._DEFAULT_RPY_CMD)),
            "freq_cmd": self.config.get('freq_cmd', self._DEFAULT_FREQ_CMD),
        }

        self.model = mujoco.MjModel.from_xml_path(self.config['xml_path'])
        self.data = mujoco.MjData(self.model)
        self.model.opt.timestep = self.config['simulation_dt']
        # The first 7 qpos entries are the floating base (position + quaternion,
        # see qpos[3:7] in compute_observation); the rest are joints.
        self.n_joints = self.data.qpos.shape[0] - 7
        self.torso_index = mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_BODY, "torso_link")
        self.base_index = mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_BODY, "pelvis")

        self.action = np.zeros(self.config['num_actions'], dtype=np.float32)
        self.target_dof_pos = self.config['default_angles'].copy()
        self.policy = self.load_onnx_policy(self.config['policy_path'])

        # Gait-clock state, advanced by compute_observation.
        self.gait_indices = torch.zeros((1), dtype=torch.float32)
        self.counter = 0
        self.just_started = 0.0
        self.walking_mask = False
        self.frozen_FL = False
        self.frozen_FR = False

        self.single_obs, self.single_obs_dim = self.compute_observation(
            self.data, self.config, self.action, self.control_dict, self.n_joints
        )
        history_len = self.config['obs_history_len']
        self.obs_history = collections.deque(
            [np.zeros(self.single_obs_dim, dtype=np.float32)] * history_len,
            maxlen=history_len,
        )
        # NOTE(review): assumes num_obs == obs_history_len * single_obs_dim — confirm.
        self.obs = np.zeros(self.config['num_obs'], dtype=np.float32)

        self.keyboard_listener(self.control_dict, self.config)

    def _handle_key(self, k, control_dict, config):
        """Apply the command change bound to character ``k``; unbound keys are ignored.

        'z' restores every command to its configured default. The caller is
        responsible for holding ``cmd_lock``.
        """
        if k == 'z':
            control_dict['loco_cmd'][:] = config['cmd_init']
            control_dict["height_cmd"] = config['height_cmd']
            control_dict['rpy_cmd'][:] = config.get('rpy_cmd', self._DEFAULT_RPY_CMD)
            control_dict['freq_cmd'] = config.get('freq_cmd', self._DEFAULT_FREQ_CMD)
            return

        binding = self._KEY_BINDINGS.get(k)
        if binding is None:
            return
        name, index, delta = binding
        if index is None:
            control_dict[name] += delta
        else:
            control_dict[name][index] += delta

    def keyboard_listener(self, control_dict, config):
        """Start a daemon ``pynput`` listener that edits ``control_dict`` on key presses.

        Returns the started listener (also stored as ``self._keyboard_listener``).
        """
        def on_press(key):
            try:
                k = key.char
            except AttributeError:
                return  # Special keys (arrows, modifiers, ...) have no .char.

            with self.cmd_lock:
                self._handle_key(k, control_dict, config)
                print(f"Current Commands: loco_cmd = {control_dict['loco_cmd']}, height_cmd = {control_dict['height_cmd']}, rpy_cmd = {control_dict['rpy_cmd']}, freq_cmd = {control_dict['freq_cmd']}")

        listener = pkb.Listener(on_press=on_press)
        listener.daemon = True
        listener.start()
        self._keyboard_listener = listener
        return listener

    def load_config(self, config_path):
        """Read the YAML config, resolve asset paths and convert gain/angle lists to float32 arrays.

        Relative ``policy_path`` and ``xml_path`` are joined onto ``self.CONFIG_PATH``
        (previously this read a module global that only exists when the file is
        run as a script).
        """
        with open(config_path, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)

        for path_key in ('policy_path', 'xml_path'):
            config[path_key] = os.path.join(self.CONFIG_PATH, config[path_key])

        for key in ('kps', 'kds', 'default_angles', 'cmd_scale', 'cmd_init'):
            config[key] = np.array(config[key], dtype=np.float32)

        return config

    def pd_control(self, target_q, q, kp, target_dq, dq, kd):
        """Return PD torques ``kp * (target_q - q) + kd * (target_dq - dq)``."""
        return (target_q - q) * kp + (target_dq - dq) * kd

    def quat_rotate_inverse(self, q, v):
        """Rotate vector ``v`` by the inverse of quaternion ``q`` given as (w, x, y, z).

        Applies the rotation matrix of the conjugate quaternion; for a unit ``q``
        this is ``R(q).T @ v``.
        """
        w, x, y, z = q
        q_conj = np.array([w, -x, -y, -z])
        return np.array([
            v[0] * (q_conj[0]**2 + q_conj[1]**2 - q_conj[2]**2 - q_conj[3]**2) +
            v[1] * 2 * (q_conj[1]*q_conj[2] - q_conj[0]*q_conj[3]) +
            v[2] * 2 * (q_conj[1]*q_conj[3] + q_conj[0]*q_conj[2]),
            v[0] * 2 * (q_conj[1]*q_conj[2] + q_conj[0]*q_conj[3]) +
            v[1] * (q_conj[0]**2 - q_conj[1]**2 + q_conj[2]**2 - q_conj[3]**2) +
            v[2] * 2 * (q_conj[2]*q_conj[3] - q_conj[0]*q_conj[1]),
            v[0] * 2 * (q_conj[1]*q_conj[3] - q_conj[0]*q_conj[2]) +
            v[1] * 2 * (q_conj[2]*q_conj[3] + q_conj[0]*q_conj[1]) +
            v[2] * (q_conj[0]**2 - q_conj[1]**2 - q_conj[2]**2 + q_conj[3]**2)
        ])

    def get_gravity_orientation(self, quat):
        """Return the world gravity direction (0, 0, -1) expressed in the frame of ``quat``."""
        gravity_vec = np.array([0.0, 0.0, -1.0])
        return self.quat_rotate_inverse(quat, gravity_vec)

    def _update_gait_clock(self, command):
        """Advance the gait-phase state and return the (1, 2, 1) foot clock tensor [FL, FR].

        When the velocity command is near zero, each foot's clock is latched to
        1.0 once its sine passes 0.98. The latches are released when walking resumes.
        """
        is_static = np.linalg.norm(command[:3]) < 0.1
        just_entered_walk = (not is_static) and (not self.walking_mask)
        self.walking_mask = not is_static

        if just_entered_walk:
            self.just_started = 0.0
            self.gait_indices = torch.tensor([-0.25])
        if is_static:
            self.just_started = 0.0
        else:
            self.just_started += self.GAIT_DT
            self.frozen_FL = False
            self.frozen_FR = False

        self.gait_indices = torch.remainder(self.gait_indices + self.GAIT_DT * command[4], 1.0)

        duration = 0.5  # fraction of the cycle mapped onto the first half-period
        phase = 0.5     # phase offset between the two feet

        gait_FR = self.gait_indices.clone()
        gait_FL = torch.remainder(gait_FR + phase, 1.0)
        # During the first half-period of walking (and while static) the right
        # foot is held at phase 0.25, i.e. its sine peak.
        if self.just_started < (0.5 / command[4]):
            gait_FR = torch.tensor([0.25])

        gait_pair = []
        for fi in (gait_FL, gait_FR):
            if fi.item() < duration:
                gait_pair.append(fi * (0.5 / duration))
            else:
                gait_pair.append(0.5 + (fi - duration) * (0.5 / (1 - duration)))

        clock = [torch.sin(2 * np.pi * fi) for fi in gait_pair]
        for i, attr in enumerate(('frozen_FL', 'frozen_FR')):
            if is_static and not getattr(self, attr) and clock[i].item() > 0.98:
                setattr(self, attr, True)
            if getattr(self, attr):
                clock[i] = torch.tensor([1.0])

        return torch.stack(clock).unsqueeze(0)

    def compute_observation(self, d, config, action, control_dict, n_joints):
        """Build one policy observation frame and advance the gait clock.

        Layout: [0:8] command (vx, vy, yaw-rate scaled, height, freq, rpy),
        [8:11] base angular velocity, [11:14] base gravity direction,
        [14:20] torso features (written as zeros), then joint positions,
        joint velocities, previous action and the 2-element foot clock.

        Returns:
            (single_obs, single_obs_dim); the dimension is 95 for 29 joints and
            15 actions.
        """
        command = np.zeros(8, dtype=np.float32)
        command[:3] = control_dict['loco_cmd'][:3] * config['cmd_scale']
        command[3] = control_dict['height_cmd']
        command[4] = control_dict['freq_cmd']
        command[5:8] = control_dict['rpy_cmd']

        self.clock_inputs = self._update_gait_clock(command)

        qj = d.qpos[7:7 + n_joints].copy()
        dqj = d.qvel[6:6 + n_joints].copy()
        quat = d.qpos[3:7].copy()
        omega = d.qvel[3:6].copy()

        # Joints beyond the configured defaults are referenced to zero.
        padded_defaults = np.zeros(n_joints, dtype=np.float32)
        n_defaults = min(len(config['default_angles']), n_joints)
        padded_defaults[:n_defaults] = config['default_angles'][:n_defaults]

        action = np.asarray(action, dtype=np.float32).reshape(-1)
        n_actions = action.shape[0]
        joints_end = 20 + 2 * n_joints
        single_obs_dim = joints_end + n_actions + 2

        single_obs = np.zeros(single_obs_dim, dtype=np.float32)
        single_obs[0:8] = command
        single_obs[8:11] = omega * config['ang_vel_scale']
        single_obs[11:14] = self.get_gravity_orientation(quat)
        # [14:20] would hold torso angular velocity / gravity; they are left at
        # zero. NOTE(review): presumably matches the policy's training — confirm.
        single_obs[20:20 + n_joints] = (qj - padded_defaults) * config['dof_pos_scale']
        single_obs[20 + n_joints:joints_end] = dqj * config['dof_vel_scale']
        single_obs[joints_end:joints_end + n_actions] = action
        single_obs[joints_end + n_actions:] = self.clock_inputs.cpu().numpy().reshape(2)

        return single_obs, single_obs_dim

    def load_onnx_policy(self, path):
        """Create an ONNX Runtime session and return a ``torch.Tensor -> torch.Tensor`` callable."""
        session = ort.InferenceSession(path)
        input_name = session.get_inputs()[0].name

        def run_inference(input_tensor):
            outputs = session.run(None, {input_name: input_tensor.cpu().numpy()})
            # Kept on the CPU: the caller only converts back to NumPy, and pinning
            # the result to "cuda:0" made the script fail without a GPU.
            return torch.tensor(outputs[0])

        return run_inference

    def _snapshot_commands(self):
        """Return a copy of ``control_dict`` taken under ``cmd_lock``.

        Copying ensures the observation sees a consistent set of commands even
        while the keyboard thread mutates the live dict.
        """
        with self.cmd_lock:
            return {
                'loco_cmd': np.array(self.control_dict['loco_cmd'], copy=True),
                'height_cmd': self.control_dict['height_cmd'],
                'rpy_cmd': list(self.control_dict['rpy_cmd']),
                'freq_cmd': self.control_dict['freq_cmd'],
            }

    def _policy_step(self):
        """Append a new observation frame, run the policy and update the joint targets."""
        current_cmd = self._snapshot_commands()
        single_obs, _ = self.compute_observation(
            self.data, self.config, self.action, current_cmd, self.n_joints
        )
        self.obs_history.append(single_obs)
        for i, hist_obs in enumerate(self.obs_history):
            self.obs[i * self.single_obs_dim:(i + 1) * self.single_obs_dim] = hist_obs

        obs_tensor = torch.from_numpy(self.obs).unsqueeze(0)
        self.action = self.policy(obs_tensor).cpu().detach().numpy().squeeze()
        self.target_dof_pos = self.action * self.config['action_scale'] + self.config['default_angles']

    def run(self):
        """Run the viewer loop until it is closed or ``simulation_duration`` wall-clock seconds pass.

        NOTE: there is no real-time pacing; physics steps run as fast as possible.
        """
        self.counter = 0
        num_actions = self.config['num_actions']
        n_extra = self.n_joints - num_actions

        with mujoco.viewer.launch_passive(self.model, self.data) as viewer:
            start = time.time()
            while viewer.is_running() and time.time() - start < self.config['simulation_duration']:
                # Policy-controlled joints track target_dof_pos with the config gains.
                self.data.ctrl[:num_actions] = self.pd_control(
                    self.target_dof_pos,
                    self.data.qpos[7:7 + num_actions],
                    self.config['kps'],
                    np.zeros_like(self.config['kps']),
                    self.data.qvel[6:6 + num_actions],
                    self.config['kds'],
                )

                # Remaining joints are held at zero with fixed gains (kp=100, kd=0.5).
                if n_extra > 0:
                    self.data.ctrl[num_actions:] = self.pd_control(
                        np.zeros(n_extra, dtype=np.float32),
                        self.data.qpos[7 + num_actions:7 + self.n_joints],
                        np.full(n_extra, 100.0),
                        np.zeros(n_extra),
                        self.data.qvel[6 + num_actions:6 + self.n_joints],
                        np.full(n_extra, 0.5),
                    )

                mujoco.mj_step(self.model, self.data)

                self.counter += 1
                if self.counter % self.config['control_decimation'] == 0:
                    self._policy_step()

                viewer.sync()
if __name__ == "__main__":
    # Robot resources sit under <repo>/resources/robots/g1, where <repo> is the
    # parent of the directory holding this script. CONFIG_PATH stays a
    # module-level name in case other code in this file reads it.
    _script_dir = os.path.dirname(os.path.abspath(__file__))
    CONFIG_PATH = os.path.join(os.path.dirname(_script_dir), "resources", "robots", "g1")

    controller = GearWbcController(CONFIG_PATH)
    controller.run()