Skip to content

conditioning

Conditioning stage for diffusion pipelines.

Classes

fastvideo.pipelines.stages.conditioning.ConditioningStage

Bases: PipelineStage

Stage for applying conditioning to the diffusion process.

This stage handles the application of conditioning, such as classifier-free guidance, to the diffusion process.

Functions

fastvideo.pipelines.stages.conditioning.ConditioningStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Apply conditioning to the diffusion process.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The batch with applied conditioning.

Source code in fastvideo/pipelines/stages/conditioning.py
@torch.no_grad()
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Apply conditioning to the diffusion process.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with applied conditioning.
    """
    # TODO!!
    if not batch.do_classifier_free_guidance:
        return batch
    else:
        return batch

    logger.info("batch.negative_prompt_embeds: %s",
                batch.negative_prompt_embeds)
    logger.info("do_classifier_free_guidance: %s",
                batch.do_classifier_free_guidance)
    logger.info("cfg_scale: %s", batch.guidance_scale)

    # Ensure negative prompt embeddings are available
    assert batch.negative_prompt_embeds is not None, (
        "Negative prompt embeddings are required for classifier-free guidance"
    )

    # Concatenate primary embeddings and masks
    batch.prompt_embeds = torch.cat(
        [batch.negative_prompt_embeds, batch.prompt_embeds])
    if batch.attention_mask is not None:
        batch.attention_mask = torch.cat(
            [batch.negative_attention_mask, batch.attention_mask])

    # Concatenate secondary embeddings and masks if present
    if batch.prompt_embeds_2 is not None:
        batch.prompt_embeds_2 = torch.cat(
            [batch.negative_prompt_embeds_2, batch.prompt_embeds_2])
    if batch.attention_mask_2 is not None:
        batch.attention_mask_2 = torch.cat(
            [batch.negative_attention_mask_2, batch.attention_mask_2])

    return batch
fastvideo.pipelines.stages.conditioning.ConditioningStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify conditioning stage inputs.

Source code in fastvideo/pipelines/stages/conditioning.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify conditioning stage inputs."""
    result = VerificationResult()
    result.add_check("do_classifier_free_guidance",
                     batch.do_classifier_free_guidance, V.bool_value)
    result.add_check("guidance_scale", batch.guidance_scale,
                     V.positive_float)
    result.add_check("prompt_embeds", batch.prompt_embeds, V.list_not_empty)
    result.add_check(
        "negative_prompt_embeds", batch.negative_prompt_embeds, lambda x:
        not batch.do_classifier_free_guidance or V.list_not_empty(x))
    return result
fastvideo.pipelines.stages.conditioning.ConditioningStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify conditioning stage outputs.

Source code in fastvideo/pipelines/stages/conditioning.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify conditioning stage outputs."""
    result = VerificationResult()
    result.add_check("prompt_embeds", batch.prompt_embeds, V.list_not_empty)
    return result

Functions