
vmoba

Classes

fastvideo.attention.backends.vmoba.VMOBAAttentionImpl

VMOBAAttentionImpl(num_heads, head_size, softmax_scale, causal=False, num_kv_heads=None, prefix='', **extra_impl_args)

Bases: AttentionImpl

Source code in fastvideo/attention/backends/vmoba.py
def __init__(self,
             num_heads,
             head_size,
             softmax_scale,
             causal=False,
             num_kv_heads=None,
             prefix="",
             **extra_impl_args) -> None:
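    # prefix identifies this attention layer inside the model; the layer
    # index parsed from it decides which VMoBA chunk settings forward() uses.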
    self.prefix = prefix
    self.layer_idx = self._get_layer_idx(prefix)
    from flash_attn.bert_padding import pad_input
    self.pad_input = pad_input
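
A construction sketch with illustrative values (not from the source). The prefix format is an assumption chosen so that _get_layer_idx can recover a block index from it; note that constructing the backend imports flash_attn.bert_padding, so flash-attn must be installed.

impl = VMOBAAttentionImpl(
    num_heads=24,
    head_size=128,
    softmax_scale=128 ** -0.5,
    causal=False,
    prefix="transformer_blocks.12.attn",  # hypothetical prefix; assumed to embed the block index (12)
)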

Functions

fastvideo.attention.backends.vmoba.VMOBAAttentionImpl.forward
forward(query: Tensor, key: Tensor, value: Tensor, attn_metadata: AttentionMetadata) -> Tensor

query, key, value: [B, L, H, D] tensors; attn_metadata: AttentionMetadata carrying the VMoBA chunk sizes, top-k values and selection settings.

Source code in fastvideo/attention/backends/vmoba.py
def forward(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_metadata: AttentionMetadata,
) -> torch.Tensor:
    """
    query: [B, L, H, D]
    key:   [B, L, H, D]
    value: [B, L, H, D]
    attn_metadata: AttentionMetadata
    """
    batch_size, sequence_length, num_heads, head_dim = query.shape

    # select chunk type according to layer idx:
    loop_layer_num = attn_metadata.temporal_layer + attn_metadata.spatial_layer + attn_metadata.st_layer
    moba_layer = self.layer_idx - attn_metadata.first_full_layer
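    # After first_full_layer, layers cycle through the chunk types in order:
    # temporal -> spatial -> spatio-temporal (st), repeating every loop_layer_num layers.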
    if moba_layer % loop_layer_num < attn_metadata.temporal_layer:
        moba_chunk_size = attn_metadata.temporal_chunk_size
        moba_topk = attn_metadata.temporal_topk
    elif moba_layer % loop_layer_num < attn_metadata.temporal_layer + attn_metadata.spatial_layer:
        moba_chunk_size = attn_metadata.spatial_chunk_size
        moba_topk = attn_metadata.spatial_topk
    elif moba_layer % loop_layer_num < attn_metadata.temporal_layer + attn_metadata.spatial_layer + attn_metadata.st_layer:
        moba_chunk_size = attn_metadata.st_chunk_size
        moba_topk = attn_metadata.st_topk

    query, chunk_size = process_moba_input(query,
                                           attn_metadata.patch_resolution,
                                           moba_chunk_size)
    key, chunk_size = process_moba_input(key,
                                         attn_metadata.patch_resolution,
                                         moba_chunk_size)
    value, chunk_size = process_moba_input(value,
                                           attn_metadata.patch_resolution,
                                           moba_chunk_size)
    max_seqlen = query.shape[1]
    indices_q = torch.arange(0,
                             query.shape[0] * query.shape[1],
                             device=query.device)
    cu_seqlens = torch.arange(0,
                              query.shape[0] * query.shape[1] + 1,
                              query.shape[1],
                              dtype=torch.int32,
                              device=query.device)
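    # cu_seqlens marks the cumulative boundaries of the equal-length sequences
    # in the flattened layout, e.g. two sequences of 4 tokens give tensor([0, 4, 8]).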
    query = rearrange(query, "b s ... -> (b s) ...")
    key = rearrange(key, "b s ... -> (b s) ...")
    value = rearrange(value, "b s ... -> (b s) ...")

    # current_timestep=attn_metadata.current_timestep
    hidden_states = moba_attn_varlen(
        query,
        key,
        value,
        cu_seqlens=cu_seqlens,
        max_seqlen=max_seqlen,
        moba_chunk_size=chunk_size,
        moba_topk=moba_topk,
        select_mode=attn_metadata.moba_select_mode,
        simsum_threshold=attn_metadata.moba_threshold,
        threshold_type=attn_metadata.moba_threshold_type,
    )
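    # pad_input un-flattens the varlen output back to [batch, seq, heads, dim];
    # process_moba_output is the counterpart of process_moba_input above.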
    hidden_states = self.pad_input(hidden_states, indices_q, batch_size,
                                   sequence_length)
    hidden_states = process_moba_output(hidden_states,
                                        attn_metadata.patch_resolution,
                                        moba_chunk_size)

    return hidden_states
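
A minimal call sketch using the impl constructed above; it is illustrative only. The shapes, dtype and device are assumptions, L is assumed to be consistent with the patch_resolution carried in the metadata, and attn_metadata stands in for an already-built metadata object exposing the attributes this forward pass reads.

import torch

B, L, H, D = 1, 4096, 24, 128   # batch, tokens, heads, head dim (illustrative)
q = torch.randn(B, L, H, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B, L, H, D, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, L, H, D, device="cuda", dtype=torch.bfloat16)

# attn_metadata is assumed to expose: first_full_layer, temporal_layer,
# spatial_layer, st_layer, the matching *_chunk_size and *_topk values,
# patch_resolution, moba_select_mode, moba_threshold and moba_threshold_type.
out = impl.forward(q, k, v, attn_metadata)  # expected to return [B, L, H, D]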
