ltx2_text_encoding

LTX2-specific text encoding stage with sequence parallelism broadcast support.

When running with sequence parallelism (SP), the Gemma text encoder is only executed on rank 0, and the embeddings are broadcast to all other ranks. This avoids I/O contention from all ranks loading the Gemma model simultaneously.

Classes

fastvideo.pipelines.stages.ltx2_text_encoding.LTX2TextEncodingStage

LTX2TextEncodingStage(text_encoders, tokenizers)

Bases: TextEncodingStage

LTX2 text encoding stage with sequence parallelism support.

When SP is enabled (sp_world_size > 1), only rank 0 runs the text encoder and broadcasts embeddings to other ranks. This avoids I/O contention from all ranks loading the Gemma model simultaneously, which can cause text encoding to take 100+ seconds instead of ~5 seconds.
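
The broadcast protocol described above can be sketched with plain torch.distributed collectives. This is a minimal illustration, not FastVideo's actual implementation: text_encoder is a stand-in callable, the three-dimensional (batch, seq_len, hidden) layout and bfloat16 dtype are assumptions, and the sequence-parallel group is assumed to span all launched ranks.

import torch
import torch.distributed as dist

def encode_with_broadcast(prompt, text_encoder, device):
    """Run the text encoder on rank 0 only and broadcast the result."""
    rank = dist.get_rank()
    if rank == 0:
        # Only rank 0 pays the cost of loading and running the encoder.
        embeds = text_encoder(prompt).to(device=device, dtype=torch.bfloat16)
        shape = torch.tensor(embeds.shape, dtype=torch.long, device=device)
    else:
        # Assumed (batch, seq_len, hidden) embedding layout.
        shape = torch.empty(3, dtype=torch.long, device=device)
    # Broadcast the shape first so receiving ranks can allocate a buffer.
    dist.broadcast(shape, src=0)
    if rank != 0:
        embeds = torch.empty(*shape.tolist(), dtype=torch.bfloat16, device=device)
    dist.broadcast(embeds, src=0)
    return embeds

Broadcasting the shape before the payload lets non-source ranks allocate a matching receive buffer; the dtype must be agreed on out of band, since broadcast carries no tensor metadata.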

Source code in fastvideo/pipelines/stages/text_encoding.py
def __init__(self, text_encoders, tokenizers) -> None:
    """
    Initialize the text encoding stage.

    Args:
        text_encoders: List of text encoder models used by this stage.
        tokenizers: List of tokenizers paired one-to-one with the text encoders.
    """
    super().__init__()
    self.tokenizers = tokenizers
    self.text_encoders = text_encoders
    # Audio embeddings cached from the most recent encode, if any.
    self._last_audio_embeds: list[torch.Tensor] | None = None
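
A hypothetical construction of the stage for illustration; in practice FastVideo's pipeline loader instantiates stages, and the HuggingFace-style objects and checkpoint name here are assumptions:

from transformers import AutoModel, AutoTokenizer

from fastvideo.pipelines.stages.ltx2_text_encoding import LTX2TextEncodingStage

# Checkpoint name is illustrative; FastVideo's loader normally supplies these.
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
text_encoder = AutoModel.from_pretrained("google/gemma-2-2b")

stage = LTX2TextEncodingStage(
    text_encoders=[text_encoder],
    tokenizers=[tokenizer],
)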

Functions