Skip to content

causal_denoising

Classes

fastvideo.pipelines.stages.causal_denoising.CausalDMDDenosingStage

CausalDMDDenosingStage(transformer, scheduler, transformer_2=None)

Bases: DenoisingStage

Denoising stage for causal diffusion.

Source code in fastvideo/pipelines/stages/causal_denoising.py
def __init__(self, transformer, scheduler, transformer_2=None) -> None:
    """Set up the stage and record model-dependent constants.

    Args:
        transformer: Primary transformer. It must expose ``blocks`` and
            ``config.arch_config`` with ``num_frames_per_block`` and
            ``sliding_window_num_frames``.
        scheduler: Scheduler, forwarded unchanged to the base initializer.
        transformer_2: Optional second transformer, forwarded to the base
            initializer and also stored on this instance.
    """
    super().__init__(transformer, scheduler, transformer_2)
    self.transformer = transformer
    self.transformer_2 = transformer_2

    # KV and cross-attention cache state; both start empty here
    # (per the original note, they are initialized on first forward).
    self.kv_cache1: list | None = None
    self.crossattn_cache: list | None = None

    # Model-dependent constants (aligned with causal_inference.py
    # assumptions).
    self.num_transformer_blocks = len(self.transformer.blocks)
    arch_config = self.transformer.config.arch_config
    self.num_frames_per_block = arch_config.num_frames_per_block
    self.sliding_window_num_frames = arch_config.sliding_window_num_frames

    # Best-effort lookup: any failure while reaching ``transformer.model``
    # (or a missing ``local_attn_size`` attribute) falls back to -1.
    try:
        inner_model = self.transformer.model
        self.local_attn_size = getattr(inner_model, "local_attn_size",
                                       -1)  # type: ignore
    except Exception:
        self.local_attn_size = -1

Functions

fastvideo.pipelines.stages.causal_denoising.CausalDMDDenosingStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify denoising stage inputs.

Source code in fastvideo/pipelines/stages/causal_denoising.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify denoising stage inputs.

    Registers one check per batch field the stage reads and returns the
    collected result.

    Args:
        batch: Forward batch whose fields are validated.
        fastvideo_args: Not read by this method.

    Returns:
        The ``VerificationResult`` holding every registered check.
    """

    def _negatives_ok(embeds):
        # Negative prompt embeddings only matter when classifier-free
        # guidance is enabled; otherwise the check passes trivially.
        if not batch.do_classifier_free_guidance:
            return True
        return V.list_not_empty(embeds)

    verification = VerificationResult()
    check = verification.add_check

    # Tensor-valued inputs.
    check("latents", batch.latents, [V.is_tensor, V.with_dims(5)])
    check("prompt_embeds", batch.prompt_embeds, V.list_not_empty)
    check("image_embeds", batch.image_embeds, V.is_list)
    check("image_latent", batch.image_latent, V.none_or_tensor_with_dims(5))

    # Scalar sampling parameters.
    check("num_inference_steps", batch.num_inference_steps, V.positive_int)
    check("guidance_scale", batch.guidance_scale, V.positive_float)
    check("eta", batch.eta, V.non_negative_float)
    check("generator", batch.generator, V.generator_or_list_generators)

    # Classifier-free guidance flag and its dependent input.
    check("do_classifier_free_guidance", batch.do_classifier_free_guidance,
          V.bool_value)
    check("negative_prompt_embeds", batch.negative_prompt_embeds,
          _negatives_ok)

    return verification

Functions