Skip to content

latent_preparation

Latent preparation stage for diffusion pipelines.

Classes

fastvideo.pipelines.stages.latent_preparation.CosmosLatentPreparationStage

CosmosLatentPreparationStage(scheduler, transformer, vae=None)

Bases: PipelineStage

Cosmos-specific latent preparation stage that properly handles the tensor shapes and conditioning masks required by the Cosmos transformer.

This stage replicates the logic of diffusers' Cosmos2VideoToWorldPipeline.prepare_latents().

Source code in fastvideo/pipelines/stages/latent_preparation.py
def __init__(self, scheduler, transformer, vae=None) -> None:
    """Initialise the base stage and keep references to its collaborators.

    Args:
        scheduler: Stored as ``self.scheduler``.
        transformer: Stored as ``self.transformer``.
        vae: Optional; stored as ``self.vae`` (defaults to ``None``).
    """
    super().__init__()
    # Attributes are assigned left to right: scheduler, transformer, vae.
    self.scheduler, self.transformer, self.vae = scheduler, transformer, vae

Functions

fastvideo.pipelines.stages.latent_preparation.CosmosLatentPreparationStage.adjust_video_length
adjust_video_length(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> int

Adjust video length based on VAE version.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
int

The number of latent frames computed from the requested video length.

Source code in fastvideo/pipelines/stages/latent_preparation.py
def adjust_video_length(self, batch: ForwardBatch,
                        fastvideo_args: FastVideoArgs) -> int:
    """
    Compute the latent-space frame count for the requested video length.

    Args:
        batch: The current batch information; ``batch.num_frames`` is read.
        fastvideo_args: The inference arguments; the VAE settings are read
            from ``fastvideo_args.pipeline_config.vae_config``.

    Returns:
        The number of latent frames, as an ``int``.
    """
    requested_frames = batch.num_frames
    vae_config = fastvideo_args.pipeline_config.vae_config

    if not vae_config.use_temporal_scaling_frames:
        # stepvideo only
        return int(requested_frames // 17 * 3)

    # The first frame maps to one latent frame; each further group of
    # `ratio` frames adds one more.
    ratio = vae_config.arch_config.temporal_compression_ratio
    return int((requested_frames - 1) // ratio + 1)
fastvideo.pipelines.stages.latent_preparation.CosmosLatentPreparationStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify Cosmos latent preparation stage inputs.

Source code in fastvideo/pipelines/stages/latent_preparation.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Collect the input checks for the Cosmos latent preparation stage.

    Args:
        batch: The current batch information whose fields are checked.
        fastvideo_args: The inference arguments (not read here).

    Returns:
        A ``VerificationResult`` holding one check per inspected field.
    """

    def prompt_or_embeds_present(_):
        # Passes when either validator accepts: a prompt, or prompt embeddings.
        return (V.string_or_list_strings(batch.prompt)
                or V.list_not_empty(batch.prompt_embeds))

    checks = VerificationResult()
    checks.add_check("prompt_or_embeds", None, prompt_or_embeds_present)

    # Remaining checks are plain (name, value, validator) triples, registered
    # in the order listed.
    field_checks = (
        ("prompt_embeds", batch.prompt_embeds, V.list_of_tensors),
        ("num_videos_per_prompt", batch.num_videos_per_prompt,
         V.positive_int),
        ("generator", batch.generator, V.generator_or_list_generators),
        ("num_frames", batch.num_frames, V.positive_int),
        ("height", batch.height, V.positive_int),
        ("width", batch.width, V.positive_int),
        ("latents", batch.latents, V.none_or_tensor),
    )
    for field_name, value, validator in field_checks:
        checks.add_check(field_name, value, validator)
    return checks
fastvideo.pipelines.stages.latent_preparation.CosmosLatentPreparationStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify latent preparation stage outputs.

Source code in fastvideo/pipelines/stages/latent_preparation.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Collect the output checks for the latent preparation stage.

    Args:
        batch: The batch produced by the stage; ``latents`` and
            ``raw_latent_shape`` are checked.
        fastvideo_args: The inference arguments (not read here).

    Returns:
        A ``VerificationResult`` with the two registered checks.
    """
    outcome = VerificationResult()
    # Latents go through both validators; presumably this means a 5-D
    # tensor, depending on how V.with_dims is defined.
    latent_validators = [V.is_tensor, V.with_dims(5)]
    outcome.add_check("latents", batch.latents, latent_validators)
    outcome.add_check("raw_latent_shape", batch.raw_latent_shape, V.is_tuple)
    return outcome

fastvideo.pipelines.stages.latent_preparation.LatentPreparationStage

LatentPreparationStage(scheduler, transformer)

Bases: PipelineStage

Stage for preparing initial latent variables for the diffusion process.

This stage handles the preparation of the initial latent variables that will be denoised during the diffusion process.

Source code in fastvideo/pipelines/stages/latent_preparation.py
def __init__(self, scheduler, transformer) -> None:
    """Initialise the base stage and keep references to its collaborators.

    Args:
        scheduler: Stored as ``self.scheduler``.
        transformer: Stored as ``self.transformer``.
    """
    super().__init__()
    # Attributes are assigned left to right: scheduler, then transformer.
    self.scheduler, self.transformer = scheduler, transformer

Functions

fastvideo.pipelines.stages.latent_preparation.LatentPreparationStage.adjust_video_length
adjust_video_length(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> int

Adjust video length based on VAE version.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
int

The number of latent frames computed from the requested video length.

Source code in fastvideo/pipelines/stages/latent_preparation.py
def adjust_video_length(self, batch: ForwardBatch,
                        fastvideo_args: FastVideoArgs) -> int:
    """
    Translate the requested frame count into a latent frame count.

    Args:
        batch: The current batch information; ``batch.num_frames`` is read.
        fastvideo_args: The inference arguments; the VAE settings are read
            from ``fastvideo_args.pipeline_config.vae_config``.

    Returns:
        The number of latent frames, as an ``int``.
    """
    frame_count = batch.num_frames
    vae_cfg = fastvideo_args.pipeline_config.vae_config

    if vae_cfg.use_temporal_scaling_frames:
        compression = vae_cfg.arch_config.temporal_compression_ratio
        # One latent frame for the first frame, plus one per further group
        # of `compression` frames.
        latent_frames = (frame_count - 1) // compression + 1
    else:
        # stepvideo only
        latent_frames = frame_count // 17 * 3
    return int(latent_frames)
fastvideo.pipelines.stages.latent_preparation.LatentPreparationStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Prepare initial latent variables for the diffusion process.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The batch with prepared latent variables.

Source code in fastvideo/pipelines/stages/latent_preparation.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Create (or adopt) the initial latents and attach them to the batch.

    The latent shape is ``(batch_size, num_channels_latents, num_frames,
    height // r, width // r)`` where ``r`` is the VAE's
    ``spatial_compression_ratio``. If ``batch.latents`` is already set it is
    moved to the local device instead of being sampled.

    Args:
        batch: The current batch information. Reads ``prompt``,
            ``prompt_embeds``, ``num_videos_per_prompt``, ``generator``,
            ``latents``, ``num_frames``, ``height`` and ``width``.
        fastvideo_args: The inference arguments; the VAE architecture config
            supplies the compression ratios.

    Returns:
        The same batch, with ``latents`` and ``raw_latent_shape`` set.

    Raises:
        ValueError: If ``height`` or ``width`` is ``None``, or if a list of
            generators does not match the effective batch size.
    """
    # Frame count in latent space, when the stage can compute one.
    latent_num_frames = (self.adjust_video_length(batch, fastvideo_args)
                         if hasattr(self, 'adjust_video_length') else None)

    # Number of prompts: taken from the prompt itself when present,
    # otherwise from the leading dimension of the first embedding.
    prompt = batch.prompt
    if prompt is None:
        prompt_count = batch.prompt_embeds[0].shape[0]
    elif isinstance(prompt, list):
        prompt_count = len(prompt)
    else:
        prompt_count = 1

    # Effective batch size accounts for several videos per prompt.
    batch_size = prompt_count * batch.num_videos_per_prompt

    # Sampled noise takes its dtype from the first prompt embedding.
    noise_dtype = batch.prompt_embeds[0].dtype
    target_device = get_local_torch_device()
    generator = batch.generator
    latents = batch.latents
    if latent_num_frames is None:
        num_frames = batch.num_frames
    else:
        num_frames = latent_num_frames
    height, width = batch.height, batch.width

    # TODO(will): remove this once we add input/output validation for stages
    if height is None or width is None:
        raise ValueError("Height and width must be provided")

    arch_config = fastvideo_args.pipeline_config.vae_config.arch_config
    spatial_ratio = arch_config.spatial_compression_ratio
    latent_shape = (
        batch_size,
        self.transformer.num_channels_latents,
        num_frames,
        height // spatial_ratio,
        width // spatial_ratio,
    )

    # A list of generators must supply exactly one generator per sample.
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is not None:
        # Caller-supplied latents are only moved to the device; their dtype
        # is left untouched.
        latents = latents.to(target_device)
    else:
        latents = randn_tensor(latent_shape,
                               generator=generator,
                               device=target_device,
                               dtype=noise_dtype)

    # Scale by the scheduler's initial sigma when it defines one; this also
    # applies to caller-supplied latents.
    if hasattr(self.scheduler, "init_noise_sigma"):
        latents = latents * self.scheduler.init_noise_sigma

    batch.latents = latents
    batch.raw_latent_shape = latents.shape
    return batch
fastvideo.pipelines.stages.latent_preparation.LatentPreparationStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify latent preparation stage inputs.

Source code in fastvideo/pipelines/stages/latent_preparation.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Collect the input checks for the latent preparation stage.

    Args:
        batch: The current batch information whose fields are checked.
        fastvideo_args: The inference arguments (not read here).

    Returns:
        A ``VerificationResult`` holding one check per inspected field.
    """

    def has_prompt_or_embeds(_):
        # Passes when either validator accepts: a prompt, or prompt embeddings.
        return (V.string_or_list_strings(batch.prompt)
                or V.list_not_empty(batch.prompt_embeds))

    verification = VerificationResult()
    verification.add_check("prompt_or_embeds", None, has_prompt_or_embeds)

    # The remaining checks each pair one batch field with one validator and
    # are registered in the order listed.
    per_field = (
        ("prompt_embeds", batch.prompt_embeds, V.list_of_tensors),
        ("num_videos_per_prompt", batch.num_videos_per_prompt,
         V.positive_int),
        ("generator", batch.generator, V.generator_or_list_generators),
        ("num_frames", batch.num_frames, V.positive_int),
        ("height", batch.height, V.positive_int),
        ("width", batch.width, V.positive_int),
        ("latents", batch.latents, V.none_or_tensor),
    )
    for label, value, validator in per_field:
        verification.add_check(label, value, validator)
    return verification
fastvideo.pipelines.stages.latent_preparation.LatentPreparationStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify latent preparation stage outputs.

Source code in fastvideo/pipelines/stages/latent_preparation.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Collect the output checks for the latent preparation stage.

    Args:
        batch: The batch produced by the stage; ``latents`` and
            ``raw_latent_shape`` are checked.
        fastvideo_args: The inference arguments (not read here).

    Returns:
        A ``VerificationResult`` with the two registered checks.
    """
    verification = VerificationResult()
    # Latents go through both validators; presumably this means a 5-D
    # tensor, depending on how V.with_dims is defined.
    tensor_checks = [V.is_tensor, V.with_dims(5)]
    verification.add_check("latents", batch.latents, tensor_checks)
    verification.add_check("raw_latent_shape", batch.raw_latent_shape,
                           V.is_tuple)
    return verification

Functions