text_encoding

Prompt encoding stages for diffusion pipelines.

This module contains implementations of prompt encoding stages for diffusion pipelines.

Classes

fastvideo.pipelines.stages.text_encoding.TextEncodingStage

TextEncodingStage(text_encoders, tokenizers)

Bases: PipelineStage

Stage for encoding text prompts into embeddings for diffusion models.

This stage handles the encoding of text prompts into the embedding space expected by the diffusion model.

Initialize the prompt encoding stage.

Parameters:

Name Type Description Default
text_encoders

The text encoder models used to produce prompt embeddings, aligned by index with tokenizers.

required
tokenizers

The tokenizers corresponding to each text encoder, one per encoder.

required
Source code in fastvideo/pipelines/stages/text_encoding.py
def __init__(self, text_encoders, tokenizers) -> None:
    """
    Initialize the prompt encoding stage.

    Args:
        enable_logging: Whether to enable logging for this stage.
        is_secondary: Whether this is a secondary text encoder.
    """
    super().__init__()
    self.tokenizers = tokenizers
    self.text_encoders = text_encoders
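
For example, a minimal construction sketch, assuming a HuggingFace-style tokenizer/encoder pair (the model id below is illustrative, not a FastVideo default):

from transformers import AutoModel, AutoTokenizer

# Hypothetical encoder pair; the stage only requires that text_encoders and
# tokenizers are lists of equal length, aligned by index.
tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-large-patch14")
encoder = AutoModel.from_pretrained("openai/clip-vit-large-patch14")

stage = TextEncodingStage(text_encoders=[encoder], tokenizers=[tokenizer])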

Functions

fastvideo.pipelines.stages.text_encoding.TextEncodingStage.encode_text
encode_text(text: str | list[str], fastvideo_args: FastVideoArgs, encoder_index: int | list[int] | None = None, return_attention_mask: bool = False, return_type: str = 'list', device: device | str | None = None, dtype: dtype | None = None, max_length: int | None = None, truncation: bool | None = None, padding: bool | str | None = None)

Encode plain text using selected text encoder(s) and return embeddings.

Parameters:

Name Type Description Default
text str | list[str]

A single string or a list of strings to encode.

required
fastvideo_args FastVideoArgs

The inference arguments providing pipeline config, including tokenizer and encoder settings, preprocess and postprocess functions.

required
encoder_index int | list[int] | None

Encoder selector by index. Accepts an int or list of ints.

None
return_attention_mask bool

If True, also return attention masks for each selected encoder.

False
return_type str

"list" (default) returns a list aligned with selection; "dict" returns a dict keyed by encoder index as a string; "stack" stacks along a new first dimension (requires matching shapes).

'list'
device device | str | None

Optional device override for inputs; defaults to local torch device.

None
dtype dtype | None

Optional dtype to cast returned embeddings to.

None
max_length int | None

Optional per-call tokenizer override.

None
truncation bool | None

Optional per-call tokenizer override.

None
padding bool | str | None

Optional per-call tokenizer override.

None

Returns:

Type Description

Depending on return_type and return_attention_mask:

  • list: List[Tensor] or (List[Tensor], List[Tensor])
  • dict: Dict[str, Tensor] or (Dict[str, Tensor], Dict[str, Tensor])
  • stack: Tensor of shape [num_encoders, ...] or a tuple with stacked attention masks
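
For example, a usage sketch (stage, fastvideo_args, and the prompt below are illustrative stand-ins for objects built elsewhere in the pipeline):

# Encode with the first two encoders and also collect attention masks.
embeds, masks = stage.encode_text(
    "a cat playing piano",
    fastvideo_args,
    encoder_index=[0, 1],
    return_attention_mask=True,
    return_type="dict",
)
# With return_type="dict", results are keyed by encoder index as a string.
prompt_embeds = embeds["0"]
attention_mask = masks["0"]
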
Source code in fastvideo/pipelines/stages/text_encoding.py
@torch.no_grad()
def encode_text(
    self,
    text: str | list[str],
    fastvideo_args: FastVideoArgs,
    encoder_index: int | list[int] | None = None,
    return_attention_mask: bool = False,
    return_type: str = "list",  # one of: "list", "dict", "stack"
    device: torch.device | str | None = None,
    dtype: torch.dtype | None = None,
    max_length: int | None = None,
    truncation: bool | None = None,
    padding: bool | str | None = None,
):
    """
    Encode plain text using selected text encoder(s) and return embeddings.

    Args:
        text: A single string or a list of strings to encode.
        fastvideo_args: The inference arguments providing pipeline config,
            including tokenizer and encoder settings, preprocess and postprocess
            functions.
        encoder_index: Encoder selector by index. Accepts an int or list of ints.
        return_attention_mask: If True, also return attention masks for each
            selected encoder.
        return_type: "list" (default) returns a list aligned with selection;
            "dict" returns a dict keyed by encoder index as a string; "stack" stacks along a
            new first dimension (requires matching shapes).
        device: Optional device override for inputs; defaults to local torch device.
        dtype: Optional dtype to cast returned embeddings to.
        max_length: Optional per-call tokenizer override.
        truncation: Optional per-call tokenizer override.
        padding: Optional per-call tokenizer override.

    Returns:
        Depending on return_type and return_attention_mask:
        - list: List[Tensor] or (List[Tensor], List[Tensor])
        - dict: Dict[str, Tensor] or (Dict[str, Tensor], Dict[str, Tensor])
        - stack: Tensor of shape [num_encoders, ...] or a tuple with stacked
          attention masks
    """

    assert len(self.tokenizers) == len(self.text_encoders)
    assert len(self.text_encoders) == len(
        fastvideo_args.pipeline_config.text_encoder_configs)

    # Resolve selection into indices
    if encoder_index is None:
        indices: list[int] = [0]
    elif isinstance(encoder_index, int):
        indices = [encoder_index]
    else:
        indices = list(encoder_index)
    # validate range
    num_encoders = len(self.text_encoders)
    for idx in indices:
        if idx < 0 or idx >= num_encoders:
            raise IndexError(
                f"encoder index {idx} out of range [0, {num_encoders-1}]")

    # Normalize input to list[str]
    assert isinstance(text, str | list)
    if isinstance(text, str):
        texts: list[str] = [text]
    else:
        texts = text

    embeds_list: list[torch.Tensor] = []
    attn_masks_list: list[torch.Tensor] = []

    preprocess_funcs = fastvideo_args.pipeline_config.preprocess_text_funcs
    postprocess_funcs = fastvideo_args.pipeline_config.postprocess_text_funcs
    encoder_cfgs = fastvideo_args.pipeline_config.text_encoder_configs

    if return_type not in ("list", "dict", "stack"):
        raise ValueError(
            f"Invalid return_type '{return_type}'. Expected one of: 'list', 'dict', 'stack'"
        )

    target_device = device if device is not None else get_local_torch_device()

    for i in indices:
        tokenizer = self.tokenizers[i]
        text_encoder = self.text_encoders[i]
        encoder_config = encoder_cfgs[i]
        preprocess_func = preprocess_funcs[i]
        postprocess_func = postprocess_funcs[i]

        tok_kwargs = dict(encoder_config.tokenizer_kwargs)
        if max_length is not None:
            tok_kwargs["max_length"] = max_length
        elif hasattr(fastvideo_args.pipeline_config,
                     "text_encoder_max_lengths"):
            tok_kwargs["max_length"] = (
                fastvideo_args.pipeline_config.text_encoder_max_lengths[i])

        if truncation is not None:
            tok_kwargs["truncation"] = truncation
        if padding is not None:
            tok_kwargs["padding"] = padding

        processed_texts: list[str] = []
        for prompt_str in texts:
            processed_text = preprocess_func(prompt_str)
            if processed_text is not None:
                processed_texts.append(processed_text)
            else:
                # preprocess_func returning None signals a rejected prompt:
                # short-circuit with all-zero embeddings and masks
                # (assuming batch_size = 1).
                prompt_embeds = torch.zeros((1, tok_kwargs["max_length"],
                                             encoder_config.hidden_size),
                                            device=target_device)
                attention_mask = torch.zeros((1, tok_kwargs["max_length"]),
                                             device=target_device,
                                             dtype=torch.int64)
                embeds_list.append(prompt_embeds)
                attn_masks_list.append(attention_mask)
                return self.return_embeds(embeds_list, attn_masks_list,
                                          return_type,
                                          return_attention_mask, indices)

        if encoder_config.is_chat_model:
            text_inputs = tokenizer.apply_chat_template(
                processed_texts, **tok_kwargs).to(target_device)
        else:
            text_inputs = tokenizer(processed_texts,
                                    **tok_kwargs).to(target_device)

        input_ids = text_inputs["input_ids"]
        attention_mask = text_inputs["attention_mask"]

        with set_forward_context(current_timestep=0, attn_metadata=None):
            outputs = text_encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True,
            )

        # Postprocess functions differ in arity: most take only the encoder
        # outputs, but some also consume and return the attention mask, so
        # fall back to the two-argument form when the first call raises.
        try:
            prompt_embeds = postprocess_func(outputs)
        except Exception:
            prompt_embeds, attention_mask = postprocess_func(
                outputs, attention_mask)

        if dtype is not None:
            prompt_embeds = prompt_embeds.to(dtype=dtype)
        embeds_list.append(prompt_embeds)
        if return_attention_mask:
            attn_masks_list.append(attention_mask)

    return self.return_embeds(embeds_list, attn_masks_list, return_type,
                              return_attention_mask, indices)
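
A further sketch covering the "stack" return path together with the per-call tokenizer overrides (all values are illustrative; stacking assumes the selected encoders produce matching shapes):

import torch

# The stacked result gains a new leading dimension of size len(encoder_index).
stacked = stage.encode_text(
    ["a dog surfing"],
    fastvideo_args,
    encoder_index=[0, 1],
    return_type="stack",
    dtype=torch.bfloat16,
    max_length=77,          # per-call tokenizer overrides
    truncation=True,
    padding="max_length",
)
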
fastvideo.pipelines.stages.text_encoding.TextEncodingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Encode the prompt into text encoder hidden states.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The batch with encoded prompt embeddings.

Source code in fastvideo/pipelines/stages/text_encoding.py
@torch.no_grad()
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Encode the prompt into text encoder hidden states.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with encoded prompt embeddings.
    """
    assert len(self.tokenizers) == len(self.text_encoders)
    assert len(self.text_encoders) == len(
        fastvideo_args.pipeline_config.text_encoder_configs)

    # Encode positive prompt with all available encoders
    assert batch.prompt is not None
    prompt_text: str | list[str] = batch.prompt
    all_indices: list[int] = list(range(len(self.text_encoders)))
    prompt_embeds_list, prompt_masks_list = self.encode_text(
        prompt_text,
        fastvideo_args,
        encoder_index=all_indices,
        return_attention_mask=True,
    )

    for pe in prompt_embeds_list:
        batch.prompt_embeds.append(pe)
    if batch.prompt_attention_mask is not None:
        for am in prompt_masks_list:
            batch.prompt_attention_mask.append(am)

    # Encode negative prompt if CFG is enabled
    if batch.do_classifier_free_guidance:
        assert isinstance(batch.negative_prompt, str)
        neg_embeds_list, neg_masks_list = self.encode_text(
            batch.negative_prompt,
            fastvideo_args,
            encoder_index=all_indices,
            return_attention_mask=True,
        )

        assert batch.negative_prompt_embeds is not None
        for ne in neg_embeds_list:
            batch.negative_prompt_embeds.append(ne)
        if batch.negative_attention_mask is not None:
            for nm in neg_masks_list:
                batch.negative_attention_mask.append(nm)

    return batch
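
A usage sketch, assuming a ForwardBatch prepared upstream with prompt set (and negative_prompt when classifier-free guidance is enabled):

batch = stage.forward(batch, fastvideo_args)
# forward appends one embedding tensor per text encoder, in encoder order.
assert len(batch.prompt_embeds) == len(stage.text_encoders)
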
fastvideo.pipelines.stages.text_encoding.TextEncodingStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify text encoding stage inputs.

Source code in fastvideo/pipelines/stages/text_encoding.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify text encoding stage inputs."""
    result = VerificationResult()
    result.add_check("prompt", batch.prompt, V.string_or_list_strings)
    # result.add_check(
    #     "negative_prompt", batch.negative_prompt, lambda x: not batch.
    #     do_classifier_free_guidance or V.string_not_empty(x))
    result.add_check("do_classifier_free_guidance",
                     batch.do_classifier_free_guidance, V.bool_value)
    result.add_check("prompt_embeds", batch.prompt_embeds, V.is_list)
    result.add_check("negative_prompt_embeds", batch.negative_prompt_embeds,
                     V.none_or_list)
    return result
fastvideo.pipelines.stages.text_encoding.TextEncodingStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify text encoding stage outputs.

Source code in fastvideo/pipelines/stages/text_encoding.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify text encoding stage outputs."""
    result = VerificationResult()
    result.add_check("prompt_embeds", batch.prompt_embeds,
                     V.list_of_tensors_min_dims(2))
    result.add_check(
        "negative_prompt_embeds", batch.negative_prompt_embeds,
        lambda x: (not batch.do_classifier_free_guidance
                   or V.list_of_tensors_with_min_dims(x, 2)))
    return result
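
A sketch of exercising the verification hooks around the stage; how a pipeline consumes a VerificationResult is assumed here, and only the calls themselves are taken from this page:

# Each add_check registers a named predicate over a batch field; a pipeline
# would inspect the returned VerificationResult before running the stage.
result = stage.verify_input(batch, fastvideo_args)
out_batch = stage.forward(batch, fastvideo_args)
result_out = stage.verify_output(out_batch, fastvideo_args)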
