You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 

286 lines
11 KiB

import numpy as np
import torch
import os
import h5py
from torch.utils.data import TensorDataset, DataLoader
import time
import IPython
e = IPython.embed
from pathlib import Path
class EpisodicDataset(torch.utils.data.Dataset):
    """Dataset over preprocessed HDF5 episodes for imitation-learning training.

    Each episode lives in ``processed_episode_<id>.hdf5`` under ``dataset_dir``
    and provides ``observation.state``, ``observation.image.<cam>`` per camera,
    and ``qpos_action``. ``__getitem__`` returns a tuple
    ``(image_data, qpos_data, action_data, is_pad)`` where the action chunk is
    a fixed-length (``max_pad_len``) window padded by repeating the final action.

    Note: this class intentionally defines no ``__len__``; sampling is driven
    externally by ``BatchSampler`` via flat step indices.
    """

    def __init__(self, episode_ids, dataset_dir, camera_names, norm_stats, episode_len, history_stack=0):
        # Bug fix: the original called `super(EpisodicDataset).__init__()`,
        # which initializes an *unbound* super object and never invokes the
        # parent class constructor. `super().__init__()` is the correct form.
        super().__init__()
        self.episode_ids = episode_ids
        self.dataset_dir = dataset_dir
        self.camera_names = camera_names
        self.norm_stats = norm_stats
        self.is_sim = None
        self.max_pad_len = 200  # fixed action-chunk horizon returned by __getitem__
        action_str = 'qpos_action'
        self.history_stack = history_stack

        self.dataset_paths = []
        self.roots = []                    # open h5py.File handles, kept for lazy image reads
        self.is_sims = []
        self.original_action_shapes = []
        self.states = []                   # per-episode qpos arrays, loaded eagerly
        self.image_dict = {cam_name: [] for cam_name in self.camera_names}
        self.actions = []                  # per-episode action arrays, loaded eagerly
        for i, episode_id in enumerate(self.episode_ids):
            self.dataset_paths.append(os.path.join(self.dataset_dir, f'processed_episode_{episode_id}.hdf5'))
            # NOTE(review): file handles are never closed; acceptable for a
            # long-lived training dataset, but worth confirming intent.
            root = h5py.File(self.dataset_paths[i], 'r')
            self.roots.append(root)
            self.is_sims.append(root.attrs['sim'])
            self.original_action_shapes.append(root[action_str].shape)
            self.states.append(np.array(root['observation.state']))
            for cam_name in self.camera_names:
                # keep the h5py dataset object (not a copy) so images are read lazily
                self.image_dict[cam_name].append(root[f'observation.image.{cam_name}'])
            self.actions.append(np.array(root[action_str]))
        self.is_sim = self.is_sims[0]
        self.episode_len = episode_len
        self.cumulative_len = np.cumsum(self.episode_len)

    def _locate_transition(self, index):
        """Map a flat step index to ``(episode_index, step_within_episode)``."""
        assert index < self.cumulative_len[-1]
        episode_index = np.argmax(self.cumulative_len > index)  # argmax returns first True index
        start_ts = index - (self.cumulative_len[episode_index] - self.episode_len[episode_index])
        return episode_index, start_ts

    def __getitem__(self, ts_index):
        """Return one training sample: ``(image_data, qpos_data, action_data, is_pad)``."""
        sample_full_episode = False  # hardcode
        index, start_ts = self._locate_transition(ts_index)
        original_action_shape = self.original_action_shapes[index]
        episode_len = original_action_shape[0]
        # NOTE(review): the start_ts located above is immediately discarded and
        # re-sampled uniformly over the episode; the flat index only selects
        # *which* episode to draw from. Confirm this is intentional.
        if sample_full_episode:
            start_ts = 0
        else:
            start_ts = np.random.choice(episode_len)
        # observation at start_ts only
        qpos = self.states[index][start_ts]
        if self.history_stack > 0:
            # indices of the previous `history_stack` actions, clamped at 0
            last_indices = np.maximum(0, np.arange(start_ts - self.history_stack, start_ts)).astype(int)
            last_action = self.actions[index][last_indices, :]
        image_dict = dict()
        for cam_name in self.camera_names:
            image_dict[cam_name] = self.image_dict[cam_name][index][start_ts]
        # all actions from start_ts on, padded past the episode end by
        # repeating the final action; is_pad marks the repeated tail
        all_time_action = self.actions[index][:]
        all_time_action_padded = np.zeros((self.max_pad_len + original_action_shape[0], original_action_shape[1]), dtype=np.float32)
        all_time_action_padded[:episode_len] = all_time_action
        all_time_action_padded[episode_len:] = all_time_action[-1]
        padded_action = all_time_action_padded[start_ts:start_ts + self.max_pad_len]
        real_len = episode_len - start_ts
        is_pad = np.zeros(self.max_pad_len)
        is_pad[real_len:] = 1
        # stack cameras along a new leading axis
        all_cam_images = []
        for cam_name in self.camera_names:
            all_cam_images.append(image_dict[cam_name])
        all_cam_images = np.stack(all_cam_images, axis=0)
        # construct observations
        image_data = torch.from_numpy(all_cam_images)
        qpos_data = torch.from_numpy(qpos).float()
        action_data = torch.from_numpy(padded_action).float()
        is_pad = torch.from_numpy(is_pad).bool()
        if self.history_stack > 0:
            last_action_data = torch.from_numpy(last_action).float()
        # scale images to [0, 1]; standardize qpos/action with dataset stats
        image_data = image_data / 255.0
        action_data = (action_data - self.norm_stats["action_mean"]) / self.norm_stats["action_std"]
        qpos_data = (qpos_data - self.norm_stats["qpos_mean"]) / self.norm_stats["qpos_std"]
        if self.history_stack > 0:
            # normalized action history is appended flat onto the qpos vector
            last_action_data = (last_action_data - self.norm_stats['action_mean']) / self.norm_stats['action_std']
            qpos_data = torch.cat((qpos_data, last_action_data.flatten()))
        return image_data, qpos_data, action_data, is_pad
def get_norm_stats(dataset_dir, num_episodes):
    """Compute normalization statistics for qpos and action over all episodes.

    Reads ``processed_episode_<i>.hdf5`` for ``i in range(num_episodes)`` from
    ``dataset_dir``.

    Returns:
        stats: dict with ``action_mean``/``action_std`` and
            ``qpos_mean``/``qpos_std`` (1-D numpy arrays; std clipped below at
            1e-2 to avoid division by ~0), plus ``example_qpos`` — the qpos
            trajectory of the last episode read.
        all_episode_len: list of per-episode lengths (number of timesteps).
    """
    action_str = 'qpos_action'
    all_qpos_data = []
    all_action_data = []
    all_episode_len = []
    qpos = None  # kept after the loop as the 'example_qpos' of the last episode
    for episode_idx in range(num_episodes):
        dataset_path = os.path.join(dataset_dir, f'processed_episode_{episode_idx}.hdf5')
        with h5py.File(dataset_path, 'r') as root:
            qpos = root['observation.state'][()]
            action = root[action_str][()]
        all_qpos_data.append(torch.from_numpy(qpos))
        all_action_data.append(torch.from_numpy(action))
        all_episode_len.append(len(qpos))
    # concatenate along time: shapes become (total_steps, qpos_dim) / (total_steps, action_dim)
    all_qpos_data = torch.cat(all_qpos_data)
    all_action_data = torch.cat(all_action_data)
    # action stats, per dimension
    action_mean = all_action_data.mean(dim=0, keepdim=True)
    action_std = all_action_data.std(dim=0, keepdim=True)
    action_std = torch.clip(action_std, 1e-2, np.inf)  # clipping avoids tiny divisors
    # qpos stats, per dimension
    qpos_mean = all_qpos_data.mean(dim=0, keepdim=True)
    qpos_std = all_qpos_data.std(dim=0, keepdim=True)
    qpos_std = torch.clip(qpos_std, 1e-2, np.inf)  # clipping
    stats = {"action_mean": action_mean.numpy().squeeze(), "action_std": action_std.numpy().squeeze(),
             "qpos_mean": qpos_mean.numpy().squeeze(), "qpos_std": qpos_std.numpy().squeeze(),
             "example_qpos": qpos}
    return stats, all_episode_len
def find_all_processed_episodes(path):
    """Return the names of all entries in *path*.

    NOTE(review): this returns every directory entry, with no filtering on the
    'processed_episode_*.hdf5' pattern, and load_data uses len() of the result
    as the episode count — confirm the directory contains nothing else.
    """
    return os.listdir(path)
def BatchSampler(batch_size, episode_len_l, sample_weights=None):
    """Yield endless batches of flat step indices.

    Each index is drawn by first picking a dataset (optionally weighted by
    ``sample_weights``) and then picking a uniform step inside that dataset's
    contiguous index range.
    """
    if sample_weights is None:
        probs = None
    else:
        probs = np.array(sample_weights) / np.sum(sample_weights)
    # boundaries[k] .. boundaries[k+1] is the flat index range of dataset k
    boundaries = np.cumsum([0] + [np.sum(lengths) for lengths in episode_len_l])
    n_datasets = len(episode_len_l)

    def _draw_one():
        # pick a dataset, then a uniform step within it
        chosen = np.random.choice(n_datasets, p=probs)
        return np.random.randint(boundaries[chosen], boundaries[chosen + 1])

    while True:
        yield [_draw_one() for _ in range(batch_size)]
def load_data(dataset_dir, camera_names, batch_size_train, batch_size_val):
    """Build train/val dataloaders over the processed episodes in ``dataset_dir``.

    Returns ``(train_dataloader, val_dataloader, norm_stats, is_sim)``.
    """
    print(f'\nData from: {dataset_dir}\n')
    episodes = find_all_processed_episodes(dataset_dir)
    n_total = len(episodes)
    # 99/1 train/val split over a random permutation of episode indices
    train_ratio = 0.99
    perm = np.random.permutation(n_total)
    split = int(train_ratio * n_total)
    train_indices, val_indices = perm[:split], perm[split:]
    print(f"Train episodes: {len(train_indices)}, Val episodes: {len(val_indices)}")
    # normalization stats for qpos and action, computed over all episodes
    norm_stats, all_episode_len = get_norm_stats(dataset_dir, n_total)
    train_len_l = [all_episode_len[i] for i in train_indices]
    val_len_l = [all_episode_len[i] for i in val_indices]
    sampler_train = BatchSampler(batch_size_train, train_len_l)
    sampler_val = BatchSampler(batch_size_val, val_len_l, None)
    # datasets + loaders; batch_sampler yields lists of flat step indices
    train_dataset = EpisodicDataset(train_indices, dataset_dir, camera_names, norm_stats, train_len_l)
    val_dataset = EpisodicDataset(val_indices, dataset_dir, camera_names, norm_stats, val_len_l)
    train_dataloader = DataLoader(train_dataset, batch_sampler=sampler_train, pin_memory=True, num_workers=24, prefetch_factor=2)
    val_dataloader = DataLoader(val_dataset, batch_sampler=sampler_val, pin_memory=True, num_workers=16, prefetch_factor=2)
    return train_dataloader, val_dataloader, norm_stats, train_dataset.is_sim
def sample_box_pose():
    """Sample a random cube pose: xyz uniform within fixed ranges, identity quaternion.

    Returns a length-7 array ``[x, y, z, qw, qx, qy, qz]``; z is fixed at 0.05.
    """
    lows = np.array([0.0, 0.4, 0.05])
    highs = np.array([0.2, 0.6, 0.05])
    cube_position = np.random.uniform(lows, highs)
    cube_quat = np.array([1, 0, 0, 0])
    return np.concatenate([cube_position, cube_quat])
def sample_insertion_pose():
    """Sample random peg and socket poses for the insertion task.

    Peg x in [0.1, 0.2], socket x in [-0.2, -0.1]; both share y in [0.4, 0.6],
    z fixed at 0.05, and an identity quaternion. Returns ``(peg_pose, socket_pose)``,
    each a length-7 array ``[x, y, z, qw, qx, qy, qz]``.
    """
    def _sample(x_lo, x_hi):
        # one pose draw: uniform xyz then identity quaternion
        position = np.random.uniform([x_lo, 0.4, 0.05], [x_hi, 0.6, 0.05])
        return np.concatenate([position, np.array([1, 0, 0, 0])])

    peg_pose = _sample(0.1, 0.2)      # peg first: RNG order matters
    socket_pose = _sample(-0.2, -0.1)
    return peg_pose, socket_pose
### helper functions
def compute_dict_mean(epoch_dicts):
    """Average each key across a list of dicts.

    Keys are taken from the first dict; every dict is expected to contain them.
    """
    count = len(epoch_dicts)
    return {key: sum(d[key] for d in epoch_dicts) / count for key in epoch_dicts[0]}
def detach_dict(d):
    """Return a new dict with every value detached from the autograd graph."""
    return {key: value.detach() for key, value in d.items()}
def set_seed(seed):
    """Seed both the numpy and torch global RNGs for reproducibility."""
    np.random.seed(seed)
    torch.manual_seed(seed)
def parse_id(base_dir, prefix):
    """Find the first subdirectory of *base_dir* whose name starts with *prefix*.

    Returns ``(path_str, name)`` for the match, or ``(None, None)`` if no
    subdirectory matches.

    Raises:
        ValueError: if *base_dir* does not exist or is not a directory.
    """
    base_path = Path(base_dir)
    if not base_path.exists() or not base_path.is_dir():
        raise ValueError(f"The provided base directory does not exist or is not a directory: \n{base_path}")
    # lazily scan subdirectories; take the first name match (order is OS-defined)
    candidates = (entry for entry in base_path.iterdir()
                  if entry.is_dir() and entry.name.startswith(prefix))
    found = next(candidates, None)
    if found is None:
        return None, None
    return str(found), found.name
def find_all_ckpt(base_dir, prefix="policy_epoch_"):
    """Find the latest checkpoint file in *base_dir*.

    Filenames are expected to look like ``<prefix><epoch>_...`` (e.g.
    ``policy_epoch_100_seed_0.ckpt``); the latest checkpoint is the one with
    the highest epoch number.

    Returns:
        ``(filename, epoch)`` of the newest checkpoint.

    Raises:
        ValueError: if *base_dir* is missing / not a directory, or contains no
            files starting with *prefix* (the original crashed with a bare
            IndexError in that case).
    """
    base_path = Path(base_dir)
    if not base_path.exists() or not base_path.is_dir():
        raise ValueError("The provided base directory does not exist or is not a directory.")

    def _epoch(name):
        # '<prefix>100_seed_0.ckpt' -> 100
        return int(name.split(prefix)[-1].split('_')[0])

    ckpt_files = [f.name for f in base_path.iterdir()
                  if f.is_file() and f.name.startswith(prefix)]
    if not ckpt_files:
        raise ValueError(f"No checkpoint files with prefix '{prefix}' found in {base_dir}")
    latest = max(ckpt_files, key=_epoch)
    return latest, _epoch(latest)