fastvideo.attention.backends.video_sparse_attn#
Module Contents#
Classes#

- VideoSparseAttentionBackend
- VideoSparseAttentionImpl
- VideoSparseAttentionMetadata
- VideoSparseAttentionMetadataBuilder

Functions#

- construct_variable_block_sizes: Compute the number of valid (non-padded) tokens inside every (ts_t × ts_h × ts_w) tile after padding, flattened in the (t-tile, h-tile, w-tile) order that rearrange uses.
- get_non_pad_index
- get_reverse_tile_partition_indices
Data#
API#
- class fastvideo.attention.backends.video_sparse_attn.VideoSparseAttentionBackend[source]#
Bases: fastvideo.attention.backends.abstract.AttentionBackend
- static get_builder_cls() type[fastvideo.attention.backends.video_sparse_attn.VideoSparseAttentionMetadataBuilder] [source]#
- static get_impl_cls() type[fastvideo.attention.backends.video_sparse_attn.VideoSparseAttentionImpl] [source]#
- static get_metadata_cls() type[fastvideo.attention.backends.video_sparse_attn.VideoSparseAttentionMetadata] [source]#
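The backend is a thin factory: each static getter resolves one of the companion classes documented below. A minimal sketch of the lookup; the hyperparameters here (num_heads=8, head_size=64, and the usual head_size ** -0.5 softmax scale) are illustrative placeholders, not values the library prescribes:

```python
from fastvideo.attention.backends.video_sparse_attn import (
    VideoSparseAttentionBackend as backend,
)

impl_cls = backend.get_impl_cls()          # VideoSparseAttentionImpl
metadata_cls = backend.get_metadata_cls()  # VideoSparseAttentionMetadata
builder_cls = backend.get_builder_cls()    # VideoSparseAttentionMetadataBuilder

# Placeholder hyperparameters; a real pipeline derives these from its model
# config before calling impl.forward(q, k, v, gate_compress, attn_metadata).
impl = impl_cls(num_heads=8, head_size=64, causal=False,
                softmax_scale=64 ** -0.5)
```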
- class fastvideo.attention.backends.video_sparse_attn.VideoSparseAttentionImpl(num_heads: int, head_size: int, causal: bool, softmax_scale: float, num_kv_heads: int | None = None, prefix: str = '', **extra_impl_args)[source]#
Bases: fastvideo.attention.backends.abstract.AttentionImpl
- forward(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, gate_compress: torch.Tensor, attn_metadata: fastvideo.attention.backends.video_sparse_attn.VideoSparseAttentionMetadata) torch.Tensor [source]#
- postprocess_output(output: torch.Tensor, attn_metadata: fastvideo.attention.backends.video_sparse_attn.VideoSparseAttentionMetadata) torch.Tensor [source]#
- preprocess_qkv(qkv: torch.Tensor, attn_metadata: fastvideo.attention.backends.video_sparse_attn.VideoSparseAttentionMetadata) torch.Tensor [source]#
- tile(x: torch.Tensor, num_tiles: list[int], tile_partition_indices: torch.LongTensor, non_pad_index: torch.LongTensor) torch.Tensor [source]#
- untile(x: torch.Tensor, reverse_tile_partition_indices: torch.LongTensor, non_pad_index: torch.LongTensor) torch.Tensor [source]#
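The tile/untile pair permutes a flattened (t, h, w) token sequence so that each (ts_t × ts_h × ts_w) tile becomes contiguous, then restores the original order. A self-contained toy of that round-trip invariant using einops.rearrange; it illustrates the permutation only, not the class's actual index-based implementation, and the dimensions are chosen so every tile is full and no padding is needed:

```python
import torch
from einops import rearrange

# Toy latent: (T, H, W) = (4, 8, 8) with (ts_t, ts_h, ts_w) = (2, 4, 4) tiles;
# every axis divides evenly, so all tiles are full.
T, H, W = 4, 8, 8
ts_t, ts_h, ts_w = 2, 4, 4

seq = torch.arange(T * H * W)  # flattened (t, h, w) token order
grid = seq.view(T, H, W)

# tile: group tokens so each (ts_t, ts_h, ts_w) tile is contiguous.
tiled = rearrange(grid, "(nt tt) (nh th) (nw tw) -> (nt nh nw) (tt th tw)",
                  tt=ts_t, th=ts_h, tw=ts_w)

# untile: the inverse rearrangement restores the original token order.
untiled = rearrange(tiled, "(nt nh nw) (tt th tw) -> (nt tt) (nh th) (nw tw)",
                    nt=T // ts_t, nh=H // ts_h, nw=W // ts_w,
                    tt=ts_t, th=ts_h, tw=ts_w)

assert torch.equal(untiled.reshape(-1), seq)  # round trip is exact
```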
- class fastvideo.attention.backends.video_sparse_attn.VideoSparseAttentionMetadata[source]#
Bases: fastvideo.attention.backends.abstract.AttentionMetadata
- class fastvideo.attention.backends.video_sparse_attn.VideoSparseAttentionMetadataBuilder[source]#
Bases: fastvideo.attention.backends.abstract.AttentionMetadataBuilder
- fastvideo.attention.backends.video_sparse_attn.construct_variable_block_sizes(dit_seq_shape: tuple[int, int, int], num_tiles: tuple[int, int, int], device: torch.device) torch.LongTensor [source]#
Compute the number of valid (non-padded) tokens inside every (ts_t × ts_h × ts_w) tile after padding, flattened in the (t-tile, h-tile, w-tile) order that rearrange uses.
- Returns: per-tile counts of valid tokens, shape [∏ full_window_size]
- Return type: torch.LongTensor
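For intuition, here is a plain-PyTorch re-derivation of the quantity this function computes. It assumes the tile size along each axis is ceil(dim / num_tiles), so only the trailing tile on each axis can be short; treat it as an illustrative sketch, not the library's implementation:

```python
import torch

def variable_block_sizes(dit_seq_shape, num_tiles, device="cpu"):
    """Toy re-derivation: valid (non-padded) token count per tile.

    Assumes tile size = ceil(dim / num_tiles) along each axis.
    """
    per_axis = []
    for dim, n in zip(dit_seq_shape, num_tiles):
        ts = -(-dim // n)  # ceil division: tile size along this axis
        idx = torch.arange(n, device=device)
        # Interior tiles are full; the last tile is clipped by padding.
        per_axis.append(torch.clamp(dim - idx * ts, min=0, max=ts))
    ct, ch, cw = per_axis
    # Outer product, flattened in (t-tile, h-tile, w-tile) order.
    return (ct[:, None, None] * ch[None, :, None] * cw[None, None, :]).reshape(-1)

# Example: a (9, 8, 8) latent with (2, 2, 2) tiles -> tile sizes (5, 4, 4);
# the second t-tile holds only 4 of its 5 rows, so those tiles shrink.
print(variable_block_sizes((9, 8, 8), (2, 2, 2)))
# tensor([80, 80, 80, 80, 64, 64, 64, 64])
```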
- fastvideo.attention.backends.video_sparse_attn.get_non_pad_index(variable_block_sizes: torch.LongTensor, max_block_size: int)[source]#
- fastvideo.attention.backends.video_sparse_attn.get_reverse_tile_partition_indices(dit_seq_shape: tuple[int, int, int], tile_size: tuple[int, int, int], device: torch.device) torch.LongTensor [source]#
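Taken together, these helpers produce the index tensors that tile() and untile() consume. A hedged wiring sketch built only from the signatures above; the relation num_tiles = ceil(dim / tile_size) and passing ∏ tile_size as max_block_size are assumptions, not documented requirements:

```python
import math
import torch

from fastvideo.attention.backends.video_sparse_attn import (
    construct_variable_block_sizes,
    get_non_pad_index,
    get_reverse_tile_partition_indices,
)

device = torch.device("cpu")
dit_seq_shape = (9, 8, 8)   # (T, H, W) of the video latent
tile_size = (4, 4, 4)       # (ts_t, ts_h, ts_w)
# Assumption: one tile per ceil(dim / ts) window along each axis.
num_tiles = tuple(math.ceil(d / t) for d, t in zip(dit_seq_shape, tile_size))

# Valid-token counts per tile, flattened in (t-tile, h-tile, w-tile) order.
block_sizes = construct_variable_block_sizes(dit_seq_shape, num_tiles, device)

# Indices of the non-padded positions inside the padded, tiled layout.
# Assumption: max_block_size is the full tile volume, ∏ tile_size.
non_pad_index = get_non_pad_index(block_sizes, math.prod(tile_size))

# Permutation that maps the tiled layout back to (t, h, w) order,
# i.e. the reverse_tile_partition_indices argument of untile().
reverse_indices = get_reverse_tile_partition_indices(
    dit_seq_shape, tile_size, device)
```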