Skip to content

timestep_preparation

Timestep preparation stages for diffusion pipelines.

This module contains implementations of timestep preparation stages for diffusion pipelines.

Classes

fastvideo.pipelines.stages.timestep_preparation.TimestepPreparationStage

TimestepPreparationStage(scheduler)

Bases: PipelineStage

Stage for preparing timesteps for the diffusion process.

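This stage prepares the timestep sequence used during the diffusion process: it configures the scheduler from custom timesteps, custom sigmas, or the number of inference steps, and stores the resulting schedule on the batch.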

Source code in fastvideo/pipelines/stages/timestep_preparation.py
def __init__(self, scheduler) -> None:
    """Store the scheduler that this stage configures.

    Args:
        scheduler: Scheduler object. ``forward`` calls its
            ``set_timesteps`` method and reads its ``timesteps`` attribute.
    """
    self.scheduler = scheduler

Functions

fastvideo.pipelines.stages.timestep_preparation.TimestepPreparationStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Prepare timesteps for the diffusion process.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The batch with prepared timesteps.

Source code in fastvideo/pipelines/stages/timestep_preparation.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Configure the scheduler's timestep schedule and store it on the batch.

    The scheduler's ``set_timesteps`` is called with exactly one of the
    following, in priority order: custom ``batch.timesteps``, custom
    ``batch.sigmas``, or ``batch.num_inference_steps``. ``batch.n_tokens``
    is forwarded as well when it is set and ``set_timesteps`` declares an
    ``n_tokens`` parameter.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments (not read by this method).

    Returns:
        The same batch, with ``timesteps`` set to ``scheduler.timesteps``.

    Raises:
        ValueError: If both ``timesteps`` and ``sigmas`` are provided, or if
            the provided one is not a parameter of the scheduler's
            ``set_timesteps``.
    """
    scheduler = self.scheduler
    device = get_local_torch_device()

    def scheduler_accepts(param_name: str) -> bool:
        # Evaluated lazily: the signature is only inspected on the code
        # paths that actually need to know about an optional parameter.
        return param_name in inspect.signature(
            scheduler.set_timesteps).parameters

    step_count = batch.num_inference_steps
    custom_timesteps = batch.timesteps
    custom_sigmas = batch.sigmas
    token_count = batch.n_tokens

    # Optional keyword arguments forwarded to every set_timesteps call.
    optional_kwargs = {}
    if token_count is not None and scheduler_accepts("n_tokens"):
        optional_kwargs["n_tokens"] = token_count

    # Custom timesteps and custom sigmas are mutually exclusive.
    if custom_timesteps is not None and custom_sigmas is not None:
        raise ValueError(
            "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
        )

    if custom_timesteps is not None:
        if not scheduler_accepts("timesteps"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=custom_timesteps,
                                device=device,
                                **optional_kwargs)
    elif custom_sigmas is not None:
        if not scheduler_accepts("sigmas"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=custom_sigmas,
                                device=device,
                                **optional_kwargs)
    else:
        # No custom schedule: let the scheduler derive one from the step count.
        scheduler.set_timesteps(step_count,
                                device=device,
                                **optional_kwargs)

    # Whichever path ran, the scheduler now holds the prepared schedule.
    batch.timesteps = scheduler.timesteps
    return batch
fastvideo.pipelines.stages.timestep_preparation.TimestepPreparationStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify timestep preparation stage inputs.

Source code in fastvideo/pipelines/stages/timestep_preparation.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Check the batch fields read by this stage before it runs.

    Args:
        batch: The batch whose ``num_inference_steps``, ``timesteps``,
            ``sigmas`` and ``n_tokens`` fields are checked.
        fastvideo_args: The inference arguments (not read by this method).

    Returns:
        A ``VerificationResult`` holding one check per field.
    """
    outcome = VerificationResult()
    # (field name, current value, validator) in the order they are recorded.
    checks = (
        ("num_inference_steps", batch.num_inference_steps, V.positive_int),
        ("timesteps", batch.timesteps, V.none_or_tensor),
        ("sigmas", batch.sigmas, V.none_or_list),
        ("n_tokens", batch.n_tokens, V.none_or_positive_int),
    )
    for field_name, value, validator in checks:
        outcome.add_check(field_name, value, validator)
    return outcome
fastvideo.pipelines.stages.timestep_preparation.TimestepPreparationStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify timestep preparation stage outputs.

Source code in fastvideo/pipelines/stages/timestep_preparation.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Check the batch field written by this stage after it runs.

    Args:
        batch: The batch whose ``timesteps`` field is checked.
        fastvideo_args: The inference arguments (not read by this method).

    Returns:
        A ``VerificationResult`` holding the ``timesteps`` check.
    """
    # Validators applied to the prepared timesteps: tensor, with_dims(1).
    timestep_validators = [V.is_tensor, V.with_dims(1)]
    outcome = VerificationResult()
    outcome.add_check("timesteps", batch.timesteps, timestep_validators)
    return outcome

Functions