
utils

Utility functions for pipeline stages.

Functions

fastvideo.pipelines.stages.utils.retrieve_timesteps

retrieve_timesteps(scheduler: Any, num_inference_steps: int | None = None, device: str | torch.device | None = None, timesteps: list[int] | None = None, sigmas: list[float] | None = None, **kwargs: Any) -> tuple[Any, int]

Calls the scheduler's set_timesteps method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to scheduler.set_timesteps.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `scheduler` | `SchedulerMixin` | The scheduler to get timesteps from. | *required* |
| `num_inference_steps` | `int` | The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` must be `None`. | `None` |
| `device` | `str` or `torch.device`, *optional* | The device to which the timesteps should be moved. If `None`, the timesteps are not moved. | `None` |
| `timesteps` | `List[int]`, *optional* | Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, `num_inference_steps` and `sigmas` must be `None`. | `None` |
| `sigmas` | `List[float]`, *optional* | Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, `num_inference_steps` and `timesteps` must be `None`. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `Tuple[torch.Tensor, int]` | A tuple where the first element is the timestep schedule and the second element is the number of inference steps. |
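
A minimal usage sketch. The `FlowMatchEulerDiscreteScheduler` import and its support for custom `sigmas` are assumptions about the installed diffusers version; any scheduler whose `set_timesteps` accepts the corresponding keyword works the same way:

```python
import torch
from diffusers import FlowMatchEulerDiscreteScheduler  # assumed scheduler choice

from fastvideo.pipelines.stages.utils import retrieve_timesteps

scheduler = FlowMatchEulerDiscreteScheduler()
device = "cuda" if torch.cuda.is_available() else "cpu"

# Default spacing: the scheduler derives 30 timesteps on its own.
timesteps, num_inference_steps = retrieve_timesteps(
    scheduler, num_inference_steps=30, device=device)

# Custom sigmas override the spacing strategy; num_inference_steps and
# timesteps must be None, and the scheduler's set_timesteps must accept
# a `sigmas` keyword (retrieve_timesteps verifies this via inspect).
timesteps, num_inference_steps = retrieve_timesteps(
    scheduler, sigmas=[1.0, 0.75, 0.5, 0.25], device=device)
```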

Source code in fastvideo/pipelines/stages/utils.py
import inspect

import torch


def retrieve_timesteps(
    scheduler: Any,
    num_inference_steps: int | None = None,
    device: str | torch.device | None = None,
    timesteps: list[int] | None = None,
    sigmas: list[float] | None = None,
    **kwargs: Any,
) -> tuple[Any, int]:
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError(
            "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
        )
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(
            inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        if timesteps is None:
            raise ValueError("scheduler.timesteps is None after set_timesteps")
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(
            inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        if timesteps is None:
            raise ValueError("scheduler.timesteps is None after set_timesteps")
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        if timesteps is None:
            raise ValueError("scheduler.timesteps is None after set_timesteps")
        num_inference_steps = len(timesteps)
    return timesteps, num_inference_steps
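
The branches above gate custom schedules on whether `set_timesteps` actually exposes a `timesteps` or `sigmas` keyword. A standalone sketch of that introspection pattern, using a hypothetical `DummyScheduler` stub for illustration:

```python
import inspect


class DummyScheduler:
    """Hypothetical stub: set_timesteps accepts custom sigmas but not timesteps."""

    def set_timesteps(self, num_inference_steps=None, device=None, sigmas=None):
        n = len(sigmas) if sigmas is not None else num_inference_steps
        self.timesteps = list(range(n))


params = inspect.signature(DummyScheduler.set_timesteps).parameters
print("timesteps" in params)  # False -> passing timesteps= would raise ValueError
print("sigmas" in params)     # True  -> custom sigmas are supported
```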