def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
"""
query: [B, L, H, D]
key: [B, L, H, D]
value: [B, L, H, D]
attn_metadata: AttentionMetadata
"""
batch_size, sequence_length, num_heads, head_dim = query.shape
        # Select the MoBA chunk size and top-k for this layer: layers after
        # the first `first_full_layer` full-attention layers cycle through
        # temporal, spatial, and spatio-temporal (st) chunking.
loop_layer_num = attn_metadata.temporal_layer + attn_metadata.spatial_layer + attn_metadata.st_layer
moba_layer = self.layer_idx - attn_metadata.first_full_layer
if moba_layer % loop_layer_num < attn_metadata.temporal_layer:
moba_chunk_size = attn_metadata.temporal_chunk_size
moba_topk = attn_metadata.temporal_topk
elif moba_layer % loop_layer_num < attn_metadata.temporal_layer + attn_metadata.spatial_layer:
moba_chunk_size = attn_metadata.spatial_chunk_size
moba_topk = attn_metadata.spatial_topk
        else:
            # The remainder modulo loop_layer_num always falls in the st
            # range, so a plain `else` keeps moba_chunk_size / moba_topk
            # from being possibly unbound.
            moba_chunk_size = attn_metadata.st_chunk_size
            moba_topk = attn_metadata.st_topk
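        # NOTE: process_moba_input is assumed to reorder (and pad if needed)
        # the tokens into a chunk-aligned layout derived from the video patch
        # resolution, returning the effective chunk size. The returned
        # chunk_size is the same for q, k, and v, so overwriting it is safe.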
query, chunk_size = process_moba_input(query,
attn_metadata.patch_resolution,
moba_chunk_size)
key, chunk_size = process_moba_input(key,
attn_metadata.patch_resolution,
moba_chunk_size)
value, chunk_size = process_moba_input(value,
attn_metadata.patch_resolution,
moba_chunk_size)
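        # All sequences share the (possibly re-padded) length query.shape[1].
        # Build flat token indices and cumulative sequence lengths so the
        # varlen kernel sees batch_size equal-length sequences packed
        # back-to-back: cu_seqlens = [0, S, 2S, ..., B*S].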
max_seqlen = query.shape[1]
indices_q = torch.arange(0,
query.shape[0] * query.shape[1],
device=query.device)
cu_seqlens = torch.arange(0,
query.shape[0] * query.shape[1] + 1,
query.shape[1],
dtype=torch.int32,
device=query.device)
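        # Flatten the batch and sequence dims into the packed [B*S, H, D]
        # layout expected by the varlen attention kernel.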
query = rearrange(query, "b s ... -> (b s) ...")
key = rearrange(key, "b s ... -> (b s) ...")
value = rearrange(value, "b s ... -> (b s) ...")
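        # moba_attn_varlen computes block-sparse attention over the packed
        # tokens: each query attends only to the key/value chunks selected by
        # the configured select mode / similarity-sum threshold, up to
        # moba_topk chunks per query.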
hidden_states = moba_attn_varlen(
query,
key,
value,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
moba_chunk_size=chunk_size,
moba_topk=moba_topk,
select_mode=attn_metadata.moba_select_mode,
simsum_threshold=attn_metadata.moba_threshold,
threshold_type=attn_metadata.moba_threshold_type,
)
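        # pad_input scatters the packed tokens back to [B, L, H, D] via the
        # flat indices; with equal-length sequences this is a plain unflatten.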
hidden_states = self.pad_input(hidden_states, indices_q, batch_size,
sequence_length)
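        # process_moba_output presumably undoes the chunk-aligned reordering
        # applied by process_moba_input, restoring the original token order.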
hidden_states = process_moba_output(hidden_states,
attn_metadata.patch_resolution,
moba_chunk_size)
return hidden_states