fastvideo.attention.backends.vmoba#
Module Contents#
Classes#
Data#
API#
- class fastvideo.attention.backends.vmoba.VMOBAAttentionBackend[source]#
Bases: fastvideo.attention.backends.abstract.AttentionBackend
- static get_builder_cls() type[fastvideo.attention.backends.vmoba.VideoMobaAttentionMetadataBuilder] [source]#
- static get_impl_cls() type[fastvideo.attention.backends.vmoba.VMOBAAttentionImpl] [source]#
- static get_metadata_cls() type[fastvideo.attention.backends.vmoba.VideoMobaAttentionMetadata] [source]#
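The three static accessors above resolve the concrete classes that make up this backend. A minimal sketch of how they might be wired together is shown below; the constructor arguments for `VMOBAAttentionImpl` follow its documented signature, but all concrete values (head count, head size, scale) are illustrative assumptions.

```python
# Hedged sketch: resolving the concrete classes exposed by the backend.
from fastvideo.attention.backends.vmoba import VMOBAAttentionBackend

impl_cls = VMOBAAttentionBackend.get_impl_cls()          # -> VMOBAAttentionImpl
builder_cls = VMOBAAttentionBackend.get_builder_cls()    # -> VideoMobaAttentionMetadataBuilder
metadata_cls = VMOBAAttentionBackend.get_metadata_cls()  # -> VideoMobaAttentionMetadata

# Values below are assumptions for illustration, not recommended settings.
impl = impl_cls(num_heads=12, head_size=128, softmax_scale=128 ** -0.5, causal=False)
```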
- class fastvideo.attention.backends.vmoba.VMOBAAttentionImpl(num_heads, head_size, softmax_scale, causal=False, num_kv_heads=None, prefix='', **extra_impl_args)[source]#
Bases: fastvideo.attention.backends.abstract.AttentionImpl
- forward(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_metadata: fastvideo.attention.backends.abstract.AttentionMetadata) torch.Tensor [source]#
query: [B, L, H, D]
key: [B, L, H, D]
value: [B, L, H, D]
attn_metadata: AttentionMetadata
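The shape convention can be exercised with the short sketch below; the batch size, sequence length, head count, and head dimension are illustrative assumptions, and the metadata is assumed to come from the builder documented further down.

```python
# Minimal sketch of the expected input shapes for forward().
# All concrete sizes here are assumptions for illustration only.
import torch

B, L, H, D = 1, 1024, 12, 128  # [batch, seq_len, num_heads, head_dim]
query = torch.randn(B, L, H, D)
key = torch.randn(B, L, H, D)
value = torch.randn(B, L, H, D)

# attn_metadata would come from VideoMobaAttentionMetadataBuilder.build(...)
# (see the sketch after the builder below); the output is assumed to keep
# the [B, L, H, D] layout of the inputs.
# out = impl.forward(query, key, value, attn_metadata)
```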
- class fastvideo.attention.backends.vmoba.VideoMobaAttentionMetadata[source]#
Bases: fastvideo.attention.backends.abstract.AttentionMetadata
- class fastvideo.attention.backends.vmoba.VideoMobaAttentionMetadataBuilder[source]#
Bases: fastvideo.attention.backends.abstract.AttentionMetadataBuilder
- build(current_timestep: int, raw_latent_shape: tuple[int, int, int], patch_size: tuple[int, int, int], temporal_chunk_size: int, temporal_topk: int, spatial_chunk_size: tuple[int, int], spatial_topk: int, st_chunk_size: tuple[int, int, int], st_topk: int, moba_select_mode: str = 'threshold', moba_threshold: float = 0.25, moba_threshold_type: str = 'query_head', device: torch.device = None, first_full_layer: int = 0, first_full_step: int = 12, temporal_layer: int = 1, spatial_layer: int = 1, st_layer: int = 1, **kwargs) fastvideo.attention.backends.vmoba.VideoMobaAttentionMetadata [source]#
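As a hedged illustration of the builder interface, the sketch below constructs metadata using the parameters listed in the signature above. Every concrete value (latent shape, chunk sizes, top-k counts, threshold, device) is an assumption rather than a recommended configuration, and the no-argument builder construction is likewise assumed.

```python
# Hedged sketch: building VideoMobaAttentionMetadata with the documented
# parameters. All concrete values are illustrative assumptions.
import torch
from fastvideo.attention.backends.vmoba import VMOBAAttentionBackend

builder = VMOBAAttentionBackend.get_builder_cls()()  # assumed no-arg constructor
attn_metadata = builder.build(
    current_timestep=0,
    raw_latent_shape=(16, 60, 104),   # assumed (T, H, W) ordering of the latent video
    patch_size=(1, 2, 2),
    temporal_chunk_size=4,
    temporal_topk=2,
    spatial_chunk_size=(6, 8),
    spatial_topk=4,
    st_chunk_size=(4, 6, 8),
    st_topk=4,
    moba_select_mode="threshold",
    moba_threshold=0.25,
    moba_threshold_type="query_head",
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)
```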