
mask_utils

Functions

fastvideo.models.mask_utils.and_masks

and_masks(*mask_functions: Callable) -> Callable

Returns a mask function that is the intersection of the provided mask functions.

Source code in fastvideo/models/mask_utils.py
def and_masks(*mask_functions: Callable) -> Callable:
    """Returns a mask function that is the intersection of provided mask functions"""
    if not all(callable(arg) for arg in mask_functions):
        raise RuntimeError(f"All inputs should be callable mask_functions: {mask_functions}")

    def and_mask(batch_idx, head_idx, q_idx, kv_idx):
        result = q_idx.new_ones((), dtype=torch.bool)
        for mask in mask_functions:
            result = result & mask(batch_idx, head_idx, q_idx, kv_idx).to(result.device)
        return result

    return and_mask
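As an illustration, the sketch below combines `causal_mask_function` with a hypothetical `even_kv_mask` (defined only for this example, not part of the module). Indices are passed as tensors because the combined mask calls `q_idx.new_ones(...)` internally.

```python
import torch

from fastvideo.models.mask_utils import and_masks, causal_mask_function

# Hypothetical extra constraint used only for this example: allow even kv positions.
def even_kv_mask(batch_idx, head_idx, q_idx, kv_idx):
    return kv_idx % 2 == 0

combined = and_masks(causal_mask_function, even_kv_mask)

# Evaluate a few (q_idx, kv_idx) pairs; each result is the logical AND of both masks.
b, h = torch.tensor(0), torch.tensor(0)
print(combined(b, h, torch.tensor(3), torch.tensor(2)))  # tensor(True):  causal and even
print(combined(b, h, torch.tensor(3), torch.tensor(1)))  # tensor(False): causal but odd kv position
print(combined(b, h, torch.tensor(1), torch.tensor(2)))  # tensor(False): even but not causal
```
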

fastvideo.models.mask_utils.causal_mask_function

causal_mask_function(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool

This creates a basic lower-diagonal causal mask.

Source code in fastvideo/models/mask_utils.py
def causal_mask_function(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
    """
    This creates a basic lower-diagonal causal mask.
    """
    return kv_idx <= q_idx
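For example, a query at position 3 may attend to key/value positions up to and including 3, but not beyond:

```python
from fastvideo.models.mask_utils import causal_mask_function

print(causal_mask_function(0, 0, 3, 2))  # True:  kv_idx 2 <= q_idx 3
print(causal_mask_function(0, 0, 3, 4))  # False: kv_idx 4 >  q_idx 3
```
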

fastvideo.models.mask_utils.padding_mask_function

padding_mask_function(padding_mask: Tensor) -> Callable

Returns the mask function corresponding to a 2D padding mask.

Source code in fastvideo/models/mask_utils.py
def padding_mask_function(padding_mask: torch.Tensor) -> Callable:
    """
    This return the mask_function function corresponding to a 2D padding mask.
    """

    def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
        # Note that here the mask should ALWAYS be at least of the max `kv_index` size in the dimension 1. This is because
        # we cannot pad it here in the mask_function as we don't know the final size, and we cannot try/except, as it is not
        # vectorizable on accelerator devices
        return padding_mask[batch_idx, kv_idx]

    return inner_mask
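A minimal sketch of the returned mask function; the padding mask below is illustrative. Note that the result depends only on `batch_idx` and `kv_idx`, not on the query or head index.

```python
import torch

from fastvideo.models.mask_utils import padding_mask_function

# Illustrative 1x5 padding mask: the last kv position of the single batch item is padding.
padding_mask = torch.tensor([[True, True, True, True, False]])
mask_fn = padding_mask_function(padding_mask)

print(mask_fn(0, 0, q_idx=2, kv_idx=1))  # tensor(True):  kv position 1 is a real token
print(mask_fn(0, 0, q_idx=2, kv_idx=4))  # tensor(False): kv position 4 is padding
```
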

fastvideo.models.mask_utils.prepare_padding_mask

prepare_padding_mask(attention_mask: Optional[Tensor], kv_length: int, kv_offset: int) -> Optional[Tensor]

From the 2D attention mask, prepare the correct padding mask to use by potentially padding it.

Source code in fastvideo/models/mask_utils.py
def prepare_padding_mask(
    attention_mask: Optional[torch.Tensor], kv_length: int, kv_offset: int
) -> Optional[torch.Tensor]:
    """
    From the 2D attention mask, prepare the correct padding mask to use by potentially padding it.
    """
    local_padding_mask = attention_mask
    if attention_mask is not None:
        # Pad it if necessary
        if (padding_length := kv_length + kv_offset - attention_mask.shape[-1]) > 0:
            local_padding_mask = torch.nn.functional.pad(attention_mask, (0, padding_length))
    return local_padding_mask
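A minimal sketch: a 2D mask covering 3 seen tokens is right-padded with `False` so that it spans `kv_length + kv_offset` positions.

```python
import torch

from fastvideo.models.mask_utils import prepare_padding_mask

# 2D mask for a batch of 1 with 3 seen tokens.
attention_mask = torch.ones(1, 3, dtype=torch.bool)

# kv_length + kv_offset = 5 > 3, so two False (masked) positions are appended on the right.
padded = prepare_padding_mask(attention_mask, kv_length=4, kv_offset=1)
print(padded)  # tensor([[ True,  True,  True, False, False]])

# If the mask already covers kv_length + kv_offset, it is returned unchanged.
print(prepare_padding_mask(attention_mask, kv_length=3, kv_offset=0))  # tensor([[True, True, True]])
```
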

fastvideo.models.mask_utils.sdpa_mask

sdpa_mask(batch_size: int, cache_position: Tensor, kv_length: int, kv_offset: int = 0, mask_function: Callable = causal_mask_function, attention_mask: Optional[Tensor] = None, local_size: Optional[int] = None, allow_is_causal_skip: bool = True, allow_is_bidirectional_skip: bool = False, allow_torch_fix: bool = True, use_vmap: bool = False, **kwargs) -> Optional[Tensor]

Create a 4D boolean mask of shape (batch_size, 1, query_length, kv_length) where a value of True indicates that the element should take part in the attention computation, and False that it should not. This function can only be used with torch>=2.5, as the context manager is otherwise not available.

Parameters:

- `batch_size` (`int`, required): The batch size of the input sequence.
- `cache_position` (`torch.Tensor`, required): A tensor of shape `(query_length,)` indicating the current indices of the input sequence elements.
- `kv_length` (`int`, required): The size that the key and value states will have during the attention computation.
- `kv_offset` (`int`, default `0`): An optional offset indicating the first position the key and value states refer to.
- `mask_function` (`Callable`, default `causal_mask_function`): The mask factory function describing the mask pattern.
- `attention_mask` (`torch.Tensor`, default `None`): The 2D attention mask corresponding to padded tokens, of shape `(batch_size, number_of_seen_tokens + q_length)`.
- `local_size` (`int`, default `None`): The size of the local attention window when full attention is not used. Only used when `allow_is_causal_skip=True`, to try to skip mask creation if possible.
- `allow_is_causal_skip` (`bool`, default `True`): Whether to allow returning `None` for the mask when the `is_causal` argument of `torch.sdpa` can be used instead.
- `allow_is_bidirectional_skip` (`bool`, default `False`): Whether to allow returning `None` for the mask when no bias needs to be added at all, i.e. full attention without any padding.
- `allow_torch_fix` (`bool`, default `True`): Whether to update the mask when a query attends to no tokens, to work around a bug in older torch versions. Exposed as an argument so it can be skipped when using eager attention.
- `use_vmap` (`bool`, default `False`): Whether to use `vmap` during mask construction. Allows powerful custom patterns that may not be index-based, at the cost of speed.
Creating a simple causal mask:

To create the following causal mask:

0 ■ ⬚ ⬚ ⬚ ⬚
1 ■ ■ ⬚ ⬚ ⬚
2 ■ ■ ■ ⬚ ⬚
3 ■ ■ ■ ■ ⬚
4 ■ ■ ■ ■ ■

You can do:

```python
>>> sdpa_mask(batch_size=1, cache_position=torch.arange(5), kv_length=5)
>>> tensor([[[[ True, False, False, False, False],
              [ True,  True, False, False, False],
              [ True,  True,  True, False, False],
              [ True,  True,  True,  True, False],
              [ True,  True,  True,  True,  True]]]])
```

Creating a sliding window mask:

To create the following sliding window mask (`sliding_window=3`):

0 ■ ⬚ ⬚ ⬚ ⬚
1 ■ ■ ⬚ ⬚ ⬚
2 ■ ■ ■ ⬚ ⬚
3 ⬚ ■ ■ ■ ⬚
4 ⬚ ⬚ ■ ■ ■

You can do:

```python
>>> sdpa_mask(batch_size=1, cache_position=torch.arange(5), kv_length=5, mask_function=sliding_window_causal_mask_function(3))
>>> tensor([[[[ True, False, False, False, False],
              [ True,  True, False, False, False],
              [ True,  True,  True, False, False],
              [False,  True,  True,  True, False],
              [False, False,  True,  True,  True]]]])
```

Creating a chunked attention mask:

To create the following chunked attention mask (`chunk_size=3`):

0 ■ ⬚ ⬚ ⬚ ⬚
1 ■ ■ ⬚ ⬚ ⬚
2 ■ ■ ■ ⬚ ⬚
3 ⬚ ⬚ ⬚ ■ ⬚
4 ⬚ ⬚ ⬚ ■ ■

You can do:

```python
>>> sdpa_mask(batch_size=1, cache_position=torch.arange(5), kv_length=5, mask_function=chunked_causal_mask_function(3, torch.zeros(1, dtype=int)))
>>> tensor([[[[ True, False, False, False, False],
              [ True,  True, False, False, False],
              [ True,  True,  True, False, False],
              [False, False, False,  True, False],
              [False, False, False,  True,  True]]]])
```
Source code in fastvideo/models/mask_utils.py
def sdpa_mask(
    batch_size: int,
    cache_position: torch.Tensor,
    kv_length: int,
    kv_offset: int = 0,
    mask_function: Callable = causal_mask_function,
    attention_mask: Optional[torch.Tensor] = None,
    local_size: Optional[int] = None,
    allow_is_causal_skip: bool = True,
    allow_is_bidirectional_skip: bool = False,
    allow_torch_fix: bool = True,
    use_vmap: bool = False,
    **kwargs,
) -> Optional[torch.Tensor]:
    """
    Create a 4D boolean mask of shape `(batch_size, 1, query_length, kv_length)` where a value of True indicates that
    the element should take part in the attention computation, and False that it should not.
    This function can only be used with torch>=2.5, as the context manager is otherwise not available.

    Args:
        batch_size (`int`):
            The batch size of the input sequence.
        cache_position (`torch.Tensor`):
            A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
        kv_length (`int`):
            The size that the key and value states will have during the attention computation.
        kv_offset (`int`, optional):
            An optional offset to indicate at which first position the key and values states will refer to.
        mask_function (`Callable`):
            The mask factory function describing the mask pattern.
        attention_mask (`torch.Tensor`, optional):
            The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
        local_size (`int`, optional):
            The size of the local attention, if we do not use full attention. This is used only if `allow_is_causal_skip=True`
            to try to skip mask creation if possible.
        allow_is_causal_skip (`bool`, optional):
            Whether to allow to return `None` for the mask under conditions where we can use the `is_causal` argument in
            `torch.sdpa` instead. Default to `True`.
        allow_is_bidirectional_skip (`bool`, optional):
            Whether to allow to return `None` for the mask under conditions where we do not have to add any bias,
            i.e. full attention without any padding. Default to `False`.
        allow_torch_fix (`bool`, optional):
            Whether to update the mask in case a query is not attending to any tokens, to solve a bug in torch's older
            versions. We need an arg to skip it when using eager. By default `True`.
        use_vmap (`bool`, optional):
            Whether to use `vmap` during the mask construction or not. Allows powerful custom patterns that may not be
            index-based (for the cost of speed performance). By default `False`.


    ## Creating a simple causal mask:

    To create the following causal mask:

        0 ■ ⬚ ⬚ ⬚ ⬚
        1 ■ ■ ⬚ ⬚ ⬚
        2 ■ ■ ■ ⬚ ⬚
        3 ■ ■ ■ ■ ⬚
        4 ■ ■ ■ ■ ■

    You can do

    ```python
    >>> sdpa_mask(batch_size=1, cache_position=torch.arange(5), kv_length=5)
    >>> tensor([[[[ True, False, False, False, False],
                  [ True,  True, False, False, False],
                  [ True,  True,  True, False, False],
                  [ True,  True,  True,  True, False],
                  [ True,  True,  True,  True,  True]]]])
    ```

    ## Creating a sliding window mask:

    To create the following sliding window mask (`sliding_window=3`):

        0 ■ ⬚ ⬚ ⬚ ⬚
        1 ■ ■ ⬚ ⬚ ⬚
        2 ■ ■ ■ ⬚ ⬚
        3 ⬚ ■ ■ ■ ⬚
        4 ⬚ ⬚ ■ ■ ■

    You can do

    ```python
    >>> sdpa_mask(batch_size=1, cache_position=torch.arange(5), kv_length=5, mask_function=sliding_window_causal_mask_function(3))
    >>> tensor([[[[ True, False, False, False, False],
                  [ True,  True, False, False, False],
                  [ True,  True,  True, False, False],
                  [False,  True,  True,  True, False],
                  [False, False,  True,  True,  True]]]])
    ```

    ## Creating a chunked attention mask

    To create the following chunked attention mask (`chunk_size=3`):

        0 ■ ⬚ ⬚ ⬚ ⬚
        1 ■ ■ ⬚ ⬚ ⬚
        2 ■ ■ ■ ⬚ ⬚
        3 ⬚ ⬚ ⬚ ■ ⬚
        4 ⬚ ⬚ ⬚ ■ ■

    You can do

    ```python
    >>> sdpa_mask(batch_size=1, cache_position=torch.arange(5), kv_length=5, mask_function=chunked_causal_mask_function(3, torch.zeros(1, dtype=int)))
    >>> tensor([[[[ True, False, False, False, False],
                  [ True,  True, False, False, False],
                  [ True,  True,  True, False, False],
                  [False, False, False,  True, False],
                  [False, False, False,  True,  True]]]])
    ```

    """
    q_length = cache_position.shape[0]

    # Potentially pad the 2D mask
    padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset)

    # Potentially add the padding 2D mask
    if padding_mask is not None:
        mask_function = and_masks(mask_function, padding_mask_function(padding_mask))

    batch_arange = torch.arange(batch_size, device=cache_position.device)
    head_arange = torch.arange(1, device=cache_position.device)
    # Similar to `kv_arange = torch.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)`
    # but without data-dependent slicing (i.e. torch.compile friendly)
    kv_arange = torch.arange(kv_length, device=cache_position.device) + kv_offset

    # Actual mask creation
    # Apply mask function element-wise through broadcasting
    attention_mask = mask_function(*_non_vmap_expansion_sdpa(batch_arange, head_arange, cache_position, kv_arange))
    # Expand the mask to match batch size and query length if they weren't used in the mask function
    attention_mask = attention_mask.expand(batch_size, -1, q_length, kv_length)

    return attention_mask
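To complement the examples above, here is a hedged sketch of how a 2D padding mask combines with the default causal pattern. Given the source above, the last key/value column (a padded token) is masked out for every query.

```python
import torch

from fastvideo.models.mask_utils import sdpa_mask

# 2D padding mask for a batch of 1: the last of the 4 kv positions is padding.
attention_mask_2d = torch.tensor([[1, 1, 1, 0]], dtype=torch.bool)

mask = sdpa_mask(
    batch_size=1,
    cache_position=torch.arange(4),  # 4 query positions
    kv_length=4,
    attention_mask=attention_mask_2d,
)

print(mask.shape)  # torch.Size([1, 1, 4, 4])
print(mask[0, 0])
# tensor([[ True, False, False, False],
#         [ True,  True, False, False],
#         [ True,  True,  True, False],
#         [ True,  True,  True, False]])
```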