fastvideo.v1.attention.layer
Module Contents
Classes

| Class | Description |
| --- | --- |
| `DistributedAttention` | Distributed attention layer. |
| `LocalAttention` | Attention layer. |
API
- class fastvideo.v1.attention.layer.DistributedAttention(num_heads: int, head_size: int, num_kv_heads: Optional[int] = None, softmax_scale: Optional[float] = None, causal: bool = False, supported_attention_backends: Optional[Tuple[fastvideo.v1.platforms._Backend, ...]] = None, prefix: str = '', **extra_impl_args)[source]
Bases:
torch.nn.Module
Distributed attention layer.
Initialization
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, replicated_q: Optional[torch.Tensor] = None, replicated_k: Optional[torch.Tensor] = None, replicated_v: Optional[torch.Tensor] = None) → tuple[torch.Tensor, Optional[torch.Tensor]] [source]
Forward pass for distributed attention.
- Parameters:
q (torch.Tensor) – Query tensor of shape [batch_size, seq_len, num_heads, head_dim]
k (torch.Tensor) – Key tensor of shape [batch_size, seq_len, num_heads, head_dim]
v (torch.Tensor) – Value tensor of shape [batch_size, seq_len, num_heads, head_dim]
replicated_q (Optional[torch.Tensor]) – Replicated query tensor, typically for text tokens
replicated_k (Optional[torch.Tensor]) – Replicated key tensor
replicated_v (Optional[torch.Tensor]) – Replicated value tensor
- Returns:
A tuple containing:
- o (torch.Tensor): Output tensor after attention for the main sequence
- replicated_o (Optional[torch.Tensor]): Output tensor for the replicated tokens, if provided
- Return type:
Tuple[torch.Tensor, Optional[torch.Tensor]]
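A minimal usage sketch based on the constructor and forward signatures above. It assumes fastvideo's distributed environment and a supported attention backend have already been initialized elsewhere (not shown here), and that inputs follow the documented [batch_size, seq_len, num_heads, head_dim] layout; the concrete sizes are illustrative only.

```python
import torch

from fastvideo.v1.attention.layer import DistributedAttention

# Illustrative sizes only; real values come from the model config.
batch_size, seq_len, num_heads, head_dim = 2, 1024, 16, 64

# Assumes the distributed/parallel state and an attention backend are
# already set up by the surrounding fastvideo pipeline.
attn = DistributedAttention(
    num_heads=num_heads,
    head_size=head_dim,
    causal=False,
)

q = torch.randn(batch_size, seq_len, num_heads, head_dim)
k = torch.randn(batch_size, seq_len, num_heads, head_dim)
v = torch.randn(batch_size, seq_len, num_heads, head_dim)

# No replicated (text-token) tensors are passed, so replicated_o is None.
o, replicated_o = attn(q, k, v)
```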
- class fastvideo.v1.attention.layer.LocalAttention(num_heads: int, head_size: int, num_kv_heads: Optional[int] = None, softmax_scale: Optional[float] = None, causal: bool = False, supported_attention_backends: Optional[Tuple[fastvideo.v1.platforms._Backend, ...]] = None, **extra_impl_args)[source]
Bases:
torch.nn.Module
Attention layer.
Initialization
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) → torch.Tensor [source]
Apply local attention between query, key and value tensors.
- Parameters:
q (torch.Tensor) – Query tensor of shape [batch_size, seq_len, num_heads, head_dim]
k (torch.Tensor) – Key tensor of shape [batch_size, seq_len, num_heads, head_dim]
v (torch.Tensor) – Value tensor of shape [batch_size, seq_len, num_heads, head_dim]
- Returns:
Output tensor after local attention
- Return type:
torch.Tensor
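A similar sketch for LocalAttention, again assuming a supported attention backend is available. The output shape noted in the comment is an assumption based on the input layout rather than something stated by this page.

```python
import torch

from fastvideo.v1.attention.layer import LocalAttention

# Illustrative sizes only.
batch_size, seq_len, num_heads, head_dim = 2, 256, 8, 64

attn = LocalAttention(
    num_heads=num_heads,
    head_size=head_dim,
    softmax_scale=head_dim ** -0.5,  # explicit scale; None is also accepted
    causal=False,
)

q = torch.randn(batch_size, seq_len, num_heads, head_dim)
k = torch.randn(batch_size, seq_len, num_heads, head_dim)
v = torch.randn(batch_size, seq_len, num_heads, head_dim)

# Assumed to match the input layout: [batch_size, seq_len, num_heads, head_dim].
o = attn(q, k, v)
```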