Skip to content

sla

Classes

fastvideo.attention.backends.sla.SLAAttentionBackend

Bases: AttentionBackend

Sparse-Linear Attention backend.

fastvideo.attention.backends.sla.SLAAttentionImpl

SLAAttentionImpl(num_heads: int, head_size: int, causal: bool = False, softmax_scale: float | None = None, num_kv_heads: int | None = None, prefix: str = '', topk_ratio: float = 0.1, feature_map: str = 'softmax', BLKQ: int = 128, BLKK: int = 64, use_bf16: bool = True, **extra_impl_args)

Bases: AttentionImpl, Module

SLA attention implementation with learnable linear projection.

This implementation combines sparse attention with linear attention, using a learnable projection to blend the outputs. The sparse attention uses a block-sparse pattern determined by QK similarity.

Parameters:

Name Type Description Default
num_heads int

Number of attention heads

required
head_size int

Dimension of each head

required
topk_ratio float

Ratio of key blocks to attend to (0-1), default 0.1

0.1
feature_map str

Feature map for linear attention ('softmax', 'elu', 'relu')

'softmax'
BLKQ int

Query block size for sparse attention

128
BLKK int

Key block size for sparse attention

64
use_bf16 bool

Whether to use bfloat16 for computation

True
Source code in fastvideo/attention/backends/sla.py
def __init__(
    self,
    num_heads: int,
    head_size: int,
    causal: bool = False,
    softmax_scale: float | None = None,
    num_kv_heads: int | None = None,
    prefix: str = "",
    # SLA-specific knobs; defaults mirror TurboDiffusion
    # (topk=0.1, BLKQ=128, BLKK=64).
    topk_ratio: float = 0.1,
    feature_map: str = "softmax",
    BLKQ: int = 128,
    BLKK: int = 64,
    use_bf16: bool = True,
    **extra_impl_args,
) -> None:
    """Set up the sparse + linear attention module.

    Args:
        num_heads: Number of attention heads.
        head_size: Dimension of each head; also the in/out width of the
            learnable projection applied to the linear-attention branch.
        causal: Stored on the instance as ``self.causal``.
        softmax_scale: Stored scale; a falsy value (``None`` or ``0``)
            falls back to ``head_size ** -0.5``.
        num_kv_heads: Accepted for interface compatibility; not used here.
        prefix: Stored on the instance as ``self.prefix``.
        topk_ratio: Default ratio of key blocks kept by the sparse branch.
        feature_map: Feature map for linear attention, one of ``'elu'``,
            ``'relu'`` or ``'softmax'``.
        BLKQ: Query block size for sparse attention.
        BLKK: Key block size for sparse attention.
        use_bf16: Compute in bfloat16 when true, float16 otherwise.
        **extra_impl_args: Accepted and ignored by this constructor.

    Raises:
        ValueError: If ``feature_map`` is not one of the supported names.
    """
    nn.Module.__init__(self)

    # Generic attention settings.
    self.num_heads = num_heads
    self.head_size = head_size
    self.softmax_scale = softmax_scale or head_size**-0.5
    self.causal = causal
    self.prefix = prefix

    # Settings specific to the sparse branch and the compute precision.
    self.topk_ratio = topk_ratio
    self.BLKQ = BLKQ
    self.BLKK = BLKK
    self.dtype = torch.bfloat16 if use_bf16 else torch.float16

    # Learnable per-head projection applied to the linear-attention output
    # before it is summed with the sparse-attention output.
    self.proj_l = nn.Linear(head_size, head_size, dtype=torch.float32)

    def _elu_plus_one(x: torch.Tensor) -> torch.Tensor:
        return F.elu(x) + 1

    def _softmax_last_dim(x: torch.Tensor) -> torch.Tensor:
        return F.softmax(x, dim=-1)

    # Queries and keys always share the same feature map.
    supported_maps = (
        ("elu", _elu_plus_one),
        ("relu", F.relu),
        ("softmax", _softmax_last_dim),
    )
    for map_name, map_fn in supported_maps:
        if feature_map == map_name:
            self.feature_map_q: Callable[[torch.Tensor],
                                         torch.Tensor] = map_fn
            self.feature_map_k: Callable[[torch.Tensor],
                                         torch.Tensor] = map_fn
            break
    else:
        raise ValueError(f"Unknown feature map: {feature_map}")

    self._init_weights()

Functions

fastvideo.attention.backends.sla.SLAAttentionImpl.forward
forward(query: Tensor, key: Tensor, value: Tensor, attn_metadata: AttentionMetadata) -> Tensor

Forward pass for SLA attention.

Input tensors are in FastVideo format: (B, L, H, D). Internally they are converted to SLA format: (B, H, L, D).

Parameters:

Name Type Description Default
query Tensor

Query tensor (B, L, H, D)

required
key Tensor

Key tensor (B, L, H, D)

required
value Tensor

Value tensor (B, L, H, D)

required
attn_metadata AttentionMetadata

Attention metadata

required

Returns:

Type Description
Tensor

Output tensor (B, L, H, D)

Source code in fastvideo/attention/backends/sla.py
def forward(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_metadata: AttentionMetadata,
) -> torch.Tensor:
    """Run sparse-linear attention and return the blended result.

    Inputs arrive in FastVideo layout (B, L, H, D); the computation runs
    in SLA layout (B, H, L, D) and the result is transposed back.

    Args:
        query: Query tensor (B, L, H, D)
        key: Key tensor (B, L, H, D)
        value: Value tensor (B, L, H, D)
        attn_metadata: Attention metadata; when it exposes a
            ``topk_ratio`` attribute, that value overrides the ratio
            configured on the module.

    Returns:
        Output tensor (B, L, H, D), cast back to the dtype of ``query``.
    """
    input_dtype = query.dtype
    compute_dtype = self.dtype

    # FastVideo (B, L, H, D) -> SLA (B, H, L, D), materialised contiguously.
    q, k, v = (t.transpose(1, 2).contiguous() for t in (query, key, value))

    # A ratio carried by the metadata takes precedence over the module's.
    block_ratio = getattr(attn_metadata, 'topk_ratio', self.topk_ratio)

    # Block selection runs on the tensors before the compute-dtype cast.
    sparse_map, lut, real_topk = get_block_map(
        q,
        k,
        topk_ratio=block_ratio,
        BLKQ=self.BLKQ,
        BLKK=self.BLKK,
    )

    q, k, v = (t.to(compute_dtype) for t in (q, k, v))

    # Sparse branch: attention restricted to the selected key blocks.
    sparse_out = _attention.apply(q, k, v, sparse_map, lut, real_topk,
                                  self.BLKQ, self.BLKK)

    # Linear branch: feature-mapped q/k, then the learnable projection.
    phi_q = self.feature_map_q(q).contiguous().to(compute_dtype)
    phi_k = self.feature_map_k(k).contiguous().to(compute_dtype)
    linear_out = self._calc_linear_attention(phi_q, phi_k, v)

    # proj_l holds float32 weights; autocast runs it in the compute dtype.
    with torch.amp.autocast('cuda', dtype=compute_dtype):
        linear_out = self.proj_l(linear_out)

    # Sum both branches, restore the caller's dtype and the
    # FastVideo (B, L, H, D) layout.
    return (sparse_out + linear_out).to(input_dtype).transpose(1, 2)

fastvideo.attention.backends.sla.SLAAttentionMetadata dataclass

SLAAttentionMetadata(current_timestep: int, topk_ratio: float = 0.5)

Bases: AttentionMetadata

Metadata for SLA attention.

fastvideo.attention.backends.sla.SLAAttentionMetadataBuilder

SLAAttentionMetadataBuilder()

Bases: AttentionMetadataBuilder

Builder for SLA attention metadata.

Source code in fastvideo/attention/backends/sla.py
def __init__(self) -> None:
    """Create the builder; no state is initialised at construction time."""
    pass

fastvideo.attention.backends.sla.SageSLAAttentionBackend

Bases: AttentionBackend

Quantized Sparse-Linear Attention backend using SageAttention kernels.

fastvideo.attention.backends.sla.SageSLAAttentionImpl

SageSLAAttentionImpl(num_heads: int, head_size: int, causal: bool = False, softmax_scale: float | None = None, num_kv_heads: int | None = None, prefix: str = '', topk_ratio: float = 0.5, feature_map: str = 'softmax', use_bf16: bool = True, **extra_impl_args)

Bases: AttentionImpl, Module

SageSLA attention implementation using quantized SageAttention kernels.

This uses INT8 quantization for Q/K and FP8 for V to achieve better performance while maintaining accuracy. Requires spas_sage_attn package.

Parameters:

Name Type Description Default
num_heads int

Number of attention heads

required
head_size int

Dimension of each head (must be 64 or 128)

required
topk_ratio float

Ratio of key blocks to attend to (0-1), default 0.5

0.5
feature_map str

Feature map for linear attention ('softmax', 'elu', 'relu')

'softmax'
use_bf16 bool

Whether to use bfloat16 for computation

True
Source code in fastvideo/attention/backends/sla.py
def __init__(
    self,
    num_heads: int,
    head_size: int,
    causal: bool = False,
    softmax_scale: float | None = None,
    num_kv_heads: int | None = None,
    prefix: str = "",
    # Parameters specific to SageSLA
    topk_ratio: float = 0.5,
    feature_map: str = "softmax",
    use_bf16: bool = True,
    **extra_impl_args,
) -> None:
    """Set up the quantized sparse + linear attention module.

    Args:
        num_heads: Number of attention heads.
        head_size: Dimension of each head (must be 64 or 128).
        causal: Stored on the instance as ``self.causal``.
        softmax_scale: Stored scale; a falsy value (``None`` or ``0``)
            falls back to ``head_size ** -0.5``.
        num_kv_heads: Accepted for interface compatibility; not used here.
        prefix: Stored on the instance as ``self.prefix``.
        topk_ratio: Default ratio of key blocks kept by the sparse branch.
        feature_map: Feature map for linear attention, one of ``'elu'``,
            ``'relu'`` or ``'softmax'``.
        use_bf16: Compute in bfloat16 when true, float16 otherwise.
        **extra_impl_args: Accepted and ignored by this constructor.

    Raises:
        ImportError: If the SageSLA kernels are not available.
        AssertionError: If ``head_size`` is neither 64 nor 128.
        ValueError: If ``feature_map`` is not one of the supported names.
    """
    nn.Module.__init__(self)

    # Fail early when the optional quantized kernels are missing.
    if not SAGESLA_ENABLED:
        raise ImportError(
            "SageSLA requires spas_sage_attn. "
            "Install with: pip install git+https://github.com/thu-ml/SpargeAttn.git"
        )

    assert head_size in (
        64, 128), f"SageSLA requires head_size in [64, 128], got {head_size}"

    # Generic attention settings.
    self.num_heads = num_heads
    self.head_size = head_size
    self.softmax_scale = softmax_scale or head_size**-0.5
    self.causal = causal
    self.prefix = prefix

    # Sparse-branch ratio and compute precision.
    self.topk_ratio = topk_ratio
    self.dtype = torch.bfloat16 if use_bf16 else torch.float16

    # Learnable per-head projection applied to the linear-attention output
    # before it is summed with the sparse-attention output.
    self.proj_l = nn.Linear(head_size, head_size, dtype=torch.float32)

    def _elu_plus_one(x: torch.Tensor) -> torch.Tensor:
        return F.elu(x) + 1

    def _softmax_last_dim(x: torch.Tensor) -> torch.Tensor:
        return F.softmax(x, dim=-1)

    # Pick one feature map and share it between queries and keys.
    if feature_map == "softmax":
        phi = _softmax_last_dim
    elif feature_map == "relu":
        phi = F.relu
    elif feature_map == "elu":
        phi = _elu_plus_one
    else:
        raise ValueError(f"Unknown feature map: {feature_map}")

    self.feature_map_q: Callable[[torch.Tensor], torch.Tensor] = phi
    self.feature_map_k: Callable[[torch.Tensor], torch.Tensor] = phi

    self._init_weights()

Functions

fastvideo.attention.backends.sla.SageSLAAttentionImpl.forward
forward(query: Tensor, key: Tensor, value: Tensor, attn_metadata: AttentionMetadata) -> Tensor

Forward pass for SageSLA attention with quantized kernels.

Input tensors are in FastVideo format: (B, L, H, D)

Parameters:

Name Type Description Default
query Tensor

Query tensor (B, L, H, D)

required
key Tensor

Key tensor (B, L, H, D)

required
value Tensor

Value tensor (B, L, H, D)

required
attn_metadata AttentionMetadata

Attention metadata

required

Returns:

Type Description
Tensor

Output tensor (B, L, H, D)

Source code in fastvideo/attention/backends/sla.py
def forward(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_metadata: AttentionMetadata,
) -> torch.Tensor:
    """Forward pass for SageSLA attention with quantized kernels.

    Input tensors are in FastVideo format: (B, L, H, D). They are
    transposed to (B, H, L, D) for the computation and the result is
    transposed back before returning.

    The output is the sum of two branches: block-sparse attention computed
    by the quantized kernels (INT8 Q/K, FP8 V) and linear attention passed
    through the learnable ``proj_l`` projection.

    Note: the attention scale passed to the kernels is recomputed here as
    ``1 / sqrt(head_dim)``; ``self.softmax_scale`` and ``self.causal`` are
    not read in this method.

    Args:
        query: Query tensor (B, L, H, D)
        key: Key tensor (B, L, H, D)
        value: Value tensor (B, L, H, D)
        attn_metadata: Attention metadata; if it has a ``topk_ratio``
            attribute, that value overrides ``self.topk_ratio``.

    Returns:
        Output tensor (B, L, H, D), cast back to the dtype of ``query``.
    """
    # Remember the caller's dtype so the result can be cast back at the end.
    original_dtype = query.dtype

    # Convert from FastVideo format (B, L, H, D) to SLA format (B, H, L, D)
    q = query.transpose(1, 2).contiguous()
    k = key.transpose(1, 2).contiguous()
    v = value.transpose(1, 2).contiguous()

    # Get topk ratio from metadata if available
    topk_ratio = self.topk_ratio
    if hasattr(attn_metadata, 'topk_ratio'):
        topk_ratio = attn_metadata.topk_ratio  # type: ignore[union-attr]

    # Determine block sizes based on GPU architecture
    # NOTE(review): q.device.index can be None for a device without an
    # explicit index — confirm _get_cuda_arch handles that.
    arch = _get_cuda_arch(q.device.index)
    if arch == "sm90":
        BLKQ, BLKK = 64, 128
    else:
        BLKQ, BLKK = 128, 64

    # Compute block-sparse attention pattern
    # (done on the tensors before they are cast to the compute dtype;
    # `lut` and `real_topk` are not used further in this method).
    sparse_map, lut, real_topk = get_block_map(q,
                                               k,
                                               topk_ratio=topk_ratio,
                                               BLKQ=BLKQ,
                                               BLKK=BLKK)

    # Convert to compute dtype
    q = q.to(self.dtype)
    k = k.to(self.dtype)
    v = v.to(self.dtype)

    # ========== SPARGE QUANTIZED ATTENTION ==========
    # Mean of K over the sequence dimension, handed to the quantizer
    # (presumably for K smoothing — confirm against spas_sage_attn).
    km = k.mean(dim=-2, keepdim=True)
    headdim = q.size(-1)
    scale = 1.0 / (headdim**0.5)

    # Quantize Q, K to INT8
    q_int8, q_scale, k_int8, k_scale = get_vanilla_qk_quant(
        q, k, km, BLKQ, BLKK)
    # Turn the binary block map into the lookup-table form the kernels take.
    lut_triton, valid_block_num = block_map_lut_triton(sparse_map)

    # Quantize V to FP8
    b, h_kv, kv_len, head_dim = v.shape
    # Round the KV length up to the next multiple of 128.
    padded_len = (kv_len + 127) // 128 * 128
    # Destination buffer has sequence and head dims swapped relative to v:
    # (B, H, D, padded_len).
    v_transposed_permutted = torch.empty((b, h_kv, head_dim, padded_len),
                                         dtype=v.dtype,
                                         device=v.device)
    # Expected to fill the buffer in place; the meaning of the trailing `1`
    # is not visible here — TODO confirm against the fused extension.
    fused.transpose_pad_permute_cuda(v, v_transposed_permutted, 1)
    v_fp8 = torch.empty(v_transposed_permutted.shape,
                        dtype=torch.float8_e4m3fn,
                        device=v.device)
    # One float32 scale per (batch, head, channel).
    v_scale = torch.empty((b, h_kv, head_dim),
                          dtype=torch.float32,
                          device=v.device)
    # Expected to write v_fp8 and v_scale in place; the constants 2.25 and 1
    # are opaque from this file — TODO confirm their meaning upstream.
    fused.scale_fuse_quant_cuda(v_transposed_permutted, v_fp8, v_scale,
                                kv_len, 2.25, 1)

    # Sparse attention with quantized kernels
    # o_s is preallocated; the kernel return values are not used, so the
    # kernels are expected to write their result into it.
    o_s = torch.empty_like(q)
    if arch == "sm90":
        qattn.qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale_sm90(
            q_int8, k_int8, v_fp8, o_s, lut_triton, valid_block_num,
            q_scale, k_scale, v_scale, 1, False, 1, scale)
    else:
        # One threshold per head (q is (B, H, L, D), so shape[-3] is H),
        # set to 1e6 — presumably large enough to disable PV skipping;
        # verify against the kernel documentation.
        pvthreshold = torch.full((q.shape[-3], ),
                                 1e6,
                                 dtype=torch.float32,
                                 device=q.device)
        if SAGE2PP_ENABLED:
            # FP16-accumulating kernel variant, used when SAGE2PP_ENABLED.
            qk_int8_sv_f8_accum_f16_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold(
                q_int8, k_int8, v_fp8, o_s, lut_triton, valid_block_num,
                pvthreshold, q_scale, k_scale, v_scale, 1, False, 1, scale,
                0)
        else:
            # FP32-accumulating fallback.
            qattn.qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold(
                q_int8, k_int8, v_fp8, o_s, lut_triton, valid_block_num,
                pvthreshold, q_scale, k_scale, v_scale, 1, False, 1, scale,
                0)
    # ========== END SPARGE ==========

    # Linear attention with feature maps
    # (uses the unquantized q/k/v in the compute dtype).
    q_linear = self.feature_map_q(q).contiguous().to(self.dtype)
    k_linear = self.feature_map_k(k).contiguous().to(self.dtype)
    o_l = self._calc_linear_attention(q_linear, k_linear, v)

    # Project linear attention output and combine
    # proj_l is created with float32 weights; autocast runs it in self.dtype.
    with torch.amp.autocast('cuda', dtype=self.dtype):
        o_l = self.proj_l(o_l)

    # Combine sparse and linear outputs
    output = (o_s + o_l).to(original_dtype)

    # Convert back to FastVideo format (B, L, H, D)
    output = output.transpose(1, 2)

    return output

Functions

fastvideo.attention.backends.sla.get_block_map

get_block_map(q: Tensor, k: Tensor, topk_ratio: float, BLKQ: int = 64, BLKK: int = 64) -> tuple[Tensor, Tensor, int]

Compute sparse block map for attention based on QK similarity.

Parameters:

Name Type Description Default
q Tensor

Query tensor of shape (B, H, L, D)

required
k Tensor

Key tensor of shape (B, H, L, D)

required
topk_ratio float

Ratio of key blocks to attend to (0-1)

required
BLKQ int

Query block size

64
BLKK int

Key block size

64

Returns:

Name Type Description
sparse_map Tensor

Binary mask of shape (B, H, num_q_blocks, num_k_blocks)

lut Tensor

Top-k indices of shape (B, H, num_q_blocks, topk)

topk int

Number of key blocks selected

Source code in fastvideo/attention/backends/sla.py
def get_block_map(
    q: torch.Tensor,
    k: torch.Tensor,
    topk_ratio: float,
    BLKQ: int = 64,
    BLKK: int = 64,
) -> tuple[torch.Tensor, torch.Tensor, int]:
    """Compute sparse block map for attention based on QK similarity.

    Queries and mean-centred keys are mean-pooled per block, the pooled
    blocks are scored against each other with a matrix product, and the
    highest-scoring key blocks are kept for every query block.

    The number of kept key blocks is ``int(topk_ratio * num_k_blocks)``,
    clamped to the range ``[1, num_k_blocks]``. The lower bound matters for
    short sequences: with few key blocks the truncating ``int()`` would
    otherwise yield 0 and no key block would be selected for any query
    block.

    Args:
        q: Query tensor of shape (B, H, L, D); must be contiguous, since
            ``mean_pool`` asserts it.
        k: Key tensor of shape (B, H, L, D)
        topk_ratio: Ratio of key blocks to attend to (0-1)
        BLKQ: Query block size
        BLKK: Key block size

    Returns:
        sparse_map: Binary int8 mask of shape (B, H, num_q_blocks, num_k_blocks)
        lut: Top-k indices of shape (B, H, num_q_blocks, topk), unsorted
        topk: Number of key blocks selected
    """
    # smooth-k technique from SageAttention: subtract the per-head mean key
    # before pooling.
    arg_k = k - torch.mean(k, dim=-2, keepdim=True)
    pooled_qblocks = mean_pool(q, BLKQ)
    pooled_kblocks = mean_pool(arg_k, BLKK)
    pooled_score = pooled_qblocks @ pooled_kblocks.transpose(-1, -2)

    K = pooled_score.shape[-1]
    # Keep at least one key block whenever any exist, and never more than K.
    topk = min(K, max(1, int(topk_ratio * K)))
    lut = torch.topk(pooled_score, topk, dim=-1, sorted=False).indices

    # Mark the selected key blocks with 1 in an otherwise all-zero mask.
    sparse_map = torch.zeros_like(pooled_score, dtype=torch.int8)
    sparse_map.scatter_(-1, lut, 1)
    return sparse_map, lut, topk

fastvideo.attention.backends.sla.mean_pool

mean_pool(x: Tensor, BLK: int) -> Tensor

Mean pool tensor along sequence dimension with block size BLK.

Source code in fastvideo/attention/backends/sla.py
def mean_pool(x: torch.Tensor, BLK: int) -> torch.Tensor:
    """Mean pool tensor along sequence dimension with block size BLK.

    Args:
        x: Contiguous tensor of shape (B, H, L, D).
        BLK: Number of sequence positions per pooled block.

    Returns:
        Tensor of shape (B, H, ceil(L / BLK), D) with the device and dtype
        of ``x``, filled by ``compress_kernel``.
    """
    assert x.is_contiguous()

    batch, heads, seq_len, dim = x.shape
    # Ceiling division, so a trailing partial block gets its own slot.
    num_blocks = (seq_len + BLK - 1) // BLK
    pooled = torch.empty(
        (batch, heads, num_blocks, dim),
        device=x.device,
        dtype=x.dtype,
    )

    # Launch grid: one program per (sequence block, batch*head) pair.
    compress_kernel[(num_blocks, batch * heads)](x, pooled, seq_len, dim, BLK)
    return pooled