You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

180 lines
6.5 KiB

"""Interpolation and frame-rate rescaling for pose sequences.
Provides linear interpolation (via scipy), quaternion slerp, and
functions to up/down-sample joint pose trajectories to a target frame rate.
"""
import torch
import numpy as np
from scipy.interpolate import interp1d
from .kornia_transform import angle_axis_to_quaternion, quaternion_to_angle_axis
def interp_tensor_with_scipy(x, new_len=None, scale=None, dim=-1):
    """Linearly resample ``x`` along ``dim`` to a new length using scipy.

    The ``orig_len`` input samples are treated as evenly spaced over a fixed
    interval, and ``new_len`` output samples are drawn evenly over the same
    interval. The first and last samples are therefore kept, and the interior
    is linearly interpolated (``scipy.interpolate.interp1d``).

    Args:
        x: Tensor to resample; any shape.
        new_len: Target length along ``dim``. Takes precedence over ``scale``.
        scale: Used only when ``new_len`` is None. The target length is
            ``int(orig_len * scale)`` (truncated).
        dim: Axis to resample along (negative values allowed).

    Returns:
        Tensor shaped like ``x`` except with length ``new_len`` along ``dim``,
        cast back to ``x``'s tensor type via ``type_as``.

    Raises:
        ValueError: If both ``new_len`` and ``scale`` are None.
    """
    orig_len = x.shape[dim]
    if new_len is None:
        if scale is None:
            raise ValueError("interp_tensor_with_scipy requires either new_len or scale")
        new_len = int(orig_len * scale)
    if orig_len == 1:
        # interp1d needs at least two samples; a single sample is a constant signal.
        return x.repeat_interleave(new_len, dim=dim)
    # .numpy() refuses tensors that require grad; no gradient flows through scipy anyway.
    x_cpu = x.detach().cpu()
    if x_cpu.dtype == torch.bfloat16:
        # NumPy has no bfloat16: interpolate in float32, type_as(x) casts back below.
        x_cpu = x_cpu.float()
    f = interp1d(
        np.linspace(0, orig_len, orig_len),
        x_cpu.numpy(),
        axis=dim,
        assume_sorted=True,
        fill_value="extrapolate",
    )
    return torch.from_numpy(f(np.linspace(0, orig_len, new_len))).type_as(x)
def slerp(q0, q1, t):
    # type: (torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
    """Spherical linear interpolation between quaternions ``q0`` and ``q1``.

    Quaternion components lie on the last axis. ``t`` is the interpolation
    fraction (0 -> ``q0``, 1 -> ``q1``). It receives a trailing singleton axis,
    so a 1-D ``t`` of shape (B,) pairs with quaternions of shape (B, 4). A
    Python float or 0-d tensor applies one fraction to every quaternion.

    Because ``q`` and ``-q`` encode the same rotation, ``q1`` is negated
    wherever ``dot(q0, q1) < 0`` so the interpolation follows the shorter arc.
    When the two quaternions are nearly parallel (``sin(half_theta) < 1e-3``),
    the slerp ratios become ill-conditioned (0/0 in the limit). In that case a
    linear blend ``(1 - t) * q0 + t * q1`` is used instead, which still returns
    ``q0`` at ``t = 0`` and ``q1`` at ``t = 1``.

    Returns:
        Interpolated quaternions with the broadcast shape of the inputs.
    """
    t = torch.as_tensor(t, dtype=q0.dtype, device=q0.device).unsqueeze(-1)
    cos_half_theta = torch.sum(q0 * q1, dim=-1, keepdim=True)
    # Flip q1 onto q0's hemisphere (shorter arc).
    q1 = torch.where(cos_half_theta < 0, -q1, q1)
    # abs() matches the flip above; clamp keeps acos/sqrt finite when rounding pushes |cos| past 1.
    cos_half_theta = cos_half_theta.abs().clamp(max=1.0)
    half_theta = torch.acos(cos_half_theta)
    sin_half_theta = torch.sqrt(1.0 - cos_half_theta * cos_half_theta)
    ratio_a = torch.sin((1 - t) * half_theta) / sin_half_theta
    ratio_b = torch.sin(t * half_theta) / sin_half_theta
    slerped = ratio_a * q0 + ratio_b * q1
    # Near-parallel fallback must still honour t (a fixed 50/50 blend would not).
    lerped = (1 - t) * q0 + t * q1
    return torch.where(sin_half_theta < 1e-3, lerped, slerped)
def _slerp_batch(a: torch.Tensor, b: torch.Tensor, blend: torch.Tensor) -> torch.Tensor:
    """Spherically interpolate quaternion batches ``a`` and ``b`` by ``blend``.

    Thin wrapper around ``slerp``. ``blend`` holds one interpolation fraction
    per leading entry of ``a``/``b`` (0 -> ``a``, 1 -> ``b``).
    """
    # slerp already allocates its result, so no output buffer is preallocated here.
    return slerp(a, b, blend)
def interpolate_quaternions(
    pose_quat: torch.Tensor, source_fps: float, target_fps: float
) -> torch.Tensor:
    """
    Resample quaternion sequences from source_fps to target_fps using slerp.

    Output sample k sits at time k / target_fps. Samples are emitted up to and
    including the input duration (T - 1) / source_fps. Each output sample is
    slerped between the two input frames that bracket it.

    Args:
        pose_quat: Input quaternions, shape (B, T, 4). Every sequence in the
            batch is resampled on the same time grid; B == 1 is the common case.
        source_fps: Frame rate of ``pose_quat``.
        target_fps: Desired output frame rate.

    Returns:
        Interpolated quaternions, shape (B, T_new, 4).
    """
    device = pose_quat.device
    num_frames = pose_quat.shape[1]
    duration = (num_frames - 1) * (1 / source_fps)
    # The +1e-6 lets float32 arange reach a final sample landing exactly on
    # `duration`; the mask below then drops anything that overshoots it.
    times = torch.arange(0, duration + 1e-6, 1 / target_fps, dtype=torch.float32, device=device)
    times = times[times <= duration]
    # Fractional source-frame position of every output sample.
    frame_indices = times * source_fps
    index_0 = torch.floor(frame_indices).long()
    index_1 = (index_0 + 1).clamp(max=num_frames - 1)  # the last frame has no successor
    blend = frame_indices - index_0.float()
    # Index the time axis (dim 1) so all batch entries are resampled in one call.
    return _slerp_batch(pose_quat[:, index_0], pose_quat[:, index_1], blend)
def interpolate_pose(
    pose_aa: torch.Tensor,
    source_fps: float,
    target_fps: float,
    device: str = "cpu",
    interpolation_type: str = "slerp",
    rot_type: str = "aa",
) -> torch.Tensor:
    """
    Interpolate pose_aa from source_fps to target_fps using the specified method.

    Output frame k sits at time k / target_fps, up to and including the input
    duration (T - 1) / source_fps.

    Args:
        pose_aa: Input pose, shape (T, ...). Trailing dimensions are flattened
            to (T, D) for processing and restored on output.
        source_fps: Source frame rate.
        target_fps: Target frame rate.
        device: Kept for backward compatibility only. All computation happens
            on ``pose_aa.device``; a mismatched device used to crash the
            linear path and was ignored by the slerp path.
        interpolation_type: "linear" blends every value independently.
            "slerp" interpolates each rotation on the quaternion sphere.
        rot_type: Used by "slerp" only. "aa" means D is a multiple of 3
            (angle-axis per joint, converted to quaternions and back). Any other
            value means D is a multiple of 4 and the values are already quaternions.

    Returns:
        Interpolated pose of shape (T_new, *pose_aa.shape[1:]), with the same
        dtype and device as ``pose_aa``.

    Raises:
        ValueError: If ``interpolation_type`` is unsupported, or if D is not a
            multiple of the per-joint rotation size for "slerp".
    """
    orig_shape = pose_aa.shape[1:]
    if pose_aa.ndim != 2:
        pose_aa = pose_aa.reshape(pose_aa.shape[0], -1)
    T, D = pose_aa.shape
    # Follow the input's device so index/blend tensors always match it (see `device` above).
    compute_device = pose_aa.device

    if interpolation_type == "linear":
        duration = (T - 1) * (1 / source_fps)
        # +1e-6 keeps a final sample landing exactly on `duration` despite float32 rounding.
        times = torch.arange(
            0, duration + 1e-6, 1 / target_fps, dtype=torch.float32, device=compute_device
        )
        times = times[times <= duration]
        frame_indices = times * source_fps
        index_0 = torch.floor(frame_indices).long()
        index_1 = (index_0 + 1).clamp(max=T - 1)  # the last frame has no successor
        blend = (frame_indices - index_0.float()).unsqueeze(1)
        pose_aa_interp = (1 - blend) * pose_aa[index_0] + blend * pose_aa[index_1]
        pose_aa_interp = pose_aa_interp.reshape(pose_aa_interp.shape[0], *orig_shape)
        integer_dtypes = (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
        if pose_aa.dtype in integer_dtypes:
            # Round rather than let the cast back to an integer dtype truncate.
            pose_aa_interp = pose_aa_interp.round()
        return pose_aa_interp.type_as(pose_aa)

    if interpolation_type == "slerp":
        dim = 3 if rot_type == "aa" else 4
        if D % dim != 0:
            raise ValueError(
                f"Pose width {D} is not a multiple of {dim} (rot_type={rot_type!r})."
            )
        N = D // dim
        # reshape (not view) tolerates non-contiguous 2-D inputs.
        pose_reshaped = pose_aa.reshape(T, N, dim)
        per_joint = []
        for joint in range(N):
            rot = pose_reshaped[:, joint]  # (T, dim)
            quat = angle_axis_to_quaternion(rot) if rot_type == "aa" else rot  # (T, 4)
            quat_interp = interpolate_quaternions(quat.unsqueeze(0), source_fps, target_fps)[0]
            if rot_type == "aa":
                per_joint.append(quaternion_to_angle_axis(quat_interp))  # (T_new, 3)
            else:
                per_joint.append(quat_interp)  # (T_new, 4)
        pose_interp = torch.stack(per_joint, dim=1)  # (T_new, N, dim)
        return pose_interp.reshape(pose_interp.shape[0], *orig_shape).to(pose_aa)

    raise ValueError(
        f"Unsupported interpolation_type: {interpolation_type}. Must be 'linear' or 'slerp'."
    )