stages

Pipeline stages for diffusion models.

This package contains the various stages that can be composed to create complete diffusion pipelines.
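
Every stage documented below shares the same call surface: forward(batch, fastvideo_args) returns an updated ForwardBatch, while verify_input and verify_output return a VerificationResult describing the registered checks. A minimal sketch of driving a list of stages by hand, assuming the stages, batch, and args already exist (this is a hypothetical driver loop, not the library's pipeline executor):

def run_stages(stages, batch, fastvideo_args):
    # Hypothetical driver loop for illustration only.
    for stage in stages:
        stage.verify_input(batch, fastvideo_args)   # returns a VerificationResult
        batch = stage.forward(batch, fastvideo_args)
        stage.verify_output(batch, fastvideo_args)  # returns a VerificationResult
    return batch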

Classes

fastvideo.pipelines.stages.CausalDMDDenosingStage

CausalDMDDenosingStage(transformer, scheduler, transformer_2=None, vae=None)

Bases: DenoisingStage

Denoising stage for causal diffusion.

Source code in fastvideo/pipelines/stages/causal_denoising.py
def __init__(self,
             transformer,
             scheduler,
             transformer_2=None,
             vae=None) -> None:
    super().__init__(transformer, scheduler, transformer_2)
    # KV and cross-attention cache state (initialized on first forward)
    self.transformer = transformer
    self.transformer_2 = transformer_2
    self.vae = vae
    # Model-dependent constants (aligned with causal_inference.py assumptions)
    self.num_transformer_blocks = len(self.transformer.blocks)
    self.num_frames_per_block = self.transformer.config.arch_config.num_frames_per_block
    self.sliding_window_num_frames = self.transformer.config.arch_config.sliding_window_num_frames

    try:
        self.local_attn_size = getattr(self.transformer.model,
                                       "local_attn_size",
                                       -1)  # type: ignore
    except Exception:
        self.local_attn_size = -1

Functions

fastvideo.pipelines.stages.CausalDMDDenosingStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify denoising stage inputs.

Source code in fastvideo/pipelines/stages/causal_denoising.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify denoising stage inputs."""
    result = VerificationResult()
    result.add_check("latents", batch.latents,
                     [V.is_tensor, V.with_dims(5)])
    result.add_check("prompt_embeds", batch.prompt_embeds, V.list_not_empty)
    result.add_check("image_embeds", batch.image_embeds, V.is_list)
    result.add_check("image_latent", batch.image_latent,
                     V.none_or_tensor_with_dims(5))
    result.add_check("num_inference_steps", batch.num_inference_steps,
                     V.positive_int)
    result.add_check("guidance_scale", batch.guidance_scale,
                     V.positive_float)
    result.add_check("eta", batch.eta, V.non_negative_float)
    result.add_check("generator", batch.generator,
                     V.generator_or_list_generators)
    result.add_check("do_classifier_free_guidance",
                     batch.do_classifier_free_guidance, V.bool_value)
    result.add_check(
        "negative_prompt_embeds", batch.negative_prompt_embeds, lambda x:
        not batch.do_classifier_free_guidance or V.list_not_empty(x))
    return result
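
The verify_* methods build a VerificationResult by registering named checks: a check can be a single validator from V, a list of validators (as with latents above), or a lambda. The negative_prompt_embeds lambda encodes "required only when classifier-free guidance is enabled"; a plain-Python sketch of that boolean logic (V and VerificationResult are library helpers and are not reimplemented here):

def negative_embeds_ok(negative_prompt_embeds, do_classifier_free_guidance) -> bool:
    # Passes when CFG is disabled, or when CFG is enabled and embeddings exist.
    return (not do_classifier_free_guidance) or bool(negative_prompt_embeds)

assert negative_embeds_ok([], do_classifier_free_guidance=False)
assert not negative_embeds_ok([], do_classifier_free_guidance=True)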

fastvideo.pipelines.stages.ConditioningStage

Bases: PipelineStage

Stage for applying conditioning to the diffusion process.

This stage handles the application of conditioning, such as classifier-free guidance, to the diffusion process.

Functions

fastvideo.pipelines.stages.ConditioningStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Apply conditioning to the diffusion process.

Parameters:

- batch (ForwardBatch, required): The current batch information.
- fastvideo_args (FastVideoArgs, required): The inference arguments.

Returns:

- ForwardBatch: The batch with applied conditioning.

Source code in fastvideo/pipelines/stages/conditioning.py
@torch.no_grad()
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Apply conditioning to the diffusion process.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with applied conditioning.
    """
    # TODO!!
    if not batch.do_classifier_free_guidance:
        return batch
    else:
        return batch

    logger.info("batch.negative_prompt_embeds: %s",
                batch.negative_prompt_embeds)
    logger.info("do_classifier_free_guidance: %s",
                batch.do_classifier_free_guidance)
    logger.info("cfg_scale: %s", batch.guidance_scale)

    # Ensure negative prompt embeddings are available
    assert batch.negative_prompt_embeds is not None, (
        "Negative prompt embeddings are required for classifier-free guidance"
    )

    # Concatenate primary embeddings and masks
    batch.prompt_embeds = torch.cat(
        [batch.negative_prompt_embeds, batch.prompt_embeds])
    if batch.attention_mask is not None:
        batch.attention_mask = torch.cat(
            [batch.negative_attention_mask, batch.attention_mask])

    # Concatenate secondary embeddings and masks if present
    if batch.prompt_embeds_2 is not None:
        batch.prompt_embeds_2 = torch.cat(
            [batch.negative_prompt_embeds_2, batch.prompt_embeds_2])
    if batch.attention_mask_2 is not None:
        batch.attention_mask_2 = torch.cat(
            [batch.negative_attention_mask_2, batch.attention_mask_2])

    return batch
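
The concatenation below the early return shows the intended classifier-free-guidance layout: negative embeddings are stacked ahead of the positive ones along the batch dimension. A toy illustration of that layout (shapes are invented for the example; real embeddings come from the text-encoding stages):

import torch

prompt_embeds = torch.randn(1, 512, 4096)           # toy positive embeddings
negative_prompt_embeds = torch.randn(1, 512, 4096)  # toy negative embeddings

# Negative first, positive second, matching the torch.cat order above.
cfg_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])  # (2, 512, 4096)
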
fastvideo.pipelines.stages.ConditioningStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify conditioning stage inputs.

Source code in fastvideo/pipelines/stages/conditioning.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify conditioning stage inputs."""
    result = VerificationResult()
    if not batch.prompt_embeds:
        # No text encoder/prompt embeddings: skip checks and effectively disable CFG.
        batch.do_classifier_free_guidance = False
        return result
    result.add_check("do_classifier_free_guidance",
                     batch.do_classifier_free_guidance, V.bool_value)
    result.add_check("guidance_scale", batch.guidance_scale,
                     V.positive_float)
    # Matrix-Game allows empty prompt
    # embeddings when CFG isn't enabled.
    if batch.do_classifier_free_guidance or batch.prompt_embeds:
        result.add_check("prompt_embeds", batch.prompt_embeds,
                         V.list_not_empty)
        result.add_check(
            "negative_prompt_embeds", batch.negative_prompt_embeds, lambda
            x: not batch.do_classifier_free_guidance or V.list_not_empty(x))
    return result
fastvideo.pipelines.stages.ConditioningStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify conditioning stage outputs.

Source code in fastvideo/pipelines/stages/conditioning.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify conditioning stage outputs."""
    result = VerificationResult()
    if batch.prompt_embeds is None or not batch.prompt_embeds:
        batch.do_classifier_free_guidance = False
        return result
    if batch.do_classifier_free_guidance or batch.prompt_embeds:
        result.add_check("prompt_embeds", batch.prompt_embeds,
                         V.list_not_empty)
    return result

fastvideo.pipelines.stages.Cosmos25DenoisingStage

Cosmos25DenoisingStage(transformer, scheduler, pipeline=None)

Bases: CosmosDenoisingStage

Denoising stage for Cosmos 2.5 DiT (expects 1D/2D timestep, not 5D).

Source code in fastvideo/pipelines/stages/denoising.py
def __init__(self, transformer, scheduler, pipeline=None) -> None:
    super().__init__(transformer, scheduler, pipeline)

fastvideo.pipelines.stages.Cosmos25LatentPreparationStage

Cosmos25LatentPreparationStage(scheduler, transformer, vae=None)

Bases: CosmosLatentPreparationStage

Latent preparation for Cosmos 2.5 DiT input conventions.

Source code in fastvideo/pipelines/stages/latent_preparation.py
def __init__(self, scheduler, transformer, vae=None) -> None:
    super().__init__()
    self.scheduler = scheduler
    self.transformer = transformer
    self.vae = vae

Functions

fastvideo.pipelines.stages.Cosmos25LatentPreparationStage.adjust_video_length
adjust_video_length(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> int

Adjust video length based on VAE version.

Parameters:

- batch (ForwardBatch, required): The current batch information.
- fastvideo_args (FastVideoArgs, required): The inference arguments.

Returns:

- int: The adjusted number of latent frames.

Source code in fastvideo/pipelines/stages/latent_preparation.py
def adjust_video_length(self, batch: ForwardBatch,
                        fastvideo_args: FastVideoArgs) -> int:
    """
    Adjust video length based on VAE version.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The adjusted number of latent frames.
    """

    video_length = batch.num_frames
    use_temporal_scaling_frames = fastvideo_args.pipeline_config.vae_config.use_temporal_scaling_frames
    if use_temporal_scaling_frames:
        temporal_scale_factor = fastvideo_args.pipeline_config.vae_config.arch_config.temporal_compression_ratio
        latent_num_frames = (video_length - 1) // temporal_scale_factor + 1
    else:  # stepvideo only
        latent_num_frames = video_length // 17 * 3
    return int(latent_num_frames)
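
A worked example of the temporal-scaling branch above, assuming a VAE with a temporal compression ratio of 4 (the actual ratio comes from the pipeline's VAE config):

video_length = 81
temporal_scale_factor = 4                                              # assumed ratio
latent_num_frames = (video_length - 1) // temporal_scale_factor + 1    # -> 21

# The stepvideo-only branch, for comparison:
latent_num_frames_stepvideo = 102 // 17 * 3                            # -> 18
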
fastvideo.pipelines.stages.Cosmos25LatentPreparationStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify Cosmos latent preparation stage inputs.

Source code in fastvideo/pipelines/stages/latent_preparation.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify Cosmos latent preparation stage inputs."""
    result = VerificationResult()
    result.add_check(
        "prompt_or_embeds", None, lambda _: V.string_or_list_strings(
            batch.prompt) or V.list_not_empty(batch.prompt_embeds))
    result.add_check("prompt_embeds", batch.prompt_embeds,
                     V.list_of_tensors)
    result.add_check("num_videos_per_prompt", batch.num_videos_per_prompt,
                     V.positive_int)
    result.add_check("generator", batch.generator,
                     V.generator_or_list_generators)
    result.add_check("num_frames", batch.num_frames, V.positive_int)
    result.add_check("height", batch.height, V.positive_int)
    result.add_check("width", batch.width, V.positive_int)
    result.add_check("latents", batch.latents, V.none_or_tensor)
    return result
fastvideo.pipelines.stages.Cosmos25LatentPreparationStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify latent preparation stage outputs.

Source code in fastvideo/pipelines/stages/latent_preparation.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify latent preparation stage outputs."""
    result = VerificationResult()
    result.add_check("latents", batch.latents,
                     [V.is_tensor, V.with_dims(5)])
    result.add_check("raw_latent_shape", batch.raw_latent_shape, V.is_tuple)
    return result

fastvideo.pipelines.stages.Cosmos25TextEncodingStage

Cosmos25TextEncodingStage(text_encoder)

Bases: PipelineStage

Cosmos 2.5 text encoding stage.

Cosmos 2.5 uses Reason1 (Qwen2.5-VL) and relies on the encoder's compute_text_embeddings_online().

Source code in fastvideo/pipelines/stages/text_encoding.py
def __init__(self, text_encoder) -> None:
    super().__init__()
    self.text_encoder = text_encoder

fastvideo.pipelines.stages.Cosmos25TimestepPreparationStage

Cosmos25TimestepPreparationStage(scheduler)

Bases: TimestepPreparationStage

Cosmos 2.5 timestep preparation with scheduler-specific kwargs.

Source code in fastvideo/pipelines/stages/timestep_preparation.py
def __init__(self, scheduler) -> None:
    self.scheduler = scheduler

fastvideo.pipelines.stages.CosmosDenoisingStage

CosmosDenoisingStage(transformer, scheduler, pipeline=None)

Bases: DenoisingStage

Denoising stage for Cosmos models using FlowMatchEulerDiscreteScheduler.

Source code in fastvideo/pipelines/stages/denoising.py
def __init__(self, transformer, scheduler, pipeline=None) -> None:
    super().__init__(transformer, scheduler, pipeline)

Functions

fastvideo.pipelines.stages.CosmosDenoisingStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify Cosmos denoising stage inputs.

Source code in fastvideo/pipelines/stages/denoising.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify Cosmos denoising stage inputs."""
    result = VerificationResult()
    result.add_check("latents", batch.latents,
                     [V.is_tensor, V.with_dims(5)])
    result.add_check("prompt_embeds", batch.prompt_embeds, V.list_not_empty)
    result.add_check("num_inference_steps", batch.num_inference_steps,
                     V.positive_int)
    result.add_check("guidance_scale", batch.guidance_scale,
                     V.positive_float)
    result.add_check("do_classifier_free_guidance",
                     batch.do_classifier_free_guidance, V.bool_value)
    result.add_check(
        "negative_prompt_embeds", batch.negative_prompt_embeds, lambda x:
        not batch.do_classifier_free_guidance or V.list_not_empty(x))
    return result
fastvideo.pipelines.stages.CosmosDenoisingStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify Cosmos denoising stage outputs.

Source code in fastvideo/pipelines/stages/denoising.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify Cosmos denoising stage outputs."""
    result = VerificationResult()
    result.add_check("latents", batch.latents,
                     [V.is_tensor, V.with_dims(5)])
    return result

fastvideo.pipelines.stages.CosmosLatentPreparationStage

CosmosLatentPreparationStage(scheduler, transformer, vae=None)

Bases: PipelineStage

Cosmos-specific latent preparation stage that properly handles the tensor shapes and conditioning masks required by the Cosmos transformer.

This stage replicates the logic from diffusers' Cosmos2VideoToWorldPipeline.prepare_latents()

Source code in fastvideo/pipelines/stages/latent_preparation.py
def __init__(self, scheduler, transformer, vae=None) -> None:
    super().__init__()
    self.scheduler = scheduler
    self.transformer = transformer
    self.vae = vae

fastvideo.pipelines.stages.DecodingStage

DecodingStage(vae, pipeline=None)

Bases: PipelineStage

Stage for decoding latent representations into pixel space.

This stage handles the decoding of latent representations into the final output format (e.g., pixel values).

Source code in fastvideo/pipelines/stages/decoding.py
def __init__(self, vae, pipeline=None) -> None:
    self.vae: ParallelTiledVAE = vae
    self.pipeline = weakref.ref(pipeline) if pipeline else None

Functions

fastvideo.pipelines.stages.DecodingStage.decode
decode(latents: Tensor, fastvideo_args: FastVideoArgs) -> Tensor

Decode latent representations into pixel space using VAE.

Parameters:

- latents (Tensor, required): Input latent tensor with shape (batch, channels, frames, height_latents, width_latents).
- fastvideo_args (FastVideoArgs, required): Configuration containing:
    - disable_autocast: whether to disable automatic mixed precision (default: False)
    - pipeline_config.vae_precision: VAE computation precision ("fp32", "fp16", "bf16")
    - pipeline_config.vae_tiling: whether to enable VAE tiling for memory efficiency

Returns:

- Tensor: Decoded video tensor with shape (batch, channels, frames, height, width), normalized to [0, 1] range and moved to CPU as float32.

Source code in fastvideo/pipelines/stages/decoding.py
@torch.no_grad()
def decode(self, latents: torch.Tensor,
           fastvideo_args: FastVideoArgs) -> torch.Tensor:
    """
    Decode latent representations into pixel space using VAE.

    Args:
        latents: Input latent tensor with shape (batch, channels, frames, height_latents, width_latents)
        fastvideo_args: Configuration containing:
            - disable_autocast: Whether to disable automatic mixed precision (default: False)
            - pipeline_config.vae_precision: VAE computation precision ("fp32", "fp16", "bf16")
            - pipeline_config.vae_tiling: Whether to enable VAE tiling for memory efficiency

    Returns:
        Decoded video tensor with shape (batch, channels, frames, height, width), 
        normalized to [0, 1] range and moved to CPU as float32
    """
    self.vae = self.vae.to(get_local_torch_device())
    latents = latents.to(get_local_torch_device())

    # Setup VAE precision
    vae_dtype = PRECISION_TO_TYPE[
        fastvideo_args.pipeline_config.vae_precision]
    vae_autocast_enabled = (
        vae_dtype != torch.float32) and not fastvideo_args.disable_autocast

    latents = self._denormalize_latents(latents)

    # Decode latents
    with torch.autocast(device_type="cuda",
                        dtype=vae_dtype,
                        enabled=vae_autocast_enabled):
        if fastvideo_args.pipeline_config.vae_tiling:
            self.vae.enable_tiling()
        # if fastvideo_args.vae_sp:
        #     self.vae.enable_parallel()
        if not vae_autocast_enabled:
            latents = latents.to(vae_dtype)
        image = self.vae.decode(latents)

    # Normalize image to [0, 1] range
    image = (image / 2 + 0.5).clamp(0, 1)
    return image
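
A hypothetical usage sketch: decoding_stage is an already-constructed DecodingStage and fastvideo_args a populated FastVideoArgs; the latent shape is illustrative. Since the decoded tensor is in [0, 1], converting it to uint8 frames is a simple scale-and-permute:

import torch

latents = torch.randn(1, 16, 21, 60, 104)               # (B, C, T, H_lat, W_lat), toy shape
video = decoding_stage.decode(latents, fastvideo_args)  # (B, C, T, H, W), values in [0, 1]

# (T, H, W, C) uint8 frames, ready for an external video writer.
frames_uint8 = (video[0] * 255).round().to(torch.uint8).permute(1, 2, 3, 0)
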
fastvideo.pipelines.stages.DecodingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Decode latent representations into pixel space.

This method processes the batch through the VAE decoder, converting latent representations to pixel-space video/images. It also optionally decodes trajectory latents for visualization purposes.

Parameters:

- batch (ForwardBatch, required): The current batch containing:
    - latents: tensor to decode (batch, channels, frames, height_latents, width_latents)
    - return_trajectory_decoded (optional): flag to decode trajectory latents
    - trajectory_latents (optional): latents at different timesteps
    - trajectory_timesteps (optional): corresponding timesteps
- fastvideo_args (FastVideoArgs, required): Configuration containing:
    - output_type: "latent" to skip decoding, otherwise decode to pixels
    - vae_cpu_offload: whether to offload the VAE to CPU after decoding
    - model_loaded: tracks the VAE loading state
    - model_paths: path to the VAE model if loading is needed

Returns:

- ForwardBatch: Modified batch with:
    - output: decoded frames (batch, channels, frames, height, width) as CPU float32
    - trajectory_decoded (if requested): list of decoded frames per timestep

Source code in fastvideo/pipelines/stages/decoding.py
@torch.no_grad()
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Decode latent representations into pixel space.

    This method processes the batch through the VAE decoder, converting latent
    representations to pixel-space video/images. It also optionally decodes
    trajectory latents for visualization purposes.

    Args:
        batch: The current batch containing:
            - latents: Tensor to decode (batch, channels, frames, height_latents, width_latents)
            - return_trajectory_decoded (optional): Flag to decode trajectory latents
            - trajectory_latents (optional): Latents at different timesteps
            - trajectory_timesteps (optional): Corresponding timesteps
        fastvideo_args: Configuration containing:
            - output_type: "latent" to skip decoding, otherwise decode to pixels
            - vae_cpu_offload: Whether to offload VAE to CPU after decoding
            - model_loaded: Track VAE loading state
            - model_paths: Path to VAE model if loading needed

    Returns:
        Modified batch with:
            - output: Decoded frames (batch, channels, frames, height, width) as CPU float32
            - trajectory_decoded (if requested): List of decoded frames per timestep
    """
    # load vae if not already loaded (used for memory constrained devices)
    pipeline = self.pipeline() if self.pipeline else None
    if not fastvideo_args.model_loaded["vae"]:
        loader = VAELoader()
        self.vae = loader.load(fastvideo_args.model_paths["vae"],
                               fastvideo_args)
        if pipeline:
            pipeline.add_module("vae", self.vae)
        fastvideo_args.model_loaded["vae"] = True

    if fastvideo_args.output_type == "latent":
        frames = batch.latents
    else:
        frames = self.decode(batch.latents, fastvideo_args)

    # decode trajectory latents if needed
    if batch.return_trajectory_decoded:
        batch.trajectory_decoded = []
        assert batch.trajectory_latents is not None, "batch should have trajectory latents"
        for idx in range(batch.trajectory_latents.shape[1]):
            # batch.trajectory_latents is [batch_size, timesteps, channels, frames, height, width]
            cur_latent = batch.trajectory_latents[:, idx, :, :, :, :]
            cur_timestep = batch.trajectory_timesteps[idx]
            logger.info("decoding trajectory latent for timestep: %s",
                        cur_timestep)
            decoded_frames = self.decode(cur_latent, fastvideo_args)
            batch.trajectory_decoded.append(decoded_frames.cpu().float())

    # Convert to CPU float32 for compatibility
    frames = frames.cpu().float()

    # Crop padding if this is a LongCat refinement
    if hasattr(batch, 'num_cond_frames_added') and hasattr(
            batch, 'new_frame_size_before_padding'):
        num_cond_frames_added = batch.num_cond_frames_added
        new_frame_size = batch.new_frame_size_before_padding
        if num_cond_frames_added > 0 or frames.shape[2] != new_frame_size:
            # frames is [B, C, T, H, W], crop temporal dimension
            frames = frames[:, :,
                            num_cond_frames_added:num_cond_frames_added +
                            new_frame_size, :, :]
            logger.info(
                "Cropped LongCat refinement padding: %s:%s, final shape: %s",
                num_cond_frames_added,
                num_cond_frames_added + new_frame_size, frames.shape)

    # Update batch with decoded image
    batch.output = frames

    # Offload models if needed
    if hasattr(self, 'maybe_free_model_hooks'):
        self.maybe_free_model_hooks()

    if fastvideo_args.vae_cpu_offload:
        self.vae.to("cpu")

    if torch.backends.mps.is_available():
        del self.vae
        if pipeline is not None and "vae" in pipeline.modules:
            del pipeline.modules["vae"]
        fastvideo_args.model_loaded["vae"] = False

    return batch
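
A sketch of consuming the trajectory outputs described above, assuming the denoising stage ran with trajectory recording enabled and that decoding_stage and fastvideo_args already exist:

batch = decoding_stage.forward(batch, fastvideo_args)

if batch.return_trajectory_decoded and batch.trajectory_decoded:
    for t, decoded in zip(batch.trajectory_timesteps.tolist(), batch.trajectory_decoded):
        # decoded: (B, C, T, H, W) CPU float32 frames for this intermediate timestep
        print(f"timestep {t}: decoded shape {tuple(decoded.shape)}")
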
fastvideo.pipelines.stages.DecodingStage.streaming_decode
streaming_decode(latents: Tensor, fastvideo_args: FastVideoArgs, cache: list[Tensor | None] | None = None, is_first_chunk: bool = False) -> tuple[Tensor, list[Tensor | None]]

Decode latent representations into pixel space using VAE with streaming cache.

Parameters:

- latents (Tensor, required): Input latent tensor with shape (batch, channels, frames, height_latents, width_latents).
- fastvideo_args (FastVideoArgs, required): Configuration object.
- cache (list[Tensor | None] | None, default None): VAE cache from previous call, or None to initialize a new cache.
- is_first_chunk (bool, default False): Whether this is the first chunk.

Returns:

- tuple[Tensor, list[Tensor | None]]: A tuple of (decoded_frames, updated_cache).

Source code in fastvideo/pipelines/stages/decoding.py
@torch.no_grad()
def streaming_decode(
    self,
    latents: torch.Tensor,
    fastvideo_args: FastVideoArgs,
    cache: list[torch.Tensor | None] | None = None,
    is_first_chunk: bool = False,
) -> tuple[torch.Tensor, list[torch.Tensor | None]]:
    """
    Decode latent representations into pixel space using VAE with streaming cache.

    Args:
        latents: Input latent tensor with shape (batch, channels, frames, height_latents, width_latents)
        fastvideo_args: Configuration object.
        cache: VAE cache from previous call, or None to initialize a new cache.
        is_first_chunk: Whether this is the first chunk.

    Returns:
        A tuple of (decoded_frames, updated_cache).
    """
    self.vae = self.vae.to(get_local_torch_device())
    latents = latents.to(get_local_torch_device())

    # Setup VAE precision
    vae_dtype = PRECISION_TO_TYPE[
        fastvideo_args.pipeline_config.vae_precision]
    vae_autocast_enabled = (
        vae_dtype != torch.float32) and not fastvideo_args.disable_autocast

    latents = self._denormalize_latents(latents)

    # Initialize cache if needed
    if cache is None:
        cache = self.vae.get_streaming_cache()

    # Decode latents with streaming
    with torch.autocast(device_type="cuda",
                        dtype=vae_dtype,
                        enabled=vae_autocast_enabled):
        if fastvideo_args.pipeline_config.vae_tiling:
            self.vae.enable_tiling()
        if not vae_autocast_enabled:
            latents = latents.to(vae_dtype)
        image, cache = self.vae.streaming_decode(latents, cache,
                                                 is_first_chunk)

    # Normalize image to [0, 1] range
    image = (image / 2 + 0.5).clamp(0, 1)
    assert cache is not None, "cache should not be None after streaming_decode"
    return image, cache
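
A hedged sketch of chunk-wise decoding with streaming_decode: latent chunks are decoded one at a time and the VAE cache is threaded through successive calls. Here latent_chunks is an illustrative list of (B, C, T_chunk, H_lat, W_lat) tensors and decoding_stage/fastvideo_args are assumed to exist:

import torch

cache = None
decoded_chunks = []
for i, chunk in enumerate(latent_chunks):
    frames, cache = decoding_stage.streaming_decode(
        chunk, fastvideo_args, cache=cache, is_first_chunk=(i == 0))
    decoded_chunks.append(frames.cpu().float())

video = torch.cat(decoded_chunks, dim=2)  # concatenate along the frame dimension
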
fastvideo.pipelines.stages.DecodingStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify decoding stage inputs.

Source code in fastvideo/pipelines/stages/decoding.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify decoding stage inputs."""
    result = VerificationResult()
    # Denoised latents for VAE decoding: [batch_size, channels, frames, height_latents, width_latents]
    result.add_check("latents", batch.latents,
                     [V.is_tensor, V.with_dims(5)])
    return result
fastvideo.pipelines.stages.DecodingStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify decoding stage outputs.

Source code in fastvideo/pipelines/stages/decoding.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify decoding stage outputs."""
    result = VerificationResult()
    # Decoded video/images: [batch_size, channels, frames, height, width]
    result.add_check("output", batch.output, [V.is_tensor, V.with_dims(5)])
    return result

fastvideo.pipelines.stages.DenoisingStage

DenoisingStage(transformer, scheduler, pipeline=None, transformer_2=None, vae=None)

Bases: PipelineStage

Stage for running the denoising loop in diffusion pipelines.

This stage handles the iterative denoising process that transforms the initial noise into the final output.

Source code in fastvideo/pipelines/stages/denoising.py
def __init__(self,
             transformer,
             scheduler,
             pipeline=None,
             transformer_2=None,
             vae=None) -> None:
    super().__init__()
    self.transformer = transformer
    self.transformer_2 = transformer_2
    self.scheduler = scheduler
    self.vae = vae
    self.pipeline = weakref.ref(pipeline) if pipeline else None
    attn_head_size = self.transformer.hidden_size // self.transformer.num_attention_heads
    self.attn_backend = get_attn_backend(
        head_size=attn_head_size,
        dtype=torch.float16,  # TODO(will): hack
        supported_attention_backends=(
            AttentionBackendEnum.SLIDING_TILE_ATTN,
            AttentionBackendEnum.VIDEO_SPARSE_ATTN,
            AttentionBackendEnum.VMOBA_ATTN,
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.TORCH_SDPA,
            AttentionBackendEnum.SAGE_ATTN_THREE)  # hack
    )

Functions

fastvideo.pipelines.stages.DenoisingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Run the denoising loop.

Parameters:

- batch (ForwardBatch, required): The current batch information.
- fastvideo_args (FastVideoArgs, required): The inference arguments.

Returns:

- ForwardBatch: The batch with denoised latents.

Source code in fastvideo/pipelines/stages/denoising.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Run the denoising loop.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with denoised latents.
    """
    pipeline = self.pipeline() if self.pipeline else None
    if not fastvideo_args.model_loaded["transformer"]:
        loader = TransformerLoader()
        self.transformer = loader.load(
            fastvideo_args.model_paths["transformer"], fastvideo_args)
        if pipeline:
            pipeline.add_module("transformer", self.transformer)
        fastvideo_args.model_loaded["transformer"] = True

    # Prepare extra step kwargs for scheduler
    extra_step_kwargs = self.prepare_extra_func_kwargs(
        self.scheduler.step,
        {
            "generator": batch.generator,
            "eta": batch.eta
        },
    )

    # Setup precision and autocast settings
    # TODO(will): make the precision configurable for inference
    # target_dtype = PRECISION_TO_TYPE[fastvideo_args.precision]
    target_dtype = torch.bfloat16
    autocast_enabled = (target_dtype != torch.float32
                        ) and not fastvideo_args.disable_autocast

    # Get timesteps and calculate warmup steps
    timesteps = batch.timesteps
    # TODO(will): remove this once we add input/output validation for stages
    if timesteps is None:
        raise ValueError("Timesteps must be provided")
    num_inference_steps = batch.num_inference_steps
    num_warmup_steps = len(
        timesteps) - num_inference_steps * self.scheduler.order

    # Prepare image latents and embeddings for I2V generation
    image_embeds = batch.image_embeds
    if len(image_embeds) > 0:
        assert not torch.isnan(
            image_embeds[0]).any(), "image_embeds contains nan"
        image_embeds = [
            image_embed.to(target_dtype) for image_embed in image_embeds
        ]

    image_kwargs = self.prepare_extra_func_kwargs(
        self.transformer.forward,
        {
            "encoder_hidden_states_image": image_embeds,
            "mask_strategy": dict_to_3d_list(
                None, t_max=50, l_max=60, h_max=24)
        },
    )

    pos_cond_kwargs = self.prepare_extra_func_kwargs(
        self.transformer.forward,
        {
            "encoder_hidden_states_2": batch.clip_embedding_pos,
            "encoder_attention_mask": batch.prompt_attention_mask,
        },
    )

    neg_cond_kwargs = self.prepare_extra_func_kwargs(
        self.transformer.forward,
        {
            "encoder_hidden_states_2": batch.clip_embedding_neg,
            "encoder_attention_mask": batch.negative_attention_mask,
        },
    )

    action_kwargs = self.prepare_extra_func_kwargs(
        self.transformer.forward,
        {
            "mouse_cond": batch.mouse_cond,
            "keyboard_cond": batch.keyboard_cond,
        },
    )

    # Prepare STA parameters
    if st_attn_available and self.attn_backend == SlidingTileAttentionBackend:
        self.prepare_sta_param(batch, fastvideo_args)

    # Get latents and embeddings
    latents = batch.latents
    prompt_embeds = batch.prompt_embeds
    assert not torch.isnan(
        prompt_embeds[0]).any(), "prompt_embeds contains nan"
    if batch.do_classifier_free_guidance:
        neg_prompt_embeds = batch.negative_prompt_embeds
        assert neg_prompt_embeds is not None
        assert not torch.isnan(
            neg_prompt_embeds[0]).any(), "neg_prompt_embeds contains nan"

    # (Wan2.2) Calculate timestep to switch from high noise expert to low noise expert
    boundary_ratio = fastvideo_args.pipeline_config.dit_config.boundary_ratio
    if batch.boundary_ratio is not None:
        logger.info("Overriding boundary ratio from %s to %s",
                    boundary_ratio, batch.boundary_ratio)
        boundary_ratio = batch.boundary_ratio

    if boundary_ratio is not None:
        boundary_timestep = boundary_ratio * self.scheduler.num_train_timesteps
    else:
        boundary_timestep = None
    latent_model_input = latents.to(target_dtype)
    assert latent_model_input.shape[0] == 1, "only support batch size 1"

    if fastvideo_args.pipeline_config.ti2v_task and batch.pil_image is not None:
        # TI2V directly replaces the first frame of the latent with
        # the image latent instead of appending along the channel dim
        assert batch.image_latent is None, "TI2V task should not have image latents"
        assert self.vae is not None, "VAE is not provided for TI2V task"
        z = self.vae.encode(batch.pil_image).mean.float()
        if (hasattr(self.vae, "shift_factor")
                and self.vae.shift_factor is not None):
            if isinstance(self.vae.shift_factor, torch.Tensor):
                z -= self.vae.shift_factor.to(z.device, z.dtype)
            else:
                z -= self.vae.shift_factor

        if isinstance(self.vae.scaling_factor, torch.Tensor):
            z = z * self.vae.scaling_factor.to(z.device, z.dtype)
        else:
            z = z * self.vae.scaling_factor

        latent_model_input = latent_model_input.squeeze(0)
        _, mask2 = masks_like([latent_model_input], zero=True)

        latent_model_input = (1. -
                              mask2[0]) * z + mask2[0] * latent_model_input
        # latent_model_input = latent_model_input.unsqueeze(0)
        latent_model_input = latent_model_input.to(get_local_torch_device())
        latents = latent_model_input
        F = batch.num_frames
        temporal_scale = fastvideo_args.pipeline_config.vae_config.arch_config.scale_factor_temporal
        spatial_scale = fastvideo_args.pipeline_config.vae_config.arch_config.scale_factor_spatial
        patch_size = fastvideo_args.pipeline_config.dit_config.arch_config.patch_size
        seq_len = ((F - 1) // temporal_scale +
                   1) * (batch.height // spatial_scale) * (
                       batch.width // spatial_scale) // (patch_size[1] *
                                                         patch_size[2])

    # Initialize lists for ODE trajectory
    trajectory_timesteps: list[torch.Tensor] = []
    trajectory_latents: list[torch.Tensor] = []

    # Run denoising loop
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # Skip if interrupted
            if hasattr(self, 'interrupt') and self.interrupt:
                continue

            if boundary_timestep is None or t >= boundary_timestep:
                if (fastvideo_args.dit_cpu_offload
                        and not fastvideo_args.dit_layerwise_offload
                        and self.transformer_2 is not None and next(
                            self.transformer_2.parameters()).device.type
                        == 'cuda'):
                    self.transformer_2.to('cpu')
                current_model = self.transformer
                if (fastvideo_args.dit_cpu_offload
                        and not fastvideo_args.dit_layerwise_offload
                        and not fastvideo_args.use_fsdp_inference
                        and current_model is not None):
                    transformer_device = next(
                        current_model.parameters()).device.type
                    if transformer_device == 'cpu':
                        current_model.to(get_local_torch_device())
                current_guidance_scale = batch.guidance_scale
            else:
                # low-noise stage in wan2.2
                if (fastvideo_args.dit_cpu_offload
                        and not fastvideo_args.dit_layerwise_offload
                        and next(self.transformer.parameters()).device.type
                        == 'cuda'):
                    self.transformer.to('cpu')
                current_model = self.transformer_2
                if (fastvideo_args.dit_cpu_offload
                        and not fastvideo_args.dit_layerwise_offload
                        and not fastvideo_args.use_fsdp_inference
                        and current_model is not None):
                    transformer_2_device = next(
                        current_model.parameters()).device.type
                    if transformer_2_device == 'cpu':
                        current_model.to(get_local_torch_device())
                current_guidance_scale = batch.guidance_scale_2
            assert current_model is not None, "current_model is None"

            # Expand latents for V2V/I2V
            latent_model_input = latents.to(target_dtype)
            if batch.video_latent is not None:
                latent_model_input = torch.cat([
                    latent_model_input, batch.video_latent,
                    torch.zeros_like(latents)
                ],
                                               dim=1).to(target_dtype)
            elif batch.image_latent is not None:
                assert not fastvideo_args.pipeline_config.ti2v_task, "image latents should not be provided for TI2V task"
                latent_model_input = torch.cat(
                    [latent_model_input, batch.image_latent],
                    dim=1).to(target_dtype)

            assert not torch.isnan(
                latent_model_input).any(), "latent_model_input contains nan"
            if fastvideo_args.pipeline_config.ti2v_task and batch.pil_image is not None:
                timestep = torch.stack([t]).to(get_local_torch_device())
                temp_ts = (mask2[0][0][:, ::2, ::2] * timestep).flatten()
                temp_ts = torch.cat([
                    temp_ts,
                    temp_ts.new_ones(seq_len - temp_ts.size(0)) * timestep
                ])
                timestep = temp_ts.unsqueeze(0)
                t_expand = timestep.repeat(latent_model_input.shape[0], 1)
            else:
                t_expand = t.repeat(latent_model_input.shape[0])

            latent_model_input = self.scheduler.scale_model_input(
                latent_model_input, t)

            # Prepare inputs for transformer
            guidance_expand = (
                torch.tensor(
                    [fastvideo_args.pipeline_config.embedded_cfg_scale] *
                    latent_model_input.shape[0],
                    dtype=torch.float32,
                    device=get_local_torch_device(),
                ).to(target_dtype) *
                1000.0 if fastvideo_args.pipeline_config.embedded_cfg_scale
                is not None else None)

            # Predict noise residual
            with torch.autocast(device_type="cuda",
                                dtype=target_dtype,
                                enabled=autocast_enabled):
                if (st_attn_available
                        and self.attn_backend == SlidingTileAttentionBackend
                    ) or (vsa_available and self.attn_backend
                          == VideoSparseAttentionBackend):
                    self.attn_metadata_builder_cls = self.attn_backend.get_builder_cls(
                    )

                    if self.attn_metadata_builder_cls is not None:
                        self.attn_metadata_builder = self.attn_metadata_builder_cls(
                        )
                        # TODO(will): clean this up
                        attn_metadata = self.attn_metadata_builder.build(  # type: ignore
                            current_timestep=i,  # type: ignore
                            raw_latent_shape=batch.
                            raw_latent_shape[2:5],  # type: ignore
                            patch_size=fastvideo_args.
                            pipeline_config.  # type: ignore
                            dit_config.patch_size,  # type: ignore
                            STA_param=batch.STA_param,  # type: ignore
                            VSA_sparsity=fastvideo_args.
                            VSA_sparsity,  # type: ignore
                            device=get_local_torch_device(),
                        )
                        assert attn_metadata is not None, "attn_metadata cannot be None"
                    else:
                        attn_metadata = None
                elif (vmoba_attn_available
                      and self.attn_backend == VMOBAAttentionBackend):
                    self.attn_metadata_builder_cls = self.attn_backend.get_builder_cls(
                    )
                    if self.attn_metadata_builder_cls is not None:
                        self.attn_metadata_builder = self.attn_metadata_builder_cls(
                        )
                        # Prepare V-MoBA parameters from config
                        moba_params = fastvideo_args.moba_config.copy()
                        moba_params.update({
                            "current_timestep":
                            i,
                            "raw_latent_shape":
                            batch.raw_latent_shape[2:5],
                            "patch_size":
                            fastvideo_args.pipeline_config.dit_config.
                            patch_size,
                            "device":
                            get_local_torch_device(),
                        })
                        attn_metadata = self.attn_metadata_builder.build(
                            **moba_params)
                        assert attn_metadata is not None, "attn_metadata cannot be None"
                    else:
                        attn_metadata = None
                else:
                    attn_metadata = None
                # TODO(will): finalize the interface. vLLM uses this to
                # support torch dynamo compilation. They pass in
                # attn_metadata, vllm_config, and num_tokens. We can pass in
                # fastvideo_args or training_args, and attn_metadata.
                batch.is_cfg_negative = False
                with set_forward_context(
                        current_timestep=i,
                        attn_metadata=attn_metadata,
                        forward_batch=batch,
                        # fastvideo_args=fastvideo_args
                ):
                    # Run transformer
                    noise_pred = current_model(
                        latent_model_input,
                        prompt_embeds,
                        t_expand,
                        guidance=guidance_expand,
                        **image_kwargs,
                        **pos_cond_kwargs,
                        **action_kwargs,
                    )

                if batch.do_classifier_free_guidance:
                    batch.is_cfg_negative = True
                    with set_forward_context(
                            current_timestep=i,
                            attn_metadata=attn_metadata,
                            forward_batch=batch,
                    ):
                        noise_pred_uncond = current_model(
                            latent_model_input,
                            neg_prompt_embeds,
                            t_expand,
                            guidance=guidance_expand,
                            **image_kwargs,
                            **neg_cond_kwargs,
                            **action_kwargs,
                        )

                    noise_pred_text = noise_pred
                    noise_pred = noise_pred_uncond + current_guidance_scale * (
                        noise_pred_text - noise_pred_uncond)

                    # Apply guidance rescale if needed
                    if batch.guidance_rescale > 0.0:
                        # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                        noise_pred = self.rescale_noise_cfg(
                            noise_pred,
                            noise_pred_text,
                            guidance_rescale=batch.guidance_rescale,
                        )
                # Compute the previous noisy sample
                latents = self.scheduler.step(noise_pred,
                                              t,
                                              latents,
                                              **extra_step_kwargs,
                                              return_dict=False)[0]
                if fastvideo_args.pipeline_config.ti2v_task and batch.pil_image is not None:
                    latents = latents.squeeze(0)
                    latents = (1. - mask2[0]) * z + mask2[0] * latents
                    # latents = latents.unsqueeze(0)

            # save trajectory latents if needed
            if batch.return_trajectory_latents:
                trajectory_timesteps.append(t)
                trajectory_latents.append(latents)

            # Update progress bar
            if i == len(timesteps) - 1 or (
                (i + 1) > num_warmup_steps and
                (i + 1) % self.scheduler.order == 0
                    and progress_bar is not None):
                progress_bar.update()

    trajectory_tensor: torch.Tensor | None = None
    if trajectory_latents:
        trajectory_tensor = torch.stack(trajectory_latents, dim=1)
        trajectory_timesteps_tensor = torch.stack(trajectory_timesteps,
                                                  dim=0)
    else:
        trajectory_tensor = None
        trajectory_timesteps_tensor = None

    if trajectory_tensor is not None and trajectory_timesteps_tensor is not None:
        batch.trajectory_timesteps = trajectory_timesteps_tensor.cpu()
        batch.trajectory_latents = trajectory_tensor.cpu()

    # Update batch with final latents
    batch.latents = latents

    if fastvideo_args.dit_layerwise_offload:
        mgr = getattr(self.transformer, "_layerwise_offload_manager", None)
        if mgr is not None and getattr(mgr, "enabled", False):
            mgr.release_all()
        if self.transformer_2 is not None:
            mgr2 = getattr(self.transformer_2, "_layerwise_offload_manager",
                           None)
            if mgr2 is not None and getattr(mgr2, "enabled", False):
                mgr2.release_all()

    # Save STA mask search results if needed
    if st_attn_available and self.attn_backend == SlidingTileAttentionBackend and fastvideo_args.STA_mode == STA_Mode.STA_SEARCHING:
        self.save_sta_search_results(batch)

    # deallocate transformer if on mps
    if torch.backends.mps.is_available():
        logger.info("Memory before deallocating transformer: %s",
                    torch.mps.current_allocated_memory())
        del self.transformer
        if pipeline is not None and "transformer" in pipeline.modules:
            del pipeline.modules["transformer"]
        fastvideo_args.model_loaded["transformer"] = False
        logger.info("Memory after deallocating transformer: %s",
                    torch.mps.current_allocated_memory())

    return batch
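
An illustration of the Wan2.2 expert-switching rule from the loop above, using toy numbers (the real boundary_ratio comes from the DiT config or the batch override): timesteps at or above the boundary go to the high-noise expert (transformer), the rest to the low-noise expert (transformer_2).

boundary_ratio = 0.875            # toy value for illustration
num_train_timesteps = 1000        # toy value for illustration
boundary_timestep = boundary_ratio * num_train_timesteps  # 875.0

def pick_expert(t: float) -> str:
    # Mirrors the `t >= boundary_timestep` branch in the denoising loop above.
    return "transformer (high noise)" if t >= boundary_timestep else "transformer_2 (low noise)"

assert pick_expert(990.0) == "transformer (high noise)"
assert pick_expert(400.0) == "transformer_2 (low noise)"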
fastvideo.pipelines.stages.DenoisingStage.prepare_extra_func_kwargs
prepare_extra_func_kwargs(func, kwargs) -> dict[str, Any]

Prepare extra kwargs for the scheduler step / denoise step.

Parameters:

- func (required): The function to prepare kwargs for.
- kwargs (required): The kwargs to prepare.

Returns:

- dict[str, Any]: The prepared kwargs.

Source code in fastvideo/pipelines/stages/denoising.py
def prepare_extra_func_kwargs(self, func, kwargs) -> dict[str, Any]:
    """
    Prepare extra kwargs for the scheduler step / denoise step.

    Args:
        func: The function to prepare kwargs for.
        kwargs: The kwargs to prepare.

    Returns:
        The prepared kwargs.
    """
    extra_step_kwargs = {}
    for k, v in kwargs.items():
        accepts = k in set(inspect.signature(func).parameters.keys())
        if accepts:
            extra_step_kwargs[k] = v
    return extra_step_kwargs
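
A self-contained demo of the signature filtering performed above: keyword arguments the callee does not accept are silently dropped. The function name here is a toy stand-in, not the real scheduler API.

import inspect

def toy_scheduler_step(noise_pred, t, latents, generator=None):
    return latents

candidate_kwargs = {"generator": None, "eta": 0.0}
accepted = set(inspect.signature(toy_scheduler_step).parameters.keys())
extra_step_kwargs = {k: v for k, v in candidate_kwargs.items() if k in accepted}
# -> {"generator": None}; "eta" is dropped because toy_scheduler_step has no such parameter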
fastvideo.pipelines.stages.DenoisingStage.prepare_sta_param
prepare_sta_param(batch: ForwardBatch, fastvideo_args: FastVideoArgs)

Prepare Sliding Tile Attention (STA) parameters and settings.

Parameters:

- batch (ForwardBatch, required): The current batch information.
- fastvideo_args (FastVideoArgs, required): The inference arguments.

Source code in fastvideo/pipelines/stages/denoising.py
def prepare_sta_param(self, batch: ForwardBatch,
                      fastvideo_args: FastVideoArgs):
    """
    Prepare Sliding Tile Attention (STA) parameters and settings.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.
    """
    # TODO(kevin): STA mask search, currently only support Wan2.1 with 69x768x1280
    from fastvideo.attention.backends.STA_configuration import configure_sta
    STA_mode = fastvideo_args.STA_mode
    skip_time_steps = fastvideo_args.skip_time_steps
    if batch.timesteps is None:
        raise ValueError("Timesteps must be provided")
    timesteps_num = batch.timesteps.shape[0]

    logger.info("STA_mode: %s", STA_mode)
    if (batch.num_frames, batch.height,
            batch.width) != (69, 768, 1280) and STA_mode != "STA_inference":
        raise NotImplementedError(
            "STA mask search/tuning is not supported for this resolution")

    if STA_mode == STA_Mode.STA_SEARCHING or STA_mode == STA_Mode.STA_TUNING or STA_mode == STA_Mode.STA_TUNING_CFG:
        size = (batch.width, batch.height)
        if size == (1280, 768):
            # TODO: make it configurable
            sparse_mask_candidates_searching = [
                "3, 1, 10", "1, 5, 7", "3, 3, 3", "1, 6, 5", "1, 3, 10",
                "3, 6, 1"
            ]
            sparse_mask_candidates_tuning = [
                "3, 1, 10", "1, 5, 7", "3, 3, 3", "1, 6, 5", "1, 3, 10",
                "3, 6, 1"
            ]
            full_mask = ["3,6,10"]
        else:
            raise NotImplementedError(
                "STA mask search is not supported for this resolution")
    layer_num = self.transformer.config.num_layers
    # specific for HunyuanVideo
    if hasattr(self.transformer.config, "num_single_layers"):
        layer_num += self.transformer.config.num_single_layers
    head_num = self.transformer.config.num_attention_heads

    if STA_mode == STA_Mode.STA_SEARCHING:
        STA_param = configure_sta(
            mode=STA_Mode.STA_SEARCHING,
            layer_num=layer_num,
            head_num=head_num,
            time_step_num=timesteps_num,
            mask_candidates=sparse_mask_candidates_searching +
            full_mask,  # last entry is the full mask; more sparse masks can be added, keeping the last one as the full mask
        )
    elif STA_mode == STA_Mode.STA_TUNING:
        STA_param = configure_sta(
            mode=STA_Mode.STA_TUNING,
            layer_num=layer_num,
            head_num=head_num,
            time_step_num=timesteps_num,
            mask_search_files_path=
            f'output/mask_search_result_pos_{size[0]}x{size[1]}/',
            mask_candidates=sparse_mask_candidates_tuning,
            full_attention_mask=[int(x) for x in full_mask[0].split(',')],
            skip_time_steps=
            skip_time_steps,  # Use full attention for first 12 steps
            save_dir=
            f'output/mask_search_strategy_{size[0]}x{size[1]}/',  # Custom save directory
            timesteps=timesteps_num)
    elif STA_mode == STA_Mode.STA_TUNING_CFG:
        STA_param = configure_sta(
            mode=STA_Mode.STA_TUNING_CFG,
            layer_num=layer_num,
            head_num=head_num,
            time_step_num=timesteps_num,
            mask_search_files_path_pos=
            f'output/mask_search_result_pos_{size[0]}x{size[1]}/',
            mask_search_files_path_neg=
            f'output/mask_search_result_neg_{size[0]}x{size[1]}/',
            mask_candidates=sparse_mask_candidates_tuning,
            full_attention_mask=[int(x) for x in full_mask[0].split(',')],
            skip_time_steps=skip_time_steps,
            save_dir=f'output/mask_search_strategy_{size[0]}x{size[1]}/',
            timesteps=timesteps_num)
    elif STA_mode == STA_Mode.STA_INFERENCE:
        import fastvideo.envs as envs
        config_file = envs.FASTVIDEO_ATTENTION_CONFIG
        if config_file is None:
            raise ValueError("FASTVIDEO_ATTENTION_CONFIG is not set")
        STA_param = configure_sta(mode=STA_Mode.STA_INFERENCE,
                                  layer_num=layer_num,
                                  head_num=head_num,
                                  time_step_num=timesteps_num,
                                  load_path=config_file)

    batch.STA_param = STA_param
    batch.mask_search_final_result_pos = [[] for _ in range(timesteps_num)]
    batch.mask_search_final_result_neg = [[] for _ in range(timesteps_num)]
fastvideo.pipelines.stages.DenoisingStage.progress_bar
progress_bar(iterable: Iterable | None = None, total: int | None = None) -> tqdm

Create a progress bar for the denoising process.

Parameters:

- iterable (Iterable | None, default None): The iterable to iterate over.
- total (int | None, default None): The total number of items.

Returns:

- tqdm: A tqdm progress bar.

Source code in fastvideo/pipelines/stages/denoising.py
def progress_bar(self,
                 iterable: Iterable | None = None,
                 total: int | None = None) -> tqdm:
    """
    Create a progress bar for the denoising process.

    Args:
        iterable: The iterable to iterate over.
        total: The total number of items.

    Returns:
        A tqdm progress bar.
    """
    local_rank = get_world_group().local_rank
    if local_rank == 0:
        return tqdm(iterable=iterable, total=total)
    else:
        return tqdm(iterable=iterable, total=total, disable=True)
fastvideo.pipelines.stages.DenoisingStage.rescale_noise_cfg
rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0) -> Tensor

Rescale noise prediction according to guidance_rescale.

Based on findings of "Common Diffusion Noise Schedules and Sample Steps are Flawed" (https://arxiv.org/pdf/2305.08891.pdf), Section 3.4.

Parameters:

- noise_cfg (required): The noise prediction with guidance.
- noise_pred_text (required): The text-conditioned noise prediction.
- guidance_rescale (default 0.0): The guidance rescale factor.

Returns:

- Tensor: The rescaled noise prediction.

Source code in fastvideo/pipelines/stages/denoising.py
def rescale_noise_cfg(self,
                      noise_cfg,
                      noise_pred_text,
                      guidance_rescale=0.0) -> torch.Tensor:
    """
    Rescale noise prediction according to guidance_rescale.

    Based on findings of "Common Diffusion Noise Schedules and Sample Steps are Flawed"
    (https://arxiv.org/pdf/2305.08891.pdf), Section 3.4.

    Args:
        noise_cfg: The noise prediction with guidance.
        noise_pred_text: The text-conditioned noise prediction.
        guidance_rescale: The guidance rescale factor.

    Returns:
        The rescaled noise prediction.
    """
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)),
                                   keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)),
                            keepdim=True)
    # Rescale the results from guidance (fixes overexposure)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # Mix with the original results from guidance by factor guidance_rescale
    noise_cfg = (guidance_rescale * noise_pred_rescaled +
                 (1 - guidance_rescale) * noise_cfg)
    return noise_cfg
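
A self-contained sketch of the same rescaling applied after a standard classifier-free guidance combination; the tensor shapes and guidance values are illustrative:

import torch

# Illustrative latents: [batch, channels, frames, height, width]
noise_pred_text = torch.randn(1, 16, 8, 30, 52)
noise_pred_uncond = torch.randn(1, 16, 8, 30, 52)
guidance_scale, guidance_rescale = 6.0, 0.7

# Standard CFG combination
noise_cfg = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# Rescale the per-sample std back toward the text-conditioned prediction, then blend
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
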
fastvideo.pipelines.stages.DenoisingStage.save_sta_search_results
save_sta_search_results(batch: ForwardBatch)

Save the STA mask search results.

Parameters:

    batch (ForwardBatch, required): The current batch information.
Source code in fastvideo/pipelines/stages/denoising.py
def save_sta_search_results(self, batch: ForwardBatch):
    """
    Save the STA mask search results.

    Args:
        batch: The current batch information.
    """
    size = (batch.width, batch.height)
    if size == (1280, 768):
        # TODO: make it configurable
        sparse_mask_candidates_searching = [
            "3, 1, 10", "1, 5, 7", "3, 3, 3", "1, 6, 5", "1, 3, 10",
            "3, 6, 1"
        ]
    else:
        raise NotImplementedError(
            "STA mask search is not supported for this resolution")

    from fastvideo.attention.backends.STA_configuration import save_mask_search_results
    if batch.mask_search_final_result_pos is not None and batch.prompt is not None:
        save_mask_search_results(
            [
                dict(layer_data)
                for layer_data in batch.mask_search_final_result_pos
            ],
            prompt=str(batch.prompt),
            mask_strategies=sparse_mask_candidates_searching,
            output_dir=f'output/mask_search_result_pos_{size[0]}x{size[1]}/'
        )
    if batch.mask_search_final_result_neg is not None and batch.prompt is not None:
        save_mask_search_results(
            [
                dict(layer_data)
                for layer_data in batch.mask_search_final_result_neg
            ],
            prompt=str(batch.prompt),
            mask_strategies=sparse_mask_candidates_searching,
            output_dir=f'output/mask_search_result_neg_{size[0]}x{size[1]}/'
        )
fastvideo.pipelines.stages.DenoisingStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify denoising stage inputs.

Source code in fastvideo/pipelines/stages/denoising.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify denoising stage inputs."""
    result = VerificationResult()
    result.add_check("timesteps", batch.timesteps,
                     [V.is_tensor, V.min_dims(1)])
    result.add_check("latents", batch.latents,
                     [V.is_tensor, V.with_dims(5)])
    result.add_check("prompt_embeds", batch.prompt_embeds, V.list_not_empty)
    result.add_check("image_embeds", batch.image_embeds, V.is_list)
    result.add_check("image_latent", batch.image_latent,
                     V.none_or_tensor_with_dims(5))
    result.add_check("num_inference_steps", batch.num_inference_steps,
                     V.positive_int)
    result.add_check("guidance_scale", batch.guidance_scale,
                     V.positive_float)
    result.add_check("eta", batch.eta, V.non_negative_float)
    result.add_check("generator", batch.generator,
                     V.generator_or_list_generators)
    result.add_check("do_classifier_free_guidance",
                     batch.do_classifier_free_guidance, V.bool_value)
    result.add_check(
        "negative_prompt_embeds", batch.negative_prompt_embeds, lambda x:
        not batch.do_classifier_free_guidance or V.list_not_empty(x))
    return result
fastvideo.pipelines.stages.DenoisingStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify denoising stage outputs.

Source code in fastvideo/pipelines/stages/denoising.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify denoising stage outputs."""
    result = VerificationResult()
    result.add_check("latents", batch.latents,
                     [V.is_tensor, V.with_dims(5)])
    return result

fastvideo.pipelines.stages.DmdDenoisingStage

DmdDenoisingStage(transformer, scheduler)

Bases: DenoisingStage

Denoising stage for DMD.

Source code in fastvideo/pipelines/stages/denoising.py
def __init__(self, transformer, scheduler) -> None:
    super().__init__(transformer, scheduler)
    self.scheduler = FlowMatchEulerDiscreteScheduler(shift=8.0)

Functions

fastvideo.pipelines.stages.DmdDenoisingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Run the denoising loop.

Parameters:

    batch (ForwardBatch, required): The current batch information.
    fastvideo_args (FastVideoArgs, required): The inference arguments.

Returns:

    ForwardBatch: The batch with denoised latents.

Source code in fastvideo/pipelines/stages/denoising.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Run the denoising loop.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with denoised latents.
    """
    # Setup precision and autocast settings
    # TODO(will): make the precision configurable for inference
    # target_dtype = PRECISION_TO_TYPE[fastvideo_args.precision]
    target_dtype = torch.bfloat16
    autocast_enabled = (target_dtype != torch.float32
                        ) and not fastvideo_args.disable_autocast

    # Get timesteps and calculate warmup steps
    timesteps = batch.timesteps

    # TODO(will): remove this once we add input/output validation for stages
    if timesteps is None:
        raise ValueError("Timesteps must be provided")
    num_inference_steps = batch.num_inference_steps
    num_warmup_steps = len(
        timesteps) - num_inference_steps * self.scheduler.order

    # Prepare image latents and embeddings for I2V generation
    image_embeds = batch.image_embeds
    if len(image_embeds) > 0:
        assert torch.isnan(image_embeds[0]).sum() == 0
        image_embeds = [
            image_embed.to(target_dtype) for image_embed in image_embeds
        ]

    image_kwargs = self.prepare_extra_func_kwargs(
        self.transformer.forward,
        {
            "encoder_hidden_states_image": image_embeds,
            "mask_strategy": dict_to_3d_list(
                None, t_max=50, l_max=60, h_max=24)
        },
    )

    pos_cond_kwargs = self.prepare_extra_func_kwargs(
        self.transformer.forward,
        {
            "encoder_hidden_states_2": batch.clip_embedding_pos,
            "encoder_attention_mask": batch.prompt_attention_mask,
        },
    )

    # Prepare STA parameters
    if st_attn_available and self.attn_backend == SlidingTileAttentionBackend:
        self.prepare_sta_param(batch, fastvideo_args)

    # Get latents and embeddings
    assert batch.latents is not None, "latents must be provided"
    latents = batch.latents

    video_raw_latent_shape = latents.shape
    prompt_embeds = batch.prompt_embeds
    assert not torch.isnan(
        prompt_embeds[0]).any(), "prompt_embeds contains nan"
    timesteps = torch.tensor(
        fastvideo_args.pipeline_config.dmd_denoising_steps,
        dtype=torch.long,
        device=get_local_torch_device())

    # Run denoising loop
    with self.progress_bar(total=len(timesteps)) as progress_bar:
        for i, t in enumerate(timesteps):
            # Skip if interrupted
            if hasattr(self, 'interrupt') and self.interrupt:
                continue
            # Expand latents for I2V
            noise_latents = latents.clone()
            latent_model_input = latents.to(target_dtype)

            if batch.image_latent is not None:
                latent_model_input = torch.cat([
                    latent_model_input,
                    batch.image_latent.permute(0, 2, 1, 3, 4)
                ],
                                               dim=2).to(target_dtype)
            assert not torch.isnan(
                latent_model_input).any(), "latent_model_input contains nan"

            # Prepare inputs for transformer
            t_expand = t.repeat(latent_model_input.shape[0])
            guidance_expand = (
                torch.tensor(
                    [fastvideo_args.pipeline_config.embedded_cfg_scale] *
                    latent_model_input.shape[0],
                    dtype=torch.float32,
                    device=get_local_torch_device(),
                ).to(target_dtype) *
                1000.0 if fastvideo_args.pipeline_config.embedded_cfg_scale
                is not None else None)

            # Predict noise residual
            with torch.autocast(device_type="cuda",
                                dtype=target_dtype,
                                enabled=autocast_enabled):
                if (vsa_available and self.attn_backend
                        == VideoSparseAttentionBackend):
                    self.attn_metadata_builder_cls = self.attn_backend.get_builder_cls(
                    )

                    if self.attn_metadata_builder_cls is not None:
                        self.attn_metadata_builder = self.attn_metadata_builder_cls(
                        )
                        # TODO(will): clean this up
                        attn_metadata = self.attn_metadata_builder.build(  # type: ignore
                            current_timestep=i,  # type: ignore
                            raw_latent_shape=batch.
                            raw_latent_shape[2:5],  # type: ignore
                            patch_size=fastvideo_args.
                            pipeline_config.  # type: ignore
                            dit_config.patch_size,  # type: ignore
                            STA_param=batch.STA_param,  # type: ignore
                            VSA_sparsity=fastvideo_args.
                            VSA_sparsity,  # type: ignore
                            device=get_local_torch_device(),  # type: ignore
                        )  # type: ignore
                        assert attn_metadata is not None, "attn_metadata cannot be None"
                    else:
                        attn_metadata = None
                else:
                    attn_metadata = None

                batch.is_cfg_negative = False
                with set_forward_context(
                        current_timestep=i,
                        attn_metadata=attn_metadata,
                        forward_batch=batch,
                        # fastvideo_args=fastvideo_args
                ):
                    # Run transformer
                    pred_noise = self.transformer(
                        latent_model_input.permute(0, 2, 1, 3, 4),
                        prompt_embeds,
                        t_expand,
                        guidance=guidance_expand,
                        **image_kwargs,
                        **pos_cond_kwargs,
                    ).permute(0, 2, 1, 3, 4)

                pred_video = pred_noise_to_pred_video(
                    pred_noise=pred_noise.flatten(0, 1),
                    noise_input_latent=noise_latents.flatten(0, 1),
                    timestep=t_expand,
                    scheduler=self.scheduler).unflatten(
                        0, pred_noise.shape[:2])

                if i < len(timesteps) - 1:
                    next_timestep = timesteps[i + 1] * torch.ones(
                        [1], dtype=torch.long, device=pred_video.device)
                    noise = torch.randn(video_raw_latent_shape,
                                        dtype=pred_video.dtype,
                                        generator=batch.generator[0]).to(
                                            self.device)
                    latents = self.scheduler.add_noise(
                        pred_video.flatten(0, 1), noise.flatten(0, 1),
                        next_timestep).unflatten(0, pred_video.shape[:2])
                else:
                    latents = pred_video

                # Update progress bar
                if i == len(timesteps) - 1 or (
                    (i + 1) > num_warmup_steps and
                    (i + 1) % self.scheduler.order == 0
                        and progress_bar is not None):
                    progress_bar.update()

    # Permute latents back to the layout expected by downstream stages
    latents = latents.permute(0, 2, 1, 3, 4)
    # Update batch with final latents
    batch.latents = latents

    return batch
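
The loop above runs only a handful of DMD steps: each step converts the transformer output into a clean-video estimate, then re-noises that estimate to the next timestep (except at the last step, where the clean estimate is kept). A standalone sketch of this pattern, assuming a flow-matching convention x_t = (1 - sigma)*x0 + sigma*eps with the model predicting v = eps - x0; the actual conversion is handled by pred_noise_to_pred_video and scheduler.add_noise, so this is illustrative only:

import torch

def dmd_step(x_t: torch.Tensor, pred_velocity: torch.Tensor,
             sigma_t: float, sigma_next: float | None) -> torch.Tensor:
    # Clean estimate under x_t = (1 - sigma)*x0 + sigma*eps, v = eps - x0
    x0 = x_t - sigma_t * pred_velocity
    if sigma_next is None:        # last step: keep the clean estimate
        return x0
    noise = torch.randn_like(x0)  # fresh noise, as in the loop above
    return (1 - sigma_next) * x0 + sigma_next * noise

latents = torch.randn(1, 16, 8, 30, 52)
for sigma_t, sigma_next in [(1.0, 0.75), (0.75, 0.5), (0.5, None)]:
    pred_velocity = torch.randn_like(latents)  # stand-in for the transformer output
    latents = dmd_step(latents, pred_velocity, sigma_t, sigma_next)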

fastvideo.pipelines.stages.EncodingStage

EncodingStage(vae: ParallelTiledVAE)

Bases: PipelineStage

Stage for encoding pixel space representations into latent space.

This stage handles the encoding of pixel-space video/images into latent representations for further processing in the diffusion pipeline.

Source code in fastvideo/pipelines/stages/encoding.py
def __init__(self, vae: ParallelTiledVAE) -> None:
    self.vae: ParallelTiledVAE = vae

Functions

fastvideo.pipelines.stages.EncodingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Encode pixel space representations into latent space.

Parameters:

    batch (ForwardBatch, required): The current batch information.
    fastvideo_args (FastVideoArgs, required): The inference arguments.

Returns:

    ForwardBatch: The batch with encoded latents.

Source code in fastvideo/pipelines/stages/encoding.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Encode pixel space representations into latent space.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with encoded latents.
    """
    assert batch.latents is not None and isinstance(batch.latents,
                                                    torch.Tensor)

    self.vae = self.vae.to(get_local_torch_device())

    # Setup VAE precision
    vae_dtype = PRECISION_TO_TYPE[
        fastvideo_args.pipeline_config.vae_precision]
    vae_autocast_enabled = (
        vae_dtype != torch.float32) and not fastvideo_args.disable_autocast

    # Normalize input to [-1, 1] range (reverse of decoding normalization)
    latents = (batch.latents * 2.0 - 1.0).clamp(-1, 1)

    # Move to appropriate device and dtype
    latents = latents.to(get_local_torch_device())

    # Encode image to latents
    with torch.autocast(device_type="cuda",
                        dtype=vae_dtype,
                        enabled=vae_autocast_enabled):
        if fastvideo_args.pipeline_config.vae_tiling:
            self.vae.enable_tiling()
        # if fastvideo_args.vae_sp:
        #     self.vae.enable_parallel()
        if not vae_autocast_enabled:
            latents = latents.to(vae_dtype)
        latents = self.vae.encode(latents).mean

    # Update batch with encoded latents
    batch.latents = latents

    # Offload models if needed
    if hasattr(self, 'maybe_free_model_hooks'):
        self.maybe_free_model_hooks()

    if fastvideo_args.vae_cpu_offload:
        self.vae.to("cpu")

    return batch
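
The stage expects batch.latents to hold pixel-space video in [0, 1] with shape [B, C, T, H, W], rescales it to [-1, 1], and replaces the field with the VAE posterior mean. A rough sketch of the same transformation; the vae object and its 8x spatial / 4x temporal compression are assumptions:

import torch

# Illustrative pixel video in [0, 1], shape [B, C, T, H, W]
pixels = torch.rand(1, 3, 17, 480, 832)
pixels = (pixels * 2.0 - 1.0).clamp(-1, 1)  # same normalization as the stage
# latents = vae.encode(pixels).mean          # roughly [1, C_latent, 5, 60, 104] for the assumed VAE
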
fastvideo.pipelines.stages.EncodingStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify encoding stage inputs.

Source code in fastvideo/pipelines/stages/encoding.py
@torch.no_grad()
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify encoding stage inputs."""
    result = VerificationResult()
    # Input video/images for VAE encoding: [batch_size, channels, frames, height, width]
    result.add_check("latents", batch.latents,
                     [V.is_tensor, V.with_dims(5)])
    return result
fastvideo.pipelines.stages.EncodingStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify encoding stage outputs.

Source code in fastvideo/pipelines/stages/encoding.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify encoding stage outputs."""
    result = VerificationResult()
    # Encoded latents: [batch_size, channels, frames, height_latents, width_latents]
    result.add_check("latents", batch.latents,
                     [V.is_tensor, V.with_dims(5)])
    return result

fastvideo.pipelines.stages.Hy15ImageEncodingStage

Hy15ImageEncodingStage(image_encoder, image_processor)

Bases: ImageEncodingStage

Stage for encoding image prompts into embeddings for HunyuanVideo1.5 models.

Source code in fastvideo/pipelines/stages/image_encoding.py
def __init__(self, image_encoder, image_processor) -> None:
    """
    Initialize the HunyuanVideo1.5 image encoding stage.

    Args:
        image_encoder: The image encoder model used to produce image embeddings.
        image_processor: The processor that prepares input images for the encoder.
    """
    super().__init__()
    self.image_processor = image_processor
    self.image_encoder = image_encoder

Functions

fastvideo.pipelines.stages.Hy15ImageEncodingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Encode the prompt into image encoder hidden states.

Source code in fastvideo/pipelines/stages/image_encoding.py
def forward(self, batch: ForwardBatch,
            fastvideo_args: FastVideoArgs) -> ForwardBatch:
    """
    Encode the prompt into image encoder hidden states.
    """
    if batch.pil_image is None:
        batch.image_embeds = [
            torch.zeros(1, 729, 1152, device=get_local_torch_device())
        ]

    raw_latent_shape = list(batch.raw_latent_shape)
    raw_latent_shape[1] = 1
    batch.video_latent = torch.zeros(tuple(raw_latent_shape),
                                     device=get_local_torch_device())
    return batch
fastvideo.pipelines.stages.Hy15ImageEncodingStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify image encoding stage inputs.

Source code in fastvideo/pipelines/stages/image_encoding.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify image encoding stage inputs."""
    return VerificationResult()

fastvideo.pipelines.stages.ImageEncodingStage

ImageEncodingStage(image_encoder, image_processor)

Bases: PipelineStage

Stage for encoding image prompts into embeddings for diffusion models.

This stage handles the encoding of image prompts into the embedding space expected by the diffusion model.

Initialize the image encoding stage.

Parameters:

    image_encoder (required): The image encoder model used to produce image embeddings.
    image_processor (required): The processor that prepares input images for the encoder.
Source code in fastvideo/pipelines/stages/image_encoding.py
def __init__(self, image_encoder, image_processor) -> None:
    """
    Initialize the image encoding stage.

    Args:
        image_encoder: The image encoder model used to produce image embeddings.
        image_processor: The processor that prepares input images for the encoder.
    """
    super().__init__()
    self.image_processor = image_processor
    self.image_encoder = image_encoder

Functions

fastvideo.pipelines.stages.ImageEncodingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Encode the input image into image encoder hidden states.

Parameters:

    batch (ForwardBatch, required): The current batch information.
    fastvideo_args (FastVideoArgs, required): The inference arguments.

Returns:

    ForwardBatch: The batch with encoded image embeddings.

Source code in fastvideo/pipelines/stages/image_encoding.py
@torch.no_grad()
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Encode the input image into image encoder hidden states.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with encoded image embeddings.
    """
    self.image_encoder = self.image_encoder.to(get_local_torch_device())

    image = batch.pil_image

    image_inputs = self.image_processor(
        images=image, return_tensors="pt").to(get_local_torch_device())
    with set_forward_context(current_timestep=0, attn_metadata=None):
        outputs = self.image_encoder(**image_inputs)
        image_embeds = outputs.last_hidden_state

    batch.image_embeds.append(image_embeds)

    if fastvideo_args.image_encoder_cpu_offload:
        self.image_encoder.to('cpu')

    return batch
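
A rough usage sketch, assuming a CLIP-style vision encoder and processor from Hugging Face transformers; the concrete encoder wired into a given pipeline may differ:

import torch
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModel

# Assumed components; the pipeline supplies its own image_encoder/image_processor.
processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
encoder = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14")

image = Image.new("RGB", (512, 512))
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    image_embeds = encoder(**inputs).last_hidden_state  # [1, num_patches + 1, hidden_size]
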
fastvideo.pipelines.stages.ImageEncodingStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify image encoding stage inputs.

Source code in fastvideo/pipelines/stages/image_encoding.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify image encoding stage inputs."""
    result = VerificationResult()
    result.add_check("pil_image", batch.pil_image, V.not_none)
    result.add_check("image_embeds", batch.image_embeds, V.is_list)
    return result
fastvideo.pipelines.stages.ImageEncodingStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify image encoding stage outputs.

Source code in fastvideo/pipelines/stages/image_encoding.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify image encoding stage outputs."""
    result = VerificationResult()
    result.add_check("image_embeds", batch.image_embeds,
                     V.list_of_tensors_dims(3))
    return result

fastvideo.pipelines.stages.ImageVAEEncodingStage

ImageVAEEncodingStage(vae: ParallelTiledVAE)

Bases: PipelineStage

Stage for encoding image pixel representations into latent space.

This stage handles the encoding of image pixel representations into the final input format (e.g., latents) for image-to-video generation.

Source code in fastvideo/pipelines/stages/image_encoding.py
def __init__(self, vae: ParallelTiledVAE) -> None:
    self.vae: ParallelTiledVAE = vae

Functions

fastvideo.pipelines.stages.ImageVAEEncodingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Encode pixel representations into latent space.

Parameters:

    batch (ForwardBatch, required): The current batch information.
    fastvideo_args (FastVideoArgs, required): The inference arguments.

Returns:

    ForwardBatch: The batch with encoded outputs.

Source code in fastvideo/pipelines/stages/image_encoding.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Encode pixel representations into latent space.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with encoded outputs.
    """
    assert batch.pil_image is not None
    if fastvideo_args.mode == ExecutionMode.INFERENCE:
        assert batch.pil_image is not None and isinstance(
            batch.pil_image, PIL.Image.Image)
        assert batch.height is not None and isinstance(batch.height, int)
        assert batch.width is not None and isinstance(batch.width, int)
        assert batch.num_frames is not None and isinstance(
            batch.num_frames, int)
        height = batch.height
        width = batch.width
        num_frames = batch.num_frames
    elif fastvideo_args.mode == ExecutionMode.PREPROCESS:
        assert batch.pil_image is not None and isinstance(
            batch.pil_image, torch.Tensor)
        assert batch.height is not None and isinstance(batch.height, list)
        assert batch.width is not None and isinstance(batch.width, list)
        assert batch.num_frames is not None and isinstance(
            batch.num_frames, list)
        num_frames = batch.num_frames[0]
        height = batch.height[0]
        width = batch.width[0]

    self.vae = self.vae.to(get_local_torch_device())

    # Process single image for I2V
    latent_height = height // self.vae.spatial_compression_ratio
    latent_width = width // self.vae.spatial_compression_ratio
    image = batch.pil_image
    image = self.preprocess(
        image,
        vae_scale_factor=self.vae.spatial_compression_ratio,
        height=height,
        width=width).to(get_local_torch_device(), dtype=torch.float32)

    # (B, C, H, W) -> (B, C, 1, H, W)
    image = image.unsqueeze(2)

    video_condition = torch.cat([
        image,
        image.new_zeros(image.shape[0], image.shape[1], num_frames - 1,
                        image.shape[3], image.shape[4])
    ],
                                dim=2)
    video_condition = video_condition.to(device=get_local_torch_device(),
                                         dtype=torch.float32)

    # Setup VAE precision
    vae_dtype = PRECISION_TO_TYPE[
        fastvideo_args.pipeline_config.vae_precision]
    vae_autocast_enabled = (
        vae_dtype != torch.float32) and not fastvideo_args.disable_autocast

    # Encode Image
    with torch.autocast(device_type="cuda",
                        dtype=vae_dtype,
                        enabled=vae_autocast_enabled):
        if fastvideo_args.pipeline_config.vae_tiling:
            self.vae.enable_tiling()
        # if fastvideo_args.vae_sp:
        #     self.vae.enable_parallel()
        if not vae_autocast_enabled:
            video_condition = video_condition.to(vae_dtype)
        encoder_output = self.vae.encode(video_condition)

    if fastvideo_args.mode == ExecutionMode.PREPROCESS:
        latent_condition = encoder_output.mean
    else:
        generator = batch.generator
        if generator is None:
            raise ValueError("Generator must be provided")
        latent_condition = self.retrieve_latents(encoder_output, generator)

    # Apply shifting if needed
    if (hasattr(self.vae, "shift_factor")
            and self.vae.shift_factor is not None):
        if isinstance(self.vae.shift_factor, torch.Tensor):
            latent_condition -= self.vae.shift_factor.to(
                latent_condition.device, latent_condition.dtype)
        else:
            latent_condition -= self.vae.shift_factor

    if isinstance(self.vae.scaling_factor, torch.Tensor):
        latent_condition = latent_condition * self.vae.scaling_factor.to(
            latent_condition.device, latent_condition.dtype)
    else:
        latent_condition = latent_condition * self.vae.scaling_factor

    if fastvideo_args.mode == ExecutionMode.PREPROCESS:
        batch.image_latent = latent_condition
    else:
        mask_lat_size = torch.ones(1, 1, num_frames, latent_height,
                                   latent_width)
        mask_lat_size[:, :, list(range(1, num_frames))] = 0
        first_frame_mask = mask_lat_size[:, :, 0:1]
        first_frame_mask = torch.repeat_interleave(
            first_frame_mask,
            dim=2,
            repeats=self.vae.temporal_compression_ratio)
        mask_lat_size = torch.concat(
            [first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
        mask_lat_size = mask_lat_size.view(
            1, -1, self.vae.temporal_compression_ratio, latent_height,
            latent_width)
        mask_lat_size = mask_lat_size.transpose(1, 2)
        mask_lat_size = mask_lat_size.to(latent_condition.device)

        batch.image_latent = torch.concat([mask_lat_size, latent_condition],
                                          dim=1)

    # Offload models if needed
    if hasattr(self, 'maybe_free_model_hooks'):
        self.maybe_free_model_hooks()

    self.vae.to("cpu")

    return batch
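
The mask construction at the end packs a per-frame conditioning mask into the latent temporal grid. A small worked example, assuming 81 input frames, a temporal compression ratio of 4, and a 60x104 latent grid:

import torch

num_frames, tcr, lh, lw = 81, 4, 60, 104  # assumed example values

mask = torch.ones(1, 1, num_frames, lh, lw)
mask[:, :, 1:] = 0                                      # only the first frame is conditioned
first = mask[:, :, 0:1].repeat_interleave(tcr, dim=2)   # [1, 1, 4, 60, 104]
mask = torch.cat([first, mask[:, :, 1:]], dim=2)        # [1, 1, 84, 60, 104]
mask = mask.view(1, -1, tcr, lh, lw).transpose(1, 2)    # [1, 4, 21, 60, 104]
# 21 latent frames = (81 - 1) // 4 + 1; the mask is then concatenated with
# latent_condition along the channel dimension to form batch.image_latent.
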
fastvideo.pipelines.stages.ImageVAEEncodingStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify encoding stage inputs.

Source code in fastvideo/pipelines/stages/image_encoding.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify encoding stage inputs."""
    result = VerificationResult()
    result.add_check("generator", batch.generator,
                     V.generator_or_list_generators)
    if fastvideo_args.mode == ExecutionMode.PREPROCESS:
        result.add_check("height", batch.height, V.list_not_empty)
        result.add_check("width", batch.width, V.list_not_empty)
        result.add_check("num_frames", batch.num_frames, V.list_not_empty)
    else:
        result.add_check("height", batch.height, V.positive_int)
        result.add_check("width", batch.width, V.positive_int)
        result.add_check("num_frames", batch.num_frames, V.positive_int)
    return result
fastvideo.pipelines.stages.ImageVAEEncodingStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify encoding stage outputs.

Source code in fastvideo/pipelines/stages/image_encoding.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify encoding stage outputs."""
    result = VerificationResult()
    result.add_check("image_latent", batch.image_latent,
                     [V.is_tensor, V.with_dims(5)])
    return result

fastvideo.pipelines.stages.InputValidationStage

Bases: PipelineStage

Stage for validating and preparing inputs for diffusion pipelines.

This stage validates that all required inputs are present and properly formatted before proceeding with the diffusion process.

Functions

fastvideo.pipelines.stages.InputValidationStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Validate and prepare inputs.

Parameters:

    batch (ForwardBatch, required): The current batch information.
    fastvideo_args (FastVideoArgs, required): The inference arguments.

Returns:

    ForwardBatch: The validated batch information.

Source code in fastvideo/pipelines/stages/input_validation.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Validate and prepare inputs.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The validated batch information.
    """

    self._generate_seeds(batch, fastvideo_args)

    # Ensure prompt is properly formatted
    if batch.prompt is None and batch.prompt_embeds is None:
        raise ValueError(
            "Either `prompt` or `prompt_embeds` must be provided")

    # Ensure negative prompt is properly formatted if using classifier-free guidance
    if (batch.do_classifier_free_guidance and batch.negative_prompt is None
            and batch.negative_prompt_embeds is None):
        raise ValueError(
            "For classifier-free guidance, either `negative_prompt` or "
            "`negative_prompt_embeds` must be provided")

    # Validate height and width
    if batch.height is None or batch.width is None:
        raise ValueError(
            "Height and width must be provided. Please set `height` and `width`."
        )
    if batch.height % 8 != 0 or batch.width % 8 != 0:
        raise ValueError(
            f"Height and width must be divisible by 8 but are {batch.height} and {batch.width}."
        )

    # Validate number of inference steps
    if batch.num_inference_steps <= 0:
        raise ValueError(
            f"Number of inference steps must be positive, but got {batch.num_inference_steps}"
        )

    # Validate guidance scale if using classifier-free guidance
    if batch.do_classifier_free_guidance and batch.guidance_scale <= 0:
        raise ValueError(
            f"Guidance scale must be positive, but got {batch.guidance_scale}"
        )

    # for i2v, get image from image_path
    # @TODO(Wei) hard-coded for wan2.2 5b ti2v for now. Should put this in image_encoding stage
    if batch.image_path is not None:
        if batch.image_path.endswith(".mp4"):
            image = load_video(batch.image_path)[0]
        else:
            image = load_image(batch.image_path)
        batch.pil_image = image

    # further processing for ti2v task
    if (fastvideo_args.pipeline_config.ti2v_task
            or fastvideo_args.pipeline_config.is_causal
        ) and batch.pil_image is not None:
        img = batch.pil_image
        ih, iw = img.height, img.width

        pipeline_class_name = type(fastvideo_args.pipeline_config).__name__
        if 'MatrixGame' in pipeline_class_name or 'MatrixCausal' in pipeline_class_name:
            oh, ow = batch.height, batch.width
            img = img.resize((ow, oh), Image.LANCZOS)
        else:
            # Standard Wan logic
            patch_size = fastvideo_args.pipeline_config.dit_config.arch_config.patch_size
            vae_stride = fastvideo_args.pipeline_config.vae_config.arch_config.scale_factor_spatial
            dh, dw = patch_size[1] * vae_stride, patch_size[2] * vae_stride
            max_area = 480 * 832
            ow, oh = best_output_size(iw, ih, dw, dh, max_area)

            scale = max(ow / iw, oh / ih)
            img = img.resize((round(iw * scale), round(ih * scale)),
                             Image.LANCZOS)

            # center-crop
            x1 = (img.width - ow) // 2
            y1 = (img.height - oh) // 2
            img = img.crop((x1, y1, x1 + ow, y1 + oh))

        assert img.width == ow and img.height == oh
        logger.info("final processed img height: %s, img width: %s",
                    img.height, img.width)

        # to tensor
        img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(
            self.device).unsqueeze(1)
        img = img.unsqueeze(0)
        batch.height = oh
        batch.width = ow
        batch.pil_image = img

    # for v2v, get control video from video path
    if batch.video_path is not None:
        pil_images, original_fps = load_video(batch.video_path,
                                              return_fps=True)
        logger.info("Loaded video with %s frames, original FPS: %s",
                    len(pil_images), original_fps)

        # Get target parameters from batch
        target_fps = batch.fps
        target_num_frames = batch.num_frames
        target_height = batch.height
        target_width = batch.width

        if target_fps is not None and original_fps is not None:
            frame_skip = max(1, int(original_fps // target_fps))
            if frame_skip > 1:
                pil_images = pil_images[::frame_skip]
                effective_fps = original_fps / frame_skip
                logger.info(
                    "Resampled video from %.1f fps to %.1f fps (skip=%s)",
                    original_fps, effective_fps, frame_skip)

        # Limit to target number of frames
        if target_num_frames is not None and len(
                pil_images) > target_num_frames:
            total_frames = len(pil_images)
            pil_images = pil_images[:target_num_frames]
            logger.info("Limited video to %s frames (from %s total)",
                        target_num_frames, total_frames)

        # Resize each PIL image to target dimensions
        resized_images = []
        for pil_img in pil_images:
            resized_img = resize(pil_img,
                                 target_height,
                                 target_width,
                                 resize_mode="default",
                                 resample="lanczos")
            resized_images.append(resized_img)

        # Convert PIL images to numpy array
        video_numpy = pil_to_numpy(resized_images)
        video_numpy = normalize(video_numpy)
        video_tensor = numpy_to_pt(video_numpy)

        # Rearrange to [C, T, H, W] and add batch dimension -> [B, C, T, H, W]
        input_video = video_tensor.permute(1, 0, 2, 3).unsqueeze(0)

        batch.video_latent = input_video

    # Validate action control inputs (Matrix-Game)
    if batch.mouse_cond is not None:
        if batch.mouse_cond.dim() != 3 or batch.mouse_cond.shape[-1] != 2:
            raise ValueError(
                f"mouse_cond must have shape (B, T, 2), but got {batch.mouse_cond.shape}"
            )
        logger.info("Action control: mouse_cond validated - shape %s",
                    batch.mouse_cond.shape)

    if batch.keyboard_cond is not None:
        if batch.keyboard_cond.dim() != 3:
            raise ValueError(
                f"keyboard_cond must have 3 dimensions (B, T, K), but got {batch.keyboard_cond.dim()}"
            )
        keyboard_dim = batch.keyboard_cond.shape[-1]
        if keyboard_dim not in {2, 4, 6, 7}:
            raise ValueError(
                f"keyboard_cond last dimension must be 2, 4, 6, or 7, but got {keyboard_dim}"
            )
        logger.info(
            "Action control: keyboard_cond validated - shape %s (dim=%d)",
            batch.keyboard_cond.shape, keyboard_dim)

    if batch.grid_sizes is not None:
        if not isinstance(batch.grid_sizes, list | tuple | torch.Tensor):
            raise ValueError("grid_sizes must be a list, tuple, or tensor")
        if isinstance(batch.grid_sizes, torch.Tensor):
            if batch.grid_sizes.numel() != 3:
                raise ValueError(
                    "grid_sizes must have 3 elements [F, H, W]")
        else:
            if len(batch.grid_sizes) != 3:
                raise ValueError(
                    "grid_sizes must have 3 elements [F, H, W]")
        logger.info("Action control: grid_sizes validated - %s",
                    batch.grid_sizes)

    return batch
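
For the v2v path, frame resampling uses an integer stride. A small worked example of the stride and effective fps computation (values are illustrative):

# Illustrative: a 30 fps source resampled toward an 8 fps target.
original_fps, target_fps = 30.0, 8.0
frame_skip = max(1, int(original_fps // target_fps))  # 3
effective_fps = original_fps / frame_skip             # 10.0 fps, slightly above the target
frames = list(range(90))
resampled = frames[::frame_skip]                      # 30 frames kept
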
fastvideo.pipelines.stages.InputValidationStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify input validation stage inputs.

Source code in fastvideo/pipelines/stages/input_validation.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify input validation stage inputs."""
    result = VerificationResult()
    result.add_check("seed", batch.seed, [V.not_none, V.positive_int])
    result.add_check("num_videos_per_prompt", batch.num_videos_per_prompt,
                     V.positive_int)
    result.add_check(
        "prompt_or_embeds", None, lambda _: V.string_or_list_strings(
            batch.prompt) or V.list_not_empty(batch.prompt_embeds))
    result.add_check("height", batch.height, V.positive_int)
    result.add_check("width", batch.width, V.positive_int)
    result.add_check("num_inference_steps", batch.num_inference_steps,
                     V.positive_int)
    result.add_check(
        "guidance_scale", batch.guidance_scale, lambda x: not batch.
        do_classifier_free_guidance or V.positive_float(x))
    return result
fastvideo.pipelines.stages.InputValidationStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify input validation stage outputs.

Source code in fastvideo/pipelines/stages/input_validation.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify input validation stage outputs."""
    result = VerificationResult()
    result.add_check("seeds", batch.seeds, V.list_not_empty)
    result.add_check("generator", batch.generator,
                     V.generator_or_list_generators)
    return result

fastvideo.pipelines.stages.LatentPreparationStage

LatentPreparationStage(scheduler, transformer, use_btchw_layout: bool = False)

Bases: PipelineStage

Stage for preparing initial latent variables for the diffusion process.

This stage handles the preparation of the initial latent variables that will be denoised during the diffusion process.

Source code in fastvideo/pipelines/stages/latent_preparation.py
def __init__(self,
             scheduler,
             transformer,
             use_btchw_layout: bool = False) -> None:
    super().__init__()
    self.scheduler = scheduler
    self.transformer = transformer
    self.use_btchw_layout = use_btchw_layout

Functions

fastvideo.pipelines.stages.LatentPreparationStage.adjust_video_length
adjust_video_length(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> int

Adjust video length based on VAE version.

Parameters:

    batch (ForwardBatch, required): The current batch information.
    fastvideo_args (FastVideoArgs, required): The inference arguments.

Returns:

    int: The adjusted number of latent frames.

Source code in fastvideo/pipelines/stages/latent_preparation.py
def adjust_video_length(self, batch: ForwardBatch,
                        fastvideo_args: FastVideoArgs) -> int:
    """
    Adjust video length based on VAE version.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The adjusted number of latent frames.
    """

    video_length = batch.num_frames
    use_temporal_scaling_frames = fastvideo_args.pipeline_config.vae_config.use_temporal_scaling_frames
    if use_temporal_scaling_frames:
        temporal_scale_factor = fastvideo_args.pipeline_config.vae_config.arch_config.temporal_compression_ratio
        latent_num_frames = (video_length - 1) // temporal_scale_factor + 1
    else:  # stepvideo only
        latent_num_frames = video_length // 17 * 3
    return int(latent_num_frames)
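
A quick worked example of the two branches (values illustrative):

# Temporal-scaling VAEs, e.g. a compression ratio of 4: 81 pixel frames -> 21 latent frames.
video_length, temporal_scale_factor = 81, 4
latent_num_frames = (video_length - 1) // temporal_scale_factor + 1  # 21

# StepVideo branch: 102 pixel frames -> 18 latent frames.
latent_num_frames_stepvideo = 102 // 17 * 3  # 18
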
fastvideo.pipelines.stages.LatentPreparationStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Prepare initial latent variables for the diffusion process.

Parameters:

    batch (ForwardBatch, required): The current batch information.
    fastvideo_args (FastVideoArgs, required): The inference arguments.

Returns:

    ForwardBatch: The batch with prepared latent variables.

Source code in fastvideo/pipelines/stages/latent_preparation.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Prepare initial latent variables for the diffusion process.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with prepared latent variables.
    """

    latent_num_frames = None
    # Adjust video length based on VAE version if needed
    if hasattr(self, 'adjust_video_length'):
        latent_num_frames = self.adjust_video_length(batch, fastvideo_args)
    # Determine batch size; fall back to action/image inputs when no text encoder is present
    if not batch.prompt_embeds:
        if batch.keyboard_cond is not None:
            batch_size = batch.keyboard_cond.shape[0]
        elif batch.mouse_cond is not None:
            batch_size = batch.mouse_cond.shape[0]
        elif batch.image_embeds:
            batch_size = batch.image_embeds[0].shape[0]
        else:
            batch_size = 1
    elif isinstance(batch.prompt, list):
        batch_size = len(batch.prompt)
    elif batch.prompt is not None:
        batch_size = 1
    else:
        batch_size = batch.prompt_embeds[0].shape[0]

    # Adjust batch size for number of videos per prompt
    batch_size *= batch.num_videos_per_prompt

    # Get required parameters
    if not batch.prompt_embeds:
        # Create a dummy zero-length text embedding to satisfy downstream checks.
        # Matrix-Game models have text_dim=0 and ignore encoder_hidden_states.
        transformer_dtype = next(self.transformer.parameters()).dtype
        device = get_local_torch_device()
        dummy_prompt = torch.zeros(batch_size,
                                   0,
                                   self.transformer.hidden_size,
                                   device=device,
                                   dtype=transformer_dtype)
        batch.prompt_embeds = [dummy_prompt]
        batch.negative_prompt_embeds = []
        batch.do_classifier_free_guidance = False
    dtype = batch.prompt_embeds[0].dtype
    device = get_local_torch_device()
    generator = batch.generator
    latents = batch.latents
    num_frames = latent_num_frames if latent_num_frames is not None else batch.num_frames
    height = batch.height
    width = batch.width

    # TODO(will): remove this once we add input/output validation for stages
    if height is None or width is None:
        raise ValueError("Height and width must be provided")

    # Calculate latent shape
    bcthw_shape: tuple[int, ...] | None = None
    if self.use_btchw_layout:
        shape = (
            batch_size,
            num_frames,
            self.transformer.num_channels_latents,
            height // fastvideo_args.pipeline_config.vae_config.arch_config.
            spatial_compression_ratio,
            width // fastvideo_args.pipeline_config.vae_config.arch_config.
            spatial_compression_ratio,
        )
        bcthw_shape = tuple(shape[i] for i in [0, 2, 1, 3, 4])
    else:
        shape = (
            batch_size,
            self.transformer.num_channels_latents,
            num_frames,
            height // fastvideo_args.pipeline_config.vae_config.arch_config.
            spatial_compression_ratio,
            width // fastvideo_args.pipeline_config.vae_config.arch_config.
            spatial_compression_ratio,
        )
        bcthw_shape = shape

    # Validate generator if it's a list
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )
    # Generate or use provided latents
    if latents is None:
        latents = randn_tensor(
            shape,
            generator=generator,
            device=device,
            dtype=dtype,
        )
        if hasattr(self.scheduler, "init_noise_sigma"):
            latents = latents * self.scheduler.init_noise_sigma
    else:
        # Pre-initialized latents:
        # - For LongCat refine (refine_from or stage1_video present), we should not re-scale by init_noise_sigma.
        # - For other models, keep the original behavior.
        latents = latents.to(device)
        is_longcat_refine = (batch.refine_from
                             is not None) or (batch.stage1_video
                                              is not None)
        if (not is_longcat_refine) and hasattr(self.scheduler,
                                               "init_noise_sigma"):
            latents = latents * self.scheduler.init_noise_sigma

    # Update batch with prepared latents
    batch.latents = latents
    batch.raw_latent_shape = bcthw_shape

    return batch
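
A small worked example of the default BCTHW latent shape, assuming 16 latent channels and an 8x spatial compression ratio:

# Assumed example values: batch of 1, 16 latent channels, 8x spatial compression.
batch_size, num_channels_latents = 1, 16
num_frames, height, width = 21, 480, 832
spatial_compression_ratio = 8

shape = (
    batch_size,
    num_channels_latents,
    num_frames,
    height // spatial_compression_ratio,  # 60
    width // spatial_compression_ratio,   # 104
)
# shape == (1, 16, 21, 60, 104); with use_btchw_layout=True the noise is drawn in
# (B, T, C, H, W) order instead, while batch.raw_latent_shape stays in BCTHW.
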
fastvideo.pipelines.stages.LatentPreparationStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify latent preparation stage inputs.

Source code in fastvideo/pipelines/stages/latent_preparation.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify latent preparation stage inputs."""
    result = VerificationResult()
    result.add_check(
        "prompt_or_embeds", None,
        lambda _: V.string_or_list_strings(batch.prompt) or not batch.
        prompt_embeds or V.list_not_empty(batch.prompt_embeds))
    if batch.prompt_embeds:
        result.add_check("prompt_embeds", batch.prompt_embeds,
                         V.list_of_tensors)
    result.add_check("num_videos_per_prompt", batch.num_videos_per_prompt,
                     V.positive_int)
    result.add_check("generator", batch.generator,
                     V.generator_or_list_generators)
    result.add_check("num_frames", batch.num_frames, V.positive_int)
    result.add_check("height", batch.height, V.positive_int)
    result.add_check("width", batch.width, V.positive_int)
    result.add_check("latents", batch.latents, V.none_or_tensor)
    return result
fastvideo.pipelines.stages.LatentPreparationStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify latent preparation stage outputs.

Source code in fastvideo/pipelines/stages/latent_preparation.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify latent preparation stage outputs."""
    result = VerificationResult()
    result.add_check("latents", batch.latents,
                     [V.is_tensor, V.with_dims(5)])
    result.add_check("raw_latent_shape", batch.raw_latent_shape, V.is_tuple)
    return result

fastvideo.pipelines.stages.LongCatKVCacheInitStage

LongCatKVCacheInitStage(transformer)

Bases: PipelineStage

Pre-compute KV cache for conditioning frames.

After this stage:

- batch.kv_cache_dict contains {block_idx: (k, v)}
- batch.cond_latents contains the conditioning latents
- batch.latents contains ONLY noise latents

Source code in fastvideo/pipelines/stages/longcat_kv_cache_init.py
def __init__(self, transformer):
    super().__init__()
    self.transformer = transformer

Functions

fastvideo.pipelines.stages.LongCatKVCacheInitStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Initialize KV cache from conditioning latents.

Source code in fastvideo/pipelines/stages/longcat_kv_cache_init.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """Initialize KV cache from conditioning latents."""

    # Check if KV cache is enabled
    use_kv_cache = getattr(fastvideo_args.pipeline_config, 'use_kv_cache',
                           True)
    if not use_kv_cache:
        batch.kv_cache_dict = {}
        batch.use_kv_cache = False
        logger.info("KV cache disabled, skipping initialization")
        return batch

    batch.use_kv_cache = True
    offload_kv_cache = getattr(fastvideo_args.pipeline_config,
                               'offload_kv_cache', False)

    # Get conditioning latents
    num_cond_latents = batch.num_cond_latents
    if num_cond_latents <= 0:
        batch.kv_cache_dict = {}
        logger.warning("num_cond_latents <= 0, skipping KV cache init")
        return batch

    # Extract conditioning latents
    cond_latents = batch.latents[:, :, :num_cond_latents].clone()

    logger.info(
        "Initializing KV cache for %d conditioning latents, shape: %s",
        num_cond_latents, cond_latents.shape)

    # Timestep = 0 for conditioning (they are "clean")
    B = cond_latents.shape[0]
    T_cond = cond_latents.shape[2]
    timestep = torch.zeros(B,
                           T_cond,
                           device=cond_latents.device,
                           dtype=cond_latents.dtype)

    # Empty prompt embeddings (cross-attn will be skipped)
    max_seq_len = 512
    # Get caption dimension from transformer config
    caption_dim = self.transformer.config.caption_channels
    empty_embeds = torch.zeros(B,
                               max_seq_len,
                               caption_dim,
                               device=cond_latents.device,
                               dtype=cond_latents.dtype)

    # Get transformer dtype
    if hasattr(self.transformer, 'module'):
        transformer_dtype = next(self.transformer.module.parameters()).dtype
    else:
        transformer_dtype = next(self.transformer.parameters()).dtype

    # Run transformer with return_kv=True, skip_crs_attn=True
    with (
            torch.no_grad(),
            set_forward_context(
                current_timestep=0,
                attn_metadata=None,
                forward_batch=batch,
            ),
            torch.autocast(device_type='cuda', dtype=transformer_dtype),
    ):
        _, kv_cache_dict = self.transformer(
            hidden_states=cond_latents.to(transformer_dtype),
            encoder_hidden_states=empty_embeds.to(transformer_dtype),
            timestep=timestep.to(transformer_dtype),
            return_kv=True,
            skip_crs_attn=True,
            offload_kv_cache=offload_kv_cache,
        )

    # Store cache and save cond_latents for later concatenation
    batch.kv_cache_dict = kv_cache_dict
    batch.cond_latents = cond_latents

    # Remove conditioning latents from main latents
    # After this, batch.latents contains ONLY noise frames
    batch.latents = batch.latents[:, :, num_cond_latents:]

    logger.info(
        "KV cache initialized: %d blocks, offload=%s, remaining latents shape: %s",
        len(kv_cache_dict), offload_kv_cache, batch.latents.shape)

    return batch
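
A rough sketch of the latent bookkeeping this stage performs: the first num_cond_latents temporal positions are split off for cache construction and removed from the working latents (tensor shapes are illustrative):

import torch

# Illustrative [B, C, T, H, W] latents with 8 conditioning positions out of 24.
latents = torch.randn(1, 16, 24, 60, 104)
num_cond_latents = 8

cond_latents = latents[:, :, :num_cond_latents].clone()  # kept as batch.cond_latents
latents = latents[:, :, num_cond_latents:]               # batch.latents keeps only noise frames
assert cond_latents.shape[2] + latents.shape[2] == 24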

fastvideo.pipelines.stages.LongCatVCDenoisingStage

LongCatVCDenoisingStage(transformer, scheduler, pipeline=None, transformer_2=None, vae=None)

Bases: LongCatDenoisingStage

LongCat denoising with Video Continuation and KV cache support.

Key differences from I2V denoising:

- Supports KV cache (reuses cached K/V from conditioning frames)
- Handles larger num_cond_latents
- Concatenates conditioning latents back after denoising

When use_kv_cache=True:

- batch.latents contains ONLY noise frames (cond removed by KV cache init)
- batch.kv_cache_dict contains cached K/V
- batch.cond_latents contains conditioning latents for post-concat

When use_kv_cache=False:

- batch.latents contains ALL frames (cond + noise)
- Timestep masking: timestep[:, :num_cond_latents] = 0
- Selective denoising: only update noise frames

Source code in fastvideo/pipelines/stages/denoising.py
def __init__(self,
             transformer,
             scheduler,
             pipeline=None,
             transformer_2=None,
             vae=None) -> None:
    super().__init__()
    self.transformer = transformer
    self.transformer_2 = transformer_2
    self.scheduler = scheduler
    self.vae = vae
    self.pipeline = weakref.ref(pipeline) if pipeline else None
    attn_head_size = self.transformer.hidden_size // self.transformer.num_attention_heads
    self.attn_backend = get_attn_backend(
        head_size=attn_head_size,
        dtype=torch.float16,  # TODO(will): hack
        supported_attention_backends=(
            AttentionBackendEnum.SLIDING_TILE_ATTN,
            AttentionBackendEnum.VIDEO_SPARSE_ATTN,
            AttentionBackendEnum.VMOBA_ATTN,
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.TORCH_SDPA,
            AttentionBackendEnum.SAGE_ATTN_THREE)  # hack
    )

Functions

fastvideo.pipelines.stages.LongCatVCDenoisingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Run denoising loop with VC conditioning and optional KV cache.

Source code in fastvideo/pipelines/stages/longcat_vc_denoising.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """Run denoising loop with VC conditioning and optional KV cache."""

    # Load transformer if needed
    if not fastvideo_args.model_loaded["transformer"]:
        loader = TransformerLoader()
        self.transformer = loader.load(
            fastvideo_args.model_paths["transformer"], fastvideo_args)
        fastvideo_args.model_loaded["transformer"] = True

    # Setup
    target_dtype = torch.bfloat16
    autocast_enabled = (target_dtype != torch.float32
                        ) and not fastvideo_args.disable_autocast

    latents = batch.latents
    timesteps = batch.timesteps
    prompt_embeds = batch.prompt_embeds[0]
    prompt_attention_mask = (batch.prompt_attention_mask[0]
                             if batch.prompt_attention_mask else None)
    guidance_scale = batch.guidance_scale
    do_classifier_free_guidance = batch.do_classifier_free_guidance

    # Get VC-specific parameters
    num_cond_latents = getattr(batch, 'num_cond_latents', 0)
    use_kv_cache = getattr(batch, 'use_kv_cache', False)
    kv_cache_dict = getattr(batch, 'kv_cache_dict', {})

    logger.info(
        "VC Denoising: num_cond_latents=%d, use_kv_cache=%s, latent_shape=%s",
        num_cond_latents, use_kv_cache, latents.shape)

    # Prepare negative prompts for CFG
    if do_classifier_free_guidance:
        negative_prompt_embeds = batch.negative_prompt_embeds[0]
        negative_prompt_attention_mask = (batch.negative_attention_mask[0]
                                          if batch.negative_attention_mask
                                          else None)

        prompt_embeds_combined = torch.cat(
            [negative_prompt_embeds, prompt_embeds], dim=0)
        if prompt_attention_mask is not None:
            prompt_attention_mask_combined = torch.cat(
                [negative_prompt_attention_mask, prompt_attention_mask],
                dim=0)
        else:
            prompt_attention_mask_combined = None
    else:
        prompt_embeds_combined = prompt_embeds
        prompt_attention_mask_combined = prompt_attention_mask

    # Denoising loop
    num_inference_steps = len(timesteps)
    step_times = []

    with tqdm(total=num_inference_steps,
              desc="VC Denoising") as progress_bar:
        for i, t in enumerate(timesteps):
            step_start = time.time()

            # 1. Expand latents for CFG
            if do_classifier_free_guidance:
                latent_model_input = torch.cat([latents] * 2)
            else:
                latent_model_input = latents

            latent_model_input = latent_model_input.to(target_dtype)

            # 2. Expand timestep to match batch size
            timestep = t.expand(
                latent_model_input.shape[0]).to(target_dtype)

            # 3. Expand timestep to temporal dimension
            timestep = timestep.unsqueeze(-1).repeat(
                1, latent_model_input.shape[2])

            # 4. Timestep masking (only when NOT using KV cache)
            if not use_kv_cache and num_cond_latents > 0:
                timestep[:, :num_cond_latents] = 0

            # 5. Prepare transformer kwargs
            # IMPORTANT: num_cond_latents is ALWAYS passed - needed for RoPE position offset
            transformer_kwargs = {
                'num_cond_latents': num_cond_latents,
            }
            if use_kv_cache:
                transformer_kwargs['kv_cache_dict'] = kv_cache_dict

            # 6. Run transformer
            batch.is_cfg_negative = False
            with set_forward_context(
                    current_timestep=i,
                    attn_metadata=None,
                    forward_batch=batch,
            ), torch.autocast(device_type='cuda',
                              dtype=target_dtype,
                              enabled=autocast_enabled):
                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    encoder_hidden_states=prompt_embeds_combined,
                    timestep=timestep,
                    encoder_attention_mask=prompt_attention_mask_combined,
                    **transformer_kwargs,
                )

            # 7. Apply CFG with optimized scale (CFG-zero)
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)

                B = noise_pred_cond.shape[0]
                positive = noise_pred_cond.reshape(B, -1)
                negative = noise_pred_uncond.reshape(B, -1)

                st_star = self.optimized_scale(positive, negative)
                st_star = st_star.view(B, 1, 1, 1, 1)

                noise_pred = (
                    noise_pred_uncond * st_star + guidance_scale *
                    (noise_pred_cond - noise_pred_uncond * st_star))

            # 8. Negate for flow matching scheduler
            noise_pred = -noise_pred

            # 9. Scheduler step
            if use_kv_cache:
                # All latents are noise frames (conditioning is in cache)
                latents = self.scheduler.step(noise_pred,
                                              t,
                                              latents,
                                              return_dict=False)[0]
            else:
                # Only update noise frames (skip conditioning)
                if num_cond_latents > 0:
                    latents[:, :, num_cond_latents:] = self.scheduler.step(
                        noise_pred[:, :, num_cond_latents:],
                        t,
                        latents[:, :, num_cond_latents:],
                        return_dict=False,
                    )[0]
                else:
                    latents = self.scheduler.step(noise_pred,
                                                  t,
                                                  latents,
                                                  return_dict=False)[0]

            step_time = time.time() - step_start
            step_times.append(step_time)

            # Log timing for first few steps
            if i < 3:
                logger.info("Step %d: %.2fs", i, step_time)

            progress_bar.update()

    # 10. If using KV cache, concatenate conditioning latents back
    if use_kv_cache and hasattr(
            batch, 'cond_latents') and batch.cond_latents is not None:
        latents = torch.cat([batch.cond_latents, latents], dim=2)
        logger.info(
            "Concatenated conditioning latents back, final shape: %s",
            latents.shape)

    # Log average timing
    avg_time = sum(step_times) / len(step_times)
    logger.info("Average step time: %.2fs (total: %.1fs)", avg_time,
                sum(step_times))

    # Update batch with denoised latents
    batch.latents = latents
    return batch
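
The optimized_scale helper used in step 7 is inherited from LongCatDenoisingStage and is not reproduced on this page. As a point of reference, a CFG-Zero-style projection scale, which is what the guidance formula above assumes, can be sketched as follows; the actual implementation may differ in its details.

import torch

def optimized_scale(positive: torch.Tensor, negative: torch.Tensor) -> torch.Tensor:
    """Per-sample scale s* projecting the conditional prediction onto the
    unconditional one. Inputs are flattened to [B, D]; output is [B, 1]."""
    dot = (positive * negative).sum(dim=1, keepdim=True)
    norm_sq = (negative * negative).sum(dim=1, keepdim=True) + 1e-8
    return dot / norm_sq

With st_star equal to 1 everywhere, the guidance expression in step 7 reduces to standard classifier-free guidance.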

fastvideo.pipelines.stages.LongCatVideoVAEEncodingStage

LongCatVideoVAEEncodingStage(vae)

Bases: PipelineStage

Encode video frames to latent space for VC conditioning.

This stage:

1. Loads video frames from path or uses provided frames
2. Takes the last num_cond_frames from the video
3. Preprocesses and stacks frames
4. Encodes via VAE to latent space
5. Applies LongCat-specific normalization
6. Calculates num_cond_latents

Source code in fastvideo/pipelines/stages/longcat_video_vae_encoding.py
def __init__(self, vae):
    super().__init__()
    self.vae = vae

Functions

fastvideo.pipelines.stages.LongCatVideoVAEEncodingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Encode video frames to latent for VC conditioning.

Source code in fastvideo/pipelines/stages/longcat_video_vae_encoding.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """Encode video frames to latent for VC conditioning."""

    # Get video from batch - can be path, list of PIL images, or already loaded
    video = getattr(batch, 'video_frames', None) or getattr(
        batch, 'video_path', None)
    num_cond_frames = getattr(batch, 'num_cond_frames',
                              13)  # Default 13 for VC

    if video is None:
        raise ValueError(
            "video_frames or video_path must be provided for VC")

    # Load video if path
    if isinstance(video, str):
        from diffusers.utils import load_video
        video = load_video(video)
        logger.info("Loaded video from path: %d frames", len(video))

    # Take last num_cond_frames
    if len(video) > num_cond_frames:
        video = video[-num_cond_frames:]
        logger.info("Using last %d frames for conditioning",
                    num_cond_frames)
    elif len(video) < num_cond_frames:
        logger.warning(
            "Video has only %d frames, less than num_cond_frames=%d",
            len(video), num_cond_frames)
        num_cond_frames = len(video)

    # Get target dimensions
    height = batch.height
    width = batch.width

    if height is None or width is None:
        raise ValueError("height and width must be set for VC")

    # Preprocess and stack frames
    processed_frames = []
    for frame in video:
        if not isinstance(frame, PIL.Image.Image):
            raise TypeError(f"Frame must be PIL.Image, got {type(frame)}")

        frame = resize(frame, height, width, resize_mode="default")
        frame = pil_to_numpy(frame)  # Returns [1, H, W, C] then converted
        frame = numpy_to_pt(frame)  # Returns [1, C, H, W]
        frame = normalize(frame)  # to [-1, 1]
        processed_frames.append(frame)

    # Stack frames: [num_frames, C, H, W] -> [1, C, T, H, W]
    video_tensor = torch.cat(processed_frames, dim=0)  # [T, C, H, W]
    video_tensor = video_tensor.permute(1, 0, 2,
                                        3).unsqueeze(0)  # [1, C, T, H, W]
    video_tensor = video_tensor.to(get_local_torch_device(),
                                   dtype=torch.float32)

    logger.info("VC: Preprocessed video tensor shape: %s",
                video_tensor.shape)

    # Encode via VAE
    self.vae = self.vae.to(get_local_torch_device())

    # Setup VAE precision
    vae_dtype = PRECISION_TO_TYPE[
        fastvideo_args.pipeline_config.vae_precision]
    vae_autocast_enabled = (
        vae_dtype != torch.float32) and not fastvideo_args.disable_autocast

    with torch.autocast(device_type="cuda",
                        dtype=vae_dtype,
                        enabled=vae_autocast_enabled):
        if fastvideo_args.pipeline_config.vae_tiling:
            self.vae.enable_tiling()

        if not vae_autocast_enabled:
            video_tensor = video_tensor.to(vae_dtype)

        with torch.no_grad():
            encoder_output = self.vae.encode(video_tensor)
            latent = self.retrieve_latents(encoder_output, batch.generator)

    # Apply LongCat-specific normalization
    latent = self.normalize_latents(latent)

    # Calculate num_cond_latents
    # Formula: 1 + (num_cond_frames - 1) // vae_temporal_scale
    vae_temporal_scale = self.vae.config.scale_factor_temporal
    num_cond_latents = 1 + (num_cond_frames - 1) // vae_temporal_scale

    # Store in batch
    batch.video_latent = latent
    batch.num_cond_frames = num_cond_frames
    batch.num_cond_latents = num_cond_latents

    logger.info(
        "VC: Encoded %d frames to latent shape %s, num_cond_latents=%d",
        num_cond_frames, latent.shape, num_cond_latents)

    # Offload VAE if needed
    if fastvideo_args.vae_cpu_offload:
        self.vae.to("cpu")

    return batch
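
The num_cond_latents calculation is plain integer arithmetic over the VAE's temporal compression. A quick worked example, assuming an illustrative scale_factor_temporal of 4:

# Formula: 1 + (num_cond_frames - 1) // vae_temporal_scale
vae_temporal_scale = 4  # illustrative; read from vae.config.scale_factor_temporal in practice
for num_cond_frames in (1, 13, 16):
    num_cond_latents = 1 + (num_cond_frames - 1) // vae_temporal_scale
    print(num_cond_frames, "frames ->", num_cond_latents, "latents")
# 1 frames -> 1 latents
# 13 frames -> 4 latents
# 16 frames -> 4 latents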
fastvideo.pipelines.stages.LongCatVideoVAEEncodingStage.normalize_latents
normalize_latents(latents: Tensor) -> Tensor

Apply LongCat-specific latent normalization.

Formula: (latents - mean) / std

Source code in fastvideo/pipelines/stages/longcat_video_vae_encoding.py
def normalize_latents(self, latents: torch.Tensor) -> torch.Tensor:
    """
    Apply LongCat-specific latent normalization.

    Formula: (latents - mean) / std
    """
    if not hasattr(self.vae.config, 'latents_mean') or not hasattr(
            self.vae.config, 'latents_std'):
        raise ValueError(
            "VAE config must have 'latents_mean' and 'latents_std' "
            "for LongCat normalization")

    latents_mean = torch.tensor(self.vae.config.latents_mean).view(
        1, self.vae.config.z_dim, 1, 1, 1).to(latents.device, latents.dtype)

    latents_std = torch.tensor(self.vae.config.latents_std).view(
        1, self.vae.config.z_dim, 1, 1, 1).to(latents.device, latents.dtype)

    return (latents - latents_mean) / latents_std
fastvideo.pipelines.stages.LongCatVideoVAEEncodingStage.retrieve_latents
retrieve_latents(encoder_output: Any, generator: Generator | None) -> Tensor

Sample from VAE posterior.

Source code in fastvideo/pipelines/stages/longcat_video_vae_encoding.py
def retrieve_latents(self, encoder_output: Any,
                     generator: torch.Generator | None) -> torch.Tensor:
    """Sample from VAE posterior."""
    if hasattr(encoder_output, 'sample'):
        return encoder_output.sample(generator)
    elif hasattr(encoder_output, 'latent_dist'):
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, 'latents'):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents from encoder output")

fastvideo.pipelines.stages.PipelineStage

Bases: ABC

Abstract base class for all pipeline stages.

A pipeline stage represents a discrete step in the diffusion process that can be composed with other stages to create a complete pipeline. Each stage is responsible for a specific part of the process, such as prompt encoding, latent preparation, etc.

Attributes

fastvideo.pipelines.stages.PipelineStage.device property
device: device

Get the device for this stage.

Functions

fastvideo.pipelines.stages.PipelineStage.__call__
__call__(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Execute the stage's processing on the batch with optional verification and logging. Should not be overridden by subclasses.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The updated batch information after this stage's processing.

Source code in fastvideo/pipelines/stages/base.py
def __call__(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Execute the stage's processing on the batch with optional verification and logging.
    Should not be overridden by subclasses.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The updated batch information after this stage's processing.
    """
    stage_name = self.__class__.__name__

    # Check if verification is enabled (simple approach for prototype)
    enable_verification = getattr(fastvideo_args,
                                  'enable_stage_verification', False)

    if enable_verification:
        # Pre-execution input verification
        try:
            input_result = self.verify_input(batch, fastvideo_args)
            self._run_verification(input_result, stage_name, "input")
        except Exception as e:
            logger.error("Input verification failed for %s: %s", stage_name,
                         str(e))
            raise

    # Execute the actual stage logic
    if envs.FASTVIDEO_STAGE_LOGGING:
        logger.info("[%s] Starting execution", stage_name)
        start_time = time.perf_counter()

        try:
            result = self.forward(batch, fastvideo_args)
            execution_time = time.perf_counter() - start_time
            logger.info("[%s] Execution completed in %s ms", stage_name,
                        execution_time * 1000)
            batch.logging_info.add_stage_execution_time(
                stage_name, execution_time)
        except Exception as e:
            execution_time = time.perf_counter() - start_time
            logger.error("[%s] Error during execution after %s ms: %s",
                         stage_name, execution_time * 1000, e)
            logger.error("[%s] Traceback: %s", stage_name,
                         traceback.format_exc())
            raise
    else:
        # Direct execution (current behavior)
        result = self.forward(batch, fastvideo_args)

    if enable_verification:
        # Post-execution output verification
        try:
            output_result = self.verify_output(result, fastvideo_args)
            self._run_verification(output_result, stage_name, "output")
        except Exception as e:
            logger.error("Output verification failed for %s: %s",
                         stage_name, str(e))
            raise

    return result
fastvideo.pipelines.stages.PipelineStage.forward abstractmethod
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Forward pass of the stage's processing.

This method should be implemented by subclasses to provide the forward processing logic for the stage.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The updated batch information after this stage's processing.

Source code in fastvideo/pipelines/stages/base.py
@abstractmethod
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Forward pass of the stage's processing.

    This method should be implemented by subclasses to provide the forward
    processing logic for the stage.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The updated batch information after this stage's processing.
    """
    raise NotImplementedError
fastvideo.pipelines.stages.PipelineStage.set_logging
set_logging(enable: bool)

Enable or disable logging for this stage.

Parameters:

Name Type Description Default
enable bool

Whether to enable logging.

required
Source code in fastvideo/pipelines/stages/base.py
def set_logging(self, enable: bool):
    """
    Enable or disable logging for this stage.

    Args:
        enable: Whether to enable logging.
    """
    self._enable_logging = enable
fastvideo.pipelines.stages.PipelineStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify the input for the stage.

Example

from fastvideo.pipelines.stages.validators import V, VerificationResult

def verify_input(self, batch, fastvideo_args):
    result = VerificationResult()
    result.add_check("height", batch.height, V.positive_int_divisible(8))
    result.add_check("width", batch.width, V.positive_int_divisible(8))
    result.add_check("image_latent", batch.image_latent, V.is_tensor)
    return result

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
VerificationResult

A VerificationResult containing the verification status.

Source code in fastvideo/pipelines/stages/base.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """
    Verify the input for the stage.

    Example:
        from fastvideo.pipelines.stages.validators import V, VerificationResult

        def verify_input(self, batch, fastvideo_args):
            result = VerificationResult()
            result.add_check("height", batch.height, V.positive_int_divisible(8))
            result.add_check("width", batch.width, V.positive_int_divisible(8))
            result.add_check("image_latent", batch.image_latent, V.is_tensor)
            return result

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        A VerificationResult containing the verification status.

    """
    # Default implementation - no verification
    return VerificationResult()
fastvideo.pipelines.stages.PipelineStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify the output for the stage.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
VerificationResult

A VerificationResult containing the verification status.

Source code in fastvideo/pipelines/stages/base.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """
    Verify the output for the stage.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        A VerificationResult containing the verification status.
    """
    # Default implementation - no verification
    return VerificationResult()
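
Putting these pieces together, a custom stage only needs to implement forward (and may optionally override verify_input / verify_output); timing and verification come for free from __call__. The stage below is a hypothetical sketch for illustration, not part of the package.

from fastvideo.pipelines.stages import PipelineStage
from fastvideo.pipelines.stages.validators import V, VerificationResult


class ScaleLatentsStage(PipelineStage):
    """Hypothetical stage that multiplies the latents by a constant factor."""

    def __init__(self, factor: float = 1.0) -> None:
        super().__init__()
        self.factor = factor

    def verify_input(self, batch, fastvideo_args) -> VerificationResult:
        result = VerificationResult()
        result.add_check("latents", batch.latents, [V.is_tensor, V.with_dims(5)])
        return result

    def forward(self, batch, fastvideo_args):
        batch.latents = batch.latents * self.factor
        return batch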

fastvideo.pipelines.stages.RefImageEncodingStage

RefImageEncodingStage(image_encoder, image_processor)

Bases: ImageEncodingStage

Stage for encoding reference image prompts into embeddings for Wan2.1 Control models.

This stage extends ImageEncodingStage with specialized preprocessing for reference images.

Source code in fastvideo/pipelines/stages/image_encoding.py
def __init__(self, image_encoder, image_processor) -> None:
    """
    Initialize the reference image encoding stage.

    Args:
        image_encoder: The image encoder model (a CLIP vision encoder).
        image_processor: The processor that prepares reference images for the encoder.
    """
    super().__init__()
    self.image_processor = image_processor
    self.image_encoder = image_encoder

Functions

fastvideo.pipelines.stages.RefImageEncodingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Encode the prompt into image encoder hidden states.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The batch with encoded prompt embeddings.

Source code in fastvideo/pipelines/stages/image_encoding.py
@torch.no_grad()
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Encode the prompt into image encoder hidden states.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with encoded prompt embeddings.
    """
    self.image_encoder = self.image_encoder.to(get_local_torch_device())

    image = batch.pil_image
    if image is None:
        image = create_default_image()
    # Preprocess reference image for CLIP encoder
    image_tensor = preprocess_reference_image_for_clip(
        image, get_local_torch_device())

    image_inputs = self.image_processor(images=image_tensor,
                                        return_tensors="pt").to(
                                            get_local_torch_device())
    with set_forward_context(current_timestep=0, attn_metadata=None):
        outputs = self.image_encoder(**image_inputs)
        image_embeds = outputs.last_hidden_state
    batch.image_embeds.append(image_embeds)

    if batch.pil_image is None:
        batch.image_embeds = [
            torch.zeros_like(x) for x in batch.image_embeds
        ]

    return batch

fastvideo.pipelines.stages.StepvideoPromptEncodingStage

StepvideoPromptEncodingStage(stepllm, clip)

Bases: PipelineStage

Stage for encoding prompts using the remote caption API.

This stage applies the magic string transformations and calls the remote caption service asynchronously to get:

- primary prompt embeddings,
- an attention mask,
- and a CLIP embedding.

Source code in fastvideo/pipelines/stages/stepvideo_encoding.py
def __init__(self, stepllm, clip) -> None:
    super().__init__()
    # self.caption_client = caption_client  # This should have a call_caption(prompts: List[str]) method.
    self.stepllm = stepllm
    self.clip = clip

Functions

fastvideo.pipelines.stages.StepvideoPromptEncodingStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify stepvideo encoding stage inputs.

Source code in fastvideo/pipelines/stages/stepvideo_encoding.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify stepvideo encoding stage inputs."""
    result = VerificationResult()
    result.add_check("prompt", batch.prompt, V.string_not_empty)
    return result
fastvideo.pipelines.stages.StepvideoPromptEncodingStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify stepvideo encoding stage outputs.

Source code in fastvideo/pipelines/stages/stepvideo_encoding.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify stepvideo encoding stage outputs."""
    result = VerificationResult()
    result.add_check("prompt_embeds", batch.prompt_embeds,
                     [V.is_tensor, V.with_dims(3)])
    result.add_check("negative_prompt_embeds", batch.negative_prompt_embeds,
                     [V.is_tensor, V.with_dims(3)])
    result.add_check("prompt_attention_mask", batch.prompt_attention_mask,
                     [V.is_tensor, V.with_dims(2)])
    result.add_check("negative_attention_mask",
                     batch.negative_attention_mask,
                     [V.is_tensor, V.with_dims(2)])
    result.add_check("clip_embedding_pos", batch.clip_embedding_pos,
                     [V.is_tensor, V.with_dims(2)])
    result.add_check("clip_embedding_neg", batch.clip_embedding_neg,
                     [V.is_tensor, V.with_dims(2)])
    return result

fastvideo.pipelines.stages.TextEncodingStage

TextEncodingStage(text_encoders, tokenizers)

Bases: PipelineStage

Stage for encoding text prompts into embeddings for diffusion models.

This stage handles the encoding of text prompts into the embedding space expected by the diffusion model.

Initialize the prompt encoding stage.

Parameters:

Name Type Description Default
text_encoders

The text encoder models used to produce prompt embeddings.

required
tokenizers

The tokenizers paired, one-to-one, with the text encoders.

required
Source code in fastvideo/pipelines/stages/text_encoding.py
def __init__(self, text_encoders, tokenizers) -> None:
    """
    Initialize the prompt encoding stage.

    Args:
        text_encoders: The text encoder models used to produce prompt embeddings.
        tokenizers: The tokenizers paired, one-to-one, with the text encoders.
    """
    super().__init__()
    self.tokenizers = tokenizers
    self.text_encoders = text_encoders

Functions

fastvideo.pipelines.stages.TextEncodingStage.encode_text
encode_text(text: str | list[str], fastvideo_args: FastVideoArgs, encoder_index: int | list[int] | None = None, return_attention_mask: bool = False, return_type: str = 'list', device: device | str | None = None, dtype: dtype | None = None, max_length: int | None = None, truncation: bool | None = None, padding: bool | str | None = None)

Encode plain text using selected text encoder(s) and return embeddings.

Parameters:

Name Type Description Default
text str | list[str]

A single string or a list of strings to encode.

required
fastvideo_args FastVideoArgs

The inference arguments providing pipeline config, including tokenizer and encoder settings, preprocess and postprocess functions.

required
encoder_index int | list[int] | None

Encoder selector by index. Accepts an int or list of ints.

None
return_attention_mask bool

If True, also return attention masks for each selected encoder.

False
return_type str

"list" (default) returns a list aligned with selection; "dict" returns a dict keyed by encoder index as a string; "stack" stacks along a new first dimension (requires matching shapes).

'list'
device device | str | None

Optional device override for inputs; defaults to local torch device.

None
dtype dtype | None

Optional dtype to cast returned embeddings to.

None
max_length int | None

Optional per-call tokenizer override.

None
truncation bool | None

Optional per-call tokenizer override.

None
padding bool | str | None

Optional per-call tokenizer override.

None

Returns:

Type Description

Depending on return_type and return_attention_mask:

  • list: List[Tensor] or (List[Tensor], List[Tensor])
  • dict: Dict[str, Tensor] or (Dict[str, Tensor], Dict[str, Tensor])
  • stack: Tensor of shape [num_encoders, ...] or a tuple with stacked attention masks
Source code in fastvideo/pipelines/stages/text_encoding.py
@torch.no_grad()
def encode_text(
    self,
    text: str | list[str],
    fastvideo_args: FastVideoArgs,
    encoder_index: int | list[int] | None = None,
    return_attention_mask: bool = False,
    return_type: str = "list",  # one of: "list", "dict", "stack"
    device: torch.device | str | None = None,
    dtype: torch.dtype | None = None,
    max_length: int | None = None,
    truncation: bool | None = None,
    padding: bool | str | None = None,
):
    """
    Encode plain text using selected text encoder(s) and return embeddings.

    Args:
        text: A single string or a list of strings to encode.
        fastvideo_args: The inference arguments providing pipeline config,
            including tokenizer and encoder settings, preprocess and postprocess
            functions.
        encoder_index: Encoder selector by index. Accepts an int or list of ints.
        return_attention_mask: If True, also return attention masks for each
            selected encoder.
        return_type: "list" (default) returns a list aligned with selection;
            "dict" returns a dict keyed by encoder index as a string; "stack" stacks along a
            new first dimension (requires matching shapes).
        device: Optional device override for inputs; defaults to local torch device.
        dtype: Optional dtype to cast returned embeddings to.
        max_length: Optional per-call tokenizer override.
        truncation: Optional per-call tokenizer override.
        padding: Optional per-call tokenizer override.

    Returns:
        Depending on return_type and return_attention_mask:
        - list: List[Tensor] or (List[Tensor], List[Tensor])
        - dict: Dict[str, Tensor] or (Dict[str, Tensor], Dict[str, Tensor])
        - stack: Tensor of shape [num_encoders, ...] or a tuple with stacked
          attention masks
    """

    assert len(self.tokenizers) == len(self.text_encoders)
    assert len(self.text_encoders) == len(
        fastvideo_args.pipeline_config.text_encoder_configs)

    # Resolve selection into indices
    encoder_cfgs = fastvideo_args.pipeline_config.text_encoder_configs
    if encoder_index is None:
        indices: list[int] = [0]
    elif isinstance(encoder_index, int):
        indices = [encoder_index]
    else:
        indices = list(encoder_index)
    # validate range
    num_encoders = len(self.text_encoders)
    for idx in indices:
        if idx < 0 or idx >= num_encoders:
            raise IndexError(
                f"encoder index {idx} out of range [0, {num_encoders-1}]")

    # Validate indices are within range
    num_encoders = len(self.text_encoders)

    # Normalize input to list[str]
    assert isinstance(text, str | list)
    if isinstance(text, str):
        texts: list[str] = [text]
    else:
        texts = text

    embeds_list: list[torch.Tensor] = []
    attn_masks_list: list[torch.Tensor] = []

    preprocess_funcs = fastvideo_args.pipeline_config.preprocess_text_funcs
    postprocess_funcs = fastvideo_args.pipeline_config.postprocess_text_funcs
    encoder_cfgs = fastvideo_args.pipeline_config.text_encoder_configs

    if return_type not in ("list", "dict", "stack"):
        raise ValueError(
            f"Invalid return_type '{return_type}'. Expected one of: 'list', 'dict', 'stack'"
        )

    target_device = device if device is not None else get_local_torch_device(
    )

    for i in indices:
        tokenizer = self.tokenizers[i]
        text_encoder = self.text_encoders[i]
        encoder_config = encoder_cfgs[i]
        preprocess_func = preprocess_funcs[i]
        postprocess_func = postprocess_funcs[i]

        tok_kwargs = dict(encoder_config.tokenizer_kwargs)
        if max_length is not None:
            tok_kwargs["max_length"] = max_length
        elif hasattr(fastvideo_args.pipeline_config,
                     "text_encoder_max_lengths"):
            tok_kwargs[
                "max_length"] = fastvideo_args.pipeline_config.text_encoder_max_lengths[
                    i]

        if truncation is not None:
            tok_kwargs["truncation"] = truncation
        if padding is not None:
            tok_kwargs["padding"] = padding

        processed_texts: list[str] = []
        for prompt_str in texts:
            processed_text = preprocess_func(prompt_str)
            if processed_text is not None:
                processed_texts.append(processed_text)
            else:
                # Assuming batch_size = 1
                prompt_embeds = torch.zeros((1, tok_kwargs["max_length"],
                                             encoder_config.hidden_size),
                                            device=target_device)
                attention_mask = torch.zeros((1, tok_kwargs["max_length"]),
                                             device=target_device,
                                             dtype=torch.int64)
                embeds_list.append(prompt_embeds)
                attn_masks_list.append(attention_mask)
                return self.return_embeds(embeds_list, attn_masks_list,
                                          return_type,
                                          return_attention_mask, indices)

        if encoder_config.is_chat_model:
            text_inputs = tokenizer.apply_chat_template(
                processed_texts, **tok_kwargs).to(target_device)
        else:
            text_inputs = tokenizer(processed_texts,
                                    **tok_kwargs).to(target_device)

        input_ids = text_inputs["input_ids"]
        attention_mask = text_inputs["attention_mask"]

        with set_forward_context(current_timestep=0, attn_metadata=None):
            outputs = text_encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True,
            )

        try:
            prompt_embeds = postprocess_func(outputs)
        except Exception:
            prompt_embeds, attention_mask = postprocess_func(
                outputs, attention_mask)

        if dtype is not None:
            prompt_embeds = prompt_embeds.to(dtype=dtype)
        embeds_list.append(prompt_embeds)
        if return_attention_mask:
            attn_masks_list.append(attention_mask)

    return self.return_embeds(embeds_list, attn_masks_list, return_type,
                              return_attention_mask, indices)
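
As a usage sketch: the helper below assumes an already-constructed TextEncodingStage with at least two text encoders and a populated fastvideo_args; it is illustrative, not a complete script.

import torch

from fastvideo.pipelines.stages import TextEncodingStage


def encode_prompt(stage: TextEncodingStage, fastvideo_args) -> torch.Tensor:
    # Encode with the first two encoders and return a dict keyed by encoder index.
    embeds, masks = stage.encode_text(
        ["a cat surfing a wave at sunset"],
        fastvideo_args,
        encoder_index=[0, 1],
        return_attention_mask=True,
        return_type="dict",
        dtype=torch.bfloat16,
    )
    # embeds["0"] / embeds["1"] hold the per-encoder embeddings; masks mirrors the keys.
    return embeds["0"]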
fastvideo.pipelines.stages.TextEncodingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Encode the prompt into text encoder hidden states.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The batch with encoded prompt embeddings.

Source code in fastvideo/pipelines/stages/text_encoding.py
@torch.no_grad()
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Encode the prompt into text encoder hidden states.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with encoded prompt embeddings.
    """
    assert len(self.tokenizers) == len(self.text_encoders)
    assert len(self.text_encoders) == len(
        fastvideo_args.pipeline_config.text_encoder_configs)

    # Encode positive prompt with all available encoders
    assert batch.prompt is not None
    prompt_text: str | list[str] = batch.prompt
    all_indices: list[int] = list(range(len(self.text_encoders)))
    prompt_embeds_list, prompt_masks_list = self.encode_text(
        prompt_text,
        fastvideo_args,
        encoder_index=all_indices,
        return_attention_mask=True,
    )

    for pe in prompt_embeds_list:
        batch.prompt_embeds.append(pe)
    if batch.prompt_attention_mask is not None:
        for am in prompt_masks_list:
            batch.prompt_attention_mask.append(am)

    # Encode negative prompt if CFG is enabled
    if batch.do_classifier_free_guidance:
        assert isinstance(batch.negative_prompt, str)
        neg_embeds_list, neg_masks_list = self.encode_text(
            batch.negative_prompt,
            fastvideo_args,
            encoder_index=all_indices,
            return_attention_mask=True,
        )

        assert batch.negative_prompt_embeds is not None
        for ne in neg_embeds_list:
            batch.negative_prompt_embeds.append(ne)
        if batch.negative_attention_mask is not None:
            for nm in neg_masks_list:
                batch.negative_attention_mask.append(nm)

    return batch
fastvideo.pipelines.stages.TextEncodingStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify text encoding stage inputs.

Source code in fastvideo/pipelines/stages/text_encoding.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify text encoding stage inputs."""
    result = VerificationResult()
    result.add_check("prompt", batch.prompt, V.string_or_list_strings)
    # result.add_check(
    #     "negative_prompt", batch.negative_prompt, lambda x: not batch.
    #     do_classifier_free_guidance or V.string_not_empty(x))
    result.add_check("do_classifier_free_guidance",
                     batch.do_classifier_free_guidance, V.bool_value)
    result.add_check("prompt_embeds", batch.prompt_embeds, V.is_list)
    result.add_check("negative_prompt_embeds", batch.negative_prompt_embeds,
                     V.none_or_list)
    return result
fastvideo.pipelines.stages.TextEncodingStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify text encoding stage outputs.

Source code in fastvideo/pipelines/stages/text_encoding.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify text encoding stage outputs."""
    result = VerificationResult()
    result.add_check("prompt_embeds", batch.prompt_embeds,
                     V.list_of_tensors_min_dims(2))
    result.add_check(
        "negative_prompt_embeds", batch.negative_prompt_embeds,
        lambda x: not batch.do_classifier_free_guidance or V.
        list_of_tensors_with_min_dims(x, 2))
    return result

fastvideo.pipelines.stages.TimestepPreparationStage

TimestepPreparationStage(scheduler)

Bases: PipelineStage

Stage for preparing timesteps for the diffusion process.

This stage handles the preparation of the timestep sequence that will be used during the diffusion process.

Source code in fastvideo/pipelines/stages/timestep_preparation.py
def __init__(self, scheduler) -> None:
    self.scheduler = scheduler

Functions

fastvideo.pipelines.stages.TimestepPreparationStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Prepare timesteps for the diffusion process.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The batch with prepared timesteps.

Source code in fastvideo/pipelines/stages/timestep_preparation.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Prepare timesteps for the diffusion process.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with prepared timesteps.
    """
    scheduler = self.scheduler
    device = get_local_torch_device()
    num_inference_steps = batch.num_inference_steps
    timesteps = batch.timesteps
    sigmas = batch.sigmas
    n_tokens = batch.n_tokens

    # Prepare extra kwargs for set_timesteps
    extra_set_timesteps_kwargs = {}
    if n_tokens is not None and "n_tokens" in inspect.signature(
            scheduler.set_timesteps).parameters:
        extra_set_timesteps_kwargs["n_tokens"] = n_tokens

    # Handle custom timesteps or sigmas
    if timesteps is not None and sigmas is not None:
        raise ValueError(
            "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
        )

    if timesteps is not None:
        accepts_timesteps = "timesteps" in inspect.signature(
            scheduler.set_timesteps).parameters
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        # Convert timesteps to CPU if it's a tensor (for numpy conversion in scheduler)
        if isinstance(timesteps, torch.Tensor):
            timesteps_for_scheduler = timesteps.cpu()
        else:
            timesteps_for_scheduler = timesteps
        scheduler.set_timesteps(timesteps=timesteps_for_scheduler,
                                device=device,
                                **extra_set_timesteps_kwargs)
        timesteps = scheduler.timesteps
    elif sigmas is not None:
        accept_sigmas = "sigmas" in inspect.signature(
            scheduler.set_timesteps).parameters
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas,
                                device=device,
                                **extra_set_timesteps_kwargs)
        timesteps = scheduler.timesteps
    else:
        scheduler.set_timesteps(num_inference_steps,
                                device=device,
                                **extra_set_timesteps_kwargs)
        timesteps = scheduler.timesteps

    # Update batch with prepared timesteps
    batch.timesteps = timesteps

    return batch
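
The branching above hinges on whether the scheduler's set_timesteps accepts custom timesteps or sigmas. A standalone sketch of that capability check, using a purely illustrative DummyScheduler:

import inspect


class DummyScheduler:
    def set_timesteps(self, num_inference_steps=None, device=None, sigmas=None):
        self.timesteps = list(range(num_inference_steps or len(sigmas)))


params = inspect.signature(DummyScheduler.set_timesteps).parameters
print("timesteps" in params)  # False -> passing custom timesteps would raise ValueError
print("sigmas" in params)     # True  -> custom sigmas are supported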
fastvideo.pipelines.stages.TimestepPreparationStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify timestep preparation stage inputs.

Source code in fastvideo/pipelines/stages/timestep_preparation.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify timestep preparation stage inputs."""
    result = VerificationResult()
    result.add_check("num_inference_steps", batch.num_inference_steps,
                     V.positive_int)
    result.add_check("timesteps", batch.timesteps, V.none_or_tensor)
    result.add_check("sigmas", batch.sigmas, V.none_or_list)
    result.add_check("n_tokens", batch.n_tokens, V.none_or_positive_int)
    return result
fastvideo.pipelines.stages.TimestepPreparationStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify timestep preparation stage outputs.

Source code in fastvideo/pipelines/stages/timestep_preparation.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify timestep preparation stage outputs."""
    result = VerificationResult()
    result.add_check("timesteps", batch.timesteps,
                     [V.is_tensor, V.with_dims(1)])
    return result

fastvideo.pipelines.stages.VideoVAEEncodingStage

VideoVAEEncodingStage(vae: ParallelTiledVAE)

Bases: ImageVAEEncodingStage

Stage for encoding video pixel representations into latent space.

This stage handles the encoding of video pixel representations for video-to-video generation and control. Inherits from ImageVAEEncodingStage to reuse common functionality.

Source code in fastvideo/pipelines/stages/image_encoding.py
def __init__(self, vae: ParallelTiledVAE) -> None:
    self.vae: ParallelTiledVAE = vae

Functions

fastvideo.pipelines.stages.VideoVAEEncodingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Encode video pixel representations into latent space.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The batch with encoded outputs.

Source code in fastvideo/pipelines/stages/image_encoding.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Encode video pixel representations into latent space.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with encoded outputs.
    """
    assert batch.video_latent is not None, "Video latent input is required for VideoVAEEncodingStage"

    if fastvideo_args.mode == ExecutionMode.INFERENCE:
        assert batch.height is not None and isinstance(batch.height, int)
        assert batch.width is not None and isinstance(batch.width, int)
        assert batch.num_frames is not None and isinstance(
            batch.num_frames, int)
        height = batch.height
        width = batch.width
        num_frames = batch.num_frames
    elif fastvideo_args.mode == ExecutionMode.PREPROCESS:
        assert batch.height is not None and isinstance(batch.height, list)
        assert batch.width is not None and isinstance(batch.width, list)
        assert batch.num_frames is not None and isinstance(
            batch.num_frames, list)
        num_frames = batch.num_frames[0]
        height = batch.height[0]
        width = batch.width[0]

    self.vae = self.vae.to(get_local_torch_device())

    # Prepare video tensor from control video
    video_condition = self._prepare_control_video_tensor(
        batch.video_latent, num_frames, height,
        width).to(get_local_torch_device(), dtype=torch.float32)

    # Setup VAE precision
    vae_dtype = PRECISION_TO_TYPE[
        fastvideo_args.pipeline_config.vae_precision]
    vae_autocast_enabled = (
        vae_dtype != torch.float32) and not fastvideo_args.disable_autocast

    # Encode control video
    with torch.autocast(device_type="cuda",
                        dtype=vae_dtype,
                        enabled=vae_autocast_enabled):
        if fastvideo_args.pipeline_config.vae_tiling:
            self.vae.enable_tiling()
        if not vae_autocast_enabled:
            video_condition = video_condition.to(vae_dtype)
        encoder_output = self.vae.encode(video_condition)

    generator = batch.generator
    if generator is None:
        raise ValueError("Generator must be provided")
    latent_condition = self.retrieve_latents(encoder_output, generator)

    if (hasattr(self.vae, "shift_factor")
            and self.vae.shift_factor is not None):
        if isinstance(self.vae.shift_factor, torch.Tensor):
            latent_condition -= self.vae.shift_factor.to(
                latent_condition.device, latent_condition.dtype)
        else:
            latent_condition -= self.vae.shift_factor

    if isinstance(self.vae.scaling_factor, torch.Tensor):
        latent_condition = latent_condition * self.vae.scaling_factor.to(
            latent_condition.device, latent_condition.dtype)
    else:
        latent_condition = latent_condition * self.vae.scaling_factor

    batch.video_latent = latent_condition

    # Offload models if needed
    if hasattr(self, 'maybe_free_model_hooks'):
        self.maybe_free_model_hooks()

    self.vae.to("cpu")

    return batch
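
The shift/scale handling at the end of forward is the usual VAE latent normalization. A quick sketch with scalar factors; the numbers are illustrative, the real values come from the VAE config:

import torch

latent = torch.randn(1, 16, 4, 60, 104)
shift_factor = 0.12     # illustrative
scaling_factor = 1.36   # illustrative

latent_condition = (latent - shift_factor) * scaling_factor
# Decoding later inverts this: latent = latent_condition / scaling_factor + shift_factor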
fastvideo.pipelines.stages.VideoVAEEncodingStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify video encoding stage inputs.

Source code in fastvideo/pipelines/stages/image_encoding.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify video encoding stage inputs."""
    result = VerificationResult()
    result.add_check("video_latent", batch.video_latent, V.not_none)
    result.add_check("generator", batch.generator,
                     V.generator_or_list_generators)
    if fastvideo_args.mode == ExecutionMode.PREPROCESS:
        result.add_check("height", batch.height, V.list_not_empty)
        result.add_check("width", batch.width, V.list_not_empty)
        result.add_check("num_frames", batch.num_frames, V.list_not_empty)
    else:
        result.add_check("height", batch.height, V.positive_int)
        result.add_check("width", batch.width, V.positive_int)
        result.add_check("num_frames", batch.num_frames, V.positive_int)
    return result
fastvideo.pipelines.stages.VideoVAEEncodingStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify video encoding stage outputs.

Source code in fastvideo/pipelines/stages/image_encoding.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify video encoding stage outputs."""
    result = VerificationResult()
    result.add_check("video_latent", batch.video_latent,
                     [V.is_tensor, V.with_dims(5)])
    return result

Modules

fastvideo.pipelines.stages.base

Base classes for pipeline stages.

This module defines the abstract base classes for pipeline stages that can be composed to create complete diffusion pipelines.

Classes

fastvideo.pipelines.stages.base.PipelineStage

Bases: ABC

Abstract base class for all pipeline stages.

A pipeline stage represents a discrete step in the diffusion process that can be composed with other stages to create a complete pipeline. Each stage is responsible for a specific part of the process, such as prompt encoding, latent preparation, etc.

Attributes
fastvideo.pipelines.stages.base.PipelineStage.device property
device: device

Get the device for this stage.

Functions
fastvideo.pipelines.stages.base.PipelineStage.__call__
__call__(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Execute the stage's processing on the batch with optional verification and logging. Should not be overridden by subclasses.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The updated batch information after this stage's processing.

Source code in fastvideo/pipelines/stages/base.py
def __call__(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Execute the stage's processing on the batch with optional verification and logging.
    Should not be overridden by subclasses.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The updated batch information after this stage's processing.
    """
    stage_name = self.__class__.__name__

    # Check if verification is enabled (simple approach for prototype)
    enable_verification = getattr(fastvideo_args,
                                  'enable_stage_verification', False)

    if enable_verification:
        # Pre-execution input verification
        try:
            input_result = self.verify_input(batch, fastvideo_args)
            self._run_verification(input_result, stage_name, "input")
        except Exception as e:
            logger.error("Input verification failed for %s: %s", stage_name,
                         str(e))
            raise

    # Execute the actual stage logic
    if envs.FASTVIDEO_STAGE_LOGGING:
        logger.info("[%s] Starting execution", stage_name)
        start_time = time.perf_counter()

        try:
            result = self.forward(batch, fastvideo_args)
            execution_time = time.perf_counter() - start_time
            logger.info("[%s] Execution completed in %s ms", stage_name,
                        execution_time * 1000)
            batch.logging_info.add_stage_execution_time(
                stage_name, execution_time)
        except Exception as e:
            execution_time = time.perf_counter() - start_time
            logger.error("[%s] Error during execution after %s ms: %s",
                         stage_name, execution_time * 1000, e)
            logger.error("[%s] Traceback: %s", stage_name,
                         traceback.format_exc())
            raise
    else:
        # Direct execution (current behavior)
        result = self.forward(batch, fastvideo_args)

    if enable_verification:
        # Post-execution output verification
        try:
            output_result = self.verify_output(result, fastvideo_args)
            self._run_verification(output_result, stage_name, "output")
        except Exception as e:
            logger.error("Output verification failed for %s: %s",
                         stage_name, str(e))
            raise

    return result
fastvideo.pipelines.stages.base.PipelineStage.forward abstractmethod
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Forward pass of the stage's processing.

This method should be implemented by subclasses to provide the forward processing logic for the stage.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The updated batch information after this stage's processing.

Source code in fastvideo/pipelines/stages/base.py
@abstractmethod
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Forward pass of the stage's processing.

    This method should be implemented by subclasses to provide the forward
    processing logic for the stage.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The updated batch information after this stage's processing.
    """
    raise NotImplementedError
fastvideo.pipelines.stages.base.PipelineStage.set_logging
set_logging(enable: bool)

Enable or disable logging for this stage.

Parameters:

Name Type Description Default
enable bool

Whether to enable logging.

required
Source code in fastvideo/pipelines/stages/base.py
def set_logging(self, enable: bool):
    """
    Enable or disable logging for this stage.

    Args:
        enable: Whether to enable logging.
    """
    self._enable_logging = enable
fastvideo.pipelines.stages.base.PipelineStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify the input for the stage.

Example

from fastvideo.pipelines.stages.validators import V, VerificationResult

def verify_input(self, batch, fastvideo_args):
    result = VerificationResult()
    result.add_check("height", batch.height, V.positive_int_divisible(8))
    result.add_check("width", batch.width, V.positive_int_divisible(8))
    result.add_check("image_latent", batch.image_latent, V.is_tensor)
    return result

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
VerificationResult

A VerificationResult containing the verification status.

Source code in fastvideo/pipelines/stages/base.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """
    Verify the input for the stage.

    Example:
        from fastvideo.pipelines.stages.validators import V, VerificationResult

        def verify_input(self, batch, fastvideo_args):
            result = VerificationResult()
            result.add_check("height", batch.height, V.positive_int_divisible(8))
            result.add_check("width", batch.width, V.positive_int_divisible(8))
            result.add_check("image_latent", batch.image_latent, V.is_tensor)
            return result

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        A VerificationResult containing the verification status.

    """
    # Default implementation - no verification
    return VerificationResult()
fastvideo.pipelines.stages.base.PipelineStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify the output for the stage.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
VerificationResult

A VerificationResult containing the verification status.

Source code in fastvideo/pipelines/stages/base.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """
    Verify the output for the stage.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        A VerificationResult containing the verification status.
    """
    # Default implementation - no verification
    return VerificationResult()
fastvideo.pipelines.stages.base.StageVerificationError

Bases: Exception

Exception raised when stage verification fails.
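
A minimal sketch of how a caller might surface these failures; it assumes stages are invoked as callables with (batch, fastvideo_args), which is an assumption about the runner, not part of the listing above:

from fastvideo.pipelines.stages.base import StageVerificationError

def run_stages(stages, batch, fastvideo_args):
    # Hypothetical driver loop: stop on the first stage whose verification fails.
    for stage in stages:
        try:
            batch = stage(batch, fastvideo_args)
        except StageVerificationError as exc:
            print(f"{type(stage).__name__} failed verification: {exc}")
            raise
    return batch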

Functions

fastvideo.pipelines.stages.causal_denoising

Classes

fastvideo.pipelines.stages.causal_denoising.CausalDMDDenosingStage
CausalDMDDenosingStage(transformer, scheduler, transformer_2=None, vae=None)

Bases: DenoisingStage

Denoising stage for causal diffusion.

Source code in fastvideo/pipelines/stages/causal_denoising.py
def __init__(self,
             transformer,
             scheduler,
             transformer_2=None,
             vae=None) -> None:
    super().__init__(transformer, scheduler, transformer_2)
    # KV and cross-attention cache state (initialized on first forward)
    self.transformer = transformer
    self.transformer_2 = transformer_2
    self.vae = vae
    # Model-dependent constants (aligned with causal_inference.py assumptions)
    self.num_transformer_blocks = len(self.transformer.blocks)
    self.num_frames_per_block = self.transformer.config.arch_config.num_frames_per_block
    self.sliding_window_num_frames = self.transformer.config.arch_config.sliding_window_num_frames

    try:
        self.local_attn_size = getattr(self.transformer.model,
                                       "local_attn_size",
                                       -1)  # type: ignore
    except Exception:
        self.local_attn_size = -1
Functions
fastvideo.pipelines.stages.causal_denoising.CausalDMDDenosingStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify denoising stage inputs.

Source code in fastvideo/pipelines/stages/causal_denoising.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify denoising stage inputs."""
    result = VerificationResult()
    result.add_check("latents", batch.latents,
                     [V.is_tensor, V.with_dims(5)])
    result.add_check("prompt_embeds", batch.prompt_embeds, V.list_not_empty)
    result.add_check("image_embeds", batch.image_embeds, V.is_list)
    result.add_check("image_latent", batch.image_latent,
                     V.none_or_tensor_with_dims(5))
    result.add_check("num_inference_steps", batch.num_inference_steps,
                     V.positive_int)
    result.add_check("guidance_scale", batch.guidance_scale,
                     V.positive_float)
    result.add_check("eta", batch.eta, V.non_negative_float)
    result.add_check("generator", batch.generator,
                     V.generator_or_list_generators)
    result.add_check("do_classifier_free_guidance",
                     batch.do_classifier_free_guidance, V.bool_value)
    result.add_check(
        "negative_prompt_embeds", batch.negative_prompt_embeds, lambda x:
        not batch.do_classifier_free_guidance or V.list_not_empty(x))
    return result

Functions

fastvideo.pipelines.stages.conditioning

Conditioning stage for diffusion pipelines.

Classes

fastvideo.pipelines.stages.conditioning.ConditioningStage

Bases: PipelineStage

Stage for applying conditioning to the diffusion process.

This stage handles the application of conditioning, such as classifier-free guidance, to the diffusion process.

Functions
fastvideo.pipelines.stages.conditioning.ConditioningStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Apply conditioning to the diffusion process.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The batch with applied conditioning.

Source code in fastvideo/pipelines/stages/conditioning.py
@torch.no_grad()
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Apply conditioning to the diffusion process.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with applied conditioning.
    """
    # TODO!!
    if not batch.do_classifier_free_guidance:
        return batch
    else:
        return batch

    logger.info("batch.negative_prompt_embeds: %s",
                batch.negative_prompt_embeds)
    logger.info("do_classifier_free_guidance: %s",
                batch.do_classifier_free_guidance)
    logger.info("cfg_scale: %s", batch.guidance_scale)

    # Ensure negative prompt embeddings are available
    assert batch.negative_prompt_embeds is not None, (
        "Negative prompt embeddings are required for classifier-free guidance"
    )

    # Concatenate primary embeddings and masks
    batch.prompt_embeds = torch.cat(
        [batch.negative_prompt_embeds, batch.prompt_embeds])
    if batch.attention_mask is not None:
        batch.attention_mask = torch.cat(
            [batch.negative_attention_mask, batch.attention_mask])

    # Concatenate secondary embeddings and masks if present
    if batch.prompt_embeds_2 is not None:
        batch.prompt_embeds_2 = torch.cat(
            [batch.negative_prompt_embeds_2, batch.prompt_embeds_2])
    if batch.attention_mask_2 is not None:
        batch.attention_mask_2 = torch.cat(
            [batch.negative_attention_mask_2, batch.attention_mask_2])

    return batch
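
The concatenation in the listing (currently short-circuited by the TODO early return) is the standard classifier-free guidance setup: negative and positive embeddings are stacked so a single model call yields both predictions, which are later recombined with the guidance scale. A generic sketch of that combination (names are illustrative, not this stage's API):

import torch

def apply_cfg(noise_pred_uncond: torch.Tensor,
              noise_pred_text: torch.Tensor,
              guidance_scale: float) -> torch.Tensor:
    # Standard classifier-free guidance combination.
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# With prompt_embeds = cat([negative, positive]), the two predictions can be
# recovered from a single forward pass via noise_pred.chunk(2) before combining.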
fastvideo.pipelines.stages.conditioning.ConditioningStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify conditioning stage inputs.

Source code in fastvideo/pipelines/stages/conditioning.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify conditioning stage inputs."""
    result = VerificationResult()
    if not batch.prompt_embeds:
        # No text encoder/prompt embeddings: skip checks and effectively disable CFG.
        batch.do_classifier_free_guidance = False
        return result
    result.add_check("do_classifier_free_guidance",
                     batch.do_classifier_free_guidance, V.bool_value)
    result.add_check("guidance_scale", batch.guidance_scale,
                     V.positive_float)
    # Matrix-Game allows empty prompt
    # embeddings when CFG isn't enabled.
    if batch.do_classifier_free_guidance or batch.prompt_embeds:
        result.add_check("prompt_embeds", batch.prompt_embeds,
                         V.list_not_empty)
        result.add_check(
            "negative_prompt_embeds", batch.negative_prompt_embeds, lambda
            x: not batch.do_classifier_free_guidance or V.list_not_empty(x))
    return result
fastvideo.pipelines.stages.conditioning.ConditioningStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify conditioning stage outputs.

Source code in fastvideo/pipelines/stages/conditioning.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify conditioning stage outputs."""
    result = VerificationResult()
    if batch.prompt_embeds is None or not batch.prompt_embeds:
        batch.do_classifier_free_guidance = False
        return result
    if batch.do_classifier_free_guidance or batch.prompt_embeds:
        result.add_check("prompt_embeds", batch.prompt_embeds,
                         V.list_not_empty)
    return result

Functions

fastvideo.pipelines.stages.decoding

Decoding stage for diffusion pipelines.

Classes

fastvideo.pipelines.stages.decoding.DecodingStage
DecodingStage(vae, pipeline=None)

Bases: PipelineStage

Stage for decoding latent representations into pixel space.

This stage handles the decoding of latent representations into the final output format (e.g., pixel values).

Source code in fastvideo/pipelines/stages/decoding.py
def __init__(self, vae, pipeline=None) -> None:
    self.vae: ParallelTiledVAE = vae
    self.pipeline = weakref.ref(pipeline) if pipeline else None
Functions
fastvideo.pipelines.stages.decoding.DecodingStage.decode
decode(latents: Tensor, fastvideo_args: FastVideoArgs) -> Tensor

Decode latent representations into pixel space using VAE.

Parameters:

Name Type Description Default
latents Tensor

Input latent tensor with shape (batch, channels, frames, height_latents, width_latents)

required
fastvideo_args FastVideoArgs

Configuration containing: - disable_autocast: Whether to disable automatic mixed precision (default: False) - pipeline_config.vae_precision: VAE computation precision ("fp32", "fp16", "bf16") - pipeline_config.vae_tiling: Whether to enable VAE tiling for memory efficiency

required

Returns:

Type Description
Tensor

Decoded video tensor with shape (batch, channels, frames, height, width),

Tensor

normalized to [0, 1] range and moved to CPU as float32

Source code in fastvideo/pipelines/stages/decoding.py
@torch.no_grad()
def decode(self, latents: torch.Tensor,
           fastvideo_args: FastVideoArgs) -> torch.Tensor:
    """
    Decode latent representations into pixel space using VAE.

    Args:
        latents: Input latent tensor with shape (batch, channels, frames, height_latents, width_latents)
        fastvideo_args: Configuration containing:
            - disable_autocast: Whether to disable automatic mixed precision (default: False)
            - pipeline_config.vae_precision: VAE computation precision ("fp32", "fp16", "bf16")
            - pipeline_config.vae_tiling: Whether to enable VAE tiling for memory efficiency

    Returns:
        Decoded video tensor with shape (batch, channels, frames, height, width), 
        normalized to [0, 1] range and moved to CPU as float32
    """
    self.vae = self.vae.to(get_local_torch_device())
    latents = latents.to(get_local_torch_device())

    # Setup VAE precision
    vae_dtype = PRECISION_TO_TYPE[
        fastvideo_args.pipeline_config.vae_precision]
    vae_autocast_enabled = (
        vae_dtype != torch.float32) and not fastvideo_args.disable_autocast

    latents = self._denormalize_latents(latents)

    # Decode latents
    with torch.autocast(device_type="cuda",
                        dtype=vae_dtype,
                        enabled=vae_autocast_enabled):
        if fastvideo_args.pipeline_config.vae_tiling:
            self.vae.enable_tiling()
        # if fastvideo_args.vae_sp:
        #     self.vae.enable_parallel()
        if not vae_autocast_enabled:
            latents = latents.to(vae_dtype)
        image = self.vae.decode(latents)

    # Normalize image to [0, 1] range
    image = (image / 2 + 0.5).clamp(0, 1)
    return image
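
A minimal usage sketch, assuming a constructed DecodingStage (decoding_stage) and a FastVideoArgs instance are already available; the latent shape is illustrative:

import torch

# latents: [batch, channels, frames, height_latents, width_latents]
latents = torch.randn(1, 16, 21, 60, 104)                  # illustrative shape
frames = decoding_stage.decode(latents, fastvideo_args)
# frames: [batch, channels, frames, height, width], values clamped to [0, 1]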
fastvideo.pipelines.stages.decoding.DecodingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Decode latent representations into pixel space.

This method processes the batch through the VAE decoder, converting latent representations to pixel-space video/images. It also optionally decodes trajectory latents for visualization purposes.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch containing: - latents: Tensor to decode (batch, channels, frames, height_latents, width_latents) - return_trajectory_decoded (optional): Flag to decode trajectory latents - trajectory_latents (optional): Latents at different timesteps - trajectory_timesteps (optional): Corresponding timesteps

required
fastvideo_args FastVideoArgs

Configuration containing: - output_type: "latent" to skip decoding, otherwise decode to pixels - vae_cpu_offload: Whether to offload VAE to CPU after decoding - model_loaded: Track VAE loading state - model_paths: Path to VAE model if loading needed

required

Returns:

Type Description
ForwardBatch

Modified batch with: - output: Decoded frames (batch, channels, frames, height, width) as CPU float32 - trajectory_decoded (if requested): List of decoded frames per timestep

Source code in fastvideo/pipelines/stages/decoding.py
@torch.no_grad()
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Decode latent representations into pixel space.

    This method processes the batch through the VAE decoder, converting latent
    representations to pixel-space video/images. It also optionally decodes
    trajectory latents for visualization purposes.

    Args:
        batch: The current batch containing:
            - latents: Tensor to decode (batch, channels, frames, height_latents, width_latents)
            - return_trajectory_decoded (optional): Flag to decode trajectory latents
            - trajectory_latents (optional): Latents at different timesteps
            - trajectory_timesteps (optional): Corresponding timesteps
        fastvideo_args: Configuration containing:
            - output_type: "latent" to skip decoding, otherwise decode to pixels
            - vae_cpu_offload: Whether to offload VAE to CPU after decoding
            - model_loaded: Track VAE loading state
            - model_paths: Path to VAE model if loading needed

    Returns:
        Modified batch with:
            - output: Decoded frames (batch, channels, frames, height, width) as CPU float32
            - trajectory_decoded (if requested): List of decoded frames per timestep
    """
    # load vae if not already loaded (used for memory constrained devices)
    pipeline = self.pipeline() if self.pipeline else None
    if not fastvideo_args.model_loaded["vae"]:
        loader = VAELoader()
        self.vae = loader.load(fastvideo_args.model_paths["vae"],
                               fastvideo_args)
        if pipeline:
            pipeline.add_module("vae", self.vae)
        fastvideo_args.model_loaded["vae"] = True

    if fastvideo_args.output_type == "latent":
        frames = batch.latents
    else:
        frames = self.decode(batch.latents, fastvideo_args)

    # decode trajectory latents if needed
    if batch.return_trajectory_decoded:
        batch.trajectory_decoded = []
        assert batch.trajectory_latents is not None, "batch should have trajectory latents"
        for idx in range(batch.trajectory_latents.shape[1]):
            # batch.trajectory_latents is [batch_size, timesteps, channels, frames, height, width]
            cur_latent = batch.trajectory_latents[:, idx, :, :, :, :]
            cur_timestep = batch.trajectory_timesteps[idx]
            logger.info("decoding trajectory latent for timestep: %s",
                        cur_timestep)
            decoded_frames = self.decode(cur_latent, fastvideo_args)
            batch.trajectory_decoded.append(decoded_frames.cpu().float())

    # Convert to CPU float32 for compatibility
    frames = frames.cpu().float()

    # Crop padding if this is a LongCat refinement
    if hasattr(batch, 'num_cond_frames_added') and hasattr(
            batch, 'new_frame_size_before_padding'):
        num_cond_frames_added = batch.num_cond_frames_added
        new_frame_size = batch.new_frame_size_before_padding
        if num_cond_frames_added > 0 or frames.shape[2] != new_frame_size:
            # frames is [B, C, T, H, W], crop temporal dimension
            frames = frames[:, :,
                            num_cond_frames_added:num_cond_frames_added +
                            new_frame_size, :, :]
            logger.info(
                "Cropped LongCat refinement padding: %s:%s, final shape: %s",
                num_cond_frames_added,
                num_cond_frames_added + new_frame_size, frames.shape)

    # Update batch with decoded image
    batch.output = frames

    # Offload models if needed
    if hasattr(self, 'maybe_free_model_hooks'):
        self.maybe_free_model_hooks()

    if fastvideo_args.vae_cpu_offload:
        self.vae.to("cpu")

    if torch.backends.mps.is_available():
        del self.vae
        if pipeline is not None and "vae" in pipeline.modules:
            del pipeline.modules["vae"]
        fastvideo_args.model_loaded["vae"] = False

    return batch
fastvideo.pipelines.stages.decoding.DecodingStage.streaming_decode
streaming_decode(latents: Tensor, fastvideo_args: FastVideoArgs, cache: list[Tensor | None] | None = None, is_first_chunk: bool = False) -> tuple[Tensor, list[Tensor | None]]

Decode latent representations into pixel space using VAE with streaming cache.

Parameters:

Name Type Description Default
latents Tensor

Input latent tensor with shape (batch, channels, frames, height_latents, width_latents)

required
fastvideo_args FastVideoArgs

Configuration object.

required
cache list[Tensor | None] | None

VAE cache from previous call, or None to initialize a new cache.

None
is_first_chunk bool

Whether this is the first chunk.

False

Returns:

Type Description
tuple[Tensor, list[Tensor | None]]

A tuple of (decoded_frames, updated_cache).

Source code in fastvideo/pipelines/stages/decoding.py
@torch.no_grad()
def streaming_decode(
    self,
    latents: torch.Tensor,
    fastvideo_args: FastVideoArgs,
    cache: list[torch.Tensor | None] | None = None,
    is_first_chunk: bool = False,
) -> tuple[torch.Tensor, list[torch.Tensor | None]]:
    """
    Decode latent representations into pixel space using VAE with streaming cache.

    Args:
        latents: Input latent tensor with shape (batch, channels, frames, height_latents, width_latents)
        fastvideo_args: Configuration object.
        cache: VAE cache from previous call, or None to initialize a new cache.
        is_first_chunk: Whether this is the first chunk.

    Returns:
        A tuple of (decoded_frames, updated_cache).
    """
    self.vae = self.vae.to(get_local_torch_device())
    latents = latents.to(get_local_torch_device())

    # Setup VAE precision
    vae_dtype = PRECISION_TO_TYPE[
        fastvideo_args.pipeline_config.vae_precision]
    vae_autocast_enabled = (
        vae_dtype != torch.float32) and not fastvideo_args.disable_autocast

    latents = self._denormalize_latents(latents)

    # Initialize cache if needed
    if cache is None:
        cache = self.vae.get_streaming_cache()

    # Decode latents with streaming
    with torch.autocast(device_type="cuda",
                        dtype=vae_dtype,
                        enabled=vae_autocast_enabled):
        if fastvideo_args.pipeline_config.vae_tiling:
            self.vae.enable_tiling()
        if not vae_autocast_enabled:
            latents = latents.to(vae_dtype)
        image, cache = self.vae.streaming_decode(latents, cache,
                                                 is_first_chunk)

    # Normalize image to [0, 1] range
    image = (image / 2 + 0.5).clamp(0, 1)
    assert cache is not None, "cache should not be None after streaming_decode"
    return image, cache
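
A sketch of chunked decoding with the streaming cache; how the latents are chunked is illustrative, the point is that the cache returned by each call feeds the next one:

import torch

chunks = latents.chunk(4, dim=2)      # split along the frame dimension (illustrative)
cache = None
decoded = []
for i, chunk in enumerate(chunks):
    frames, cache = decoding_stage.streaming_decode(
        chunk, fastvideo_args, cache=cache, is_first_chunk=(i == 0))
    decoded.append(frames.cpu().float())
video = torch.cat(decoded, dim=2)     # reassemble along the frame dimension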
fastvideo.pipelines.stages.decoding.DecodingStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify decoding stage inputs.

Source code in fastvideo/pipelines/stages/decoding.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify decoding stage inputs."""
    result = VerificationResult()
    # Denoised latents for VAE decoding: [batch_size, channels, frames, height_latents, width_latents]
    result.add_check("latents", batch.latents,
                     [V.is_tensor, V.with_dims(5)])
    return result
fastvideo.pipelines.stages.decoding.DecodingStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify decoding stage outputs.

Source code in fastvideo/pipelines/stages/decoding.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify decoding stage outputs."""
    result = VerificationResult()
    # Decoded video/images: [batch_size, channels, frames, height, width]
    result.add_check("output", batch.output, [V.is_tensor, V.with_dims(5)])
    return result

Functions

fastvideo.pipelines.stages.denoising

Denoising stage for diffusion pipelines.

Classes

fastvideo.pipelines.stages.denoising.Cosmos25DenoisingStage
Cosmos25DenoisingStage(transformer, scheduler, pipeline=None)

Bases: CosmosDenoisingStage

Denoising stage for Cosmos 2.5 DiT (expects 1D/2D timestep, not 5D).

Source code in fastvideo/pipelines/stages/denoising.py
def __init__(self, transformer, scheduler, pipeline=None) -> None:
    super().__init__(transformer, scheduler, pipeline)
fastvideo.pipelines.stages.denoising.CosmosDenoisingStage
CosmosDenoisingStage(transformer, scheduler, pipeline=None)

Bases: DenoisingStage

Denoising stage for Cosmos models using FlowMatchEulerDiscreteScheduler.

Source code in fastvideo/pipelines/stages/denoising.py
def __init__(self, transformer, scheduler, pipeline=None) -> None:
    super().__init__(transformer, scheduler, pipeline)
Functions
fastvideo.pipelines.stages.denoising.CosmosDenoisingStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify Cosmos denoising stage inputs.

Source code in fastvideo/pipelines/stages/denoising.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify Cosmos denoising stage inputs."""
    result = VerificationResult()
    result.add_check("latents", batch.latents,
                     [V.is_tensor, V.with_dims(5)])
    result.add_check("prompt_embeds", batch.prompt_embeds, V.list_not_empty)
    result.add_check("num_inference_steps", batch.num_inference_steps,
                     V.positive_int)
    result.add_check("guidance_scale", batch.guidance_scale,
                     V.positive_float)
    result.add_check("do_classifier_free_guidance",
                     batch.do_classifier_free_guidance, V.bool_value)
    result.add_check(
        "negative_prompt_embeds", batch.negative_prompt_embeds, lambda x:
        not batch.do_classifier_free_guidance or V.list_not_empty(x))
    return result
fastvideo.pipelines.stages.denoising.CosmosDenoisingStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify Cosmos denoising stage outputs.

Source code in fastvideo/pipelines/stages/denoising.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify Cosmos denoising stage outputs."""
    result = VerificationResult()
    result.add_check("latents", batch.latents,
                     [V.is_tensor, V.with_dims(5)])
    return result
fastvideo.pipelines.stages.denoising.DenoisingStage
DenoisingStage(transformer, scheduler, pipeline=None, transformer_2=None, vae=None)

Bases: PipelineStage

Stage for running the denoising loop in diffusion pipelines.

This stage handles the iterative denoising process that transforms the initial noise into the final output.

Source code in fastvideo/pipelines/stages/denoising.py
def __init__(self,
             transformer,
             scheduler,
             pipeline=None,
             transformer_2=None,
             vae=None) -> None:
    super().__init__()
    self.transformer = transformer
    self.transformer_2 = transformer_2
    self.scheduler = scheduler
    self.vae = vae
    self.pipeline = weakref.ref(pipeline) if pipeline else None
    attn_head_size = self.transformer.hidden_size // self.transformer.num_attention_heads
    self.attn_backend = get_attn_backend(
        head_size=attn_head_size,
        dtype=torch.float16,  # TODO(will): hack
        supported_attention_backends=(
            AttentionBackendEnum.SLIDING_TILE_ATTN,
            AttentionBackendEnum.VIDEO_SPARSE_ATTN,
            AttentionBackendEnum.VMOBA_ATTN,
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.TORCH_SDPA,
            AttentionBackendEnum.SAGE_ATTN_THREE)  # hack
    )
Functions
fastvideo.pipelines.stages.denoising.DenoisingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Run the denoising loop.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The batch with denoised latents.

Source code in fastvideo/pipelines/stages/denoising.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Run the denoising loop.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with denoised latents.
    """
    pipeline = self.pipeline() if self.pipeline else None
    if not fastvideo_args.model_loaded["transformer"]:
        loader = TransformerLoader()
        self.transformer = loader.load(
            fastvideo_args.model_paths["transformer"], fastvideo_args)
        if pipeline:
            pipeline.add_module("transformer", self.transformer)
        fastvideo_args.model_loaded["transformer"] = True

    # Prepare extra step kwargs for scheduler
    extra_step_kwargs = self.prepare_extra_func_kwargs(
        self.scheduler.step,
        {
            "generator": batch.generator,
            "eta": batch.eta
        },
    )

    # Setup precision and autocast settings
    # TODO(will): make the precision configurable for inference
    # target_dtype = PRECISION_TO_TYPE[fastvideo_args.precision]
    target_dtype = torch.bfloat16
    autocast_enabled = (target_dtype != torch.float32
                        ) and not fastvideo_args.disable_autocast

    # Get timesteps and calculate warmup steps
    timesteps = batch.timesteps
    # TODO(will): remove this once we add input/output validation for stages
    if timesteps is None:
        raise ValueError("Timesteps must be provided")
    num_inference_steps = batch.num_inference_steps
    num_warmup_steps = len(
        timesteps) - num_inference_steps * self.scheduler.order

    # Prepare image latents and embeddings for I2V generation
    image_embeds = batch.image_embeds
    if len(image_embeds) > 0:
        assert not torch.isnan(
            image_embeds[0]).any(), "image_embeds contains nan"
        image_embeds = [
            image_embed.to(target_dtype) for image_embed in image_embeds
        ]

    image_kwargs = self.prepare_extra_func_kwargs(
        self.transformer.forward,
        {
            "encoder_hidden_states_image": image_embeds,
            "mask_strategy": dict_to_3d_list(
                None, t_max=50, l_max=60, h_max=24)
        },
    )

    pos_cond_kwargs = self.prepare_extra_func_kwargs(
        self.transformer.forward,
        {
            "encoder_hidden_states_2": batch.clip_embedding_pos,
            "encoder_attention_mask": batch.prompt_attention_mask,
        },
    )

    neg_cond_kwargs = self.prepare_extra_func_kwargs(
        self.transformer.forward,
        {
            "encoder_hidden_states_2": batch.clip_embedding_neg,
            "encoder_attention_mask": batch.negative_attention_mask,
        },
    )

    action_kwargs = self.prepare_extra_func_kwargs(
        self.transformer.forward,
        {
            "mouse_cond": batch.mouse_cond,
            "keyboard_cond": batch.keyboard_cond,
        },
    )

    # Prepare STA parameters
    if st_attn_available and self.attn_backend == SlidingTileAttentionBackend:
        self.prepare_sta_param(batch, fastvideo_args)

    # Get latents and embeddings
    latents = batch.latents
    prompt_embeds = batch.prompt_embeds
    assert not torch.isnan(
        prompt_embeds[0]).any(), "prompt_embeds contains nan"
    if batch.do_classifier_free_guidance:
        neg_prompt_embeds = batch.negative_prompt_embeds
        assert neg_prompt_embeds is not None
        assert not torch.isnan(
            neg_prompt_embeds[0]).any(), "neg_prompt_embeds contains nan"

    # (Wan2.2) Calculate timestep to switch from high noise expert to low noise expert
    boundary_ratio = fastvideo_args.pipeline_config.dit_config.boundary_ratio
    if batch.boundary_ratio is not None:
        logger.info("Overriding boundary ratio from %s to %s",
                    boundary_ratio, batch.boundary_ratio)
        boundary_ratio = batch.boundary_ratio

    if boundary_ratio is not None:
        boundary_timestep = boundary_ratio * self.scheduler.num_train_timesteps
    else:
        boundary_timestep = None
    latent_model_input = latents.to(target_dtype)
    assert latent_model_input.shape[0] == 1, "only support batch size 1"

    if fastvideo_args.pipeline_config.ti2v_task and batch.pil_image is not None:
        # TI2V directly replaces the first frame of the latent with
        # the image latent instead of appending along the channel dim
        assert batch.image_latent is None, "TI2V task should not have image latents"
        assert self.vae is not None, "VAE is not provided for TI2V task"
        z = self.vae.encode(batch.pil_image).mean.float()
        if (hasattr(self.vae, "shift_factor")
                and self.vae.shift_factor is not None):
            if isinstance(self.vae.shift_factor, torch.Tensor):
                z -= self.vae.shift_factor.to(z.device, z.dtype)
            else:
                z -= self.vae.shift_factor

        if isinstance(self.vae.scaling_factor, torch.Tensor):
            z = z * self.vae.scaling_factor.to(z.device, z.dtype)
        else:
            z = z * self.vae.scaling_factor

        latent_model_input = latent_model_input.squeeze(0)
        _, mask2 = masks_like([latent_model_input], zero=True)

        latent_model_input = (1. -
                              mask2[0]) * z + mask2[0] * latent_model_input
        # latent_model_input = latent_model_input.unsqueeze(0)
        latent_model_input = latent_model_input.to(get_local_torch_device())
        latents = latent_model_input
        F = batch.num_frames
        temporal_scale = fastvideo_args.pipeline_config.vae_config.arch_config.scale_factor_temporal
        spatial_scale = fastvideo_args.pipeline_config.vae_config.arch_config.scale_factor_spatial
        patch_size = fastvideo_args.pipeline_config.dit_config.arch_config.patch_size
        seq_len = ((F - 1) // temporal_scale +
                   1) * (batch.height // spatial_scale) * (
                       batch.width // spatial_scale) // (patch_size[1] *
                                                         patch_size[2])

    # Initialize lists for ODE trajectory
    trajectory_timesteps: list[torch.Tensor] = []
    trajectory_latents: list[torch.Tensor] = []

    # Run denoising loop
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # Skip if interrupted
            if hasattr(self, 'interrupt') and self.interrupt:
                continue

            if boundary_timestep is None or t >= boundary_timestep:
                if (fastvideo_args.dit_cpu_offload
                        and not fastvideo_args.dit_layerwise_offload
                        and self.transformer_2 is not None and next(
                            self.transformer_2.parameters()).device.type
                        == 'cuda'):
                    self.transformer_2.to('cpu')
                current_model = self.transformer
                if (fastvideo_args.dit_cpu_offload
                        and not fastvideo_args.dit_layerwise_offload
                        and not fastvideo_args.use_fsdp_inference
                        and current_model is not None):
                    transformer_device = next(
                        current_model.parameters()).device.type
                    if transformer_device == 'cpu':
                        current_model.to(get_local_torch_device())
                current_guidance_scale = batch.guidance_scale
            else:
                # low-noise stage in wan2.2
                if (fastvideo_args.dit_cpu_offload
                        and not fastvideo_args.dit_layerwise_offload
                        and next(self.transformer.parameters()).device.type
                        == 'cuda'):
                    self.transformer.to('cpu')
                current_model = self.transformer_2
                if (fastvideo_args.dit_cpu_offload
                        and not fastvideo_args.dit_layerwise_offload
                        and not fastvideo_args.use_fsdp_inference
                        and current_model is not None):
                    transformer_2_device = next(
                        current_model.parameters()).device.type
                    if transformer_2_device == 'cpu':
                        current_model.to(get_local_torch_device())
                current_guidance_scale = batch.guidance_scale_2
            assert current_model is not None, "current_model is None"

            # Expand latents for V2V/I2V
            latent_model_input = latents.to(target_dtype)
            if batch.video_latent is not None:
                latent_model_input = torch.cat([
                    latent_model_input, batch.video_latent,
                    torch.zeros_like(latents)
                ],
                                               dim=1).to(target_dtype)
            elif batch.image_latent is not None:
                assert not fastvideo_args.pipeline_config.ti2v_task, "image latents should not be provided for TI2V task"
                latent_model_input = torch.cat(
                    [latent_model_input, batch.image_latent],
                    dim=1).to(target_dtype)

            assert not torch.isnan(
                latent_model_input).any(), "latent_model_input contains nan"
            if fastvideo_args.pipeline_config.ti2v_task and batch.pil_image is not None:
                timestep = torch.stack([t]).to(get_local_torch_device())
                temp_ts = (mask2[0][0][:, ::2, ::2] * timestep).flatten()
                temp_ts = torch.cat([
                    temp_ts,
                    temp_ts.new_ones(seq_len - temp_ts.size(0)) * timestep
                ])
                timestep = temp_ts.unsqueeze(0)
                t_expand = timestep.repeat(latent_model_input.shape[0], 1)
            else:
                t_expand = t.repeat(latent_model_input.shape[0])

            latent_model_input = self.scheduler.scale_model_input(
                latent_model_input, t)

            # Prepare inputs for transformer
            guidance_expand = (
                torch.tensor(
                    [fastvideo_args.pipeline_config.embedded_cfg_scale] *
                    latent_model_input.shape[0],
                    dtype=torch.float32,
                    device=get_local_torch_device(),
                ).to(target_dtype) *
                1000.0 if fastvideo_args.pipeline_config.embedded_cfg_scale
                is not None else None)

            # Predict noise residual
            with torch.autocast(device_type="cuda",
                                dtype=target_dtype,
                                enabled=autocast_enabled):
                if (st_attn_available
                        and self.attn_backend == SlidingTileAttentionBackend
                    ) or (vsa_available and self.attn_backend
                          == VideoSparseAttentionBackend):
                    self.attn_metadata_builder_cls = self.attn_backend.get_builder_cls(
                    )

                    if self.attn_metadata_builder_cls is not None:
                        self.attn_metadata_builder = self.attn_metadata_builder_cls(
                        )
                        # TODO(will): clean this up
                        attn_metadata = self.attn_metadata_builder.build(  # type: ignore
                            current_timestep=i,  # type: ignore
                            raw_latent_shape=batch.
                            raw_latent_shape[2:5],  # type: ignore
                            patch_size=fastvideo_args.
                            pipeline_config.  # type: ignore
                            dit_config.patch_size,  # type: ignore
                            STA_param=batch.STA_param,  # type: ignore
                            VSA_sparsity=fastvideo_args.
                            VSA_sparsity,  # type: ignore
                            device=get_local_torch_device(),
                        )
                        assert attn_metadata is not None, "attn_metadata cannot be None"
                    else:
                        attn_metadata = None
                elif (vmoba_attn_available
                      and self.attn_backend == VMOBAAttentionBackend):
                    self.attn_metadata_builder_cls = self.attn_backend.get_builder_cls(
                    )
                    if self.attn_metadata_builder_cls is not None:
                        self.attn_metadata_builder = self.attn_metadata_builder_cls(
                        )
                        # Prepare V-MoBA parameters from config
                        moba_params = fastvideo_args.moba_config.copy()
                        moba_params.update({
                            "current_timestep":
                            i,
                            "raw_latent_shape":
                            batch.raw_latent_shape[2:5],
                            "patch_size":
                            fastvideo_args.pipeline_config.dit_config.
                            patch_size,
                            "device":
                            get_local_torch_device(),
                        })
                        attn_metadata = self.attn_metadata_builder.build(
                            **moba_params)
                        assert attn_metadata is not None, "attn_metadata cannot be None"
                    else:
                        attn_metadata = None
                else:
                    attn_metadata = None
                # TODO(will): finalize the interface. vLLM uses this to
                # support torch dynamo compilation. They pass in
                # attn_metadata, vllm_config, and num_tokens. We can pass in
                # fastvideo_args or training_args, and attn_metadata.
                batch.is_cfg_negative = False
                with set_forward_context(
                        current_timestep=i,
                        attn_metadata=attn_metadata,
                        forward_batch=batch,
                        # fastvideo_args=fastvideo_args
                ):
                    # Run transformer
                    noise_pred = current_model(
                        latent_model_input,
                        prompt_embeds,
                        t_expand,
                        guidance=guidance_expand,
                        **image_kwargs,
                        **pos_cond_kwargs,
                        **action_kwargs,
                    )

                if batch.do_classifier_free_guidance:
                    batch.is_cfg_negative = True
                    with set_forward_context(
                            current_timestep=i,
                            attn_metadata=attn_metadata,
                            forward_batch=batch,
                    ):
                        noise_pred_uncond = current_model(
                            latent_model_input,
                            neg_prompt_embeds,
                            t_expand,
                            guidance=guidance_expand,
                            **image_kwargs,
                            **neg_cond_kwargs,
                            **action_kwargs,
                        )

                    noise_pred_text = noise_pred
                    noise_pred = noise_pred_uncond + current_guidance_scale * (
                        noise_pred_text - noise_pred_uncond)

                    # Apply guidance rescale if needed
                    if batch.guidance_rescale > 0.0:
                        # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                        noise_pred = self.rescale_noise_cfg(
                            noise_pred,
                            noise_pred_text,
                            guidance_rescale=batch.guidance_rescale,
                        )
                # Compute the previous noisy sample
                latents = self.scheduler.step(noise_pred,
                                              t,
                                              latents,
                                              **extra_step_kwargs,
                                              return_dict=False)[0]
                if fastvideo_args.pipeline_config.ti2v_task and batch.pil_image is not None:
                    latents = latents.squeeze(0)
                    latents = (1. - mask2[0]) * z + mask2[0] * latents
                    # latents = latents.unsqueeze(0)

            # save trajectory latents if needed
            if batch.return_trajectory_latents:
                trajectory_timesteps.append(t)
                trajectory_latents.append(latents)

            # Update progress bar
            if i == len(timesteps) - 1 or (
                (i + 1) > num_warmup_steps and
                (i + 1) % self.scheduler.order == 0
                    and progress_bar is not None):
                progress_bar.update()

    trajectory_tensor: torch.Tensor | None = None
    if trajectory_latents:
        trajectory_tensor = torch.stack(trajectory_latents, dim=1)
        trajectory_timesteps_tensor = torch.stack(trajectory_timesteps,
                                                  dim=0)
    else:
        trajectory_tensor = None
        trajectory_timesteps_tensor = None

    if trajectory_tensor is not None and trajectory_timesteps_tensor is not None:
        batch.trajectory_timesteps = trajectory_timesteps_tensor.cpu()
        batch.trajectory_latents = trajectory_tensor.cpu()

    # Update batch with final latents
    batch.latents = latents

    if fastvideo_args.dit_layerwise_offload:
        mgr = getattr(self.transformer, "_layerwise_offload_manager", None)
        if mgr is not None and getattr(mgr, "enabled", False):
            mgr.release_all()
        if self.transformer_2 is not None:
            mgr2 = getattr(self.transformer_2, "_layerwise_offload_manager",
                           None)
            if mgr2 is not None and getattr(mgr2, "enabled", False):
                mgr2.release_all()

    # Save STA mask search results if needed
    if st_attn_available and self.attn_backend == SlidingTileAttentionBackend and fastvideo_args.STA_mode == STA_Mode.STA_SEARCHING:
        self.save_sta_search_results(batch)

    # deallocate transformer if on mps
    if torch.backends.mps.is_available():
        logger.info("Memory before deallocating transformer: %s",
                    torch.mps.current_allocated_memory())
        del self.transformer
        if pipeline is not None and "transformer" in pipeline.modules:
            del pipeline.modules["transformer"]
        fastvideo_args.model_loaded["transformer"] = False
        logger.info("Memory after deallocating transformer: %s",
                    torch.mps.current_allocated_memory())

    return batch
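
To make the Wan2.2 expert switch concrete: with boundary_ratio = 0.9 and num_train_timesteps = 1000 (illustrative values), the boundary falls at t = 900, so early, high-noise steps run on transformer and later steps fall through to transformer_2. A small standalone sketch of that selection rule:

num_train_timesteps = 1000
boundary_ratio = 0.9                                        # illustrative value
boundary_timestep = boundary_ratio * num_train_timesteps    # 900.0

def select_expert(t: float) -> str:
    # Mirrors the branch in forward(): high-noise expert first, low-noise expert after.
    return "transformer" if t >= boundary_timestep else "transformer_2"

assert select_expert(975.0) == "transformer"
assert select_expert(450.0) == "transformer_2"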
fastvideo.pipelines.stages.denoising.DenoisingStage.prepare_extra_func_kwargs
prepare_extra_func_kwargs(func, kwargs) -> dict[str, Any]

Prepare extra kwargs for the scheduler step / denoise step.

Parameters:

Name Type Description Default
func

The function to prepare kwargs for.

required
kwargs

The kwargs to prepare.

required

Returns:

Type Description
dict[str, Any]

The prepared kwargs.

Source code in fastvideo/pipelines/stages/denoising.py
def prepare_extra_func_kwargs(self, func, kwargs) -> dict[str, Any]:
    """
    Prepare extra kwargs for the scheduler step / denoise step.

    Args:
        func: The function to prepare kwargs for.
        kwargs: The kwargs to prepare.

    Returns:
        The prepared kwargs.
    """
    extra_step_kwargs = {}
    for k, v in kwargs.items():
        accepts = k in set(inspect.signature(func).parameters.keys())
        if accepts:
            extra_step_kwargs[k] = v
    return extra_step_kwargs
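
The helper simply drops any keyword the target callable does not accept. A self-contained illustration with a toy step function:

import inspect

def prepare_extra_func_kwargs(func, kwargs):
    # Same filtering rule as above: keep only kwargs named in func's signature.
    accepted = set(inspect.signature(func).parameters)
    return {k: v for k, v in kwargs.items() if k in accepted}

def toy_step(sample, t, generator=None):                     # no "eta" parameter
    return sample

extra = prepare_extra_func_kwargs(toy_step, {"generator": None, "eta": 0.0})
assert extra == {"generator": None}                          # "eta" is filtered out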
fastvideo.pipelines.stages.denoising.DenoisingStage.prepare_sta_param
prepare_sta_param(batch: ForwardBatch, fastvideo_args: FastVideoArgs)

Prepare Sliding Tile Attention (STA) parameters and settings.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required
Source code in fastvideo/pipelines/stages/denoising.py
def prepare_sta_param(self, batch: ForwardBatch,
                      fastvideo_args: FastVideoArgs):
    """
    Prepare Sliding Tile Attention (STA) parameters and settings.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.
    """
    # TODO(kevin): STA mask search currently only supports Wan2.1 with 69x768x1280
    from fastvideo.attention.backends.STA_configuration import configure_sta
    STA_mode = fastvideo_args.STA_mode
    skip_time_steps = fastvideo_args.skip_time_steps
    if batch.timesteps is None:
        raise ValueError("Timesteps must be provided")
    timesteps_num = batch.timesteps.shape[0]

    logger.info("STA_mode: %s", STA_mode)
    if (batch.num_frames, batch.height,
            batch.width) != (69, 768, 1280) and STA_mode != "STA_inference":
        raise NotImplementedError(
            "STA mask search/tuning is not supported for this resolution")

    if STA_mode == STA_Mode.STA_SEARCHING or STA_mode == STA_Mode.STA_TUNING or STA_mode == STA_Mode.STA_TUNING_CFG:
        size = (batch.width, batch.height)
        if size == (1280, 768):
            # TODO: make it configurable
            sparse_mask_candidates_searching = [
                "3, 1, 10", "1, 5, 7", "3, 3, 3", "1, 6, 5", "1, 3, 10",
                "3, 6, 1"
            ]
            sparse_mask_candidates_tuning = [
                "3, 1, 10", "1, 5, 7", "3, 3, 3", "1, 6, 5", "1, 3, 10",
                "3, 6, 1"
            ]
            full_mask = ["3,6,10"]
        else:
            raise NotImplementedError(
                "STA mask search is not supported for this resolution")
    layer_num = self.transformer.config.num_layers
    # specific for HunyuanVideo
    if hasattr(self.transformer.config, "num_single_layers"):
        layer_num += self.transformer.config.num_single_layers
    head_num = self.transformer.config.num_attention_heads

    if STA_mode == STA_Mode.STA_SEARCHING:
        STA_param = configure_sta(
            mode=STA_Mode.STA_SEARCHING,
            layer_num=layer_num,
            head_num=head_num,
            time_step_num=timesteps_num,
            mask_candidates=sparse_mask_candidates_searching +
            full_mask,  # last entry is the full mask; more sparse masks can be added as long as the last one stays the full mask
        )
    elif STA_mode == STA_Mode.STA_TUNING:
        STA_param = configure_sta(
            mode=STA_Mode.STA_TUNING,
            layer_num=layer_num,
            head_num=head_num,
            time_step_num=timesteps_num,
            mask_search_files_path=
            f'output/mask_search_result_pos_{size[0]}x{size[1]}/',
            mask_candidates=sparse_mask_candidates_tuning,
            full_attention_mask=[int(x) for x in full_mask[0].split(',')],
            skip_time_steps=
            skip_time_steps,  # Use full attention for first 12 steps
            save_dir=
            f'output/mask_search_strategy_{size[0]}x{size[1]}/',  # Custom save directory
            timesteps=timesteps_num)
    elif STA_mode == STA_Mode.STA_TUNING_CFG:
        STA_param = configure_sta(
            mode=STA_Mode.STA_TUNING_CFG,
            layer_num=layer_num,
            head_num=head_num,
            time_step_num=timesteps_num,
            mask_search_files_path_pos=
            f'output/mask_search_result_pos_{size[0]}x{size[1]}/',
            mask_search_files_path_neg=
            f'output/mask_search_result_neg_{size[0]}x{size[1]}/',
            mask_candidates=sparse_mask_candidates_tuning,
            full_attention_mask=[int(x) for x in full_mask[0].split(',')],
            skip_time_steps=skip_time_steps,
            save_dir=f'output/mask_search_strategy_{size[0]}x{size[1]}/',
            timesteps=timesteps_num)
    elif STA_mode == STA_Mode.STA_INFERENCE:
        import fastvideo.envs as envs
        config_file = envs.FASTVIDEO_ATTENTION_CONFIG
        if config_file is None:
            raise ValueError("FASTVIDEO_ATTENTION_CONFIG is not set")
        STA_param = configure_sta(mode=STA_Mode.STA_INFERENCE,
                                  layer_num=layer_num,
                                  head_num=head_num,
                                  time_step_num=timesteps_num,
                                  load_path=config_file)

    batch.STA_param = STA_param
    batch.mask_search_final_result_pos = [[] for _ in range(timesteps_num)]
    batch.mask_search_final_result_neg = [[] for _ in range(timesteps_num)]
fastvideo.pipelines.stages.denoising.DenoisingStage.progress_bar
progress_bar(iterable: Iterable | None = None, total: int | None = None) -> tqdm

Create a progress bar for the denoising process.

Parameters:

Name Type Description Default
iterable Iterable | None

The iterable to iterate over.

None
total int | None

The total number of items.

None

Returns:

Type Description
tqdm

A tqdm progress bar.

Source code in fastvideo/pipelines/stages/denoising.py
def progress_bar(self,
                 iterable: Iterable | None = None,
                 total: int | None = None) -> tqdm:
    """
    Create a progress bar for the denoising process.

    Args:
        iterable: The iterable to iterate over.
        total: The total number of items.

    Returns:
        A tqdm progress bar.
    """
    local_rank = get_world_group().local_rank
    if local_rank == 0:
        return tqdm(iterable=iterable, total=total)
    else:
        return tqdm(iterable=iterable, total=total, disable=True)
fastvideo.pipelines.stages.denoising.DenoisingStage.rescale_noise_cfg
rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0) -> Tensor

Rescale noise prediction according to guidance_rescale.

Based on findings of "Common Diffusion Noise Schedules and Sample Steps are Flawed" (https://arxiv.org/pdf/2305.08891.pdf), Section 3.4.

Parameters:

Name Type Description Default
noise_cfg

The noise prediction with guidance.

required
noise_pred_text

The text-conditioned noise prediction.

required
guidance_rescale

The guidance rescale factor.

0.0

Returns:

Type Description
Tensor

The rescaled noise prediction.

Source code in fastvideo/pipelines/stages/denoising.py
def rescale_noise_cfg(self,
                      noise_cfg,
                      noise_pred_text,
                      guidance_rescale=0.0) -> torch.Tensor:
    """
    Rescale noise prediction according to guidance_rescale.

    Based on findings of "Common Diffusion Noise Schedules and Sample Steps are Flawed"
    (https://arxiv.org/pdf/2305.08891.pdf), Section 3.4.

    Args:
        noise_cfg: The noise prediction with guidance.
        noise_pred_text: The text-conditioned noise prediction.
        guidance_rescale: The guidance rescale factor.

    Returns:
        The rescaled noise prediction.
    """
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)),
                                   keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)),
                            keepdim=True)
    # Rescale the results from guidance (fixes overexposure)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # Mix with the original results from guidance by factor guidance_rescale
    noise_cfg = (guidance_rescale * noise_pred_rescaled +
                 (1 - guidance_rescale) * noise_cfg)
    return noise_cfg
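
A standalone numerical check of the rescaling: the guided prediction is scaled so its per-sample standard deviation matches the text-conditioned prediction, then blended back in by guidance_rescale:

import torch

noise_pred_text = torch.randn(2, 4, 8)
noise_cfg = 3.0 * torch.randn(2, 4, 8)            # deliberately "hot" guided prediction

std_text = noise_pred_text.std(dim=[1, 2], keepdim=True)
std_cfg = noise_cfg.std(dim=[1, 2], keepdim=True)
rescaled = noise_cfg * (std_text / std_cfg)
out = 0.7 * rescaled + 0.3 * noise_cfg            # guidance_rescale = 0.7
# With guidance_rescale = 1.0, out.std(dim=[1, 2]) matches std_text exactly.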
fastvideo.pipelines.stages.denoising.DenoisingStage.save_sta_search_results
save_sta_search_results(batch: ForwardBatch)

Save the STA mask search results.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
Source code in fastvideo/pipelines/stages/denoising.py
def save_sta_search_results(self, batch: ForwardBatch):
    """
    Save the STA mask search results.

    Args:
        batch: The current batch information.
    """
    size = (batch.width, batch.height)
    if size == (1280, 768):
        # TODO: make it configurable
        sparse_mask_candidates_searching = [
            "3, 1, 10", "1, 5, 7", "3, 3, 3", "1, 6, 5", "1, 3, 10",
            "3, 6, 1"
        ]
    else:
        raise NotImplementedError(
            "STA mask search is not supported for this resolution")

    from fastvideo.attention.backends.STA_configuration import save_mask_search_results
    if batch.mask_search_final_result_pos is not None and batch.prompt is not None:
        save_mask_search_results(
            [
                dict(layer_data)
                for layer_data in batch.mask_search_final_result_pos
            ],
            prompt=str(batch.prompt),
            mask_strategies=sparse_mask_candidates_searching,
            output_dir=f'output/mask_search_result_pos_{size[0]}x{size[1]}/'
        )
    if batch.mask_search_final_result_neg is not None and batch.prompt is not None:
        save_mask_search_results(
            [
                dict(layer_data)
                for layer_data in batch.mask_search_final_result_neg
            ],
            prompt=str(batch.prompt),
            mask_strategies=sparse_mask_candidates_searching,
            output_dir=f'output/mask_search_result_neg_{size[0]}x{size[1]}/'
        )
fastvideo.pipelines.stages.denoising.DenoisingStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify denoising stage inputs.

Source code in fastvideo/pipelines/stages/denoising.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify denoising stage inputs."""
    result = VerificationResult()
    result.add_check("timesteps", batch.timesteps,
                     [V.is_tensor, V.min_dims(1)])
    result.add_check("latents", batch.latents,
                     [V.is_tensor, V.with_dims(5)])
    result.add_check("prompt_embeds", batch.prompt_embeds, V.list_not_empty)
    result.add_check("image_embeds", batch.image_embeds, V.is_list)
    result.add_check("image_latent", batch.image_latent,
                     V.none_or_tensor_with_dims(5))
    result.add_check("num_inference_steps", batch.num_inference_steps,
                     V.positive_int)
    result.add_check("guidance_scale", batch.guidance_scale,
                     V.positive_float)
    result.add_check("eta", batch.eta, V.non_negative_float)
    result.add_check("generator", batch.generator,
                     V.generator_or_list_generators)
    result.add_check("do_classifier_free_guidance",
                     batch.do_classifier_free_guidance, V.bool_value)
    result.add_check(
        "negative_prompt_embeds", batch.negative_prompt_embeds, lambda x:
        not batch.do_classifier_free_guidance or V.list_not_empty(x))
    return result
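
The final check shows the conditional-validation idiom used throughout these stages: a field is only required when a related flag is set. Reduced to plain Python, with assumed batch values for illustration:

# Illustrative values only.
do_classifier_free_guidance = True
negative_prompt_embeds: list = []

# Passes when CFG is off, or when CFG is on and negative embeds are present.
check_passes = (not do_classifier_free_guidance) or len(negative_prompt_embeds) > 0
print(check_passes)  # False: CFG is enabled but no negative embeds were provided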
fastvideo.pipelines.stages.denoising.DenoisingStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify denoising stage outputs.

Source code in fastvideo/pipelines/stages/denoising.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify denoising stage outputs."""
    result = VerificationResult()
    result.add_check("latents", batch.latents,
                     [V.is_tensor, V.with_dims(5)])
    return result
fastvideo.pipelines.stages.denoising.DmdDenoisingStage
DmdDenoisingStage(transformer, scheduler)

Bases: DenoisingStage

Denoising stage for DMD.

Source code in fastvideo/pipelines/stages/denoising.py
def __init__(self, transformer, scheduler) -> None:
    super().__init__(transformer, scheduler)
    self.scheduler = FlowMatchEulerDiscreteScheduler(shift=8.0)
Functions
fastvideo.pipelines.stages.denoising.DmdDenoisingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Run the denoising loop.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The batch with denoised latents.

Source code in fastvideo/pipelines/stages/denoising.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Run the denoising loop.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with denoised latents.
    """
    # Setup precision and autocast settings
    # TODO(will): make the precision configurable for inference
    # target_dtype = PRECISION_TO_TYPE[fastvideo_args.precision]
    target_dtype = torch.bfloat16
    autocast_enabled = (target_dtype != torch.float32
                        ) and not fastvideo_args.disable_autocast

    # Get timesteps and calculate warmup steps
    timesteps = batch.timesteps

    # TODO(will): remove this once we add input/output validation for stages
    if timesteps is None:
        raise ValueError("Timesteps must be provided")
    num_inference_steps = batch.num_inference_steps
    num_warmup_steps = len(
        timesteps) - num_inference_steps * self.scheduler.order

    # Prepare image latents and embeddings for I2V generation
    image_embeds = batch.image_embeds
    if len(image_embeds) > 0:
        assert torch.isnan(image_embeds[0]).sum() == 0
        image_embeds = [
            image_embed.to(target_dtype) for image_embed in image_embeds
        ]

    image_kwargs = self.prepare_extra_func_kwargs(
        self.transformer.forward,
        {
            "encoder_hidden_states_image": image_embeds,
            "mask_strategy": dict_to_3d_list(
                None, t_max=50, l_max=60, h_max=24)
        },
    )

    pos_cond_kwargs = self.prepare_extra_func_kwargs(
        self.transformer.forward,
        {
            "encoder_hidden_states_2": batch.clip_embedding_pos,
            "encoder_attention_mask": batch.prompt_attention_mask,
        },
    )

    # Prepare STA parameters
    if st_attn_available and self.attn_backend == SlidingTileAttentionBackend:
        self.prepare_sta_param(batch, fastvideo_args)

    # Get latents and embeddings
    assert batch.latents is not None, "latents must be provided"
    latents = batch.latents

    video_raw_latent_shape = latents.shape
    prompt_embeds = batch.prompt_embeds
    assert not torch.isnan(
        prompt_embeds[0]).any(), "prompt_embeds contains nan"
    timesteps = torch.tensor(
        fastvideo_args.pipeline_config.dmd_denoising_steps,
        dtype=torch.long,
        device=get_local_torch_device())

    # Run denoising loop
    with self.progress_bar(total=len(timesteps)) as progress_bar:
        for i, t in enumerate(timesteps):
            # Skip if interrupted
            if hasattr(self, 'interrupt') and self.interrupt:
                continue
            # Expand latents for I2V
            noise_latents = latents.clone()
            latent_model_input = latents.to(target_dtype)

            if batch.image_latent is not None:
                latent_model_input = torch.cat([
                    latent_model_input,
                    batch.image_latent.permute(0, 2, 1, 3, 4)
                ],
                                               dim=2).to(target_dtype)
            assert not torch.isnan(
                latent_model_input).any(), "latent_model_input contains nan"

            # Prepare inputs for transformer
            t_expand = t.repeat(latent_model_input.shape[0])
            guidance_expand = (
                torch.tensor(
                    [fastvideo_args.pipeline_config.embedded_cfg_scale] *
                    latent_model_input.shape[0],
                    dtype=torch.float32,
                    device=get_local_torch_device(),
                ).to(target_dtype) *
                1000.0 if fastvideo_args.pipeline_config.embedded_cfg_scale
                is not None else None)

            # Predict noise residual
            with torch.autocast(device_type="cuda",
                                dtype=target_dtype,
                                enabled=autocast_enabled):
                if (vsa_available and self.attn_backend
                        == VideoSparseAttentionBackend):
                    self.attn_metadata_builder_cls = self.attn_backend.get_builder_cls(
                    )

                    if self.attn_metadata_builder_cls is not None:
                        self.attn_metadata_builder = self.attn_metadata_builder_cls(
                        )
                        # TODO(will): clean this up
                        attn_metadata = self.attn_metadata_builder.build(  # type: ignore
                            current_timestep=i,  # type: ignore
                            raw_latent_shape=batch.
                            raw_latent_shape[2:5],  # type: ignore
                            patch_size=fastvideo_args.
                            pipeline_config.  # type: ignore
                            dit_config.patch_size,  # type: ignore
                            STA_param=batch.STA_param,  # type: ignore
                            VSA_sparsity=fastvideo_args.
                            VSA_sparsity,  # type: ignore
                            device=get_local_torch_device(),  # type: ignore
                        )  # type: ignore
                        assert attn_metadata is not None, "attn_metadata cannot be None"
                    else:
                        attn_metadata = None
                else:
                    attn_metadata = None

                batch.is_cfg_negative = False
                with set_forward_context(
                        current_timestep=i,
                        attn_metadata=attn_metadata,
                        forward_batch=batch,
                        # fastvideo_args=fastvideo_args
                ):
                    # Run transformer
                    pred_noise = self.transformer(
                        latent_model_input.permute(0, 2, 1, 3, 4),
                        prompt_embeds,
                        t_expand,
                        guidance=guidance_expand,
                        **image_kwargs,
                        **pos_cond_kwargs,
                    ).permute(0, 2, 1, 3, 4)

                pred_video = pred_noise_to_pred_video(
                    pred_noise=pred_noise.flatten(0, 1),
                    noise_input_latent=noise_latents.flatten(0, 1),
                    timestep=t_expand,
                    scheduler=self.scheduler).unflatten(
                        0, pred_noise.shape[:2])

                if i < len(timesteps) - 1:
                    next_timestep = timesteps[i + 1] * torch.ones(
                        [1], dtype=torch.long, device=pred_video.device)
                    noise = torch.randn(video_raw_latent_shape,
                                        dtype=pred_video.dtype,
                                        generator=batch.generator[0]).to(
                                            self.device)
                    latents = self.scheduler.add_noise(
                        pred_video.flatten(0, 1), noise.flatten(0, 1),
                        next_timestep).unflatten(0, pred_video.shape[:2])
                else:
                    latents = pred_video

                # Update progress bar
                if i == len(timesteps) - 1 or (
                    (i + 1) > num_warmup_steps and
                    (i + 1) % self.scheduler.order == 0
                        and progress_bar is not None):
                    progress_bar.update()

    # Gather results if using sequence parallelism
    latents = latents.permute(0, 2, 1, 3, 4)
    # Update batch with final latents
    batch.latents = latents

    return batch
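
Unlike the standard denoising loop, the DMD loop never calls scheduler.step: each iteration converts the model output into a clean-video estimate and re-noises it at the next timestep. Assuming pred_noise_to_pred_video and add_noise follow the usual flow-matching parameterization (an assumption, not taken from their implementations), one step reduces to the sketch below:

import torch

# Hypothetical noise levels for the current and next DMD timestep.
sigma_t, sigma_next = 1.0, 0.5
x_t = torch.randn(1, 16, 21, 60, 104)   # current noisy latents (assumed shape)
v_pred = torch.randn_like(x_t)          # model output ("pred_noise" in the loop)

# Flow matching: x_t = (1 - sigma_t) * x0 + sigma_t * noise, with v = noise - x0,
# so the clean-video estimate is x_t - sigma_t * v.
pred_video = x_t - sigma_t * v_pred

# Re-noise the estimate at the next (smaller) noise level for the following iteration.
noise = torch.randn_like(pred_video)
x_next = (1.0 - sigma_next) * pred_video + sigma_next * noise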

Functions

fastvideo.pipelines.stages.encoding

Encoding stage for diffusion pipelines.

Classes

fastvideo.pipelines.stages.encoding.EncodingStage
EncodingStage(vae: ParallelTiledVAE)

Bases: PipelineStage

Stage for encoding pixel space representations into latent space.

This stage handles the encoding of pixel-space video/images into latent representations for further processing in the diffusion pipeline.

Source code in fastvideo/pipelines/stages/encoding.py
def __init__(self, vae: ParallelTiledVAE) -> None:
    self.vae: ParallelTiledVAE = vae
Functions
fastvideo.pipelines.stages.encoding.EncodingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Encode pixel space representations into latent space.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The batch with encoded latents.

Source code in fastvideo/pipelines/stages/encoding.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Encode pixel space representations into latent space.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with encoded latents.
    """
    assert batch.latents is not None and isinstance(batch.latents,
                                                    torch.Tensor)

    self.vae = self.vae.to(get_local_torch_device())

    # Setup VAE precision
    vae_dtype = PRECISION_TO_TYPE[
        fastvideo_args.pipeline_config.vae_precision]
    vae_autocast_enabled = (
        vae_dtype != torch.float32) and not fastvideo_args.disable_autocast

    # Normalize input to [-1, 1] range (reverse of decoding normalization)
    latents = (batch.latents * 2.0 - 1.0).clamp(-1, 1)

    # Move to appropriate device and dtype
    latents = latents.to(get_local_torch_device())

    # Encode image to latents
    with torch.autocast(device_type="cuda",
                        dtype=vae_dtype,
                        enabled=vae_autocast_enabled):
        if fastvideo_args.pipeline_config.vae_tiling:
            self.vae.enable_tiling()
        # if fastvideo_args.vae_sp:
        #     self.vae.enable_parallel()
        if not vae_autocast_enabled:
            latents = latents.to(vae_dtype)
        latents = self.vae.encode(latents).mean

    # Update batch with encoded latents
    batch.latents = latents

    # Offload models if needed
    if hasattr(self, 'maybe_free_model_hooks'):
        self.maybe_free_model_hooks()

    if fastvideo_args.vae_cpu_offload:
        self.vae.to("cpu")

    return batch
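
Note that the stage expects pixel-space input in the [0, 1] range stored in batch.latents. A minimal sketch of the normalization it applies before the VAE call; the tensor shape and dtype are illustrative assumptions:

import torch

# Pixel-space video in [0, 1]: [batch, channels, frames, height, width]
pixels = torch.rand(1, 3, 17, 480, 832)

# The same [0, 1] -> [-1, 1] normalization applied above before VAE encoding.
normalized = (pixels * 2.0 - 1.0).clamp(-1, 1)
assert normalized.min() >= -1.0 and normalized.max() <= 1.0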
fastvideo.pipelines.stages.encoding.EncodingStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify encoding stage inputs.

Source code in fastvideo/pipelines/stages/encoding.py
@torch.no_grad()
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify encoding stage inputs."""
    result = VerificationResult()
    # Input video/images for VAE encoding: [batch_size, channels, frames, height, width]
    result.add_check("latents", batch.latents,
                     [V.is_tensor, V.with_dims(5)])
    return result
fastvideo.pipelines.stages.encoding.EncodingStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify encoding stage outputs.

Source code in fastvideo/pipelines/stages/encoding.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify encoding stage outputs."""
    result = VerificationResult()
    # Encoded latents: [batch_size, channels, frames, height_latents, width_latents]
    result.add_check("latents", batch.latents,
                     [V.is_tensor, V.with_dims(5)])
    return result

Functions

fastvideo.pipelines.stages.image_encoding

Image and video encoding stages for diffusion pipelines.

This module contains implementations of encoding stages for diffusion pipelines:

- ImageEncodingStage: Encodes images using image encoders (e.g., CLIP)
- RefImageEncodingStage: Encodes reference image for Wan2.1 control pipeline
- ImageVAEEncodingStage: Encodes images to latent space using VAE for I2V generation
- VideoVAEEncodingStage: Encodes videos to latent space using VAE for V2V and control tasks

Classes

fastvideo.pipelines.stages.image_encoding.Hy15ImageEncodingStage
Hy15ImageEncodingStage(image_encoder, image_processor)

Bases: ImageEncodingStage

Stage for encoding image prompts into embeddings for HunyuanVideo1.5 models.

Source code in fastvideo/pipelines/stages/image_encoding.py
def __init__(self, image_encoder, image_processor) -> None:
    """
    Initialize the image encoding stage.

    Args:
        image_encoder: The image encoder model.
        image_processor: The processor used to prepare images for the encoder.
    """
    super().__init__()
    self.image_processor = image_processor
    self.image_encoder = image_encoder
Functions
fastvideo.pipelines.stages.image_encoding.Hy15ImageEncodingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Encode the prompt into image encoder hidden states.

Source code in fastvideo/pipelines/stages/image_encoding.py
def forward(self, batch: ForwardBatch,
            fastvideo_args: FastVideoArgs) -> ForwardBatch:
    """
    Encode the prompt into image encoder hidden states.
    """
    if batch.pil_image is None:
        batch.image_embeds = [
            torch.zeros(1, 729, 1152, device=get_local_torch_device())
        ]

    raw_latent_shape = list(batch.raw_latent_shape)
    raw_latent_shape[1] = 1
    batch.video_latent = torch.zeros(tuple(raw_latent_shape),
                                     device=get_local_torch_device())
    return batch
fastvideo.pipelines.stages.image_encoding.Hy15ImageEncodingStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify image encoding stage inputs.

Source code in fastvideo/pipelines/stages/image_encoding.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify image encoding stage inputs."""
    return VerificationResult()
fastvideo.pipelines.stages.image_encoding.ImageEncodingStage
ImageEncodingStage(image_encoder, image_processor)

Bases: PipelineStage

Stage for encoding image prompts into embeddings for diffusion models.

This stage handles the encoding of image prompts into the embedding space expected by the diffusion model.

Initialize the image encoding stage.

Parameters:

Name Type Description Default
image_encoder

The image encoder model.

required
image_processor

The processor used to prepare images for the encoder.

required
Source code in fastvideo/pipelines/stages/image_encoding.py
def __init__(self, image_encoder, image_processor) -> None:
    """
    Initialize the image encoding stage.

    Args:
        image_encoder: The image encoder model.
        image_processor: The processor used to prepare images for the encoder.
    """
    super().__init__()
    self.image_processor = image_processor
    self.image_encoder = image_encoder
Functions
fastvideo.pipelines.stages.image_encoding.ImageEncodingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Encode the prompt into image encoder hidden states.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The batch with encoded prompt embeddings.

Source code in fastvideo/pipelines/stages/image_encoding.py
@torch.no_grad()
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Encode the prompt into image encoder hidden states.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with encoded prompt embeddings.
    """
    self.image_encoder = self.image_encoder.to(get_local_torch_device())

    image = batch.pil_image

    image_inputs = self.image_processor(
        images=image, return_tensors="pt").to(get_local_torch_device())
    with set_forward_context(current_timestep=0, attn_metadata=None):
        outputs = self.image_encoder(**image_inputs)
        image_embeds = outputs.last_hidden_state

    batch.image_embeds.append(image_embeds)

    if fastvideo_args.image_encoder_cpu_offload:
        self.image_encoder.to('cpu')

    return batch
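
The same pattern can be reproduced standalone with a Hugging Face vision encoder. The checkpoint and the CLIPVisionModel/CLIPImageProcessor classes below are assumptions for illustration; FastVideo loads its own encoder/processor pair through the pipeline:

import torch
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModel

processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
encoder = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14")

image = Image.new("RGB", (512, 512))  # placeholder image
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    # last_hidden_state: [1, num_patches + 1, hidden_size]
    image_embeds = encoder(**inputs).last_hidden_state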
fastvideo.pipelines.stages.image_encoding.ImageEncodingStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify image encoding stage inputs.

Source code in fastvideo/pipelines/stages/image_encoding.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify image encoding stage inputs."""
    result = VerificationResult()
    result.add_check("pil_image", batch.pil_image, V.not_none)
    result.add_check("image_embeds", batch.image_embeds, V.is_list)
    return result
fastvideo.pipelines.stages.image_encoding.ImageEncodingStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify image encoding stage outputs.

Source code in fastvideo/pipelines/stages/image_encoding.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify image encoding stage outputs."""
    result = VerificationResult()
    result.add_check("image_embeds", batch.image_embeds,
                     V.list_of_tensors_dims(3))
    return result
fastvideo.pipelines.stages.image_encoding.ImageVAEEncodingStage
ImageVAEEncodingStage(vae: ParallelTiledVAE)

Bases: PipelineStage

Stage for encoding image pixel representations into latent space.

This stage handles the encoding of image pixel representations into the final input format (e.g., latents) for image-to-video generation.

Source code in fastvideo/pipelines/stages/image_encoding.py
def __init__(self, vae: ParallelTiledVAE) -> None:
    self.vae: ParallelTiledVAE = vae
Functions
fastvideo.pipelines.stages.image_encoding.ImageVAEEncodingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Encode pixel representations into latent space.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The batch with encoded outputs.

Source code in fastvideo/pipelines/stages/image_encoding.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Encode pixel representations into latent space.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with encoded outputs.
    """
    assert batch.pil_image is not None
    if fastvideo_args.mode == ExecutionMode.INFERENCE:
        assert batch.pil_image is not None and isinstance(
            batch.pil_image, PIL.Image.Image)
        assert batch.height is not None and isinstance(batch.height, int)
        assert batch.width is not None and isinstance(batch.width, int)
        assert batch.num_frames is not None and isinstance(
            batch.num_frames, int)
        height = batch.height
        width = batch.width
        num_frames = batch.num_frames
    elif fastvideo_args.mode == ExecutionMode.PREPROCESS:
        assert batch.pil_image is not None and isinstance(
            batch.pil_image, torch.Tensor)
        assert batch.height is not None and isinstance(batch.height, list)
        assert batch.width is not None and isinstance(batch.width, list)
        assert batch.num_frames is not None and isinstance(
            batch.num_frames, list)
        num_frames = batch.num_frames[0]
        height = batch.height[0]
        width = batch.width[0]

    self.vae = self.vae.to(get_local_torch_device())

    # Process single image for I2V
    latent_height = height // self.vae.spatial_compression_ratio
    latent_width = width // self.vae.spatial_compression_ratio
    image = batch.pil_image
    image = self.preprocess(
        image,
        vae_scale_factor=self.vae.spatial_compression_ratio,
        height=height,
        width=width).to(get_local_torch_device(), dtype=torch.float32)

    # (B, C, H, W) -> (B, C, 1, H, W)
    image = image.unsqueeze(2)

    video_condition = torch.cat([
        image,
        image.new_zeros(image.shape[0], image.shape[1], num_frames - 1,
                        image.shape[3], image.shape[4])
    ],
                                dim=2)
    video_condition = video_condition.to(device=get_local_torch_device(),
                                         dtype=torch.float32)

    # Setup VAE precision
    vae_dtype = PRECISION_TO_TYPE[
        fastvideo_args.pipeline_config.vae_precision]
    vae_autocast_enabled = (
        vae_dtype != torch.float32) and not fastvideo_args.disable_autocast

    # Encode Image
    with torch.autocast(device_type="cuda",
                        dtype=vae_dtype,
                        enabled=vae_autocast_enabled):
        if fastvideo_args.pipeline_config.vae_tiling:
            self.vae.enable_tiling()
        # if fastvideo_args.vae_sp:
        #     self.vae.enable_parallel()
        if not vae_autocast_enabled:
            video_condition = video_condition.to(vae_dtype)
        encoder_output = self.vae.encode(video_condition)

    if fastvideo_args.mode == ExecutionMode.PREPROCESS:
        latent_condition = encoder_output.mean
    else:
        generator = batch.generator
        if generator is None:
            raise ValueError("Generator must be provided")
        latent_condition = self.retrieve_latents(encoder_output, generator)

    # Apply shifting if needed
    if (hasattr(self.vae, "shift_factor")
            and self.vae.shift_factor is not None):
        if isinstance(self.vae.shift_factor, torch.Tensor):
            latent_condition -= self.vae.shift_factor.to(
                latent_condition.device, latent_condition.dtype)
        else:
            latent_condition -= self.vae.shift_factor

    if isinstance(self.vae.scaling_factor, torch.Tensor):
        latent_condition = latent_condition * self.vae.scaling_factor.to(
            latent_condition.device, latent_condition.dtype)
    else:
        latent_condition = latent_condition * self.vae.scaling_factor

    if fastvideo_args.mode == ExecutionMode.PREPROCESS:
        batch.image_latent = latent_condition
    else:
        mask_lat_size = torch.ones(1, 1, num_frames, latent_height,
                                   latent_width)
        mask_lat_size[:, :, list(range(1, num_frames))] = 0
        first_frame_mask = mask_lat_size[:, :, 0:1]
        first_frame_mask = torch.repeat_interleave(
            first_frame_mask,
            dim=2,
            repeats=self.vae.temporal_compression_ratio)
        mask_lat_size = torch.concat(
            [first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
        mask_lat_size = mask_lat_size.view(
            1, -1, self.vae.temporal_compression_ratio, latent_height,
            latent_width)
        mask_lat_size = mask_lat_size.transpose(1, 2)
        mask_lat_size = mask_lat_size.to(latent_condition.device)

        batch.image_latent = torch.concat([mask_lat_size, latent_condition],
                                          dim=1)

    # Offload models if needed
    if hasattr(self, 'maybe_free_model_hooks'):
        self.maybe_free_model_hooks()

    self.vae.to("cpu")

    return batch
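
The mask construction at the end of forward is easiest to follow by tracking shapes. A small sketch with assumed values (81 pixel frames, temporal compression 4, 16 latent channels) shows how the first-frame mask ends up as extra channels alongside the encoded latents:

import torch

num_frames, tcr = 81, 4                        # assumed pixel frames / temporal compression
lat_h, lat_w, lat_c = 60, 104, 16              # assumed latent spatial size and channels
latent_frames = (num_frames - 1) // tcr + 1    # 21

latent_condition = torch.randn(1, lat_c, latent_frames, lat_h, lat_w)

mask = torch.ones(1, 1, num_frames, lat_h, lat_w)
mask[:, :, 1:] = 0                                          # only the first frame is "known"
first = mask[:, :, 0:1].repeat_interleave(tcr, dim=2)       # (1, 1, 4, h, w)
mask = torch.cat([first, mask[:, :, 1:]], dim=2)            # (1, 1, 84, h, w)
mask = mask.view(1, -1, tcr, lat_h, lat_w).transpose(1, 2)  # (1, 4, 21, h, w)

image_latent = torch.cat([mask, latent_condition], dim=1)   # (1, 20, 21, h, w)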
fastvideo.pipelines.stages.image_encoding.ImageVAEEncodingStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify encoding stage inputs.

Source code in fastvideo/pipelines/stages/image_encoding.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify encoding stage inputs."""
    result = VerificationResult()
    result.add_check("generator", batch.generator,
                     V.generator_or_list_generators)
    if fastvideo_args.mode == ExecutionMode.PREPROCESS:
        result.add_check("height", batch.height, V.list_not_empty)
        result.add_check("width", batch.width, V.list_not_empty)
        result.add_check("num_frames", batch.num_frames, V.list_not_empty)
    else:
        result.add_check("height", batch.height, V.positive_int)
        result.add_check("width", batch.width, V.positive_int)
        result.add_check("num_frames", batch.num_frames, V.positive_int)
    return result
fastvideo.pipelines.stages.image_encoding.ImageVAEEncodingStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify encoding stage outputs.

Source code in fastvideo/pipelines/stages/image_encoding.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify encoding stage outputs."""
    result = VerificationResult()
    result.add_check("image_latent", batch.image_latent,
                     [V.is_tensor, V.with_dims(5)])
    return result
fastvideo.pipelines.stages.image_encoding.RefImageEncodingStage
RefImageEncodingStage(image_encoder, image_processor)

Bases: ImageEncodingStage

Stage for encoding reference image prompts into embeddings for Wan2.1 Control models.

This stage extends ImageEncodingStage with specialized preprocessing for reference images.

Source code in fastvideo/pipelines/stages/image_encoding.py
def __init__(self, image_encoder, image_processor) -> None:
    """
    Initialize the reference image encoding stage.

    Args:
        image_encoder: The image encoder model.
        image_processor: The processor used to prepare images for the encoder.
    """
    super().__init__()
    self.image_processor = image_processor
    self.image_encoder = image_encoder
Functions
fastvideo.pipelines.stages.image_encoding.RefImageEncodingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Encode the prompt into image encoder hidden states.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The batch with encoded prompt embeddings.

Source code in fastvideo/pipelines/stages/image_encoding.py
@torch.no_grad()
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Encode the prompt into image encoder hidden states.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with encoded prompt embeddings.
    """
    self.image_encoder = self.image_encoder.to(get_local_torch_device())

    image = batch.pil_image
    if image is None:
        image = create_default_image()
    # Preprocess reference image for CLIP encoder
    image_tensor = preprocess_reference_image_for_clip(
        image, get_local_torch_device())

    image_inputs = self.image_processor(images=image_tensor,
                                        return_tensors="pt").to(
                                            get_local_torch_device())
    with set_forward_context(current_timestep=0, attn_metadata=None):
        outputs = self.image_encoder(**image_inputs)
        image_embeds = outputs.last_hidden_state
    batch.image_embeds.append(image_embeds)

    if batch.pil_image is None:
        batch.image_embeds = [
            torch.zeros_like(x) for x in batch.image_embeds
        ]

    return batch
fastvideo.pipelines.stages.image_encoding.VideoVAEEncodingStage
VideoVAEEncodingStage(vae: ParallelTiledVAE)

Bases: ImageVAEEncodingStage

Stage for encoding video pixel representations into latent space.

This stage handles the encoding of video pixel representations for video-to-video generation and control. Inherits from ImageVAEEncodingStage to reuse common functionality.

Source code in fastvideo/pipelines/stages/image_encoding.py
def __init__(self, vae: ParallelTiledVAE) -> None:
    self.vae: ParallelTiledVAE = vae
Functions
fastvideo.pipelines.stages.image_encoding.VideoVAEEncodingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Encode video pixel representations into latent space.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The batch with encoded outputs.

Source code in fastvideo/pipelines/stages/image_encoding.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Encode video pixel representations into latent space.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with encoded outputs.
    """
    assert batch.video_latent is not None, "Video latent input is required for VideoVAEEncodingStage"

    if fastvideo_args.mode == ExecutionMode.INFERENCE:
        assert batch.height is not None and isinstance(batch.height, int)
        assert batch.width is not None and isinstance(batch.width, int)
        assert batch.num_frames is not None and isinstance(
            batch.num_frames, int)
        height = batch.height
        width = batch.width
        num_frames = batch.num_frames
    elif fastvideo_args.mode == ExecutionMode.PREPROCESS:
        assert batch.height is not None and isinstance(batch.height, list)
        assert batch.width is not None and isinstance(batch.width, list)
        assert batch.num_frames is not None and isinstance(
            batch.num_frames, list)
        num_frames = batch.num_frames[0]
        height = batch.height[0]
        width = batch.width[0]

    self.vae = self.vae.to(get_local_torch_device())

    # Prepare video tensor from control video
    video_condition = self._prepare_control_video_tensor(
        batch.video_latent, num_frames, height,
        width).to(get_local_torch_device(), dtype=torch.float32)

    # Setup VAE precision
    vae_dtype = PRECISION_TO_TYPE[
        fastvideo_args.pipeline_config.vae_precision]
    vae_autocast_enabled = (
        vae_dtype != torch.float32) and not fastvideo_args.disable_autocast

    # Encode control video
    with torch.autocast(device_type="cuda",
                        dtype=vae_dtype,
                        enabled=vae_autocast_enabled):
        if fastvideo_args.pipeline_config.vae_tiling:
            self.vae.enable_tiling()
        if not vae_autocast_enabled:
            video_condition = video_condition.to(vae_dtype)
        encoder_output = self.vae.encode(video_condition)

    generator = batch.generator
    if generator is None:
        raise ValueError("Generator must be provided")
    latent_condition = self.retrieve_latents(encoder_output, generator)

    if (hasattr(self.vae, "shift_factor")
            and self.vae.shift_factor is not None):
        if isinstance(self.vae.shift_factor, torch.Tensor):
            latent_condition -= self.vae.shift_factor.to(
                latent_condition.device, latent_condition.dtype)
        else:
            latent_condition -= self.vae.shift_factor

    if isinstance(self.vae.scaling_factor, torch.Tensor):
        latent_condition = latent_condition * self.vae.scaling_factor.to(
            latent_condition.device, latent_condition.dtype)
    else:
        latent_condition = latent_condition * self.vae.scaling_factor

    batch.video_latent = latent_condition

    # Offload models if needed
    if hasattr(self, 'maybe_free_model_hooks'):
        self.maybe_free_model_hooks()

    self.vae.to("cpu")

    return batch
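
Both VAE encoding stages finish with the same latent normalization: subtract the VAE's optional shift factor, then multiply by its scaling factor. In isolation it reduces to the sketch below; the tensor shape and factor values are illustrative assumptions:

import torch

z = torch.randn(1, 16, 21, 60, 104)  # raw latent from vae.encode (assumed shape)
shift_factor = 0.0                   # some VAEs define no shift
scaling_factor = 1.15                # illustrative value only

latent_condition = (z - shift_factor) * scaling_factor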
fastvideo.pipelines.stages.image_encoding.VideoVAEEncodingStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify video encoding stage inputs.

Source code in fastvideo/pipelines/stages/image_encoding.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify video encoding stage inputs."""
    result = VerificationResult()
    result.add_check("video_latent", batch.video_latent, V.not_none)
    result.add_check("generator", batch.generator,
                     V.generator_or_list_generators)
    if fastvideo_args.mode == ExecutionMode.PREPROCESS:
        result.add_check("height", batch.height, V.list_not_empty)
        result.add_check("width", batch.width, V.list_not_empty)
        result.add_check("num_frames", batch.num_frames, V.list_not_empty)
    else:
        result.add_check("height", batch.height, V.positive_int)
        result.add_check("width", batch.width, V.positive_int)
        result.add_check("num_frames", batch.num_frames, V.positive_int)
    return result
fastvideo.pipelines.stages.image_encoding.VideoVAEEncodingStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify video encoding stage outputs.

Source code in fastvideo/pipelines/stages/image_encoding.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify video encoding stage outputs."""
    result = VerificationResult()
    result.add_check("video_latent", batch.video_latent,
                     [V.is_tensor, V.with_dims(5)])
    return result

Functions

fastvideo.pipelines.stages.input_validation

Input validation stage for diffusion pipelines.

Classes

fastvideo.pipelines.stages.input_validation.InputValidationStage

Bases: PipelineStage

Stage for validating and preparing inputs for diffusion pipelines.

This stage validates that all required inputs are present and properly formatted before proceeding with the diffusion process.

Functions
fastvideo.pipelines.stages.input_validation.InputValidationStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Validate and prepare inputs.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The validated batch information.

Source code in fastvideo/pipelines/stages/input_validation.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Validate and prepare inputs.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The validated batch information.
    """

    self._generate_seeds(batch, fastvideo_args)

    # Ensure prompt is properly formatted
    if batch.prompt is None and batch.prompt_embeds is None:
        raise ValueError(
            "Either `prompt` or `prompt_embeds` must be provided")

    # Ensure negative prompt is properly formatted if using classifier-free guidance
    if (batch.do_classifier_free_guidance and batch.negative_prompt is None
            and batch.negative_prompt_embeds is None):
        raise ValueError(
            "For classifier-free guidance, either `negative_prompt` or "
            "`negative_prompt_embeds` must be provided")

    # Validate height and width
    if batch.height is None or batch.width is None:
        raise ValueError(
            "Height and width must be provided. Please set `height` and `width`."
        )
    if batch.height % 8 != 0 or batch.width % 8 != 0:
        raise ValueError(
            f"Height and width must be divisible by 8 but are {batch.height} and {batch.width}."
        )

    # Validate number of inference steps
    if batch.num_inference_steps <= 0:
        raise ValueError(
            f"Number of inference steps must be positive, but got {batch.num_inference_steps}"
        )

    # Validate guidance scale if using classifier-free guidance
    if batch.do_classifier_free_guidance and batch.guidance_scale <= 0:
        raise ValueError(
            f"Guidance scale must be positive, but got {batch.guidance_scale}"
        )

    # for i2v, get image from image_path
    # @TODO(Wei) hard-coded for wan2.2 5b ti2v for now. Should put this in image_encoding stage
    if batch.image_path is not None:
        if batch.image_path.endswith(".mp4"):
            image = load_video(batch.image_path)[0]
        else:
            image = load_image(batch.image_path)
        batch.pil_image = image

    # further processing for ti2v task
    if (fastvideo_args.pipeline_config.ti2v_task
            or fastvideo_args.pipeline_config.is_causal
        ) and batch.pil_image is not None:
        img = batch.pil_image
        ih, iw = img.height, img.width

        pipeline_class_name = type(fastvideo_args.pipeline_config).__name__
        if 'MatrixGame' in pipeline_class_name or 'MatrixCausal' in pipeline_class_name:
            oh, ow = batch.height, batch.width
            img = img.resize((ow, oh), Image.LANCZOS)
        else:
            # Standard Wan logic
            patch_size = fastvideo_args.pipeline_config.dit_config.arch_config.patch_size
            vae_stride = fastvideo_args.pipeline_config.vae_config.arch_config.scale_factor_spatial
            dh, dw = patch_size[1] * vae_stride, patch_size[2] * vae_stride
            max_area = 480 * 832
            ow, oh = best_output_size(iw, ih, dw, dh, max_area)

            scale = max(ow / iw, oh / ih)
            img = img.resize((round(iw * scale), round(ih * scale)),
                             Image.LANCZOS)

            # center-crop
            x1 = (img.width - ow) // 2
            y1 = (img.height - oh) // 2
            img = img.crop((x1, y1, x1 + ow, y1 + oh))

        assert img.width == ow and img.height == oh
        logger.info("final processed img height: %s, img width: %s",
                    img.height, img.width)

        # to tensor
        img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(
            self.device).unsqueeze(1)
        img = img.unsqueeze(0)
        batch.height = oh
        batch.width = ow
        batch.pil_image = img

    # for v2v, get control video from video path
    if batch.video_path is not None:
        pil_images, original_fps = load_video(batch.video_path,
                                              return_fps=True)
        logger.info("Loaded video with %s frames, original FPS: %s",
                    len(pil_images), original_fps)

        # Get target parameters from batch
        target_fps = batch.fps
        target_num_frames = batch.num_frames
        target_height = batch.height
        target_width = batch.width

        if target_fps is not None and original_fps is not None:
            frame_skip = max(1, int(original_fps // target_fps))
            if frame_skip > 1:
                pil_images = pil_images[::frame_skip]
                effective_fps = original_fps / frame_skip
                logger.info(
                    "Resampled video from %.1f fps to %.1f fps (skip=%s)",
                    original_fps, effective_fps, frame_skip)

        # Limit to target number of frames
        if target_num_frames is not None and len(
                pil_images) > target_num_frames:
            pil_images = pil_images[:target_num_frames]
            logger.info("Limited video to %s frames (from %s total)",
                        target_num_frames, len(pil_images))

        # Resize each PIL image to target dimensions
        resized_images = []
        for pil_img in pil_images:
            resized_img = resize(pil_img,
                                 target_height,
                                 target_width,
                                 resize_mode="default",
                                 resample="lanczos")
            resized_images.append(resized_img)

        # Convert PIL images to numpy array
        video_numpy = pil_to_numpy(resized_images)
        video_numpy = normalize(video_numpy)
        video_tensor = numpy_to_pt(video_numpy)

        # Rearrange to [C, T, H, W] and add batch dimension -> [B, C, T, H, W]
        input_video = video_tensor.permute(1, 0, 2, 3).unsqueeze(0)

        batch.video_latent = input_video

    # Validate action control inputs (Matrix-Game)
    if batch.mouse_cond is not None:
        if batch.mouse_cond.dim() != 3 or batch.mouse_cond.shape[-1] != 2:
            raise ValueError(
                f"mouse_cond must have shape (B, T, 2), but got {batch.mouse_cond.shape}"
            )
        logger.info("Action control: mouse_cond validated - shape %s",
                    batch.mouse_cond.shape)

    if batch.keyboard_cond is not None:
        if batch.keyboard_cond.dim() != 3:
            raise ValueError(
                f"keyboard_cond must have 3 dimensions (B, T, K), but got {batch.keyboard_cond.dim()}"
            )
        keyboard_dim = batch.keyboard_cond.shape[-1]
        if keyboard_dim not in {2, 4, 6, 7}:
            raise ValueError(
                f"keyboard_cond last dimension must be 2, 4, 6, or 7, but got {keyboard_dim}"
            )
        logger.info(
            "Action control: keyboard_cond validated - shape %s (dim=%d)",
            batch.keyboard_cond.shape, keyboard_dim)

    if batch.grid_sizes is not None:
        if not isinstance(batch.grid_sizes, list | tuple | torch.Tensor):
            raise ValueError("grid_sizes must be a list, tuple, or tensor")
        if isinstance(batch.grid_sizes, torch.Tensor):
            if batch.grid_sizes.numel() != 3:
                raise ValueError(
                    "grid_sizes must have 3 elements [F, H, W]")
        else:
            if len(batch.grid_sizes) != 3:
                raise ValueError(
                    "grid_sizes must have 3 elements [F, H, W]")
        logger.info("Action control: grid_sizes validated - %s",
                    batch.grid_sizes)

    return batch
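
For v2v inputs, the control video is resampled by frame skipping rather than interpolation: the skip factor is the integer ratio of original to target FPS. A small worked sketch with assumed rates:

# Assumed rates and frame count for illustration.
original_fps, target_fps = 48.0, 16.0
frames = list(range(96))  # stand-in for the loaded PIL frames

frame_skip = max(1, int(original_fps // target_fps))  # 3
resampled = frames[::frame_skip]                      # keep every 3rd frame
effective_fps = original_fps / frame_skip             # 16.0
print(len(resampled), effective_fps)                  # 32 16.0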
fastvideo.pipelines.stages.input_validation.InputValidationStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify input validation stage inputs.

Source code in fastvideo/pipelines/stages/input_validation.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify input validation stage inputs."""
    result = VerificationResult()
    result.add_check("seed", batch.seed, [V.not_none, V.positive_int])
    result.add_check("num_videos_per_prompt", batch.num_videos_per_prompt,
                     V.positive_int)
    result.add_check(
        "prompt_or_embeds", None, lambda _: V.string_or_list_strings(
            batch.prompt) or V.list_not_empty(batch.prompt_embeds))
    result.add_check("height", batch.height, V.positive_int)
    result.add_check("width", batch.width, V.positive_int)
    result.add_check("num_inference_steps", batch.num_inference_steps,
                     V.positive_int)
    result.add_check(
        "guidance_scale", batch.guidance_scale, lambda x: not batch.
        do_classifier_free_guidance or V.positive_float(x))
    return result
fastvideo.pipelines.stages.input_validation.InputValidationStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify input validation stage outputs.

Source code in fastvideo/pipelines/stages/input_validation.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify input validation stage outputs."""
    result = VerificationResult()
    result.add_check("seeds", batch.seeds, V.list_not_empty)
    result.add_check("generator", batch.generator,
                     V.generator_or_list_generators)
    return result

Functions

fastvideo.pipelines.stages.latent_preparation

Latent preparation stage for diffusion pipelines.

Classes

fastvideo.pipelines.stages.latent_preparation.Cosmos25LatentPreparationStage
Cosmos25LatentPreparationStage(scheduler, transformer, vae=None)

Bases: CosmosLatentPreparationStage

Latent preparation for Cosmos 2.5 DiT input conventions.

Source code in fastvideo/pipelines/stages/latent_preparation.py
def __init__(self, scheduler, transformer, vae=None) -> None:
    super().__init__()
    self.scheduler = scheduler
    self.transformer = transformer
    self.vae = vae
Functions
fastvideo.pipelines.stages.latent_preparation.Cosmos25LatentPreparationStage.adjust_video_length
adjust_video_length(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> int

Adjust video length based on VAE version.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
int

The batch with adjusted video length.

Source code in fastvideo/pipelines/stages/latent_preparation.py
def adjust_video_length(self, batch: ForwardBatch,
                        fastvideo_args: FastVideoArgs) -> int:
    """
    Adjust video length based on VAE version.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with adjusted video length.
    """

    video_length = batch.num_frames
    use_temporal_scaling_frames = fastvideo_args.pipeline_config.vae_config.use_temporal_scaling_frames
    if use_temporal_scaling_frames:
        temporal_scale_factor = fastvideo_args.pipeline_config.vae_config.arch_config.temporal_compression_ratio
        latent_num_frames = (video_length - 1) // temporal_scale_factor + 1
    else:  # stepvideo only
        latent_num_frames = video_length // 17 * 3
    return int(latent_num_frames)
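
With the causal-VAE convention, the first pixel frame maps to one latent frame and every further group of temporal_compression_ratio frames adds one more; the stepvideo branch instead maps every 17 pixel frames to 3 latent frames. A worked sketch with assumed values:

# Assumed inputs for illustration.
video_length = 121
temporal_compression_ratio = 8

latent_num_frames = (video_length - 1) // temporal_compression_ratio + 1
print(latent_num_frames)  # 16

# stepvideo-style mapping: every 17 pixel frames -> 3 latent frames
print(121 // 17 * 3)      # 21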
fastvideo.pipelines.stages.latent_preparation.Cosmos25LatentPreparationStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify Cosmos latent preparation stage inputs.

Source code in fastvideo/pipelines/stages/latent_preparation.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify Cosmos latent preparation stage inputs."""
    result = VerificationResult()
    result.add_check(
        "prompt_or_embeds", None, lambda _: V.string_or_list_strings(
            batch.prompt) or V.list_not_empty(batch.prompt_embeds))
    result.add_check("prompt_embeds", batch.prompt_embeds,
                     V.list_of_tensors)
    result.add_check("num_videos_per_prompt", batch.num_videos_per_prompt,
                     V.positive_int)
    result.add_check("generator", batch.generator,
                     V.generator_or_list_generators)
    result.add_check("num_frames", batch.num_frames, V.positive_int)
    result.add_check("height", batch.height, V.positive_int)
    result.add_check("width", batch.width, V.positive_int)
    result.add_check("latents", batch.latents, V.none_or_tensor)
    return result
fastvideo.pipelines.stages.latent_preparation.Cosmos25LatentPreparationStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify latent preparation stage outputs.

Source code in fastvideo/pipelines/stages/latent_preparation.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify latent preparation stage outputs."""
    result = VerificationResult()
    result.add_check("latents", batch.latents,
                     [V.is_tensor, V.with_dims(5)])
    result.add_check("raw_latent_shape", batch.raw_latent_shape, V.is_tuple)
    return result
fastvideo.pipelines.stages.latent_preparation.CosmosLatentPreparationStage
CosmosLatentPreparationStage(scheduler, transformer, vae=None)

Bases: PipelineStage

Cosmos-specific latent preparation stage that properly handles the tensor shapes and conditioning masks required by the Cosmos transformer.

This stage replicates the logic from diffusers' Cosmos2VideoToWorldPipeline.prepare_latents()

Source code in fastvideo/pipelines/stages/latent_preparation.py
def __init__(self, scheduler, transformer, vae=None) -> None:
    super().__init__()
    self.scheduler = scheduler
    self.transformer = transformer
    self.vae = vae
fastvideo.pipelines.stages.latent_preparation.LatentPreparationStage
LatentPreparationStage(scheduler, transformer, use_btchw_layout: bool = False)

Bases: PipelineStage

Stage for preparing initial latent variables for the diffusion process.

This stage handles the preparation of the initial latent variables that will be denoised during the diffusion process.

Source code in fastvideo/pipelines/stages/latent_preparation.py
def __init__(self,
             scheduler,
             transformer,
             use_btchw_layout: bool = False) -> None:
    super().__init__()
    self.scheduler = scheduler
    self.transformer = transformer
    self.use_btchw_layout = use_btchw_layout
Functions
fastvideo.pipelines.stages.latent_preparation.LatentPreparationStage.adjust_video_length
adjust_video_length(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> int

Adjust video length based on VAE version.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
int

The batch with adjusted video length.

Source code in fastvideo/pipelines/stages/latent_preparation.py
def adjust_video_length(self, batch: ForwardBatch,
                        fastvideo_args: FastVideoArgs) -> int:
    """
    Adjust video length based on VAE version.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with adjusted video length.
    """

    video_length = batch.num_frames
    use_temporal_scaling_frames = fastvideo_args.pipeline_config.vae_config.use_temporal_scaling_frames
    if use_temporal_scaling_frames:
        temporal_scale_factor = fastvideo_args.pipeline_config.vae_config.arch_config.temporal_compression_ratio
        latent_num_frames = (video_length - 1) // temporal_scale_factor + 1
    else:  # stepvideo only
        latent_num_frames = video_length // 17 * 3
    return int(latent_num_frames)
fastvideo.pipelines.stages.latent_preparation.LatentPreparationStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Prepare initial latent variables for the diffusion process.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The batch with prepared latent variables.

Source code in fastvideo/pipelines/stages/latent_preparation.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Prepare initial latent variables for the diffusion process.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with prepared latent variables.
    """

    latent_num_frames = None
    # Adjust video length based on VAE version if needed
    if hasattr(self, 'adjust_video_length'):
        latent_num_frames = self.adjust_video_length(batch, fastvideo_args)
    # Determine batch size; fall back to action/image inputs when no text encoder is present
    if not batch.prompt_embeds:
        if batch.keyboard_cond is not None:
            batch_size = batch.keyboard_cond.shape[0]
        elif batch.mouse_cond is not None:
            batch_size = batch.mouse_cond.shape[0]
        elif batch.image_embeds:
            batch_size = batch.image_embeds[0].shape[0]
        else:
            batch_size = 1
    elif isinstance(batch.prompt, list):
        batch_size = len(batch.prompt)
    elif batch.prompt is not None:
        batch_size = 1
    else:
        batch_size = batch.prompt_embeds[0].shape[0]

    # Adjust batch size for number of videos per prompt
    batch_size *= batch.num_videos_per_prompt

    # Get required parameters
    if not batch.prompt_embeds:
        # Create a dummy zero-length text embedding to satisfy downstream checks.
        # Matrix-Game models have text_dim=0 and ignore encoder_hidden_states.
        transformer_dtype = next(self.transformer.parameters()).dtype
        device = get_local_torch_device()
        dummy_prompt = torch.zeros(batch_size,
                                   0,
                                   self.transformer.hidden_size,
                                   device=device,
                                   dtype=transformer_dtype)
        batch.prompt_embeds = [dummy_prompt]
        batch.negative_prompt_embeds = []
        batch.do_classifier_free_guidance = False
    dtype = batch.prompt_embeds[0].dtype
    device = get_local_torch_device()
    generator = batch.generator
    latents = batch.latents
    num_frames = latent_num_frames if latent_num_frames is not None else batch.num_frames
    height = batch.height
    width = batch.width

    # TODO(will): remove this once we add input/output validation for stages
    if height is None or width is None:
        raise ValueError("Height and width must be provided")

    # Calculate latent shape
    bcthw_shape: tuple[int, ...] | None = None
    if self.use_btchw_layout:
        shape = (
            batch_size,
            num_frames,
            self.transformer.num_channels_latents,
            height // fastvideo_args.pipeline_config.vae_config.arch_config.
            spatial_compression_ratio,
            width // fastvideo_args.pipeline_config.vae_config.arch_config.
            spatial_compression_ratio,
        )
        bcthw_shape = tuple(shape[i] for i in [0, 2, 1, 3, 4])
    else:
        shape = (
            batch_size,
            self.transformer.num_channels_latents,
            num_frames,
            height // fastvideo_args.pipeline_config.vae_config.arch_config.
            spatial_compression_ratio,
            width // fastvideo_args.pipeline_config.vae_config.arch_config.
            spatial_compression_ratio,
        )
        bcthw_shape = shape

    # Validate generator if it's a list
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )
    # Generate or use provided latents
    if latents is None:
        latents = randn_tensor(
            shape,
            generator=generator,
            device=device,
            dtype=dtype,
        )
        if hasattr(self.scheduler, "init_noise_sigma"):
            latents = latents * self.scheduler.init_noise_sigma
    else:
        # Pre-initialized latents:
        # - For LongCat refine (refine_from or stage1_video present), we should not re-scale by init_noise_sigma.
        # - For other models, keep the original behavior.
        latents = latents.to(device)
        is_longcat_refine = (batch.refine_from
                             is not None) or (batch.stage1_video
                                              is not None)
        if (not is_longcat_refine) and hasattr(self.scheduler,
                                               "init_noise_sigma"):
            latents = latents * self.scheduler.init_noise_sigma

    # Update batch with prepared latents
    batch.latents = latents
    batch.raw_latent_shape = bcthw_shape

    return batch
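
The shape arithmetic above reduces to a few lines. A minimal sketch, assuming a spatial compression ratio of 8 and 16 latent channels (illustrative values; the real stage reads both from the transformer and VAE configs):

# Sketch of the latent-shape arithmetic in forward(). The values
# spatial_compression_ratio=8 and num_channels_latents=16 are assumptions
# for illustration only.
batch_size, num_frames, height, width = 1, 21, 480, 832
spatial_compression_ratio = 8
num_channels_latents = 16

# BCTHW layout (the default branch above)
bcthw_shape = (
    batch_size,
    num_channels_latents,
    num_frames,
    height // spatial_compression_ratio,  # 60
    width // spatial_compression_ratio,   # 104
)

# The BTCHW layout swaps the temporal and channel axes for noise generation,
# but raw_latent_shape is always recorded in BCTHW order.
btchw_shape = (batch_size, num_frames, num_channels_latents,
               height // spatial_compression_ratio,
               width // spatial_compression_ratio)
assert tuple(btchw_shape[i] for i in (0, 2, 1, 3, 4)) == bcthw_shape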
fastvideo.pipelines.stages.latent_preparation.LatentPreparationStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify latent preparation stage inputs.

Source code in fastvideo/pipelines/stages/latent_preparation.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify latent preparation stage inputs."""
    result = VerificationResult()
    result.add_check(
        "prompt_or_embeds", None,
        lambda _: V.string_or_list_strings(batch.prompt) or not batch.
        prompt_embeds or V.list_not_empty(batch.prompt_embeds))
    if batch.prompt_embeds:
        result.add_check("prompt_embeds", batch.prompt_embeds,
                         V.list_of_tensors)
    result.add_check("num_videos_per_prompt", batch.num_videos_per_prompt,
                     V.positive_int)
    result.add_check("generator", batch.generator,
                     V.generator_or_list_generators)
    result.add_check("num_frames", batch.num_frames, V.positive_int)
    result.add_check("height", batch.height, V.positive_int)
    result.add_check("width", batch.width, V.positive_int)
    result.add_check("latents", batch.latents, V.none_or_tensor)
    return result
fastvideo.pipelines.stages.latent_preparation.LatentPreparationStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify latent preparation stage outputs.

Source code in fastvideo/pipelines/stages/latent_preparation.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify latent preparation stage outputs."""
    result = VerificationResult()
    result.add_check("latents", batch.latents,
                     [V.is_tensor, V.with_dims(5)])
    result.add_check("raw_latent_shape", batch.raw_latent_shape, V.is_tuple)
    return result

Functions

fastvideo.pipelines.stages.longcat_denoising

LongCat-specific denoising stage implementing CFG-zero optimized guidance.

Classes

fastvideo.pipelines.stages.longcat_denoising.LongCatDenoisingStage
LongCatDenoisingStage(transformer, scheduler, pipeline=None, transformer_2=None, vae=None)

Bases: DenoisingStage

LongCat denoising stage with CFG-zero optimized guidance scale.

Implements:

1. The optimized CFG scale from the CFG-zero paper
2. Negation of the noise prediction before the scheduler step (flow matching convention)
3. Batched CFG computation: one forward pass over the concatenated negative/positive batch, unlike FastVideo's standard separate passes

Source code in fastvideo/pipelines/stages/denoising.py
def __init__(self,
             transformer,
             scheduler,
             pipeline=None,
             transformer_2=None,
             vae=None) -> None:
    super().__init__()
    self.transformer = transformer
    self.transformer_2 = transformer_2
    self.scheduler = scheduler
    self.vae = vae
    self.pipeline = weakref.ref(pipeline) if pipeline else None
    attn_head_size = self.transformer.hidden_size // self.transformer.num_attention_heads
    self.attn_backend = get_attn_backend(
        head_size=attn_head_size,
        dtype=torch.float16,  # TODO(will): hack
        supported_attention_backends=(
            AttentionBackendEnum.SLIDING_TILE_ATTN,
            AttentionBackendEnum.VIDEO_SPARSE_ATTN,
            AttentionBackendEnum.VMOBA_ATTN,
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.TORCH_SDPA,
            AttentionBackendEnum.SAGE_ATTN_THREE)  # hack
    )
Functions
fastvideo.pipelines.stages.longcat_denoising.LongCatDenoisingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Run LongCat denoising loop with optimized CFG.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The batch with denoised latents.

Source code in fastvideo/pipelines/stages/longcat_denoising.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Run LongCat denoising loop with optimized CFG.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with denoised latents.
    """
    if not fastvideo_args.model_loaded["transformer"]:
        from fastvideo.models.loader.component_loader import TransformerLoader
        loader = TransformerLoader()
        self.transformer = loader.load(
            fastvideo_args.model_paths["transformer"], fastvideo_args)
        pipeline = self.pipeline() if self.pipeline else None
        if pipeline:
            pipeline.add_module("transformer", self.transformer)
        fastvideo_args.model_loaded["transformer"] = True

    # Get transformer dtype
    if hasattr(self.transformer, 'module'):
        transformer_dtype = next(self.transformer.module.parameters()).dtype
    else:
        transformer_dtype = next(self.transformer.parameters()).dtype

    target_dtype = transformer_dtype
    autocast_enabled = (target_dtype != torch.float32
                        ) and not fastvideo_args.disable_autocast

    # Extract batch parameters
    latents = batch.latents
    timesteps = batch.timesteps
    prompt_embeds = batch.prompt_embeds[0]  # LongCat uses single encoder
    prompt_attention_mask = batch.prompt_attention_mask[
        0] if batch.prompt_attention_mask else None
    guidance_scale = batch.guidance_scale
    do_classifier_free_guidance = batch.do_classifier_free_guidance

    # Get negative prompts if doing CFG
    if do_classifier_free_guidance:
        negative_prompt_embeds = batch.negative_prompt_embeds[0]
        negative_prompt_attention_mask = (batch.negative_attention_mask[0]
                                          if batch.negative_attention_mask
                                          else None)
        # Concatenate for batched processing
        prompt_embeds_combined = torch.cat(
            [negative_prompt_embeds, prompt_embeds], dim=0)
        if prompt_attention_mask is not None:
            prompt_attention_mask_combined = torch.cat(
                [negative_prompt_attention_mask, prompt_attention_mask],
                dim=0)
        else:
            prompt_attention_mask_combined = None
    else:
        prompt_embeds_combined = prompt_embeds
        prompt_attention_mask_combined = prompt_attention_mask

    # Denoising loop
    num_inference_steps = len(timesteps)
    with tqdm(total=num_inference_steps,
              desc="LongCat Denoising") as progress_bar:
        for i, t in enumerate(timesteps):
            # Expand latents for CFG
            if do_classifier_free_guidance:
                latent_model_input = torch.cat([latents] * 2)
            else:
                latent_model_input = latents

            latent_model_input = latent_model_input.to(target_dtype)

            # Expand timestep to match batch size
            timestep = t.expand(
                latent_model_input.shape[0]).to(target_dtype)

            # Run transformer with context
            batch.is_cfg_negative = False
            with set_forward_context(
                    current_timestep=i,
                    attn_metadata=None,
                    forward_batch=batch,
            ), torch.autocast(device_type='cuda',
                              dtype=target_dtype,
                              enabled=autocast_enabled):
                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    encoder_hidden_states=prompt_embeds_combined,
                    timestep=timestep,
                    encoder_attention_mask=prompt_attention_mask_combined,
                )

            # Apply CFG with optimized scale
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)

                B = noise_pred_cond.shape[0]
                positive = noise_pred_cond.reshape(B, -1)
                negative = noise_pred_uncond.reshape(B, -1)

                # Calculate optimized scale (CFG-zero)
                st_star = self.optimized_scale(positive, negative)

                # Reshape for broadcasting
                st_star = st_star.view(B, 1, 1, 1, 1)

                # Apply optimized CFG formula
                noise_pred = (
                    noise_pred_uncond * st_star + guidance_scale *
                    (noise_pred_cond - noise_pred_uncond * st_star))

            # CRITICAL: Negate noise prediction for flow matching scheduler
            noise_pred = -noise_pred

            # Compute previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred,
                                          t,
                                          latents,
                                          return_dict=False)[0]

            progress_bar.update()

    # Update batch with denoised latents
    batch.latents = latents
    return batch
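
With st_star fixed to 1, the optimized formula above collapses to standard classifier-free guidance, which makes a useful sanity check. A toy sketch (shapes are illustrative, not the real [B, C, T, H, W] latents):

import torch

# With st_star = 1 the CFG-zero formula reduces to standard CFG:
#   uncond + guidance_scale * (cond - uncond)
guidance_scale = 5.0
uncond = torch.randn(1, 4)
cond = torch.randn(1, 4)
st_star = torch.ones(1, 1)

cfg_zero = uncond * st_star + guidance_scale * (cond - uncond * st_star)
cfg_standard = uncond + guidance_scale * (cond - uncond)
assert torch.allclose(cfg_zero, cfg_standard)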
fastvideo.pipelines.stages.longcat_denoising.LongCatDenoisingStage.optimized_scale
optimized_scale(positive_flat, negative_flat) -> Tensor

Calculate optimized scale from CFG-zero paper.

st_star = (v_cond^T * v_uncond) / ||v_uncond||^2

Parameters:

Name Type Description Default
positive_flat

Conditional prediction, flattened [B, -1]

required
negative_flat

Unconditional prediction, flattened [B, -1]

required

Returns:

Name Type Description
st_star Tensor

Optimized scale [B, 1]

Source code in fastvideo/pipelines/stages/longcat_denoising.py
def optimized_scale(self, positive_flat, negative_flat) -> torch.Tensor:
    """
    Calculate optimized scale from CFG-zero paper.

    st_star = (v_cond^T * v_uncond) / ||v_uncond||^2

    Args:
        positive_flat: Conditional prediction, flattened [B, -1]
        negative_flat: Unconditional prediction, flattened [B, -1]

    Returns:
        st_star: Optimized scale [B, 1]
    """
    # Calculate dot product
    dot_product = torch.sum(positive_flat * negative_flat,
                            dim=1,
                            keepdim=True)
    # Squared norm of the unconditional prediction
    squared_norm = torch.sum(negative_flat**2, dim=1, keepdim=True) + 1e-8
    # st_star = v_cond^T * v_uncond / ||v_uncond||^2
    st_star = dot_product / squared_norm
    return st_star
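
The scale has a simple geometric reading: it is the least-squares projection coefficient of the conditional prediction onto the unconditional one, so the residual is orthogonal to v_uncond (up to the 1e-8 stabilizer). A quick check on toy tensors (shapes are illustrative):

import torch

torch.manual_seed(0)
B, D = 2, 8
v_cond = torch.randn(B, D)
v_uncond = torch.randn(B, D)

st_star = (v_cond * v_uncond).sum(dim=1, keepdim=True) / (
    v_uncond.pow(2).sum(dim=1, keepdim=True) + 1e-8)

# st_star minimizes ||v_cond - s * v_uncond||^2 over s, so the residual
# has (near-)zero dot product with v_uncond.
residual = v_cond - st_star * v_uncond
assert torch.allclose((residual * v_uncond).sum(dim=1),
                      torch.zeros(B), atol=1e-4)
print(st_star.shape)  # torch.Size([2, 1])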

Functions

fastvideo.pipelines.stages.longcat_i2v_denoising

LongCat I2V Denoising Stage with conditioning support.

This stage implements Tier 3 I2V denoising:

1. Per-frame timestep masking (timestep[:, :num_cond_latents] = 0)
2. Passes num_cond_latents to transformer (for RoPE skipping)
3. Selective denoising (only updates non-conditioned frames)
4. CFG-zero optimized guidance

Classes

fastvideo.pipelines.stages.longcat_i2v_denoising.LongCatI2VDenoisingStage
LongCatI2VDenoisingStage(transformer, scheduler, pipeline=None, transformer_2=None, vae=None)

Bases: LongCatDenoisingStage

LongCat denoising with I2V conditioning support.

Key modifications from base LongCat denoising:

1. Sets timestep=0 for conditioning frames
2. Passes num_cond_latents to transformer
3. Only applies scheduler step to non-conditioned frames

Source code in fastvideo/pipelines/stages/denoising.py
def __init__(self,
             transformer,
             scheduler,
             pipeline=None,
             transformer_2=None,
             vae=None) -> None:
    super().__init__()
    self.transformer = transformer
    self.transformer_2 = transformer_2
    self.scheduler = scheduler
    self.vae = vae
    self.pipeline = weakref.ref(pipeline) if pipeline else None
    attn_head_size = self.transformer.hidden_size // self.transformer.num_attention_heads
    self.attn_backend = get_attn_backend(
        head_size=attn_head_size,
        dtype=torch.float16,  # TODO(will): hack
        supported_attention_backends=(
            AttentionBackendEnum.SLIDING_TILE_ATTN,
            AttentionBackendEnum.VIDEO_SPARSE_ATTN,
            AttentionBackendEnum.VMOBA_ATTN,
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.TORCH_SDPA,
            AttentionBackendEnum.SAGE_ATTN_THREE)  # hack
    )
Functions
fastvideo.pipelines.stages.longcat_i2v_denoising.LongCatI2VDenoisingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Run denoising loop with I2V conditioning.

Source code in fastvideo/pipelines/stages/longcat_i2v_denoising.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """Run denoising loop with I2V conditioning."""

    # Load transformer if needed
    if not fastvideo_args.model_loaded["transformer"]:
        loader = TransformerLoader()
        self.transformer = loader.load(
            fastvideo_args.model_paths["transformer"], fastvideo_args)
        fastvideo_args.model_loaded["transformer"] = True

    # Setup
    target_dtype = torch.bfloat16
    autocast_enabled = (target_dtype != torch.float32
                        ) and not fastvideo_args.disable_autocast

    latents = batch.latents
    timesteps = batch.timesteps
    prompt_embeds = batch.prompt_embeds[0]
    prompt_attention_mask = (batch.prompt_attention_mask[0]
                             if batch.prompt_attention_mask else None)
    guidance_scale = batch.guidance_scale
    do_classifier_free_guidance = batch.do_classifier_free_guidance

    # Get num_cond_latents from batch
    num_cond_latents = getattr(batch, 'num_cond_latents', 0)

    if num_cond_latents > 0:
        logger.info("I2V Denoising: num_cond_latents=%s, latent_shape=%s",
                    num_cond_latents, latents.shape)

    # Prepare negative prompts for CFG
    if do_classifier_free_guidance:
        negative_prompt_embeds = batch.negative_prompt_embeds[0]
        negative_prompt_attention_mask = (batch.negative_attention_mask[0]
                                          if batch.negative_attention_mask
                                          else None)

        prompt_embeds_combined = torch.cat(
            [negative_prompt_embeds, prompt_embeds], dim=0)
        if prompt_attention_mask is not None:
            prompt_attention_mask_combined = torch.cat(
                [negative_prompt_attention_mask, prompt_attention_mask],
                dim=0)
        else:
            prompt_attention_mask_combined = None
    else:
        prompt_embeds_combined = prompt_embeds
        prompt_attention_mask_combined = prompt_attention_mask

    # Denoising loop
    num_inference_steps = len(timesteps)

    with tqdm(total=num_inference_steps,
              desc="I2V Denoising") as progress_bar:
        for i, t in enumerate(timesteps):

            # 1. Expand latents for CFG
            if do_classifier_free_guidance:
                latent_model_input = torch.cat([latents] * 2)
            else:
                latent_model_input = latents

            latent_model_input = latent_model_input.to(target_dtype)

            # 2. Expand timestep to match batch size
            timestep = t.expand(
                latent_model_input.shape[0]).to(target_dtype)

            # 3. CRITICAL: Expand timestep to temporal dimension
            # and set conditioning frames to timestep=0
            timestep = timestep.unsqueeze(-1).repeat(
                1, latent_model_input.shape[2])

            # Mark conditioning frames as clean (timestep=0)
            if num_cond_latents > 0:
                timestep[:, :num_cond_latents] = 0

            # 4. Run transformer with num_cond_latents
            batch.is_cfg_negative = False
            with set_forward_context(
                    current_timestep=i,
                    attn_metadata=None,
                    forward_batch=batch,
            ), torch.autocast(device_type='cuda',
                              dtype=target_dtype,
                              enabled=autocast_enabled):
                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    encoder_hidden_states=prompt_embeds_combined,
                    timestep=timestep,
                    encoder_attention_mask=prompt_attention_mask_combined,
                    num_cond_latents=num_cond_latents,
                )

            # 5. Apply CFG with optimized scale
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)

                B = noise_pred_cond.shape[0]
                positive = noise_pred_cond.reshape(B, -1)
                negative = noise_pred_uncond.reshape(B, -1)

                # CFG-zero optimized scale
                st_star = self.optimized_scale(positive, negative)
                st_star = st_star.view(B, 1, 1, 1, 1)

                noise_pred = (
                    noise_pred_uncond * st_star + guidance_scale *
                    (noise_pred_cond - noise_pred_uncond * st_star))

            # 6. CRITICAL: Negate for flow matching scheduler
            noise_pred = -noise_pred

            # 7. CRITICAL: Only update non-conditioned frames
            # The conditioning frames stay FIXED throughout denoising
            if num_cond_latents > 0:
                latents[:, :, num_cond_latents:] = self.scheduler.step(
                    noise_pred[:, :, num_cond_latents:],
                    t,
                    latents[:, :, num_cond_latents:],
                    return_dict=False)[0]
            else:
                # No conditioning, update all frames
                latents = self.scheduler.step(noise_pred,
                                              t,
                                              latents,
                                              return_dict=False)[0]

            progress_bar.update()

    # Update batch with denoised latents
    batch.latents = latents
    return batch
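
A minimal sketch of the timestep masking in steps 2-3 above, with illustrative sizes (2 CFG sequences, 8 latent frames, the first 2 of which are conditioning frames):

import torch

t = torch.tensor(981.0)
batch_dim, num_latent_frames, num_cond_latents = 2, 8, 2

timestep = t.expand(batch_dim)                                   # [2]
timestep = timestep.unsqueeze(-1).repeat(1, num_latent_frames)   # [2, 8]
timestep[:, :num_cond_latents] = 0  # conditioning frames are treated as clean

print(timestep[0].tolist())
# [0.0, 0.0, 981.0, 981.0, 981.0, 981.0, 981.0, 981.0]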

Functions

fastvideo.pipelines.stages.longcat_i2v_latent_preparation

LongCat I2V Latent Preparation Stage.

This stage prepares latents with image conditioning for the first frame.

Classes

fastvideo.pipelines.stages.longcat_i2v_latent_preparation.LongCatI2VLatentPreparationStage
LongCatI2VLatentPreparationStage(scheduler, transformer, use_btchw_layout: bool = False)

Bases: LatentPreparationStage

Prepare latents with image conditioning for first frame.

This stage:

1. Generates random noise for all frames
2. Replaces first latent frame with encoded image latent
3. Marks conditioning information in batch

Source code in fastvideo/pipelines/stages/latent_preparation.py
def __init__(self,
             scheduler,
             transformer,
             use_btchw_layout: bool = False) -> None:
    super().__init__()
    self.scheduler = scheduler
    self.transformer = transformer
    self.use_btchw_layout = use_btchw_layout
Functions
fastvideo.pipelines.stages.longcat_i2v_latent_preparation.LongCatI2VLatentPreparationStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Prepare latents with I2V conditioning.

Source code in fastvideo/pipelines/stages/longcat_i2v_latent_preparation.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """Prepare latents with I2V conditioning."""

    # IMPORTANT: Skip if latents already prepared (e.g., by refinement init stage)
    # The refine_init stage encodes stage1 video and mixes with noise - don't overwrite!
    if batch.latents is not None:
        logger.info(
            "I2V Latent Prep: Skipping - latents already prepared "
            "(shape=%s), likely from refinement stage", batch.latents.shape)
        return batch

    # 1. Calculate dimensions
    num_frames = batch.num_frames
    height = batch.height
    width = batch.width

    # Get VAE compression factors
    # IMPORTANT: Use VAE's temporal compression (4), NOT transformer's patch_size[0] (1)
    vae_temporal_scale = fastvideo_args.pipeline_config.vae_config.arch_config.scale_factor_temporal
    vae_spatial_scale = fastvideo_args.pipeline_config.vae_config.arch_config.scale_factor_spatial

    num_latent_frames = (num_frames - 1) // vae_temporal_scale + 1
    latent_height = height // vae_spatial_scale
    latent_width = width // vae_spatial_scale

    num_channels = self.transformer.config.in_channels

    logger.info(
        "I2V Latent Prep: num_frames=%s, num_latent_frames=%s "
        "(vae_temporal_scale=%s), latent_shape=(%s, %s)", num_frames,
        num_latent_frames, vae_temporal_scale, latent_height, latent_width)

    # 2. Generate random noise for all frames
    # batch_size might not be set, default to 1
    batch_size = batch.batch_size if batch.batch_size is not None else 1
    shape = (batch_size, num_channels, num_latent_frames, latent_height,
             latent_width)

    # Handle generator - may be a list for batch handling
    generator = batch.generator
    if isinstance(generator, list):
        generator = generator[0] if generator else None

    # torch.randn requires specific argument order: size, generator, dtype
    latents = torch.randn(*shape,
                          generator=generator).to(get_local_torch_device(),
                                                  dtype=torch.float32)

    # 3. Replace first frame with conditioned image latent
    if batch.image_latent is not None:
        num_cond_latents = batch.num_cond_latents
        latents[:, :, :
                num_cond_latents] = batch.image_latent[:, :, :
                                                       num_cond_latents]

        logger.info(
            "I2V: Replaced first %s latent frame(s) with image conditioning",
            num_cond_latents)
    else:
        logger.warning(
            "No image_latent found in batch, proceeding without conditioning"
        )

    # 4. Store in batch
    batch.latents = latents

    # Required by base class output validator
    batch.raw_latent_shape = (num_latent_frames, latent_height,
                              latent_width)

    return batch
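
A worked example of the latent frame count used above, assuming a VAE temporal compression factor of 4 (the real stage reads scale_factor_temporal from the pipeline config):

vae_temporal_scale = 4  # assumption for illustration
for num_frames in (1, 5, 93):
    num_latent_frames = (num_frames - 1) // vae_temporal_scale + 1
    print(num_frames, "->", num_latent_frames)  # 1 -> 1, 5 -> 2, 93 -> 24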

Functions

fastvideo.pipelines.stages.longcat_image_vae_encoding

LongCat Image VAE Encoding Stage for I2V generation.

This stage handles encoding a single input image to latent space with LongCat-specific normalization for I2V conditioning.

Classes

fastvideo.pipelines.stages.longcat_image_vae_encoding.LongCatImageVAEEncodingStage
LongCatImageVAEEncodingStage(vae)

Bases: PipelineStage

Encode input image to latent space for I2V conditioning.

This stage:

1. Preprocesses image to match target dimensions
2. Encodes via VAE to latent space
3. Applies LongCat-specific normalization
4. Stores latent and calculates num_cond_latents

Source code in fastvideo/pipelines/stages/longcat_image_vae_encoding.py
def __init__(self, vae):
    super().__init__()
    self.vae = vae
Functions
fastvideo.pipelines.stages.longcat_image_vae_encoding.LongCatImageVAEEncodingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Encode image to latent for I2V conditioning.

Source code in fastvideo/pipelines/stages/longcat_image_vae_encoding.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """Encode image to latent for I2V conditioning."""

    # Skip image encoding for refinement tasks - we're refining an existing video
    if getattr(batch, 'stage1_video', None) is not None or getattr(
            batch, 'refine_from', None) is not None:
        logger.info(
            "Skipping image encoding - refinement mode (using stage1_video)"
        )
        return batch

    # 1. Get image from batch
    image = batch.pil_image  # PIL.Image
    if image is None:
        raise ValueError("pil_image must be provided for I2V")

    if not isinstance(image, PIL.Image.Image):
        raise TypeError(f"pil_image must be PIL.Image, got {type(image)}")

    # 2. Get target dimensions
    height = batch.height
    width = batch.width

    if height is None or width is None:
        raise ValueError("height and width must be set for I2V")

    # 3. Preprocess image
    image = resize(image, height, width, resize_mode="default")
    image = pil_to_numpy(image)
    image = numpy_to_pt(image)
    image = normalize(image)  # to [-1, 1]

    # 4. Add temporal dimension
    # After numpy_to_pt: [1, C, H, W] (batch already added by pil_to_numpy)
    # Add T dimension: [1, C, H, W] -> [1, C, 1, H, W] = [B, C, T, H, W]
    image = image.unsqueeze(2)
    image = image.to(get_local_torch_device(), dtype=torch.float32)

    # 5. Encode via VAE
    self.vae = self.vae.to(get_local_torch_device())

    # Setup VAE precision
    vae_dtype = PRECISION_TO_TYPE[
        fastvideo_args.pipeline_config.vae_precision]
    vae_autocast_enabled = (
        vae_dtype != torch.float32) and not fastvideo_args.disable_autocast

    with torch.autocast(device_type="cuda",
                        dtype=vae_dtype,
                        enabled=vae_autocast_enabled):
        if fastvideo_args.pipeline_config.vae_tiling:
            self.vae.enable_tiling()

        if not vae_autocast_enabled:
            image = image.to(vae_dtype)

        with torch.no_grad():
            encoder_output = self.vae.encode(image)
            latent = self.retrieve_latents(encoder_output, batch.generator)

    # 6. Apply LongCat-specific normalization
    # Formula: (latents - mean) / std
    latent = self.normalize_latents(latent)

    # 7. Calculate num_cond_latents
    # Formula: 1 + (num_cond_frames - 1) // vae_temporal_scale
    # For single image (num_cond_frames=1): always 1 latent frame
    num_cond_frames = 1  # Single image
    vae_temporal_scale = self.vae.config.scale_factor_temporal
    batch.num_cond_latents = 1 + (num_cond_frames - 1) // vae_temporal_scale

    # 8. Store in batch
    batch.image_latent = latent
    batch.num_cond_frames = 1

    logger.info(
        "I2V: Encoded image to latent shape %s, num_cond_latents=%s",
        latent.shape, batch.num_cond_latents)

    # Offload VAE if needed
    if fastvideo_args.vae_cpu_offload:
        self.vae.to("cpu")

    return batch
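
The preprocessing in steps 3-4 is mostly shape bookkeeping. A minimal sketch with a synthetic image, using plain numpy/torch in place of fastvideo's resize / pil_to_numpy / numpy_to_pt / normalize helpers, only to show the shape flow:

import numpy as np
import torch
from PIL import Image

height, width = 480, 832
image = Image.fromarray(np.zeros((height, width, 3), dtype=np.uint8))

arr = np.asarray(image, dtype=np.float32)[None] / 255.0  # [1, H, W, C] in [0, 1]
tensor = torch.from_numpy(arr).permute(0, 3, 1, 2)       # [1, C, H, W]
tensor = tensor * 2.0 - 1.0                              # normalize to [-1, 1]
tensor = tensor.unsqueeze(2)                             # [1, C, 1, H, W] = [B, C, T, H, W]

print(tensor.shape)  # torch.Size([1, 3, 1, 480, 832])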
fastvideo.pipelines.stages.longcat_image_vae_encoding.LongCatImageVAEEncodingStage.normalize_latents
normalize_latents(latents: Tensor) -> Tensor

Apply LongCat-specific latent normalization.

Formula: (latents - mean) / std

This matches the original LongCat implementation and is DIFFERENT from standard VAE scaling (which uses scaling_factor).

Source code in fastvideo/pipelines/stages/longcat_image_vae_encoding.py
def normalize_latents(self, latents: torch.Tensor) -> torch.Tensor:
    """
    Apply LongCat-specific latent normalization.

    Formula: (latents - mean) / std

    This matches the original LongCat implementation and is DIFFERENT
    from standard VAE scaling (which uses scaling_factor).
    """
    if not hasattr(self.vae.config, 'latents_mean') or not hasattr(
            self.vae.config, 'latents_std'):
        raise ValueError(
            "VAE config must have 'latents_mean' and 'latents_std' "
            "for LongCat normalization")

    latents_mean = torch.tensor(self.vae.config.latents_mean).view(
        1, self.vae.config.z_dim, 1, 1, 1).to(latents.device, latents.dtype)

    latents_std = torch.tensor(self.vae.config.latents_std).view(
        1, self.vae.config.z_dim, 1, 1, 1).to(latents.device, latents.dtype)

    return (latents - latents_mean) / latents_std
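
A toy sketch contrasting this per-channel normalization with the usual scaling_factor approach; every statistic below is made up for illustration (real values come from the VAE config):

import torch

z_dim = 4
latents = torch.randn(1, z_dim, 2, 8, 8)
latents_mean = torch.zeros(1, z_dim, 1, 1, 1)        # assumed stats
latents_std = torch.ones(1, z_dim, 1, 1, 1) * 2.0    # assumed stats

longcat_normalized = (latents - latents_mean) / latents_std  # what this stage does
scaling_factor = 0.18215  # illustrative SD-style value, not LongCat's
standard_scaled = latents * scaling_factor                   # conventional VAE scaling

print(longcat_normalized.std().item(), standard_scaled.std().item())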
fastvideo.pipelines.stages.longcat_image_vae_encoding.LongCatImageVAEEncodingStage.retrieve_latents
retrieve_latents(encoder_output: object, generator: Generator | None) -> Tensor

Sample from VAE posterior.

Source code in fastvideo/pipelines/stages/longcat_image_vae_encoding.py
def retrieve_latents(self, encoder_output: object,
                     generator: torch.Generator | None) -> torch.Tensor:
    """Sample from VAE posterior."""
    # WAN VAE returns an object with .sample() method
    if hasattr(encoder_output, 'sample'):
        return encoder_output.sample(generator)
    elif hasattr(encoder_output, 'latent_dist'):
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, 'latents'):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents from encoder output")

Functions

fastvideo.pipelines.stages.longcat_kv_cache_init

LongCat KV Cache Initialization Stage for Video Continuation (VC).

This stage pre-computes K/V cache for conditioning frames.

Classes

fastvideo.pipelines.stages.longcat_kv_cache_init.LongCatKVCacheInitStage
LongCatKVCacheInitStage(transformer)

Bases: PipelineStage

Pre-compute KV cache for conditioning frames.

After this stage:

- batch.kv_cache_dict contains {block_idx: (k, v)}
- batch.cond_latents contains the conditioning latents
- batch.latents contains ONLY noise latents

Source code in fastvideo/pipelines/stages/longcat_kv_cache_init.py
def __init__(self, transformer):
    super().__init__()
    self.transformer = transformer
Functions
fastvideo.pipelines.stages.longcat_kv_cache_init.LongCatKVCacheInitStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Initialize KV cache from conditioning latents.

Source code in fastvideo/pipelines/stages/longcat_kv_cache_init.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """Initialize KV cache from conditioning latents."""

    # Check if KV cache is enabled
    use_kv_cache = getattr(fastvideo_args.pipeline_config, 'use_kv_cache',
                           True)
    if not use_kv_cache:
        batch.kv_cache_dict = {}
        batch.use_kv_cache = False
        logger.info("KV cache disabled, skipping initialization")
        return batch

    batch.use_kv_cache = True
    offload_kv_cache = getattr(fastvideo_args.pipeline_config,
                               'offload_kv_cache', False)

    # Get conditioning latents
    num_cond_latents = batch.num_cond_latents
    if num_cond_latents <= 0:
        batch.kv_cache_dict = {}
        logger.warning("num_cond_latents <= 0, skipping KV cache init")
        return batch

    # Extract conditioning latents
    cond_latents = batch.latents[:, :, :num_cond_latents].clone()

    logger.info(
        "Initializing KV cache for %d conditioning latents, shape: %s",
        num_cond_latents, cond_latents.shape)

    # Timestep = 0 for conditioning (they are "clean")
    B = cond_latents.shape[0]
    T_cond = cond_latents.shape[2]
    timestep = torch.zeros(B,
                           T_cond,
                           device=cond_latents.device,
                           dtype=cond_latents.dtype)

    # Empty prompt embeddings (cross-attn will be skipped)
    max_seq_len = 512
    # Get caption dimension from transformer config
    caption_dim = self.transformer.config.caption_channels
    empty_embeds = torch.zeros(B,
                               max_seq_len,
                               caption_dim,
                               device=cond_latents.device,
                               dtype=cond_latents.dtype)

    # Get transformer dtype
    if hasattr(self.transformer, 'module'):
        transformer_dtype = next(self.transformer.module.parameters()).dtype
    else:
        transformer_dtype = next(self.transformer.parameters()).dtype

    # Run transformer with return_kv=True, skip_crs_attn=True
    with (
            torch.no_grad(),
            set_forward_context(
                current_timestep=0,
                attn_metadata=None,
                forward_batch=batch,
            ),
            torch.autocast(device_type='cuda', dtype=transformer_dtype),
    ):
        _, kv_cache_dict = self.transformer(
            hidden_states=cond_latents.to(transformer_dtype),
            encoder_hidden_states=empty_embeds.to(transformer_dtype),
            timestep=timestep.to(transformer_dtype),
            return_kv=True,
            skip_crs_attn=True,
            offload_kv_cache=offload_kv_cache,
        )

    # Store cache and save cond_latents for later concatenation
    batch.kv_cache_dict = kv_cache_dict
    batch.cond_latents = cond_latents

    # Remove conditioning latents from main latents
    # After this, batch.latents contains ONLY noise frames
    batch.latents = batch.latents[:, :, num_cond_latents:]

    logger.info(
        "KV cache initialized: %d blocks, offload=%s, remaining latents shape: %s",
        len(kv_cache_dict), offload_kv_cache, batch.latents.shape)

    return batch
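
A minimal sketch of the latent split this stage performs (shapes are illustrative):

import torch

latents = torch.randn(1, 16, 12, 60, 104)  # [B, C, T, H, W], assumed shape
num_cond_latents = 4

cond_latents = latents[:, :, :num_cond_latents].clone()  # cached, re-attached later
noise_latents = latents[:, :, num_cond_latents:]         # denoised by the VC stage

print(cond_latents.shape, noise_latents.shape)
# torch.Size([1, 16, 4, 60, 104]) torch.Size([1, 16, 8, 60, 104])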

Functions

fastvideo.pipelines.stages.longcat_refine_init

LongCat refinement initialization stage.

This stage prepares the latent variables for LongCat's 480p->720p refinement by:

1. Loading the stage1 (480p) video
2. Upsampling it to 720p resolution
3. Encoding it with VAE
4. Mixing with noise according to t_thresh

Classes

fastvideo.pipelines.stages.longcat_refine_init.LongCatRefineInitStage
LongCatRefineInitStage(vae)

Bases: PipelineStage

Stage for initializing LongCat refinement from a stage1 (480p) video.

This replicates the logic from LongCatVideoPipeline.generate_refine():

- Load stage1_video frames
- Upsample spatially and temporally
- VAE encode and normalize
- Mix with noise according to t_thresh

Source code in fastvideo/pipelines/stages/longcat_refine_init.py
def __init__(self, vae) -> None:
    super().__init__()
    self.vae = vae
Functions
fastvideo.pipelines.stages.longcat_refine_init.LongCatRefineInitStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Initialize latents for refinement.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The batch with initialized latents for refinement.

Source code in fastvideo/pipelines/stages/longcat_refine_init.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Initialize latents for refinement.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with initialized latents for refinement.
    """
    refine_from = batch.refine_from
    in_memory_stage1 = batch.stage1_video

    # Only run for refinement tasks: either a path (refine_from) or in-memory video is provided
    if refine_from is None and in_memory_stage1 is None:
        # Not a refinement task, skip
        return batch

    # ------------------------------------------------------------------
    # 1. Obtain stage1 frames (either from disk or from in-memory input)
    # ------------------------------------------------------------------
    if in_memory_stage1 is not None:
        # User provided stage1 frames directly (e.g., from distilled stage output)
        if len(in_memory_stage1) == 0:
            raise ValueError(
                "stage1_video is empty; expected a non-empty list of frames"
            )

        if isinstance(in_memory_stage1[0], Image.Image):
            pil_images = in_memory_stage1
        else:
            # Assume numpy arrays or torch tensors with shape [H, W, C]
            pil_images = [
                Image.fromarray(np.array(frame))
                for frame in in_memory_stage1
            ]

        logger.info(
            "Initializing LongCat refinement from in-memory stage1_video (%s frames)",
            len(pil_images))
    else:
        # Path-based refine: load video from disk (original design)
        logger.info("Initializing LongCat refinement from file: %s",
                    refine_from)
        stage1_video_path = Path(refine_from)
        if not stage1_video_path.exists():
            raise FileNotFoundError(
                f"Stage1 video not found: {refine_from}")

        # Load video frames as PIL Images
        pil_images, original_fps = load_video(str(stage1_video_path),
                                              return_fps=True)
        logger.info("Loaded stage1 video: %s frames @ %s fps",
                    len(pil_images), original_fps)

    # Store in batch for reference (use PIL images, same as official demo)
    batch.stage1_video = pil_images

    # Get parameters from batch
    num_frames = len(pil_images)
    spatial_refine_only = batch.spatial_refine_only
    t_thresh = batch.t_thresh
    num_cond_frames = batch.num_cond_frames if hasattr(
        batch, 'num_cond_frames') else 0

    # Calculate new frame count (temporal upsampling if not spatial_refine_only)
    new_num_frames = num_frames if spatial_refine_only else 2 * num_frames
    logger.info(
        "Refine mode: %s",
        'spatial only' if spatial_refine_only else 'spatial + temporal')

    # Update batch.num_frames to reflect the upsampled count
    batch.num_frames = new_num_frames

    # Use bucket system to select resolution (exactly like LongCat)
    # Calculate scale_factor_spatial considering SP split
    sp_size = fastvideo_args.sp_size if fastvideo_args.sp_size > 0 else 1
    vae_scale_factor_spatial = 8  # VAE spatial downsampling
    patch_size_spatial = 2  # LongCat patch size
    bsa_latent_granularity = 4
    scale_factor_spatial = vae_scale_factor_spatial * patch_size_spatial * bsa_latent_granularity  # 64

    # Calculate optimal split like LongCat (cp_split_hw logic)
    # For sp_size=1: [1,1], max=1
    # For sp_size=2: [1,2], max=2
    # For sp_size=4: [2,2], max=2
    # For sp_size=8: [2,4], max=4
    if sp_size > 1:
        # Get optimal 2D split factors (mimic context_parallel_util.get_optimal_split)
        factors = []
        for i in range(1, int(sp_size**0.5) + 1):
            if sp_size % i == 0:
                factors.append([i, sp_size // i])
        cp_split_hw = min(factors, key=lambda x: abs(x[0] - x[1]))
        scale_factor_spatial *= max(cp_split_hw)
        logger.info("SP split: sp_size=%s, cp_split_hw=%s, max_split=%s",
                    sp_size, cp_split_hw, max(cp_split_hw))
    else:
        cp_split_hw = [1, 1]

    # Get bucket config and find closest bucket for the input aspect ratio
    bucket_config = get_bucket_config('720p', scale_factor_spatial)

    # Get input aspect ratio from stage1 video
    input_height, input_width = pil_images[0].height, pil_images[0].width
    input_ratio = input_height / input_width

    # Find closest bucket
    closest_ratio = min(bucket_config.keys(),
                        key=lambda x: abs(float(x) - input_ratio))
    height, width = bucket_config[closest_ratio][0]

    logger.info("Input aspect ratio: %.2f (%sx%s)", input_ratio,
                input_width, input_height)
    logger.info("Matched bucket ratio: %s -> resolution: %sx%s",
                closest_ratio, width, height)
    logger.info("Target: %sx%s @ %s frames (sp_size=%s, scale_factor=%s)",
                width, height, new_num_frames, sp_size,
                scale_factor_spatial)

    # Override batch height/width with bucket-selected resolution
    batch.height = height
    batch.width = width

    # Convert PIL images to tensor [T, C, H, W]
    stage1_video_tensor = torch.stack([
        torch.from_numpy(np.array(img)).permute(2, 0, 1)  # HWC -> CHW
        for img in pil_images
    ]).float()  # [T, C, H, W]

    device = batch.prompt_embeds[0].device
    dtype = batch.prompt_embeds[0].dtype
    stage1_video_tensor = stage1_video_tensor.to(device=device, dtype=dtype)

    # Replicate LongCat's exact preprocessing (lines 1227-1235 in pipeline_longcat_video.py)
    # First: spatial interpolation to target (height, width) on [T, C, H, W]
    video_down = F.interpolate(stage1_video_tensor,
                               size=(height, width),
                               mode='bilinear',
                               align_corners=True)

    # Rearrange to [C, T, H, W] and add batch dimension -> [1, C, T, H, W]
    video_down = video_down.permute(1, 0, 2,
                                    3).unsqueeze(0)  # [1, C, T, H, W]
    video_down = video_down / 255.0  # Normalize to [0, 1]

    # Then: temporal+spatial interpolation to (new_num_frames, height, width)
    video_up = F.interpolate(video_down,
                             size=(new_num_frames, height, width),
                             mode='trilinear',
                             align_corners=True)

    # Rescale to [-1, 1] for VAE
    video_up = video_up * 2.0 - 1.0

    logger.info("Upsampled video shape: %s", video_up.shape)

    # Padding logic (exactly like LongCat lines 1237-1255)
    # Only pad temporal dimension to ensure BSA compatibility
    vae_scale_factor_temporal = 4
    num_noise_frames = video_up.shape[2] - num_cond_frames

    num_cond_latents = 0
    num_cond_frames_added = 0
    if num_cond_frames > 0:
        num_cond_latents = 1 + math.ceil(
            (num_cond_frames - 1) / vae_scale_factor_temporal)
        num_cond_latents = math.ceil(
            num_cond_latents /
            bsa_latent_granularity) * bsa_latent_granularity
        num_cond_frames_added = 1 + (
            num_cond_latents -
            1) * vae_scale_factor_temporal - num_cond_frames
        num_cond_frames = num_cond_frames + num_cond_frames_added

    num_noise_latents = math.ceil(num_noise_frames /
                                  vae_scale_factor_temporal)
    num_noise_latents = math.ceil(
        num_noise_latents / bsa_latent_granularity) * bsa_latent_granularity
    num_noise_frames_added = num_noise_latents * vae_scale_factor_temporal - num_noise_frames

    if num_cond_frames_added > 0 or num_noise_frames_added > 0:
        logger.info(
            "Padding temporal dimension for BSA: cond_frames+=%s, noise_frames+=%s",
            num_cond_frames_added, num_noise_frames_added)
        pad_front = video_up[:, :, 0:1].repeat(1, 1, num_cond_frames_added,
                                               1, 1)
        pad_back = video_up[:, :, -1:].repeat(1, 1, num_noise_frames_added,
                                              1, 1)
        video_up = torch.cat([pad_front, video_up, pad_back], dim=2)
        logger.info("Padded video shape: %s", video_up.shape)

    # Update batch with actual frame count after padding
    batch.num_frames = video_up.shape[2]

    # Store padding info for later cropping (CRITICAL for correct output!)
    batch.num_cond_frames_added = num_cond_frames_added
    batch.num_noise_frames_added = num_noise_frames_added
    batch.new_frame_size_before_padding = new_num_frames

    # Store num_cond_latents for denoising stage
    if num_cond_latents > 0:
        batch.num_cond_latents = num_cond_latents
        logger.info("Will use num_cond_latents=%s during denoising",
                    num_cond_latents)

    logger.info("Padding info: cond+=%s, noise+=%s, original=%s",
                num_cond_frames_added, num_noise_frames_added,
                new_num_frames)

    # VAE encode with tiling for memory efficiency
    logger.info("Encoding stage1 video with VAE (tiling enabled)...")
    vae_dtype = next(self.vae.parameters()).dtype
    vae_device = next(self.vae.parameters()).device
    video_up = video_up.to(dtype=vae_dtype, device=vae_device)

    # Enable tiling for large video encoding
    if hasattr(self.vae, 'enable_tiling'):
        self.vae.enable_tiling()

    with torch.no_grad():
        latent_dist = self.vae.encode(video_up)
        # Extract tensor from latent distribution
        if hasattr(latent_dist, 'latent_dist'):
            # Nested distribution wrapper
            latent_up = latent_dist.latent_dist.sample()
        elif hasattr(latent_dist, 'sample'):
            # DiagonalGaussianDistribution or similar
            latent_up = latent_dist.sample()
        elif hasattr(latent_dist, 'latents'):
            # Direct latents tensor
            latent_up = latent_dist.latents
        else:
            # Assume it's already a tensor
            latent_up = latent_dist

    # Normalize latents using VAE config (exactly like LongCat)
    if hasattr(self.vae.config, 'latents_mean') and hasattr(
            self.vae.config, 'latents_std'):
        latents_mean = torch.tensor(self.vae.config.latents_mean).view(
            1, self.vae.config.z_dim, 1, 1, 1).to(latent_up.device,
                                                  latent_up.dtype)
        # LongCat uses: 1.0 / latents_std (equivalent to dividing by latents_std)
        latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(
            1, self.vae.config.z_dim, 1, 1, 1).to(latent_up.device,
                                                  latent_up.dtype)
        # LongCat: (latents - mean) * (1/std)
        latent_up = (latent_up - latents_mean) * latents_std

    logger.info("Encoded latent shape: %s", latent_up.shape)

    # Mix with noise according to t_thresh
    # latent_up = (1 - t_thresh) * latent_up + t_thresh * noise
    noise = torch.randn_like(latent_up).contiguous()
    latent_up = (1 - t_thresh) * latent_up + t_thresh * noise

    logger.info("Applied t_thresh=%s noise mixing", t_thresh)

    # Store in batch - ensure correct dtype and device
    # The latents need to be on the same device as the transformer (CUDA)
    target_device = batch.prompt_embeds[0].device
    batch.latents = latent_up.to(device=target_device, dtype=dtype)
    batch.raw_latent_shape = latent_up.shape

    logger.info("Latents device: %s, dtype: %s", batch.latents.device,
                batch.latents.dtype)
    logger.info("LongCat refinement initialization complete")

    return batch
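
A worked example of the BSA-alignment arithmetic above, with illustrative inputs: 93 stage1 frames, no conditioning frames, spatial + temporal refine, vae_scale_factor_temporal=4, bsa_latent_granularity=4:

import math

num_frames = 93
new_num_frames = 2 * num_frames  # 186 after temporal upsampling
num_cond_frames = 0
vae_scale_factor_temporal = 4
bsa_latent_granularity = 4

num_noise_frames = new_num_frames - num_cond_frames                            # 186
num_noise_latents = math.ceil(num_noise_frames / vae_scale_factor_temporal)    # 47
num_noise_latents = math.ceil(
    num_noise_latents / bsa_latent_granularity) * bsa_latent_granularity       # 48
num_noise_frames_added = (num_noise_latents * vae_scale_factor_temporal
                          - num_noise_frames)                                  # 6

print(num_noise_latents, num_noise_frames_added)  # 48 6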

Functions

fastvideo.pipelines.stages.longcat_refine_timestep

LongCat refinement timestep preparation stage.

This stage prepares special timesteps for LongCat refinement that start from t_thresh.

Classes

fastvideo.pipelines.stages.longcat_refine_timestep.LongCatRefineTimestepStage
LongCatRefineTimestepStage(scheduler)

Bases: PipelineStage

Stage for preparing timesteps specific to LongCat refinement.

For refinement, we need to start from t_thresh instead of t=1.0, so we:

1. Generate normal timesteps for num_inference_steps
2. Filter to only keep timesteps < t_thresh * 1000
3. Prepend t_thresh * 1000 as the first timestep

Source code in fastvideo/pipelines/stages/longcat_refine_timestep.py
def __init__(self, scheduler) -> None:
    super().__init__()
    self.scheduler = scheduler
Functions
fastvideo.pipelines.stages.longcat_refine_timestep.LongCatRefineTimestepStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Prepare refinement-specific timesteps.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The batch with refinement timesteps.

Source code in fastvideo/pipelines/stages/longcat_refine_timestep.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Prepare refinement-specific timesteps.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with refinement timesteps.
    """
    # Only apply if this is a refinement task
    # Trigger when either a refine_from path or in-memory stage1_video is provided
    if batch.refine_from is None and batch.stage1_video is None:
        return batch

    device = get_local_torch_device()
    num_inference_steps = batch.num_inference_steps
    t_thresh = batch.t_thresh

    logger.info("Preparing LongCat refinement timesteps (t_thresh=%s)",
                t_thresh)

    # ------------------------------------------------------------------
    # 1) Match LongCatVideoPipeline.get_timesteps_sigmas (non-distill):
    #    sigmas = linspace(1, 0.001, num_inference_steps) on CPU
    # ------------------------------------------------------------------
    base_sigmas = torch.linspace(
        1.0,
        0.001,
        num_inference_steps,
        dtype=torch.float32,
        device=
        "cpu",  # scheduler.set_timesteps expects CPU-convertible sigmas
    )
    # Let the scheduler build its internal timestep schedule from sigmas
    self.scheduler.set_timesteps(num_inference_steps,
                                 sigmas=base_sigmas,
                                 device=device)
    base_timesteps = self.scheduler.timesteps

    # ------------------------------------------------------------------
    # 2) Apply t_thresh cropping exactly like generate_refine:
    #    timesteps = [t_thresh*1000] + [t for t in base_timesteps if t < t_thresh*1000]
    #    sigmas = timesteps / 1000  (with trailing zero)
    # ------------------------------------------------------------------
    t_thresh_value = t_thresh * 1000.0
    t_thresh_tensor = torch.tensor(t_thresh_value,
                                   dtype=base_timesteps.dtype,
                                   device=device)
    filtered_timesteps = base_timesteps[base_timesteps < t_thresh_tensor]

    timesteps = torch.cat(
        [t_thresh_tensor.unsqueeze(0), filtered_timesteps])

    # Update scheduler with these custom timesteps and corresponding sigmas
    self.scheduler.timesteps = timesteps
    sigmas = torch.cat([timesteps / 1000.0, torch.zeros(1, device=device)])
    self.scheduler.sigmas = sigmas

    logger.info("Refinement timesteps: %s steps starting from t=%s",
                len(timesteps), t_thresh)
    logger.info("First few timesteps: %s", timesteps[:5].tolist())

    # Store in batch so downstream stages (denoising) use the same schedule
    batch.timesteps = timesteps

    return batch
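
A minimal sketch of the t_thresh cropping, approximating the scheduler's timesteps as sigmas * 1000 and using illustrative values (10 inference steps, t_thresh=0.5):

import torch

num_inference_steps, t_thresh = 10, 0.5

base_sigmas = torch.linspace(1.0, 0.001, num_inference_steps)
base_timesteps = base_sigmas * 1000.0  # approximation of the scheduler's schedule

t_thresh_value = torch.tensor(t_thresh * 1000.0)
filtered = base_timesteps[base_timesteps < t_thresh_value]
timesteps = torch.cat([t_thresh_value.unsqueeze(0), filtered])

print([round(t) for t in timesteps.tolist()])  # [500, 445, 334, 223, 112, 1]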

Functions

fastvideo.pipelines.stages.longcat_vc_denoising

LongCat VC Denoising Stage with KV cache support.

This stage extends the I2V denoising stage to support:

1. KV cache for conditioning frames
2. Video continuation with multiple conditioning frames

Classes

fastvideo.pipelines.stages.longcat_vc_denoising.LongCatVCDenoisingStage
LongCatVCDenoisingStage(transformer, scheduler, pipeline=None, transformer_2=None, vae=None)

Bases: LongCatDenoisingStage

LongCat denoising with Video Continuation and KV cache support.

Key differences from I2V denoising:

- Supports KV cache (reuses cached K/V from conditioning frames)
- Handles larger num_cond_latents
- Concatenates conditioning latents back after denoising

When use_kv_cache=True:

- batch.latents contains ONLY noise frames (cond removed by KV cache init)
- batch.kv_cache_dict contains cached K/V
- batch.cond_latents contains conditioning latents for post-concat

When use_kv_cache=False:

- batch.latents contains ALL frames (cond + noise)
- Timestep masking: timestep[:, :num_cond_latents] = 0
- Selective denoising: only update noise frames

Source code in fastvideo/pipelines/stages/denoising.py
def __init__(self,
             transformer,
             scheduler,
             pipeline=None,
             transformer_2=None,
             vae=None) -> None:
    super().__init__()
    self.transformer = transformer
    self.transformer_2 = transformer_2
    self.scheduler = scheduler
    self.vae = vae
    self.pipeline = weakref.ref(pipeline) if pipeline else None
    attn_head_size = self.transformer.hidden_size // self.transformer.num_attention_heads
    self.attn_backend = get_attn_backend(
        head_size=attn_head_size,
        dtype=torch.float16,  # TODO(will): hack
        supported_attention_backends=(
            AttentionBackendEnum.SLIDING_TILE_ATTN,
            AttentionBackendEnum.VIDEO_SPARSE_ATTN,
            AttentionBackendEnum.VMOBA_ATTN,
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.TORCH_SDPA,
            AttentionBackendEnum.SAGE_ATTN_THREE)  # hack
    )
Functions
fastvideo.pipelines.stages.longcat_vc_denoising.LongCatVCDenoisingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Run denoising loop with VC conditioning and optional KV cache.

Source code in fastvideo/pipelines/stages/longcat_vc_denoising.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """Run denoising loop with VC conditioning and optional KV cache."""

    # Load transformer if needed
    if not fastvideo_args.model_loaded["transformer"]:
        loader = TransformerLoader()
        self.transformer = loader.load(
            fastvideo_args.model_paths["transformer"], fastvideo_args)
        fastvideo_args.model_loaded["transformer"] = True

    # Setup
    target_dtype = torch.bfloat16
    autocast_enabled = (target_dtype != torch.float32
                        ) and not fastvideo_args.disable_autocast

    latents = batch.latents
    timesteps = batch.timesteps
    prompt_embeds = batch.prompt_embeds[0]
    prompt_attention_mask = (batch.prompt_attention_mask[0]
                             if batch.prompt_attention_mask else None)
    guidance_scale = batch.guidance_scale
    do_classifier_free_guidance = batch.do_classifier_free_guidance

    # Get VC-specific parameters
    num_cond_latents = getattr(batch, 'num_cond_latents', 0)
    use_kv_cache = getattr(batch, 'use_kv_cache', False)
    kv_cache_dict = getattr(batch, 'kv_cache_dict', {})

    logger.info(
        "VC Denoising: num_cond_latents=%d, use_kv_cache=%s, latent_shape=%s",
        num_cond_latents, use_kv_cache, latents.shape)

    # Prepare negative prompts for CFG
    if do_classifier_free_guidance:
        negative_prompt_embeds = batch.negative_prompt_embeds[0]
        negative_prompt_attention_mask = (batch.negative_attention_mask[0]
                                          if batch.negative_attention_mask
                                          else None)

        prompt_embeds_combined = torch.cat(
            [negative_prompt_embeds, prompt_embeds], dim=0)
        if prompt_attention_mask is not None:
            prompt_attention_mask_combined = torch.cat(
                [negative_prompt_attention_mask, prompt_attention_mask],
                dim=0)
        else:
            prompt_attention_mask_combined = None
    else:
        prompt_embeds_combined = prompt_embeds
        prompt_attention_mask_combined = prompt_attention_mask

    # Denoising loop
    num_inference_steps = len(timesteps)
    step_times = []

    with tqdm(total=num_inference_steps,
              desc="VC Denoising") as progress_bar:
        for i, t in enumerate(timesteps):
            step_start = time.time()

            # 1. Expand latents for CFG
            if do_classifier_free_guidance:
                latent_model_input = torch.cat([latents] * 2)
            else:
                latent_model_input = latents

            latent_model_input = latent_model_input.to(target_dtype)

            # 2. Expand timestep to match batch size
            timestep = t.expand(
                latent_model_input.shape[0]).to(target_dtype)

            # 3. Expand timestep to temporal dimension
            timestep = timestep.unsqueeze(-1).repeat(
                1, latent_model_input.shape[2])

            # 4. Timestep masking (only when NOT using KV cache)
            if not use_kv_cache and num_cond_latents > 0:
                timestep[:, :num_cond_latents] = 0

            # 5. Prepare transformer kwargs
            # IMPORTANT: num_cond_latents is ALWAYS passed - needed for RoPE position offset
            transformer_kwargs = {
                'num_cond_latents': num_cond_latents,
            }
            if use_kv_cache:
                transformer_kwargs['kv_cache_dict'] = kv_cache_dict

            # 6. Run transformer
            batch.is_cfg_negative = False
            with set_forward_context(
                    current_timestep=i,
                    attn_metadata=None,
                    forward_batch=batch,
            ), torch.autocast(device_type='cuda',
                              dtype=target_dtype,
                              enabled=autocast_enabled):
                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    encoder_hidden_states=prompt_embeds_combined,
                    timestep=timestep,
                    encoder_attention_mask=prompt_attention_mask_combined,
                    **transformer_kwargs,
                )

            # 7. Apply CFG with optimized scale (CFG-zero)
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)

                B = noise_pred_cond.shape[0]
                positive = noise_pred_cond.reshape(B, -1)
                negative = noise_pred_uncond.reshape(B, -1)

                st_star = self.optimized_scale(positive, negative)
                st_star = st_star.view(B, 1, 1, 1, 1)

                noise_pred = (
                    noise_pred_uncond * st_star + guidance_scale *
                    (noise_pred_cond - noise_pred_uncond * st_star))

            # 8. Negate for flow matching scheduler
            noise_pred = -noise_pred

            # 9. Scheduler step
            if use_kv_cache:
                # All latents are noise frames (conditioning is in cache)
                latents = self.scheduler.step(noise_pred,
                                              t,
                                              latents,
                                              return_dict=False)[0]
            else:
                # Only update noise frames (skip conditioning)
                if num_cond_latents > 0:
                    latents[:, :, num_cond_latents:] = self.scheduler.step(
                        noise_pred[:, :, num_cond_latents:],
                        t,
                        latents[:, :, num_cond_latents:],
                        return_dict=False,
                    )[0]
                else:
                    latents = self.scheduler.step(noise_pred,
                                                  t,
                                                  latents,
                                                  return_dict=False)[0]

            step_time = time.time() - step_start
            step_times.append(step_time)

            # Log timing for first few steps
            if i < 3:
                logger.info("Step %d: %.2fs", i, step_time)

            progress_bar.update()

    # 10. If using KV cache, concatenate conditioning latents back
    if use_kv_cache and hasattr(
            batch, 'cond_latents') and batch.cond_latents is not None:
        latents = torch.cat([batch.cond_latents, latents], dim=2)
        logger.info(
            "Concatenated conditioning latents back, final shape: %s",
            latents.shape)

    # Log average timing
    avg_time = sum(step_times) / len(step_times)
    logger.info("Average step time: %.2fs (total: %.1fs)", avg_time,
                sum(step_times))

    # Update batch with denoised latents
    batch.latents = latents
    return batch
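
The optimized_scale helper used in step 7 is not reproduced above. A minimal sketch of a CFG-Zero-style projection scale, assuming it computes the per-sample least-squares scale s* that minimizes ||cond - s* * uncond||^2 (the function name, epsilon, and exact reduction are illustrative, not taken from the source):

import torch

def optimized_scale_sketch(positive: torch.Tensor,
                           negative: torch.Tensor,
                           eps: float = 1e-8) -> torch.Tensor:
    """Per-sample scale s* minimizing ||positive - s* * negative||^2.

    positive/negative are flattened predictions of shape [B, D]; returns [B, 1],
    which the caller reshapes to [B, 1, 1, 1, 1] before applying CFG.
    """
    dot = (positive * negative).sum(dim=1, keepdim=True)
    norm_sq = (negative * negative).sum(dim=1, keepdim=True)
    return dot / (norm_sq + eps)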

Functions

fastvideo.pipelines.stages.longcat_video_vae_encoding

LongCat Video VAE Encoding Stage for Video Continuation (VC) generation.

This stage handles encoding multiple video frames to latent space with LongCat-specific normalization for VC conditioning.

Classes

fastvideo.pipelines.stages.longcat_video_vae_encoding.LongCatVideoVAEEncodingStage
LongCatVideoVAEEncodingStage(vae)

Bases: PipelineStage

Encode video frames to latent space for VC conditioning.

This stage:

1. Loads video frames from path or uses provided frames
2. Takes the last num_cond_frames from the video
3. Preprocesses and stacks frames
4. Encodes via VAE to latent space
5. Applies LongCat-specific normalization
6. Calculates num_cond_latents

Source code in fastvideo/pipelines/stages/longcat_video_vae_encoding.py
def __init__(self, vae):
    super().__init__()
    self.vae = vae
Functions
fastvideo.pipelines.stages.longcat_video_vae_encoding.LongCatVideoVAEEncodingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Encode video frames to latent for VC conditioning.

Source code in fastvideo/pipelines/stages/longcat_video_vae_encoding.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """Encode video frames to latent for VC conditioning."""

    # Get video from batch - can be path, list of PIL images, or already loaded
    video = getattr(batch, 'video_frames', None) or getattr(
        batch, 'video_path', None)
    num_cond_frames = getattr(batch, 'num_cond_frames',
                              13)  # Default 13 for VC

    if video is None:
        raise ValueError(
            "video_frames or video_path must be provided for VC")

    # Load video if path
    if isinstance(video, str):
        from diffusers.utils import load_video
        video = load_video(video)
        logger.info("Loaded video from path: %d frames", len(video))

    # Take last num_cond_frames
    if len(video) > num_cond_frames:
        video = video[-num_cond_frames:]
        logger.info("Using last %d frames for conditioning",
                    num_cond_frames)
    elif len(video) < num_cond_frames:
        logger.warning(
            "Video has only %d frames, less than num_cond_frames=%d",
            len(video), num_cond_frames)
        num_cond_frames = len(video)

    # Get target dimensions
    height = batch.height
    width = batch.width

    if height is None or width is None:
        raise ValueError("height and width must be set for VC")

    # Preprocess and stack frames
    processed_frames = []
    for frame in video:
        if not isinstance(frame, PIL.Image.Image):
            raise TypeError(f"Frame must be PIL.Image, got {type(frame)}")

        frame = resize(frame, height, width, resize_mode="default")
        frame = pil_to_numpy(frame)  # Returns [1, H, W, C] then converted
        frame = numpy_to_pt(frame)  # Returns [1, C, H, W]
        frame = normalize(frame)  # to [-1, 1]
        processed_frames.append(frame)

    # Stack frames: [num_frames, C, H, W] -> [1, C, T, H, W]
    video_tensor = torch.cat(processed_frames, dim=0)  # [T, C, H, W]
    video_tensor = video_tensor.permute(1, 0, 2,
                                        3).unsqueeze(0)  # [1, C, T, H, W]
    video_tensor = video_tensor.to(get_local_torch_device(),
                                   dtype=torch.float32)

    logger.info("VC: Preprocessed video tensor shape: %s",
                video_tensor.shape)

    # Encode via VAE
    self.vae = self.vae.to(get_local_torch_device())

    # Setup VAE precision
    vae_dtype = PRECISION_TO_TYPE[
        fastvideo_args.pipeline_config.vae_precision]
    vae_autocast_enabled = (
        vae_dtype != torch.float32) and not fastvideo_args.disable_autocast

    with torch.autocast(device_type="cuda",
                        dtype=vae_dtype,
                        enabled=vae_autocast_enabled):
        if fastvideo_args.pipeline_config.vae_tiling:
            self.vae.enable_tiling()

        if not vae_autocast_enabled:
            video_tensor = video_tensor.to(vae_dtype)

        with torch.no_grad():
            encoder_output = self.vae.encode(video_tensor)
            latent = self.retrieve_latents(encoder_output, batch.generator)

    # Apply LongCat-specific normalization
    latent = self.normalize_latents(latent)

    # Calculate num_cond_latents
    # Formula: 1 + (num_cond_frames - 1) // vae_temporal_scale
    vae_temporal_scale = self.vae.config.scale_factor_temporal
    num_cond_latents = 1 + (num_cond_frames - 1) // vae_temporal_scale

    # Store in batch
    batch.video_latent = latent
    batch.num_cond_frames = num_cond_frames
    batch.num_cond_latents = num_cond_latents

    logger.info(
        "VC: Encoded %d frames to latent shape %s, num_cond_latents=%d",
        num_cond_frames, latent.shape, num_cond_latents)

    # Offload VAE if needed
    if fastvideo_args.vae_cpu_offload:
        self.vae.to("cpu")

    return batch
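
As a worked example of the num_cond_latents formula above, assuming the common temporal scale factor of 4 (the real value is read from self.vae.config.scale_factor_temporal), the default of 13 conditioning frames maps to 4 conditioning latents:

num_cond_frames = 13
vae_temporal_scale = 4  # assumed here; taken from vae.config.scale_factor_temporal at runtime
num_cond_latents = 1 + (num_cond_frames - 1) // vae_temporal_scale
assert num_cond_latents == 4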
fastvideo.pipelines.stages.longcat_video_vae_encoding.LongCatVideoVAEEncodingStage.normalize_latents
normalize_latents(latents: Tensor) -> Tensor

Apply LongCat-specific latent normalization.

Formula: (latents - mean) / std

Source code in fastvideo/pipelines/stages/longcat_video_vae_encoding.py
def normalize_latents(self, latents: torch.Tensor) -> torch.Tensor:
    """
    Apply LongCat-specific latent normalization.

    Formula: (latents - mean) / std
    """
    if not hasattr(self.vae.config, 'latents_mean') or not hasattr(
            self.vae.config, 'latents_std'):
        raise ValueError(
            "VAE config must have 'latents_mean' and 'latents_std' "
            "for LongCat normalization")

    latents_mean = torch.tensor(self.vae.config.latents_mean).view(
        1, self.vae.config.z_dim, 1, 1, 1).to(latents.device, latents.dtype)

    latents_std = torch.tensor(self.vae.config.latents_std).view(
        1, self.vae.config.z_dim, 1, 1, 1).to(latents.device, latents.dtype)

    return (latents - latents_mean) / latents_std
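
Because the normalization is a plain affine transform, a decode path would typically invert it before calling the VAE decoder. A minimal sketch (the function name is illustrative and not part of this stage):

import torch

def denormalize_latents_sketch(latents: torch.Tensor,
                               latents_mean: torch.Tensor,
                               latents_std: torch.Tensor) -> torch.Tensor:
    # Inverse of (latents - mean) / std; mean/std broadcast as [1, C, 1, 1, 1].
    return latents * latents_std + latents_mean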
fastvideo.pipelines.stages.longcat_video_vae_encoding.LongCatVideoVAEEncodingStage.retrieve_latents
retrieve_latents(encoder_output: Any, generator: Generator | None) -> Tensor

Sample from VAE posterior.

Source code in fastvideo/pipelines/stages/longcat_video_vae_encoding.py
def retrieve_latents(self, encoder_output: Any,
                     generator: torch.Generator | None) -> torch.Tensor:
    """Sample from VAE posterior."""
    if hasattr(encoder_output, 'sample'):
        return encoder_output.sample(generator)
    elif hasattr(encoder_output, 'latent_dist'):
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, 'latents'):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents from encoder output")

Functions

fastvideo.pipelines.stages.matrixgame_denoising

Classes

fastvideo.pipelines.stages.matrixgame_denoising.BlockProcessingContext dataclass
BlockProcessingContext(batch: ForwardBatch, block_idx: int, start_index: int, kv_cache1: list[dict[Any, Any]], kv_cache2: list[dict[Any, Any]] | None, kv_cache_mouse: list[dict[Any, Any]] | None, kv_cache_keyboard: list[dict[Any, Any]] | None, crossattn_cache: list[dict[Any, Any]], timesteps: Tensor, block_sizes: list[int], noise_pool: list[Tensor] | None, fastvideo_args: FastVideoArgs, target_dtype: dtype, autocast_enabled: bool, boundary_timestep: float | None, high_noise_timesteps: Tensor | None, context_noise: float, image_kwargs: dict[str, Any], pos_cond_kwargs: dict[str, Any])

Dataclass containing the context needed for block processing.

Functions

fastvideo.pipelines.stages.stepvideo_encoding

Classes

fastvideo.pipelines.stages.stepvideo_encoding.StepvideoPromptEncodingStage
StepvideoPromptEncodingStage(stepllm, clip)

Bases: PipelineStage

Stage for encoding prompts using the remote caption API.

This stage applies the magic string transformations and calls the remote caption service asynchronously to get:

- primary prompt embeddings,
- an attention mask,
- and a clip embedding.

Source code in fastvideo/pipelines/stages/stepvideo_encoding.py
def __init__(self, stepllm, clip) -> None:
    super().__init__()
    # self.caption_client = caption_client  # This should have a call_caption(prompts: List[str]) method.
    self.stepllm = stepllm
    self.clip = clip
Functions
fastvideo.pipelines.stages.stepvideo_encoding.StepvideoPromptEncodingStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify stepvideo encoding stage inputs.

Source code in fastvideo/pipelines/stages/stepvideo_encoding.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify stepvideo encoding stage inputs."""
    result = VerificationResult()
    result.add_check("prompt", batch.prompt, V.string_not_empty)
    return result
fastvideo.pipelines.stages.stepvideo_encoding.StepvideoPromptEncodingStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify stepvideo encoding stage outputs.

Source code in fastvideo/pipelines/stages/stepvideo_encoding.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify stepvideo encoding stage outputs."""
    result = VerificationResult()
    result.add_check("prompt_embeds", batch.prompt_embeds,
                     [V.is_tensor, V.with_dims(3)])
    result.add_check("negative_prompt_embeds", batch.negative_prompt_embeds,
                     [V.is_tensor, V.with_dims(3)])
    result.add_check("prompt_attention_mask", batch.prompt_attention_mask,
                     [V.is_tensor, V.with_dims(2)])
    result.add_check("negative_attention_mask",
                     batch.negative_attention_mask,
                     [V.is_tensor, V.with_dims(2)])
    result.add_check("clip_embedding_pos", batch.clip_embedding_pos,
                     [V.is_tensor, V.with_dims(2)])
    result.add_check("clip_embedding_neg", batch.clip_embedding_neg,
                     [V.is_tensor, V.with_dims(2)])
    return result

Functions

fastvideo.pipelines.stages.text_encoding

Prompt encoding stages for diffusion pipelines.

This module contains implementations of prompt encoding stages for diffusion pipelines.

Classes

fastvideo.pipelines.stages.text_encoding.Cosmos25TextEncodingStage
Cosmos25TextEncodingStage(text_encoder)

Bases: PipelineStage

Cosmos 2.5 text encoding stage.

Cosmos 2.5 uses Reason1 (Qwen2.5-VL) and relies on the encoder's compute_text_embeddings_online().

Source code in fastvideo/pipelines/stages/text_encoding.py
def __init__(self, text_encoder) -> None:
    super().__init__()
    self.text_encoder = text_encoder
fastvideo.pipelines.stages.text_encoding.TextEncodingStage
TextEncodingStage(text_encoders, tokenizers)

Bases: PipelineStage

Stage for encoding text prompts into embeddings for diffusion models.

This stage handles the encoding of text prompts into the embedding space expected by the diffusion model.

Initialize the prompt encoding stage.

Parameters:

Name Type Description Default
text_encoders

List of text encoder models used to produce prompt embeddings.

required
tokenizers

List of tokenizers, one per text encoder, in matching order.

required
Source code in fastvideo/pipelines/stages/text_encoding.py
def __init__(self, text_encoders, tokenizers) -> None:
    """
    Initialize the prompt encoding stage.

    Args:
        text_encoders: List of text encoder models used to produce prompt embeddings.
        tokenizers: List of tokenizers, one per text encoder, in matching order.
    """
    super().__init__()
    self.tokenizers = tokenizers
    self.text_encoders = text_encoders
Functions
fastvideo.pipelines.stages.text_encoding.TextEncodingStage.encode_text
encode_text(text: str | list[str], fastvideo_args: FastVideoArgs, encoder_index: int | list[int] | None = None, return_attention_mask: bool = False, return_type: str = 'list', device: device | str | None = None, dtype: dtype | None = None, max_length: int | None = None, truncation: bool | None = None, padding: bool | str | None = None)

Encode plain text using selected text encoder(s) and return embeddings.

Parameters:

Name Type Description Default
text str | list[str]

A single string or a list of strings to encode.

required
fastvideo_args FastVideoArgs

The inference arguments providing pipeline config, including tokenizer and encoder settings, preprocess and postprocess functions.

required
encoder_index int | list[int] | None

Encoder selector by index. Accepts an int or list of ints.

None
return_attention_mask bool

If True, also return attention masks for each selected encoder.

False
return_type str

"list" (default) returns a list aligned with selection; "dict" returns a dict keyed by encoder index as a string; "stack" stacks along a new first dimension (requires matching shapes).

'list'
device device | str | None

Optional device override for inputs; defaults to local torch device.

None
dtype dtype | None

Optional dtype to cast returned embeddings to.

None
max_length int | None

Optional per-call tokenizer override.

None
truncation bool | None

Optional per-call tokenizer override.

None
padding bool | str | None

Optional per-call tokenizer override.

None

Returns:

Type Description

Depending on return_type and return_attention_mask:

  • list: List[Tensor] or (List[Tensor], List[Tensor])
  • dict: Dict[str, Tensor] or (Dict[str, Tensor], Dict[str, Tensor])
  • stack: Tensor of shape [num_encoders, ...] or a tuple with stacked attention masks
Source code in fastvideo/pipelines/stages/text_encoding.py
@torch.no_grad()
def encode_text(
    self,
    text: str | list[str],
    fastvideo_args: FastVideoArgs,
    encoder_index: int | list[int] | None = None,
    return_attention_mask: bool = False,
    return_type: str = "list",  # one of: "list", "dict", "stack"
    device: torch.device | str | None = None,
    dtype: torch.dtype | None = None,
    max_length: int | None = None,
    truncation: bool | None = None,
    padding: bool | str | None = None,
):
    """
    Encode plain text using selected text encoder(s) and return embeddings.

    Args:
        text: A single string or a list of strings to encode.
        fastvideo_args: The inference arguments providing pipeline config,
            including tokenizer and encoder settings, preprocess and postprocess
            functions.
        encoder_index: Encoder selector by index. Accepts an int or list of ints.
        return_attention_mask: If True, also return attention masks for each
            selected encoder.
        return_type: "list" (default) returns a list aligned with selection;
            "dict" returns a dict keyed by encoder index as a string; "stack" stacks along a
            new first dimension (requires matching shapes).
        device: Optional device override for inputs; defaults to local torch device.
        dtype: Optional dtype to cast returned embeddings to.
        max_length: Optional per-call tokenizer override.
        truncation: Optional per-call tokenizer override.
        padding: Optional per-call tokenizer override.

    Returns:
        Depending on return_type and return_attention_mask:
        - list: List[Tensor] or (List[Tensor], List[Tensor])
        - dict: Dict[str, Tensor] or (Dict[str, Tensor], Dict[str, Tensor])
        - stack: Tensor of shape [num_encoders, ...] or a tuple with stacked
          attention masks
    """

    assert len(self.tokenizers) == len(self.text_encoders)
    assert len(self.text_encoders) == len(
        fastvideo_args.pipeline_config.text_encoder_configs)

    # Resolve selection into indices
    encoder_cfgs = fastvideo_args.pipeline_config.text_encoder_configs
    if encoder_index is None:
        indices: list[int] = [0]
    elif isinstance(encoder_index, int):
        indices = [encoder_index]
    else:
        indices = list(encoder_index)
    # validate range
    num_encoders = len(self.text_encoders)
    for idx in indices:
        if idx < 0 or idx >= num_encoders:
            raise IndexError(
                f"encoder index {idx} out of range [0, {num_encoders-1}]")

    # Normalize input to list[str]
    assert isinstance(text, str | list)
    if isinstance(text, str):
        texts: list[str] = [text]
    else:
        texts = text

    embeds_list: list[torch.Tensor] = []
    attn_masks_list: list[torch.Tensor] = []

    preprocess_funcs = fastvideo_args.pipeline_config.preprocess_text_funcs
    postprocess_funcs = fastvideo_args.pipeline_config.postprocess_text_funcs
    encoder_cfgs = fastvideo_args.pipeline_config.text_encoder_configs

    if return_type not in ("list", "dict", "stack"):
        raise ValueError(
            f"Invalid return_type '{return_type}'. Expected one of: 'list', 'dict', 'stack'"
        )

    target_device = device if device is not None else get_local_torch_device(
    )

    for i in indices:
        tokenizer = self.tokenizers[i]
        text_encoder = self.text_encoders[i]
        encoder_config = encoder_cfgs[i]
        preprocess_func = preprocess_funcs[i]
        postprocess_func = postprocess_funcs[i]

        tok_kwargs = dict(encoder_config.tokenizer_kwargs)
        if max_length is not None:
            tok_kwargs["max_length"] = max_length
        elif hasattr(fastvideo_args.pipeline_config,
                     "text_encoder_max_lengths"):
            tok_kwargs[
                "max_length"] = fastvideo_args.pipeline_config.text_encoder_max_lengths[
                    i]

        if truncation is not None:
            tok_kwargs["truncation"] = truncation
        if padding is not None:
            tok_kwargs["padding"] = padding

        processed_texts: list[str] = []
        for prompt_str in texts:
            processed_text = preprocess_func(prompt_str)
            if processed_text is not None:
                processed_texts.append(processed_text)
            else:
                # Assuming batch_size = 1
                prompt_embeds = torch.zeros((1, tok_kwargs["max_length"],
                                             encoder_config.hidden_size),
                                            device=target_device)
                attention_mask = torch.zeros((1, tok_kwargs["max_length"]),
                                             device=target_device,
                                             dtype=torch.int64)
                embeds_list.append(prompt_embeds)
                attn_masks_list.append(attention_mask)
                return self.return_embeds(embeds_list, attn_masks_list,
                                          return_type,
                                          return_attention_mask, indices)

        if encoder_config.is_chat_model:
            text_inputs = tokenizer.apply_chat_template(
                processed_texts, **tok_kwargs).to(target_device)
        else:
            text_inputs = tokenizer(processed_texts,
                                    **tok_kwargs).to(target_device)

        input_ids = text_inputs["input_ids"]
        attention_mask = text_inputs["attention_mask"]

        with set_forward_context(current_timestep=0, attn_metadata=None):
            outputs = text_encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True,
            )

        try:
            prompt_embeds = postprocess_func(outputs)
        except Exception:
            prompt_embeds, attention_mask = postprocess_func(
                outputs, attention_mask)

        if dtype is not None:
            prompt_embeds = prompt_embeds.to(dtype=dtype)
        embeds_list.append(prompt_embeds)
        if return_attention_mask:
            attn_masks_list.append(attention_mask)

    return self.return_embeds(embeds_list, attn_masks_list, return_type,
                              return_attention_mask, indices)
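
A usage sketch for encode_text, assuming a TextEncodingStage built with two encoders and a populated fastvideo_args (the stage and argument variables here are illustrative):

# Encode one prompt with both encoders and also collect attention masks.
embeds, masks = stage.encode_text(
    "a cat surfing a wave",
    fastvideo_args,
    encoder_index=[0, 1],
    return_attention_mask=True,
    return_type="list",
)
# embeds[i] and masks[i] correspond to encoder i, in the order requested.

# Same call returning a dict keyed by encoder index as a string.
embeds_by_idx = stage.encode_text(
    "a cat surfing a wave",
    fastvideo_args,
    encoder_index=[0, 1],
    return_type="dict",
)
# embeds_by_idx["0"], embeds_by_idx["1"]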
fastvideo.pipelines.stages.text_encoding.TextEncodingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Encode the prompt into text encoder hidden states.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The batch with encoded prompt embeddings.

Source code in fastvideo/pipelines/stages/text_encoding.py
@torch.no_grad()
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Encode the prompt into text encoder hidden states.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with encoded prompt embeddings.
    """
    assert len(self.tokenizers) == len(self.text_encoders)
    assert len(self.text_encoders) == len(
        fastvideo_args.pipeline_config.text_encoder_configs)

    # Encode positive prompt with all available encoders
    assert batch.prompt is not None
    prompt_text: str | list[str] = batch.prompt
    all_indices: list[int] = list(range(len(self.text_encoders)))
    prompt_embeds_list, prompt_masks_list = self.encode_text(
        prompt_text,
        fastvideo_args,
        encoder_index=all_indices,
        return_attention_mask=True,
    )

    for pe in prompt_embeds_list:
        batch.prompt_embeds.append(pe)
    if batch.prompt_attention_mask is not None:
        for am in prompt_masks_list:
            batch.prompt_attention_mask.append(am)

    # Encode negative prompt if CFG is enabled
    if batch.do_classifier_free_guidance:
        assert isinstance(batch.negative_prompt, str)
        neg_embeds_list, neg_masks_list = self.encode_text(
            batch.negative_prompt,
            fastvideo_args,
            encoder_index=all_indices,
            return_attention_mask=True,
        )

        assert batch.negative_prompt_embeds is not None
        for ne in neg_embeds_list:
            batch.negative_prompt_embeds.append(ne)
        if batch.negative_attention_mask is not None:
            for nm in neg_masks_list:
                batch.negative_attention_mask.append(nm)

    return batch
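
A rough illustration of the contract above, assuming batch.prompt_embeds and batch.negative_prompt_embeds were initialized as lists upstream:

batch = stage.forward(batch, fastvideo_args)

# One embedding tensor per configured text encoder, appended in encoder order.
assert len(batch.prompt_embeds) == len(stage.text_encoders)
if batch.do_classifier_free_guidance:
    assert len(batch.negative_prompt_embeds) == len(stage.text_encoders)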
fastvideo.pipelines.stages.text_encoding.TextEncodingStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify text encoding stage inputs.

Source code in fastvideo/pipelines/stages/text_encoding.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify text encoding stage inputs."""
    result = VerificationResult()
    result.add_check("prompt", batch.prompt, V.string_or_list_strings)
    # result.add_check(
    #     "negative_prompt", batch.negative_prompt, lambda x: not batch.
    #     do_classifier_free_guidance or V.string_not_empty(x))
    result.add_check("do_classifier_free_guidance",
                     batch.do_classifier_free_guidance, V.bool_value)
    result.add_check("prompt_embeds", batch.prompt_embeds, V.is_list)
    result.add_check("negative_prompt_embeds", batch.negative_prompt_embeds,
                     V.none_or_list)
    return result
fastvideo.pipelines.stages.text_encoding.TextEncodingStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify text encoding stage outputs.

Source code in fastvideo/pipelines/stages/text_encoding.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify text encoding stage outputs."""
    result = VerificationResult()
    result.add_check("prompt_embeds", batch.prompt_embeds,
                     V.list_of_tensors_min_dims(2))
    result.add_check(
        "negative_prompt_embeds", batch.negative_prompt_embeds,
        lambda x: not batch.do_classifier_free_guidance or V.
        list_of_tensors_with_min_dims(x, 2))
    return result

Functions

fastvideo.pipelines.stages.timestep_preparation

Timestep preparation stages for diffusion pipelines.

This module contains implementations of timestep preparation stages for diffusion pipelines.

Classes

fastvideo.pipelines.stages.timestep_preparation.Cosmos25TimestepPreparationStage
Cosmos25TimestepPreparationStage(scheduler)

Bases: TimestepPreparationStage

Cosmos 2.5 timestep preparation with scheduler-specific kwargs.

Source code in fastvideo/pipelines/stages/timestep_preparation.py
def __init__(self, scheduler) -> None:
    self.scheduler = scheduler
fastvideo.pipelines.stages.timestep_preparation.TimestepPreparationStage
TimestepPreparationStage(scheduler)

Bases: PipelineStage

Stage for preparing timesteps for the diffusion process.

This stage handles the preparation of the timestep sequence that will be used during the diffusion process.

Source code in fastvideo/pipelines/stages/timestep_preparation.py
def __init__(self, scheduler) -> None:
    self.scheduler = scheduler
Functions
fastvideo.pipelines.stages.timestep_preparation.TimestepPreparationStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Prepare timesteps for the diffusion process.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The batch with prepared timesteps.

Source code in fastvideo/pipelines/stages/timestep_preparation.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Prepare timesteps for the diffusion process.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with prepared timesteps.
    """
    scheduler = self.scheduler
    device = get_local_torch_device()
    num_inference_steps = batch.num_inference_steps
    timesteps = batch.timesteps
    sigmas = batch.sigmas
    n_tokens = batch.n_tokens

    # Prepare extra kwargs for set_timesteps
    extra_set_timesteps_kwargs = {}
    if n_tokens is not None and "n_tokens" in inspect.signature(
            scheduler.set_timesteps).parameters:
        extra_set_timesteps_kwargs["n_tokens"] = n_tokens

    # Handle custom timesteps or sigmas
    if timesteps is not None and sigmas is not None:
        raise ValueError(
            "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
        )

    if timesteps is not None:
        accepts_timesteps = "timesteps" in inspect.signature(
            scheduler.set_timesteps).parameters
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        # Convert timesteps to CPU if it's a tensor (for numpy conversion in scheduler)
        if isinstance(timesteps, torch.Tensor):
            timesteps_for_scheduler = timesteps.cpu()
        else:
            timesteps_for_scheduler = timesteps
        scheduler.set_timesteps(timesteps=timesteps_for_scheduler,
                                device=device,
                                **extra_set_timesteps_kwargs)
        timesteps = scheduler.timesteps
    elif sigmas is not None:
        accept_sigmas = "sigmas" in inspect.signature(
            scheduler.set_timesteps).parameters
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas,
                                device=device,
                                **extra_set_timesteps_kwargs)
        timesteps = scheduler.timesteps
    else:
        scheduler.set_timesteps(num_inference_steps,
                                device=device,
                                **extra_set_timesteps_kwargs)
        timesteps = scheduler.timesteps

    # Update batch with prepared timesteps
    batch.timesteps = timesteps

    return batch
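
Only one of batch.timesteps or batch.sigmas may be set when this stage runs. A sketch of driving it with a custom sigma schedule, assuming the configured scheduler's set_timesteps accepts a sigmas argument:

# Custom sigma schedule; batch.timesteps must stay None.
batch.sigmas = [1.0, 0.75, 0.5, 0.25, 0.0]
batch.timesteps = None
batch = stage.forward(batch, fastvideo_args)
# batch.timesteps is now the 1-D tensor produced by scheduler.set_timesteps(sigmas=...)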
fastvideo.pipelines.stages.timestep_preparation.TimestepPreparationStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify timestep preparation stage inputs.

Source code in fastvideo/pipelines/stages/timestep_preparation.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify timestep preparation stage inputs."""
    result = VerificationResult()
    result.add_check("num_inference_steps", batch.num_inference_steps,
                     V.positive_int)
    result.add_check("timesteps", batch.timesteps, V.none_or_tensor)
    result.add_check("sigmas", batch.sigmas, V.none_or_list)
    result.add_check("n_tokens", batch.n_tokens, V.none_or_positive_int)
    return result
fastvideo.pipelines.stages.timestep_preparation.TimestepPreparationStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify timestep preparation stage outputs.

Source code in fastvideo/pipelines/stages/timestep_preparation.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify timestep preparation stage outputs."""
    result = VerificationResult()
    result.add_check("timesteps", batch.timesteps,
                     [V.is_tensor, V.with_dims(1)])
    return result

Functions

fastvideo.pipelines.stages.utils

Utility functions for pipeline stages.

Functions

fastvideo.pipelines.stages.utils.retrieve_timesteps
retrieve_timesteps(scheduler: Any, num_inference_steps: int | None = None, device: str | device | None = None, timesteps: list[int] | None = None, sigmas: list[float] | None = None, **kwargs: Any) -> tuple[Any, int]

Calls the scheduler's set_timesteps method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to scheduler.set_timesteps.

Parameters:

Name Type Description Default
scheduler `SchedulerMixin`

The scheduler to get timesteps from.

required
num_inference_steps `int`

The number of diffusion steps used when generating samples with a pre-trained model. If used, timesteps must be None.

None
device `str` or `torch.device`, *optional*

The device to which the timesteps should be moved. If None, the timesteps are not moved.

None
timesteps `List[int]`, *optional*

Custom timesteps used to override the timestep spacing strategy of the scheduler. If timesteps is passed, num_inference_steps and sigmas must be None.

None
sigmas `List[float]`, *optional*

Custom sigmas used to override the timestep spacing strategy of the scheduler. If sigmas is passed, num_inference_steps and timesteps must be None.

None

Returns:

Type Description
tuple[Any, int]

A tuple where the first element is the timestep schedule and the second element is the number of inference steps.

Source code in fastvideo/pipelines/stages/utils.py
def retrieve_timesteps(
    scheduler: Any,
    num_inference_steps: int | None = None,
    device: str | torch.device | None = None,
    timesteps: list[int] | None = None,
    sigmas: list[float] | None = None,
    **kwargs: Any,
) -> tuple[Any, int]:
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError(
            "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
        )
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(
            inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        if timesteps is None:
            raise ValueError("scheduler.timesteps is None after set_timesteps")
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(
            inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        if timesteps is None:
            raise ValueError("scheduler.timesteps is None after set_timesteps")
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        if timesteps is None:
            raise ValueError("scheduler.timesteps is None after set_timesteps")
        num_inference_steps = len(timesteps)
    return timesteps, num_inference_steps
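
A short usage sketch, assuming scheduler is any scheduler whose set_timesteps accepts the chosen argument:

# Default spacing derived from the step count.
timesteps, n = retrieve_timesteps(scheduler, num_inference_steps=30, device="cuda")

# Custom sigma schedule instead; num_inference_steps and timesteps must be None.
timesteps, n = retrieve_timesteps(scheduler, sigmas=[1.0, 0.6, 0.3, 0.0], device="cuda")
assert n == len(timesteps)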

fastvideo.pipelines.stages.validators

Common validators for pipeline stage verification.

This module provides reusable validation functions that can be used across all pipeline stages for input/output verification.

Classes

fastvideo.pipelines.stages.validators.StageValidators

Common validators for pipeline stages.

Functions
fastvideo.pipelines.stages.validators.StageValidators.bool_value staticmethod
bool_value(value: Any) -> bool

Check if value is a boolean.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def bool_value(value: Any) -> bool:
    """Check if value is a boolean."""
    return isinstance(value, bool)
fastvideo.pipelines.stages.validators.StageValidators.divisible staticmethod
divisible(divisor: int) -> Callable[[Any], bool]

Return a validator that checks if value is divisible by divisor.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def divisible(divisor: int) -> Callable[[Any], bool]:
    """Return a validator that checks if value is divisible by divisor."""

    def validator(value: Any) -> bool:
        return StageValidators.divisible_by(value, divisor)

    return validator
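
The factory returns a plain callable, so it can be passed directly to add_check or called on its own:

from fastvideo.pipelines.stages.validators import StageValidators as V

check_div8 = V.divisible(8)
assert check_div8(64)
assert not check_div8(30)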
fastvideo.pipelines.stages.validators.StageValidators.divisible_by staticmethod
divisible_by(value: Any, divisor: int) -> bool

Check if value is divisible by divisor.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def divisible_by(value: Any, divisor: int) -> bool:
    """Check if value is divisible by divisor."""
    return value is not None and isinstance(value,
                                            int) and value % divisor == 0
fastvideo.pipelines.stages.validators.StageValidators.generator_or_list_generators staticmethod
generator_or_list_generators(value: Any) -> bool

Check if value is a Generator or list of Generators.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def generator_or_list_generators(value: Any) -> bool:
    """Check if value is a Generator or list of Generators."""
    if isinstance(value, torch.Generator):
        return True
    if isinstance(value, list):
        return all(isinstance(item, torch.Generator) for item in value)
    return False
fastvideo.pipelines.stages.validators.StageValidators.is_list staticmethod
is_list(value: Any) -> bool

Check if value is a list (can be empty).

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def is_list(value: Any) -> bool:
    """Check if value is a list (can be empty)."""
    return isinstance(value, list)
fastvideo.pipelines.stages.validators.StageValidators.is_tensor staticmethod
is_tensor(value: Any) -> bool

Check if value is a torch tensor and doesn't contain NaN values.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def is_tensor(value: Any) -> bool:
    """Check if value is a torch tensor and doesn't contain NaN values."""
    if not isinstance(value, torch.Tensor):
        return False
    return not torch.isnan(value).any().item()
fastvideo.pipelines.stages.validators.StageValidators.is_tuple staticmethod
is_tuple(value: Any) -> bool

Check if value is a tuple.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def is_tuple(value: Any) -> bool:
    """Check if value is a tuple."""
    return isinstance(value, tuple)
fastvideo.pipelines.stages.validators.StageValidators.list_length staticmethod
list_length(value: Any, length: int) -> bool

Check if list has specific length.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def list_length(value: Any, length: int) -> bool:
    """Check if list has specific length."""
    return isinstance(value, list) and len(value) == length
fastvideo.pipelines.stages.validators.StageValidators.list_min_length staticmethod
list_min_length(value: Any, min_length: int) -> bool

Check if list has at least min_length items.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def list_min_length(value: Any, min_length: int) -> bool:
    """Check if list has at least min_length items."""
    return isinstance(value, list) and len(value) >= min_length
fastvideo.pipelines.stages.validators.StageValidators.list_not_empty staticmethod
list_not_empty(value: Any) -> bool

Check if value is a non-empty list.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def list_not_empty(value: Any) -> bool:
    """Check if value is a non-empty list."""
    return isinstance(value, list) and len(value) > 0
fastvideo.pipelines.stages.validators.StageValidators.list_of_tensors staticmethod
list_of_tensors(value: Any) -> bool

Check if value is a non-empty list where all items are tensors without NaN values.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def list_of_tensors(value: Any) -> bool:
    """Check if value is a non-empty list where all items are tensors without NaN values."""
    if not isinstance(value, list) or len(value) == 0:
        return False
    for item in value:
        if not isinstance(item, torch.Tensor):
            return False
        if torch.isnan(item).any().item():
            return False
    return True
fastvideo.pipelines.stages.validators.StageValidators.list_of_tensors_dims staticmethod
list_of_tensors_dims(dims: int) -> Callable[[Any], bool]

Return a validator that checks if value is a list of tensors with specific dimensions and no NaN values.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def list_of_tensors_dims(dims: int) -> Callable[[Any], bool]:
    """Return a validator that checks if value is a list of tensors with specific dimensions and no NaN values."""

    def validator(value: Any) -> bool:
        return StageValidators.list_of_tensors_with_dims(value, dims)

    return validator
fastvideo.pipelines.stages.validators.StageValidators.list_of_tensors_min_dims staticmethod
list_of_tensors_min_dims(min_dims: int) -> Callable[[Any], bool]

Return a validator that checks if value is a list of tensors with at least min_dims dimensions and no NaN values.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def list_of_tensors_min_dims(min_dims: int) -> Callable[[Any], bool]:
    """Return a validator that checks if value is a list of tensors with at least min_dims dimensions and no NaN values."""

    def validator(value: Any) -> bool:
        return StageValidators.list_of_tensors_with_min_dims(
            value, min_dims)

    return validator
fastvideo.pipelines.stages.validators.StageValidators.list_of_tensors_with_dims staticmethod
list_of_tensors_with_dims(value: Any, dims: int) -> bool

Check if value is a non-empty list where all items are tensors with specific dimensions and no NaN values.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def list_of_tensors_with_dims(value: Any, dims: int) -> bool:
    """Check if value is a non-empty list where all items are tensors with specific dimensions and no NaN values."""
    if not isinstance(value, list) or len(value) == 0:
        return False
    for item in value:
        if not isinstance(item, torch.Tensor):
            return False
        if item.dim() != dims:
            return False
        if torch.isnan(item).any().item():
            return False
    return True
fastvideo.pipelines.stages.validators.StageValidators.list_of_tensors_with_min_dims staticmethod
list_of_tensors_with_min_dims(value: Any, min_dims: int) -> bool

Check if value is a non-empty list where all items are tensors with at least min_dims dimensions and no NaN values.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def list_of_tensors_with_min_dims(value: Any, min_dims: int) -> bool:
    """Check if value is a non-empty list where all items are tensors with at least min_dims dimensions and no NaN values."""
    if not isinstance(value, list) or len(value) == 0:
        return False
    for item in value:
        if not isinstance(item, torch.Tensor):
            return False
        if item.dim() < min_dims:
            return False
        if torch.isnan(item).any().item():
            return False
    return True
fastvideo.pipelines.stages.validators.StageValidators.min_dims staticmethod
min_dims(min_dims: int) -> Callable[[Any], bool]

Return a validator that checks if tensor has at least min_dims dimensions and no NaN values.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def min_dims(min_dims: int) -> Callable[[Any], bool]:
    """Return a validator that checks if tensor has at least min_dims dimensions and no NaN values."""

    def validator(value: Any) -> bool:
        return StageValidators.tensor_min_dims(value, min_dims)

    return validator
fastvideo.pipelines.stages.validators.StageValidators.non_negative_float staticmethod
non_negative_float(value: Any) -> bool

Check if value is a non-negative float.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def non_negative_float(value: Any) -> bool:
    """Check if value is a non-negative float."""
    return isinstance(value, int | float) and value >= 0
fastvideo.pipelines.stages.validators.StageValidators.none_or_list staticmethod
none_or_list(value: Any) -> bool

Check if value is None or a list.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def none_or_list(value: Any) -> bool:
    """Check if value is None or a list."""
    return value is None or isinstance(value, list)
fastvideo.pipelines.stages.validators.StageValidators.none_or_positive_int staticmethod
none_or_positive_int(value: Any) -> bool

Check if value is None or a positive integer.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def none_or_positive_int(value: Any) -> bool:
    """Check if value is None or a positive integer."""
    return value is None or (isinstance(value, int) and value > 0)
fastvideo.pipelines.stages.validators.StageValidators.none_or_tensor staticmethod
none_or_tensor(value: Any) -> bool

Check if value is None or a tensor without NaN values.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def none_or_tensor(value: Any) -> bool:
    """Check if value is None or a tensor without NaN values."""
    if value is None:
        return True
    if not isinstance(value, torch.Tensor):
        return False
    return not torch.isnan(value).any().item()
fastvideo.pipelines.stages.validators.StageValidators.none_or_tensor_with_dims staticmethod
none_or_tensor_with_dims(dims: int) -> Callable[[Any], bool]

Return a validator that checks if value is None or a tensor with specific dimensions and no NaN values.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def none_or_tensor_with_dims(dims: int) -> Callable[[Any], bool]:
    """Return a validator that checks if value is None or a tensor with specific dimensions and no NaN values."""

    def validator(value: Any) -> bool:
        if value is None:
            return True
        if not isinstance(value, torch.Tensor):
            return False
        if value.dim() != dims:
            return False
        return not torch.isnan(value).any().item()

    return validator
fastvideo.pipelines.stages.validators.StageValidators.not_none staticmethod
not_none(value: Any) -> bool

Check if value is not None.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def not_none(value: Any) -> bool:
    """Check if value is not None."""
    return value is not None
fastvideo.pipelines.stages.validators.StageValidators.positive_float staticmethod
positive_float(value: Any) -> bool

Check if value is a positive float.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def positive_float(value: Any) -> bool:
    """Check if value is a positive float."""
    return isinstance(value, int | float) and value > 0
fastvideo.pipelines.stages.validators.StageValidators.positive_int staticmethod
positive_int(value: Any) -> bool

Check if value is a positive integer.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def positive_int(value: Any) -> bool:
    """Check if value is a positive integer."""
    return isinstance(value, int) and value > 0
fastvideo.pipelines.stages.validators.StageValidators.positive_int_divisible staticmethod
positive_int_divisible(divisor: int) -> Callable[[Any], bool]

Return a validator that checks if value is a positive integer divisible by divisor.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def positive_int_divisible(divisor: int) -> Callable[[Any], bool]:
    """Return a validator that checks if value is a positive integer divisible by divisor."""

    def validator(value: Any) -> bool:
        return (isinstance(value, int) and value > 0
                and StageValidators.divisible_by(value, divisor))

    return validator
fastvideo.pipelines.stages.validators.StageValidators.string_not_empty staticmethod
string_not_empty(value: Any) -> bool

Check if value is a non-empty string.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def string_not_empty(value: Any) -> bool:
    """Check if value is a non-empty string."""
    return isinstance(value, str) and len(value.strip()) > 0
fastvideo.pipelines.stages.validators.StageValidators.string_or_list_strings staticmethod
string_or_list_strings(value: Any) -> bool

Check if value is a string or list of strings.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def string_or_list_strings(value: Any) -> bool:
    """Check if value is a string or list of strings."""
    if isinstance(value, str):
        return True
    if isinstance(value, list):
        return all(isinstance(item, str) for item in value)
    return False
fastvideo.pipelines.stages.validators.StageValidators.tensor_min_dims staticmethod
tensor_min_dims(value: Any, min_dims: int) -> bool

Check if value is a tensor with at least min_dims dimensions and no NaN values.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def tensor_min_dims(value: Any, min_dims: int) -> bool:
    """Check if value is a tensor with at least min_dims dimensions and no NaN values."""
    if not isinstance(value, torch.Tensor):
        return False
    if value.dim() < min_dims:
        return False
    return not torch.isnan(value).any().item()
fastvideo.pipelines.stages.validators.StageValidators.tensor_shape_matches staticmethod
tensor_shape_matches(value: Any, expected_shape: tuple) -> bool

Check if tensor shape matches expected shape (None for any size) and no NaN values.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def tensor_shape_matches(value: Any, expected_shape: tuple) -> bool:
    """Check if tensor shape matches expected shape (None for any size) and no NaN values."""
    if not isinstance(value, torch.Tensor):
        return False
    if len(value.shape) != len(expected_shape):
        return False
    for actual, expected in zip(value.shape, expected_shape, strict=True):
        if expected is not None and actual != expected:
            return False
    return not torch.isnan(value).any().item()
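
A short illustration; None entries act as wildcards for individual dimensions:

import torch
from fastvideo.pipelines.stages.validators import StageValidators as V

latents = torch.randn(1, 16, 8, 32, 32)
assert V.tensor_shape_matches(latents, (1, 16, None, None, None))
assert not V.tensor_shape_matches(latents, (1, 3, None, None, None))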
fastvideo.pipelines.stages.validators.StageValidators.tensor_with_dims staticmethod
tensor_with_dims(value: Any, dims: int) -> bool

Check if value is a tensor with specific dimensions and no NaN values.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def tensor_with_dims(value: Any, dims: int) -> bool:
    """Check if value is a tensor with specific dimensions and no NaN values."""
    if not isinstance(value, torch.Tensor):
        return False
    if value.dim() != dims:
        return False
    return not torch.isnan(value).any().item()
fastvideo.pipelines.stages.validators.StageValidators.with_dims staticmethod
with_dims(dims: int) -> Callable[[Any], bool]

Return a validator that checks if tensor has specific dimensions and no NaN values.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def with_dims(dims: int) -> Callable[[Any], bool]:
    """Return a validator that checks if tensor has specific dimensions and no NaN values."""

    def validator(value: Any) -> bool:
        return StageValidators.tensor_with_dims(value, dims)

    return validator
fastvideo.pipelines.stages.validators.ValidationFailure
ValidationFailure(validator_name: str, actual_value: Any, expected: str | None = None, error_msg: str | None = None)

Details about a specific validation failure.

Source code in fastvideo/pipelines/stages/validators.py
def __init__(self,
             validator_name: str,
             actual_value: Any,
             expected: str | None = None,
             error_msg: str | None = None):
    self.validator_name = validator_name
    self.actual_value = actual_value
    self.expected = expected
    self.error_msg = error_msg
fastvideo.pipelines.stages.validators.VerificationResult
VerificationResult()

Wrapper class for stage verification results.

Source code in fastvideo/pipelines/stages/validators.py
def __init__(self) -> None:
    self._checks: dict[str, bool] = {}
    self._failures: dict[str, list[ValidationFailure]] = {}
Functions
fastvideo.pipelines.stages.validators.VerificationResult.add_check
add_check(field_name: str, value: Any, validators: Callable[[Any], bool] | list[Callable[[Any], bool]]) -> VerificationResult

Add a validation check for a field.

Parameters:

Name Type Description Default
field_name str

Name of the field being checked

required
value Any

The actual value to validate

required
validators Callable[[Any], bool] | list[Callable[[Any], bool]]

Single validation function or list of validation functions. Each function will be called with the value as its first argument.

required

Returns:

Type Description
VerificationResult

Self for method chaining

Examples:

Single validator

result.add_check("tensor", my_tensor, V.is_tensor)

Multiple validators (all must pass)

result.add_check("latents", batch.latents, [V.is_tensor, V.with_dims(5)])

Using partial functions for parameters

result.add_check("height", batch.height, [V.not_none, V.divisible(8)])

Source code in fastvideo/pipelines/stages/validators.py
def add_check(
    self, field_name: str, value: Any,
    validators: Callable[[Any], bool] | list[Callable[[Any], bool]]
) -> 'VerificationResult':
    """
    Add a validation check for a field.

    Args:
        field_name: Name of the field being checked
        value: The actual value to validate
        validators: Single validation function or list of validation functions.
                   Each function will be called with the value as its first argument.

    Returns:
        Self for method chaining

    Examples:
        # Single validator
        result.add_check("tensor", my_tensor, V.is_tensor)

        # Multiple validators (all must pass)
        result.add_check("latents", batch.latents, [V.is_tensor, V.with_dims(5)])

        # Using parameterized validators (validator factories)
        result.add_check("height", batch.height, [V.not_none, V.divisible(8)])
    """
    if not isinstance(validators, list):
        validators = [validators]

    failures = []
    all_passed = True

    # Apply all validators and collect detailed failure info
    for validator in validators:
        try:
            passed = validator(value)
            if not passed:
                all_passed = False
                failure = self._create_validation_failure(validator, value)
                failures.append(failure)
        except Exception as e:
            # If any validator raises an exception, consider the check failed
            all_passed = False
            validator_name = getattr(validator, '__name__', str(validator))
            failure = ValidationFailure(
                validator_name=validator_name,
                actual_value=value,
                error_msg=f"Exception during validation: {str(e)}")
            failures.append(failure)

    self._checks[field_name] = all_passed
    if not all_passed:
        self._failures[field_name] = failures

    return self
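
Example:

Since add_check returns self, checks can be chained when building a stage's verification result. The field names and values below are placeholders, and the V alias is assumed:

import torch

from fastvideo.pipelines.stages.validators import StageValidators as V  # alias assumed
from fastvideo.pipelines.stages.validators import VerificationResult

result = (VerificationResult()
          .add_check("latents", torch.randn(1, 16, 9, 60, 104),
                     [V.is_tensor, V.with_dims(5)])
          .add_check("height", 480, [V.not_none, V.divisible(8)]))
result.is_valid()  # expected to be True for these values; see the accessors documented below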
fastvideo.pipelines.stages.validators.VerificationResult.get_detailed_failures
get_detailed_failures() -> dict[str, list[ValidationFailure]]

Get detailed failure information for each failed field.

Source code in fastvideo/pipelines/stages/validators.py
def get_detailed_failures(self) -> dict[str, list[ValidationFailure]]:
    """Get detailed failure information for each failed field."""
    return self._failures.copy()
fastvideo.pipelines.stages.validators.VerificationResult.get_failed_fields
get_failed_fields() -> list[str]

Get list of fields that failed validation.

Source code in fastvideo/pipelines/stages/validators.py
def get_failed_fields(self) -> list[str]:
    """Get list of fields that failed validation."""
    return [field for field, passed in self._checks.items() if not passed]
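
Example:

A sketch of inspecting failures after a check that cannot pass; the field name is illustrative and the V alias is assumed:

from fastvideo.pipelines.stages.validators import StageValidators as V  # alias assumed
from fastvideo.pipelines.stages.validators import VerificationResult

result = VerificationResult().add_check("latents", None,
                                        [V.is_tensor, V.with_dims(5)])
result.get_failed_fields()  # ['latents']
for field, failures in result.get_detailed_failures().items():
    for failure in failures:  # each entry is a ValidationFailure
        print(field, failure.validator_name, failure.expected, failure.error_msg)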
fastvideo.pipelines.stages.validators.VerificationResult.get_failure_summary
get_failure_summary() -> str

Get a comprehensive summary of all validation failures.

Source code in fastvideo/pipelines/stages/validators.py
def get_failure_summary(self) -> str:
    """Get a comprehensive summary of all validation failures."""
    if self.is_valid():
        return "All validations passed"

    summary_parts = []
    for field_name, failures in self._failures.items():
        field_summary = f"\n  Field '{field_name}':"
        for i, failure in enumerate(failures, 1):
            field_summary += f"\n    {i}. {failure}"
        summary_parts.append(field_summary)

    return "Validation failures:" + "".join(summary_parts)
fastvideo.pipelines.stages.validators.VerificationResult.is_valid
is_valid() -> bool

Check if all validations passed.

Source code in fastvideo/pipelines/stages/validators.py
def is_valid(self) -> bool:
    """Check if all validations passed."""
    return all(self._checks.values())
fastvideo.pipelines.stages.validators.VerificationResult.to_dict
to_dict() -> dict

Convert to dictionary for backward compatibility.

Source code in fastvideo/pipelines/stages/validators.py
def to_dict(self) -> dict:
    """Convert to dictionary for backward compatibility."""
    return self._checks.copy()
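
Example:

Putting the accessors together: a stage or test can branch on is_valid(), log get_failure_summary() for a readable per-field report, and use to_dict() where a plain mapping of field name to pass/fail flag is needed. The field names and values are illustrative, and the V alias is assumed:

from fastvideo.pipelines.stages.validators import StageValidators as V  # alias assumed
from fastvideo.pipelines.stages.validators import VerificationResult

result = (VerificationResult()
          .add_check("height", 479, [V.not_none, V.divisible(8)])  # fails, assuming divisible(8) requires value % 8 == 0
          .add_check("width", 848, [V.not_none, V.divisible(8)]))
if not result.is_valid():
    print(result.get_failure_summary())  # "Validation failures:" grouped by field
result.to_dict()  # e.g. {'height': False, 'width': True}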