fastvideo.v1.attention.layer#

Module Contents#

Classes#

DistributedAttention

Distributed attention layer.

LocalAttention

Local (non-distributed) attention layer.

API#

class fastvideo.v1.attention.layer.DistributedAttention(num_heads: int, head_size: int, num_kv_heads: Optional[int] = None, softmax_scale: Optional[float] = None, causal: bool = False, supported_attention_backends: Optional[Tuple[fastvideo.v1.platforms._Backend, ...]] = None, prefix: str = '', **extra_impl_args)[source]#

Bases: torch.nn.Module

Distributed attention layer.

Initialization

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, replicated_q: Optional[torch.Tensor] = None, replicated_k: Optional[torch.Tensor] = None, replicated_v: Optional[torch.Tensor] = None) tuple[torch.Tensor, Optional[torch.Tensor]][source]#

Forward pass for distributed attention.

Parameters:
  • q (torch.Tensor) – Query tensor [batch_size, seq_len, num_heads, head_dim]

  • k (torch.Tensor) – Key tensor [batch_size, seq_len, num_heads, head_dim]

  • v (torch.Tensor) – Value tensor [batch_size, seq_len, num_heads, head_dim]

  • replicated_q (Optional[torch.Tensor]) – Replicated query tensor, typically for text tokens

  • replicated_k (Optional[torch.Tensor]) – Replicated key tensor

  • replicated_v (Optional[torch.Tensor]) – Replicated value tensor

Returns:

A tuple containing:
  • o (torch.Tensor) – Output tensor after attention for the main sequence

  • replicated_o (Optional[torch.Tensor]) – Output tensor for the replicated tokens, if provided

Return type:

Tuple[torch.Tensor, Optional[torch.Tensor]]
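
A minimal usage sketch based on the constructor and forward signatures above. It assumes the distributed setup that this layer relies on (e.g. sequence-parallel process groups) has already been initialized elsewhere, that head_size is the per-head dimension of the q/k/v tensors, and that the default attention backend is available; shapes follow the documented [batch_size, seq_len, num_heads, head_dim] layout.

import torch

from fastvideo.v1.attention.layer import DistributedAttention

# Assumed example dimensions.
batch_size, seq_len, text_len = 2, 128, 32
num_heads, head_dim = 8, 64

attn = DistributedAttention(num_heads=num_heads, head_size=head_dim, causal=False)

# Main (sharded) sequence tensors.
q = torch.randn(batch_size, seq_len, num_heads, head_dim)
k = torch.randn(batch_size, seq_len, num_heads, head_dim)
v = torch.randn(batch_size, seq_len, num_heads, head_dim)

# Optional replicated tensors, typically carrying text tokens that are
# not sharded across the sequence-parallel group.
replicated_q = torch.randn(batch_size, text_len, num_heads, head_dim)
replicated_k = torch.randn(batch_size, text_len, num_heads, head_dim)
replicated_v = torch.randn(batch_size, text_len, num_heads, head_dim)

o, replicated_o = attn(q, k, v, replicated_q, replicated_k, replicated_v)
# o:            attention output for the main sequence
# replicated_o: attention output for the replicated tokens (None if the
#               replicated inputs were not passed)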

class fastvideo.v1.attention.layer.LocalAttention(num_heads: int, head_size: int, num_kv_heads: Optional[int] = None, softmax_scale: Optional[float] = None, causal: bool = False, supported_attention_backends: Optional[Tuple[fastvideo.v1.platforms._Backend, ...]] = None, **extra_impl_args)[source]#

Bases: torch.nn.Module

Local (non-distributed) attention layer.

Initialization

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) torch.Tensor[source]#

Apply local attention between query, key and value tensors.

Parameters:
  • q (torch.Tensor) – Query tensor of shape [batch_size, seq_len, num_heads, head_dim]

  • k (torch.Tensor) – Key tensor of shape [batch_size, seq_len, num_heads, head_dim]

  • v (torch.Tensor) – Value tensor of shape [batch_size, seq_len, num_heads, head_dim]

Returns:

Output tensor after local attention

Return type:

torch.Tensor
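
A minimal usage sketch for LocalAttention based on the signatures above. It assumes head_size equals the per-head dimension of the q/k/v tensors and that a supported attention backend is available on the current device; the output is assumed to keep the same [batch_size, seq_len, num_heads, head_dim] layout as the query.

import torch

from fastvideo.v1.attention.layer import LocalAttention

# Assumed example dimensions.
batch_size, seq_len, num_heads, head_dim = 2, 128, 8, 64

attn = LocalAttention(num_heads=num_heads, head_size=head_dim, causal=False)

q = torch.randn(batch_size, seq_len, num_heads, head_dim)
k = torch.randn(batch_size, seq_len, num_heads, head_dim)
v = torch.randn(batch_size, seq_len, num_heads, head_dim)

# Plain attention over the local tensors; no sequence parallelism involved.
out = attn(q, k, v)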