fastvideo.attention.backends.video_sparse_attn#

Module Contents#

Classes#

VideoSparseAttentionBackend

VideoSparseAttentionImpl

VideoSparseAttentionMetadata

VideoSparseAttentionMetadataBuilder

Functions#

construct_variable_block_sizes

Compute the number of valid (non-padded) tokens inside every (ts_t × ts_h × ts_w) tile after padding, flattened in the order (t-tile, h-tile, w-tile) that rearrange uses.

get_non_pad_index

get_reverse_tile_partition_indices

get_tile_partition_indices

Data#

API#

fastvideo.attention.backends.video_sparse_attn.VSA_TILE_SIZE[source]#

(4, 4, 4)
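Each tile therefore spans 4 × 4 × 4 = 64 latent tokens. A minimal sketch of how the constant relates a latent grid to per-axis tile counts; the `dit_seq_shape` value below is a made-up example, not part of the API:

```python
import math

VSA_TILE_SIZE = (4, 4, 4)

# Hypothetical latent grid (t, h, w) after patchification; illustrative only.
dit_seq_shape = (16, 30, 52)

# One tile per (4, 4, 4) block, rounding up so padded edge tiles are counted.
num_tiles = [math.ceil(s / ts) for s, ts in zip(dit_seq_shape, VSA_TILE_SIZE)]
print(num_tiles)  # [4, 8, 13]
```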

class fastvideo.attention.backends.video_sparse_attn.VideoSparseAttentionBackend[source]#

Bases: fastvideo.attention.backends.abstract.AttentionBackend

accept_output_buffer: bool[source]#

True

static get_builder_cls() type[fastvideo.attention.backends.video_sparse_attn.VideoSparseAttentionMetadataBuilder][source]#
static get_impl_cls() type[fastvideo.attention.backends.video_sparse_attn.VideoSparseAttentionImpl][source]#
static get_metadata_cls() type[fastvideo.attention.backends.video_sparse_attn.VideoSparseAttentionMetadata][source]#
static get_name() str[source]#
static get_supported_head_sizes() list[int][source]#
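The backend is a thin registry: each static getter returns the class used for one stage of the attention pipeline. A wiring sketch under that reading; the constructor arguments below are illustrative values, not library defaults:

```python
from fastvideo.attention.backends.video_sparse_attn import (
    VideoSparseAttentionBackend,
)

# Resolve the concrete classes through the backend's static getters.
builder_cls = VideoSparseAttentionBackend.get_builder_cls()
impl_cls = VideoSparseAttentionBackend.get_impl_cls()

# Illustrative arguments; real values come from the model configuration.
impl = impl_cls(num_heads=24, head_size=128, causal=False,
                softmax_scale=128 ** -0.5)
builder = builder_cls()
```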
class fastvideo.attention.backends.video_sparse_attn.VideoSparseAttentionImpl(num_heads: int, head_size: int, causal: bool, softmax_scale: float, num_kv_heads: int | None = None, prefix: str = '', **extra_impl_args)[source]#

Bases: fastvideo.attention.backends.abstract.AttentionImpl

forward(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, gate_compress: torch.Tensor, attn_metadata: fastvideo.attention.backends.video_sparse_attn.VideoSparseAttentionMetadata) torch.Tensor[source]#
postprocess_output(output: torch.Tensor, attn_metadata: fastvideo.attention.backends.video_sparse_attn.VideoSparseAttentionMetadata) torch.Tensor[source]#
preprocess_qkv(qkv: torch.Tensor, attn_metadata: fastvideo.attention.backends.video_sparse_attn.VideoSparseAttentionMetadata) torch.Tensor[source]#
tile(x: torch.Tensor, num_tiles: list[int], tile_partition_indices: torch.LongTensor, non_pad_index: torch.LongTensor) torch.Tensor[source]#
untile(x: torch.Tensor, reverse_tile_partition_indices: torch.LongTensor, non_pad_index: torch.LongTensor) torch.Tensor[source]#
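tile gathers the flattened token sequence into tile-contiguous order using the precomputed partition indices, and untile restores the original order. A round-trip sketch, assuming `attn_metadata` was produced by the metadata builder documented below and that the (batch, seq, heads, head_dim) layout is accepted; both are assumptions for illustration:

```python
import torch

# Assumed (B, S, H, D) layout; `impl` and `attn_metadata` as built elsewhere.
x = torch.randn(1, attn_metadata.total_seq_length, 24, 128)
x_tiled = impl.tile(x, attn_metadata.num_tiles,
                    attn_metadata.tile_partition_indices,
                    attn_metadata.non_pad_index)
x_restored = impl.untile(x_tiled,
                         attn_metadata.reverse_tile_partition_indices,
                         attn_metadata.non_pad_index)
# Expected round-trip: untile should undo tile exactly.
assert torch.equal(x_restored, x)
```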
class fastvideo.attention.backends.video_sparse_attn.VideoSparseAttentionMetadata[source]#

Bases: fastvideo.attention.backends.abstract.AttentionMetadata

VSA_sparsity: float[source]#

None

current_timestep: int[source]#

None

dit_seq_shape: list[int][source]#

None

non_pad_index: torch.LongTensor[source]#

None

num_tiles: list[int][source]#

None

reverse_tile_partition_indices: torch.LongTensor[source]#

None

tile_partition_indices: torch.LongTensor[source]#

None

total_seq_length: int[source]#

None

variable_block_sizes: torch.LongTensor[source]#

None

class fastvideo.attention.backends.video_sparse_attn.VideoSparseAttentionMetadataBuilder[source]#

Bases: fastvideo.attention.backends.abstract.AttentionMetadataBuilder

build(current_timestep: int, raw_latent_shape: tuple[int, int, int], patch_size: tuple[int, int, int], VSA_sparsity: float, device: torch.device, **kwargs: dict[str, Any]) fastvideo.attention.backends.video_sparse_attn.VideoSparseAttentionMetadata[source]#
prepare()[source]#
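A usage sketch following the documented build signature; the latent shape, patch size, and sparsity values are illustrative, and the assumption that prepare() is called once before build() is ours:

```python
import torch
from fastvideo.attention.backends.video_sparse_attn import (
    VideoSparseAttentionBackend,
)

builder = VideoSparseAttentionBackend.get_builder_cls()()
builder.prepare()  # assumed to reset per-step state before building

attn_metadata = builder.build(
    current_timestep=0,
    raw_latent_shape=(16, 60, 104),  # (T, H, W) latents, illustrative
    patch_size=(1, 2, 2),            # illustrative DiT patch size
    VSA_sparsity=0.9,                # assumed meaning: fraction pruned
    device=torch.device("cuda"),
)
```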
fastvideo.attention.backends.video_sparse_attn.construct_variable_block_sizes(dit_seq_shape: tuple[int, int, int], num_tiles: tuple[int, int, int], device: torch.device) torch.LongTensor[source]#

Compute the number of valid (non-padded) tokens inside every (ts_t × ts_h × ts_w) tile after padding, flattened in the order (t-tile, h-tile, w-tile) that rearrange uses.

Returns:

torch.LongTensor of shape [∏ full_window_size], holding the number of valid tokens in each tile.

Return type:

torch.LongTensor
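A pure-PyTorch sketch that re-derives the documented contract (per-tile valid-token counts, flattened t-tile major); this illustrates the semantics and is not the library's implementation:

```python
import torch

VSA_TILE_SIZE = (4, 4, 4)

def variable_block_sizes_sketch(dit_seq_shape, num_tiles, device):
    # Per-axis valid extent of each tile: interior tiles hold a full tile
    # length, the last tile holds only the unpadded remainder.
    per_axis = []
    for size, n, ts in zip(dit_seq_shape, num_tiles, VSA_TILE_SIZE):
        starts = torch.arange(n, device=device) * ts
        per_axis.append((starts + ts).clamp(max=size) - starts)
    ct, ch, cw = per_axis
    # A tile's token count is the product of its per-axis extents; flatten
    # in (t-tile, h-tile, w-tile) order to match the rearrange layout.
    return (ct[:, None, None] * ch[None, :, None] * cw[None, None, :]).reshape(-1)

# Example: a (6, 6, 6) grid with (4, 4, 4) tiles -> 2x2x2 tiles, edge-clipped.
print(variable_block_sizes_sketch((6, 6, 6), (2, 2, 2), torch.device("cpu")))
# tensor([64, 32, 32, 16, 32, 16, 16,  8])
```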

fastvideo.attention.backends.video_sparse_attn.get_non_pad_index(variable_block_sizes: torch.LongTensor, max_block_size: int)[source]#
fastvideo.attention.backends.video_sparse_attn.get_reverse_tile_partition_indices(dit_seq_shape: tuple[int, int, int], tile_size: tuple[int, int, int], device: torch.device) torch.LongTensor[source]#
fastvideo.attention.backends.video_sparse_attn.get_tile_partition_indices(dit_seq_shape: tuple[int, int, int], tile_size: tuple[int, int, int], device: torch.device) torch.LongTensor[source]#
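get_tile_partition_indices and get_reverse_tile_partition_indices read naturally as mutually inverse index maps over the token order. A quick consistency sketch under that assumption, using a shape that divides evenly by the tile size so padding plays no role:

```python
import torch
from fastvideo.attention.backends.video_sparse_attn import (
    get_reverse_tile_partition_indices,
    get_tile_partition_indices,
)

shape, tile, dev = (8, 16, 16), (4, 4, 4), torch.device("cpu")
fwd = get_tile_partition_indices(shape, tile, dev)
rev = get_reverse_tile_partition_indices(shape, tile, dev)
# If the two are inverse permutations, composing them yields the identity.
assert torch.equal(fwd[rev], torch.arange(fwd.numel(), device=dev))
```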
fastvideo.attention.backends.video_sparse_attn.logger[source]#

'init_logger(...)'