You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

557 lines
18 KiB

"""SMPL/SMPLX body model utilities: downloading, creation, and pose decomposition."""
import os
import pickle
import subprocess
from pathlib import Path
import numpy as np
import smplx
import torch
import torch.nn.functional as F
from smplx import SMPL, SMPLX, SMPLXLayer
from groot.rl.trl.utils.smplx.body_model import BodyModelSMPLH, BodyModelSMPLX
from groot.rl.trl.utils.smplx.body_model.smplx_lite import (
SmplxLiteCoco17,
SmplxLiteSmplN24,
SmplxLiteV437Coco17,
)
# fmt: off
# Parent joint index for each of the 52 SMPL-H joints (kinematic tree);
# -1 marks the root. Presumably joints 0-21 are the body and the rest the
# two hands (inverse_kinematics_motion below uses only the first 22).
SMPLH_PARENTS = torch.tensor([-1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 12, 13, 14,
16, 17, 18, 19, 20, 22, 23, 20, 25, 26, 20, 28, 29, 20, 31, 32, 20, 34,
35, 21, 37, 38, 21, 40, 41, 21, 43, 44, 21, 46, 47, 21, 49, 50])
# fmt: on
def _ensure_body_models_downloaded(model_path: str) -> Path:
"""
Ensure body_models directory exists, downloading from gear data if needed.
Args:
model_path: Path to body_models directory (e.g., "data/body_models")
Returns:
Path object to the body_models directory
"""
model_path_obj = Path(model_path)
# Check if path exists and has required subdirectories (smplx is the main one we need)
if model_path_obj.exists():
smplx_subdir = model_path_obj / "smplx"
if smplx_subdir.exists() and any(smplx_subdir.iterdir()):
return model_path_obj
# Try relative to project root (gr00t/), not groot/
# File is at: groot/rl/trl/utils/smplx/smplx_utils.py
# Need to go up 6 levels to get to project root
project_root = Path(__file__).parent.parent.parent.parent.parent.parent
full_path = project_root / model_path
if full_path.exists():
smplx_subdir = full_path / "smplx"
if smplx_subdir.exists() and any(smplx_subdir.iterdir()):
return full_path
# Need to download - create directory
full_path.mkdir(parents=True, exist_ok=True)
print(f"Body models not found at {model_path}. Downloading from gear data storage...")
try:
# Use exact command format that works manually: gear data download wbc_data/body_models/** data/body_models/
result = subprocess.run(
[
"gear",
"data",
"download",
"wbc_data/body_models/**",
str(full_path) + "/",
],
check=True,
capture_output=True,
text=True,
)
# Print output for debugging
if result.stdout:
print(result.stdout)
if result.stderr:
print(result.stderr)
# Verify download succeeded by checking if directory has contents
if not any(full_path.iterdir()):
raise RuntimeError(
f"Download completed but directory is empty. "
f"Check if 'wbc_data/body_models' exists in gear data storage.\n"
f"Output: {result.stdout}\n"
f"Error: {result.stderr}"
)
print(f"Successfully downloaded body models to {full_path}")
except subprocess.CalledProcessError as e:
error_msg = e.stderr if e.stderr else (e.stdout if e.stdout else "Unknown error")
raise RuntimeError(
f"Failed to download body models from wbc_data/body_models.\n"
f"Error: {error_msg}\n"
f"Please ensure 'gear' CLI is installed and configured, or manually download:\n"
f" gear data download wbc_data/body_models/** {full_path}/"
)
except FileNotFoundError:
raise RuntimeError(
f"'gear' command not found. Please install gear CLI or manually download body models:\n"
f" gear data download wbc_data/body_models/** {full_path}/"
)
return full_path
def make_smplx(type="neu_fullpose", **kwargs):
    """Factory for SMPL/SMPLH/SMPLX body models, selected by a config name.

    Args:
        type: Named configuration (e.g. "supermotion", "rich-smplx", "smpl")
            that selects the model family, gender, and pose-space options.
        **kwargs: Extra options merged into (and overriding) the chosen config.

    Returns:
        The constructed body-model instance.

    Raises:
        NotImplementedError: If `type` is not a known configuration.
    """
    if type == "neu_fullpose":
        model = smplx.create(
            model_path="inputs/models/smplx/SMPLX_NEUTRAL.npz",
            use_pca=False,
            flat_hand_mean=True,
            **kwargs,
        )
    elif type == "supermotion":
        # SuperMotion is trained on BEDLAM dataset, the smplx config is the same except only 10 betas are used
        bm_kwargs = {
            "model_type": "smplx",
            "gender": "neutral",
            "num_pca_comps": 12,
            "flat_hand_mean": False,
        }
        bm_kwargs.update(kwargs)
        body_models_path = _ensure_body_models_downloaded("data/body_models")
        model = BodyModelSMPLX(model_path=str(body_models_path), **bm_kwargs)
    elif type == "supermotion_EVAL3DPW":
        # Same as "supermotion" but with flat_hand_mean=True for 3DPW evaluation.
        bm_kwargs = {
            "model_type": "smplx",
            "gender": "neutral",
            "num_pca_comps": 12,
            "flat_hand_mean": True,
        }
        bm_kwargs.update(kwargs)
        body_models_path = _ensure_body_models_downloaded("data/body_models")
        model = BodyModelSMPLX(model_path=str(body_models_path), **bm_kwargs)
    elif type == "supermotion_coco17":
        # Fast but only predicts 17 joints
        body_models_path = _ensure_body_models_downloaded("data/body_models")
        model = SmplxLiteCoco17(model_path=str(body_models_path / "smplx"))
    elif type == "supermotion_v437coco17":
        # Predicts 437 verts and 17 joints
        body_models_path = _ensure_body_models_downloaded("data/body_models")
        model = SmplxLiteV437Coco17(model_path=str(body_models_path / "smplx"))
    elif type == "supermotion_smpl24":
        body_models_path = _ensure_body_models_downloaded("data/body_models")
        model = SmplxLiteSmplN24(model_path=str(body_models_path / "smplx"))
    elif type == "supermotion_smpl24_male":
        body_models_path = _ensure_body_models_downloaded("data/body_models")
        model = SmplxLiteSmplN24(gender="male", model_path=str(body_models_path / "smplx"))
    elif type == "supermotion_smpl24_female":
        body_models_path = _ensure_body_models_downloaded("data/body_models")
        model = SmplxLiteSmplN24(gender="female", model_path=str(body_models_path / "smplx"))
    elif type == "rich-smplx":
        # https://github.com/paulchhuang/rich_toolkit/blob/main/smplx2images.py
        bm_kwargs = {
            "model_type": "smplx",
            "gender": kwargs.get("gender", "male"),
            "num_pca_comps": 12,
            "flat_hand_mean": False,
            # create_expression=True, create_jaw_pose=True
        }
        # A /smplx folder should exist under the model_path
        body_models_path = _ensure_body_models_downloaded("data/body_models")
        model = BodyModelSMPLX(model_path=str(body_models_path), **bm_kwargs)
    elif type == "rich-smplh":
        bm_kwargs = {
            "model_type": "smplh",
            "gender": kwargs.get("gender", "male"),
            "use_pca": False,
            "flat_hand_mean": True,
        }
        body_models_path = _ensure_body_models_downloaded("data/body_models")
        model = BodyModelSMPLH(model_path=str(body_models_path), **bm_kwargs)
    elif type in ["smplx-circle", "smplx-groundlink"]:
        # don't use hand
        body_models_path = _ensure_body_models_downloaded("data/body_models")
        bm_kwargs = {
            "model_path": str(body_models_path),
            "model_type": "smplx",
            "gender": kwargs.get("gender"),
            "num_betas": 16,
            "num_expression": 0,
        }
        model = BodyModelSMPLX(**bm_kwargs)
    elif type == "smplx-motionx":
        # Pure layer: all pose/shape parameters must be supplied by the caller.
        layer_args = {
            "create_global_orient": False,
            "create_body_pose": False,
            "create_left_hand_pose": False,
            "create_right_hand_pose": False,
            "create_jaw_pose": False,
            "create_leye_pose": False,
            "create_reye_pose": False,
            "create_betas": False,
            "create_expression": False,
            "create_transl": False,
        }
        body_models_path = _ensure_body_models_downloaded("data/body_models")
        bm_kwargs = {
            "model_type": "smplx",
            "model_path": str(body_models_path),
            "gender": "neutral",
            "use_pca": False,
            "use_face_contour": True,
            **layer_args,
        }
        model = smplx.create(**bm_kwargs)
    elif type == "smplx-samp":
        # don't use hand
        body_models_path = _ensure_body_models_downloaded("data/body_models")
        bm_kwargs = {
            "model_path": str(body_models_path),
            "model_type": "smplx",
            "gender": kwargs.get("gender"),
            "num_betas": 10,
            "num_expression": 0,
        }
        model = BodyModelSMPLX(**bm_kwargs)
    elif type == "smplx-bedlam":
        # don't use hand
        # Fix: resolve the body-models path like every other branch instead of
        # a hard-coded relative path (which broke when cwd != project root).
        body_models_path = _ensure_body_models_downloaded("data/body_models")
        bm_kwargs = {
            "model_path": str(body_models_path),
            "model_type": "smplx",
            "gender": kwargs.get("gender"),
            "num_betas": 11,
            "num_expression": 0,
        }
        model = BodyModelSMPLX(**bm_kwargs)
    elif type in ["smplx-layer", "smplx-fit3d"]:
        # Use layer
        if type == "smplx-fit3d":
            assert (
                kwargs.get("gender") == "neutral"
            ), "smplx-fit3d use neutral model: https://github.com/sminchisescu-research/imar_vision_datasets_tools/blob/e8c8f83ffac23cc36adf8ec8d0fd1c55679484ef/util/smplx_util.py#L15C34-L15C34"
        body_models_path = _ensure_body_models_downloaded("data/body_models")
        bm_kwargs = {
            "model_path": str(body_models_path / "smplx"),
            "gender": kwargs.get("gender"),
            "num_betas": 10,
            "num_expression": 10,
        }
        model = SMPLXLayer(**bm_kwargs)
    elif type == "smpl":
        body_models_path = _ensure_body_models_downloaded("data/body_models")
        bm_kwargs = {
            "model_path": str(body_models_path),
            "model_type": "smpl",
            "gender": "neutral",
            "num_betas": 10,
            "create_body_pose": False,
            "create_betas": False,
            "create_global_orient": False,
            "create_transl": False,
        }
        bm_kwargs.update(kwargs)
        # model = SMPL(**bm_kwargs)
        model = BodyModelSMPLH(**bm_kwargs)
    elif type == "smplh":
        bm_kwargs = {
            "model_type": "smplh",
            "gender": kwargs.get("gender", "male"),
            "use_pca": False,
            "flat_hand_mean": False,
        }
        body_models_path = _ensure_body_models_downloaded("data/body_models")
        model = BodyModelSMPLH(model_path=str(body_models_path), **bm_kwargs)
    else:
        raise NotImplementedError(f"Unknown body model type: {type}")
    if hasattr(model, "create_parents_list"):
        model.create_parents_list()
    return model
def load_parents(npz_path="models/smplx/SMPLX_NEUTRAL.npz"):
    """Load the kinematic-tree parent indices from an SMPLX .npz model file.

    Args:
        npz_path: Path to a model .npz file containing a "kintree_table" array.

    Returns:
        np.ndarray of int64 parent indices; the root joint's parent is -1.
    """
    # Fix: npz_path was previously ignored (a hard-coded path was loaded
    # instead), and np.long was removed in NumPy >= 1.24 - use np.int64.
    smplx_struct = np.load(npz_path, allow_pickle=True)
    parents = smplx_struct["kintree_table"][0].astype(np.int64)
    parents[0] = -1
    return parents
def load_smpl_faces(npz_path="models/smplh/SMPLH_FEMALE.pkl"):
    """Load the triangle faces array ("f") from a pickled SMPL/SMPLH model file."""
    with open(npz_path, "rb") as model_file:
        # latin1 encoding is required for pickles written by Python 2.
        model_data = pickle.load(model_file, encoding="latin1")
    return np.array(model_data["f"].astype(np.int64))
def decompose_fullpose(fullpose, model_type="smplx"):
    """Split a flat SMPLX full-pose vector (..., 165) into named segments."""
    assert model_type == "smplx"
    # (name, width) of each consecutive segment along the last axis.
    layout = (
        ("global_orient", 3),
        ("body_pose", 63),
        ("jaw_pose", 3),
        ("leye_pose", 3),
        ("reye_pose", 3),
        ("left_hand_pose", 45),
        ("right_hand_pose", 45),
    )
    fullpose_dict = {}
    offset = 0
    for name, width in layout:
        fullpose_dict[name] = fullpose[..., offset : offset + width]
        offset += width
    return fullpose_dict
def compose_fullpose(fullpose_dict, model_type="smplx"):
    """Concatenate named SMPLX pose segments back into a flat (..., 165) vector."""
    assert model_type == "smplx"
    # Order must match decompose_fullpose's segment layout.
    ordered_keys = (
        "global_orient",
        "body_pose",
        "jaw_pose",
        "leye_pose",
        "reye_pose",
        "left_hand_pose",
        "right_hand_pose",
    )
    return torch.cat([fullpose_dict[key] for key in ordered_keys], dim=-1)
def compute_R_from_kinetree(rot_mats, parents):
    """Forward pass of lbs/batch_rigid_transform restricted to the 3x3 rotations.

    Accumulates each joint's local rotation along the kinematic chain to get
    its global rotation.

    Parameters
    ----------
    rot_mats : torch.Tensor (B, N, 3, 3)
        Per-joint local rotation matrices.
    parents : torch.Tensor (N,)
        Parent index of each joint; the root's parent entry is unused.

    Returns
    -------
    torch.Tensor (B, N, 3, 3)
        Per-joint global rotation matrices.
    """
    chain = [rot_mats[:, 0]]
    num_joints = parents.shape[0]
    for joint in range(1, num_joints):
        # Global rotation = parent's global rotation composed with local one.
        chain.append(chain[parents[joint]] @ rot_mats[:, joint])
    return torch.stack(chain, dim=1)
def compute_relR_from_kinetree(R, parents):
    """Inverse operation of lbs/batch_rigid_transform, restricted to 3x3 rotations.

    Recovers per-joint local (parent-relative) rotations from global ones.

    Parameters
    ----------
    R : torch.Tensor (B, N, 4, 4) or (B, N, 3, 3)
        Global rotation matrices (only the top-left 3x3 is used).
    parents : torch.Tensor (N,)
        Parent index of each joint; the root's entry (-1) is handled specially.

    Returns
    -------
    torch.Tensor (B, N, 3, 3)
        Local rotation matrices.
    """
    R = R[:, :, :3, :3]
    # R_local = R_parent^T @ R_global. The slot-0 result uses parents[0] == -1
    # (wraps to the last joint) and is discarded below.
    rel = torch.matmul(R[:, parents].transpose(2, 3), R)
    # The root has no parent: it keeps its global rotation.
    return torch.cat([R[:, :1], rel[:, 1:]], dim=1)
def quat_mul(x, y):
    """
    Quaternion product of two tensors of quaternions (w-first components).

    :param x: tensor of quaternions of shape (..., Nb of joints, 4)
    :param y: tensor of quaternions of shape (..., Nb of joints, 4)
    :return: the resulting quaternions, same shape
    """
    x0, x1, x2, x3 = x[..., 0:1], x[..., 1:2], x[..., 2:3], x[..., 3:4]
    y0, y1, y2, y3 = y[..., 0:1], y[..., 1:2], y[..., 2:3], y[..., 3:4]
    w = y0 * x0 - y1 * x1 - y2 * x2 - y3 * x3
    i = y0 * x1 + y1 * x0 - y2 * x3 + y3 * x2
    j = y0 * x2 + y1 * x3 + y2 * x0 - y3 * x1
    k = y0 * x3 - y1 * x2 + y2 * x1 + y3 * x0
    return torch.cat([w, i, j, k], dim=-1)
def quat_inv(q):
    """
    Invert a tensor of quaternions by conjugation: (w, x, y, z) -> (w, -x, -y, -z).
    (This equals the inverse for unit quaternions.)

    :param q: quaternion tensor of shape (..., 4)
    :return: tensor of inverted quaternions, same shape
    """
    signs = torch.tensor([1, -1, -1, -1], device=q.device).float()
    return signs * q
def quat_mul_vec(q, x):
    """
    Rotate 3D vectors by quaternions (w-first components).

    :param q: tensor of quaternions of shape (..., Nb of joints, 4)
    :param x: tensor of vectors of shape (..., Nb of joints, 3)
    :return: the resulting array of rotated vectors, shape (..., 3)
    """
    # Fix: the original called torch.cross without `dim`, which uses the
    # deprecated default of the *first* dimension of size 3 rather than the
    # last - wrong whenever a leading/batch dimension happens to have size 3.
    qvec = q[..., 1:]
    t = 2.0 * torch.cross(qvec, x, dim=-1)
    return x + q[..., 0:1] * t + torch.cross(qvec, t, dim=-1)
def inverse_kinematics_motion(
    global_pos,
    global_rot,
    parents=SMPLH_PARENTS,
):
    """
    Convert global joint positions/rotations to parent-relative (local) ones.

    Args:
        global_pos : (B, T, J-1, 3) global joint positions -- TODO confirm
            whether the root joint is included; indexing uses parents[1:22].
        global_rot (q) : (B, T, J-1, 4) global joint rotations (quaternions)
        parents : kinematic-tree parent indices (default: SMPLH_PARENTS)
    Returns:
        local_pos : positions relative to (and expressed in) the parent frame
        local_rot (q) : parent-relative rotation quaternions
    """
    # Only the body joints' parents are used (first 22 joints of SMPL-H).
    J = 22
    parent_rot = global_rot[..., parents[1:J], :]
    local_pos = quat_mul_vec(
        quat_inv(parent_rot),
        global_pos - global_pos[..., parents[1:J], :],
    )
    # Fix: the original had a stray trailing comma here, which wrapped
    # local_rot in a 1-tuple instead of returning the tensor itself.
    local_rot = quat_mul(quat_inv(parent_rot), global_rot)
    return local_pos, local_rot
def transform_mat(R, t):
    """Build a batch of homogeneous 4x4 transformation matrices.

    Args:
        - R: Bx3x3 batch of rotation matrices
        - t: Bx3x1 batch of translation column vectors
    Returns:
        - T: Bx4x4 transformation matrices [[R, t], [0, 0, 0, 1]]
    """
    rot_with_zero_row = F.pad(R, [0, 0, 0, 1])          # (B, 4, 3)
    trans_homogeneous = F.pad(t, [0, 0, 0, 1], value=1)  # (B, 4, 1), bottom entry = 1
    return torch.cat([rot_with_zero_row, trans_homogeneous], dim=2)
def normalize_joints(joints):
    """
    Canonicalize joints: translate so joint 0 is at the origin and rotate
    about the vertical (z) axis so the left-right body axis aligns with x.

    Args:
        joints: (B, *, J, 3); indices assume 1/2 are the L/R hips and 16/17
            the L/R shoulders -- TODO confirm against the skeleton in use.
    Returns:
        Normalized joints, same shape as the input.
    """
    # Average left->right direction of hips and shoulders in the xy plane.
    hips_lr = joints[..., 2, [0, 1]] - joints[..., 1, [0, 1]]
    shoulders_lr = joints[..., 17, [0, 1]] - joints[..., 16, [0, 1]]
    lr_xy = (hips_lr + shoulders_lr) / 2  # (B, *, 2)

    # Orthonormal frame: x = L->R direction (z-component 0), z = world up.
    x_dir = F.pad(F.normalize(lr_xy, 2, -1), (0, 1), "constant", 0)  # (B, *, 3)
    z_dir = torch.zeros_like(x_dir)
    z_dir[..., 2] = 1
    y_dir = torch.cross(z_dir, x_dir, dim=-1)

    # Root-center, then express coordinates in the new frame.
    frame = torch.stack([x_dir, y_dir, z_dir], dim=-1)
    return (joints - joints[..., [0], :]) @ frame
@torch.no_grad()
def compute_Rt_af2az(joints, inverse=False):
    """Compute the rigid transform from an arbitrary frame to the aligned-z frame.

    Assumes the z coordinate is upward.

    Args:
        joints: (B, J, 3) joints in the start frame; joint 0 is the root and
            joints 1/2 the L/R hips -- TODO confirm indices.
        inverse: if True, return the inverse transform (az -> af) instead.
    Returns:
        R: (B, 3, 3) rotation
        t: (B, 3) translation
    """
    translation = joints[:, 0, :].detach().clone()
    translation[:, 2] = 0  # keep height: do not modify z

    # Facing direction from the L->R hip vector projected to the xy plane.
    hip_lr = joints[:, 2, [0, 1]] - joints[:, 1, [0, 1]]  # (B, 2)
    # When the hips (nearly) coincide the facing direction is undefined;
    # those batch elements fall back to the identity rotation below.
    degenerate = hip_lr.pow(2).sum(-1) < 1e-4

    x_axis = F.pad(F.normalize(hip_lr, 2, -1), (0, 1), "constant", 0)  # (B, 3)
    z_axis = torch.zeros_like(x_axis)
    z_axis[..., 2] = 1
    y_axis = torch.cross(z_axis, x_axis, dim=-1)
    rotation = torch.stack([x_axis, y_axis, z_axis], dim=-1)  # (B, 3, 3)
    rotation[degenerate] = torch.eye(3).to(rotation)

    if not inverse:
        return rotation, translation
    R_az2af = rotation.transpose(1, 2)
    t_az2af = -(R_az2af @ translation.unsqueeze(2)).squeeze(2)
    return R_az2af, t_az2af
def finite_difference_forward(x, dim_t=1, dup_last=True):
    """Forward finite difference x[t+1] - x[t] along the time dimension.

    Args:
        x: tensor with time along dimension `dim_t` (only dim_t == 1 supported).
        dim_t: index of the time dimension.
        dup_last: if True, repeat the final difference so the output keeps the
            same length as the input along the time dimension.
    Returns:
        Tensor of forward differences.
    Raises:
        NotImplementedError: if dim_t != 1.
    """
    if dim_t != 1:
        raise NotImplementedError
    diffs = x[:, 1:] - x[:, :-1]
    if dup_last:
        diffs = torch.cat([diffs, diffs[:, -1:]], dim=1)
    return diffs
def compute_joints_zero(betas, gender):
    """
    Compute rest-pose (zero-pose) body joints for the given shape parameters.

    Args:
        betas: (16,) shape coefficients
        gender: 'male' or 'female'
    Returns:
        joints_zero: (22, 3) joint positions in the zero pose
    """
    # Fix: previously BOTH gendered body models were constructed on every
    # call even though only one was used; build just the requested one.
    # NOTE(review): make_smplx in this file has no "humor" branch, so this
    # would raise NotImplementedError as written - confirm where the "humor"
    # model type is registered.
    body_model = make_smplx(type="humor", gender=gender)
    smpl_params = {
        "root_orient": torch.zeros((1, 3)),
        "pose_body": torch.zeros((1, 63)),
        "betas": betas[None],
        "trans": torch.zeros(1, 3),
    }
    return body_model(**smpl_params).Jtr[0, :22]