You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

29 lines
1.3 KiB

"""RL-specific helpers: episode attention masks and (legacy) schedule utilities."""
import torch
def compute_episode_attnmask(dones):
    """Build a per-environment attention mask that blocks cross-episode attention.

    Args:
        dones (torch.Tensor): Shape (num_envs, num_steps); a value of 1.0
            marks the final step of an episode.

    Returns:
        torch.Tensor: Boolean mask of shape (num_envs, num_steps, num_steps).
            True entries mark pairs of steps that belong to different
            episodes and therefore must NOT be attended to.
    """
    # A step begins a new episode exactly when the previous step ended one;
    # shifting the done flags right by one position yields that signal.
    starts = torch.roll(dones, shifts=1, dims=1)
    starts[:, 0] = True  # the very first step always opens an episode
    # The running count of episode starts labels each step with its episode id.
    ids = torch.cumsum(starts, dim=1)  # (num_envs, num_steps)
    # Broadcast (num_envs, num_steps, 1) against (num_envs, 1, num_steps):
    # position (i, j) is masked whenever the two steps' episode ids differ.
    return ids.unsqueeze(2) != ids.unsqueeze(1)