stages

Pipeline stages for diffusion models.

This package contains the various stages that can be composed to create complete diffusion pipelines.
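As a rough sketch of the composition pattern (using hypothetical stand-in types, not this package's actual classes): each stage takes the current batch and the inference arguments, mutates or replaces fields on the batch, and returns it, so a pipeline is simply an ordered list of stages applied in sequence.

# Minimal sketch of stage composition with illustrative stand-in types;
# the real stages live in fastvideo.pipelines.stages and operate on
# ForwardBatch / FastVideoArgs.
from dataclasses import dataclass, field
from typing import Any


@dataclass
class Batch:  # stand-in for ForwardBatch
    latents: Any = None
    output: Any = None
    extras: dict[str, Any] = field(default_factory=dict)


class Stage:  # stand-in for PipelineStage
    def forward(self, batch: Batch, args: dict[str, Any]) -> Batch:
        raise NotImplementedError


class Pipeline:
    def __init__(self, stages: list[Stage]) -> None:
        self.stages = stages

    def run(self, batch: Batch, args: dict[str, Any]) -> Batch:
        # Each stage consumes the batch produced by the previous one.
        for stage in self.stages:
            batch = stage.forward(batch, args)
        return batch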

Classes

fastvideo.pipelines.stages.CausalDMDDenosingStage

CausalDMDDenosingStage(transformer, scheduler, transformer_2=None)

Bases: DenoisingStage

Denoising stage for causal diffusion.

Source code in fastvideo/pipelines/stages/causal_denoising.py
def __init__(self, transformer, scheduler, transformer_2=None) -> None:
    super().__init__(transformer, scheduler, transformer_2)
    # KV and cross-attention cache state (initialized on first forward)
    self.transformer = transformer
    self.transformer_2 = transformer_2
    self.kv_cache1: list | None = None
    self.crossattn_cache: list | None = None
    # Model-dependent constants (aligned with causal_inference.py assumptions)
    self.num_transformer_blocks = len(self.transformer.blocks)
    self.num_frames_per_block = self.transformer.config.arch_config.num_frames_per_block
    self.sliding_window_num_frames = self.transformer.config.arch_config.sliding_window_num_frames

    try:
        self.local_attn_size = getattr(self.transformer.model,
                                       "local_attn_size",
                                       -1)  # type: ignore
    except Exception:
        self.local_attn_size = -1

Functions

fastvideo.pipelines.stages.CausalDMDDenosingStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify denoising stage inputs.

Source code in fastvideo/pipelines/stages/causal_denoising.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify denoising stage inputs."""
    result = VerificationResult()
    result.add_check("latents", batch.latents,
                     [V.is_tensor, V.with_dims(5)])
    result.add_check("prompt_embeds", batch.prompt_embeds, V.list_not_empty)
    result.add_check("image_embeds", batch.image_embeds, V.is_list)
    result.add_check("image_latent", batch.image_latent,
                     V.none_or_tensor_with_dims(5))
    result.add_check("num_inference_steps", batch.num_inference_steps,
                     V.positive_int)
    result.add_check("guidance_scale", batch.guidance_scale,
                     V.positive_float)
    result.add_check("eta", batch.eta, V.non_negative_float)
    result.add_check("generator", batch.generator,
                     V.generator_or_list_generators)
    result.add_check("do_classifier_free_guidance",
                     batch.do_classifier_free_guidance, V.bool_value)
    result.add_check(
        "negative_prompt_embeds", batch.negative_prompt_embeds, lambda x:
        not batch.do_classifier_free_guidance or V.list_not_empty(x))
    return result
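The last check illustrates a pattern used throughout these verifiers: negative_prompt_embeds only needs to be non-empty when classifier-free guidance is enabled. A small illustration with a simplified stand-in for V.list_not_empty (names here are illustrative, not the actual validator API):

def list_not_empty(x) -> bool:
    # Simplified stand-in for V.list_not_empty.
    return isinstance(x, list) and len(x) > 0

def negative_embeds_ok(do_cfg: bool, negative_prompt_embeds) -> bool:
    # Mirrors the lambda above: only required when CFG is enabled.
    return (not do_cfg) or list_not_empty(negative_prompt_embeds)

assert negative_embeds_ok(False, None)            # CFG off: anything passes
assert not negative_embeds_ok(True, [])           # CFG on: empty list fails
assert negative_embeds_ok(True, ["neg_embed"])    # CFG on: non-empty passes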

fastvideo.pipelines.stages.ConditioningStage

Bases: PipelineStage

Stage for applying conditioning to the diffusion process.

This stage handles the application of conditioning, such as classifier-free guidance, to the diffusion process.

Functions

fastvideo.pipelines.stages.ConditioningStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Apply conditioning to the diffusion process.

Parameters:

    batch (ForwardBatch): The current batch information. Required.
    fastvideo_args (FastVideoArgs): The inference arguments. Required.

Returns:

    ForwardBatch: The batch with applied conditioning.

Source code in fastvideo/pipelines/stages/conditioning.py
@torch.no_grad()
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Apply conditioning to the diffusion process.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with applied conditioning.
    """
    # TODO!!
    if not batch.do_classifier_free_guidance:
        return batch
    else:
        return batch

    logger.info("batch.negative_prompt_embeds: %s",
                batch.negative_prompt_embeds)
    logger.info("do_classifier_free_guidance: %s",
                batch.do_classifier_free_guidance)
    logger.info("cfg_scale: %s", batch.guidance_scale)

    # Ensure negative prompt embeddings are available
    assert batch.negative_prompt_embeds is not None, (
        "Negative prompt embeddings are required for classifier-free guidance"
    )

    # Concatenate primary embeddings and masks
    batch.prompt_embeds = torch.cat(
        [batch.negative_prompt_embeds, batch.prompt_embeds])
    if batch.attention_mask is not None:
        batch.attention_mask = torch.cat(
            [batch.negative_attention_mask, batch.attention_mask])

    # Concatenate secondary embeddings and masks if present
    if batch.prompt_embeds_2 is not None:
        batch.prompt_embeds_2 = torch.cat(
            [batch.negative_prompt_embeds_2, batch.prompt_embeds_2])
    if batch.attention_mask_2 is not None:
        batch.attention_mask_2 = torch.cat(
            [batch.negative_attention_mask_2, batch.attention_mask_2])

    return batch
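Note that both branches above return early, so the concatenation code that follows is currently unreachable; it documents the intended classifier-free-guidance path, where negative and positive embeddings are stacked along the batch dimension so a single transformer call covers both branches. A standalone sketch of that concatenation with dummy shapes (not the stage itself):

import torch

# Dummy embeddings shaped [batch, seq_len, hidden]; sizes are illustrative.
prompt_embeds = torch.randn(1, 77, 4096)
negative_prompt_embeds = torch.zeros(1, 77, 4096)

# Negative first, positive second, matching the torch.cat order above.
cfg_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
assert cfg_embeds.shape == (2, 77, 4096)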
fastvideo.pipelines.stages.ConditioningStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify conditioning stage inputs.

Source code in fastvideo/pipelines/stages/conditioning.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify conditioning stage inputs."""
    result = VerificationResult()
    result.add_check("do_classifier_free_guidance",
                     batch.do_classifier_free_guidance, V.bool_value)
    result.add_check("guidance_scale", batch.guidance_scale,
                     V.positive_float)
    result.add_check("prompt_embeds", batch.prompt_embeds, V.list_not_empty)
    result.add_check(
        "negative_prompt_embeds", batch.negative_prompt_embeds, lambda x:
        not batch.do_classifier_free_guidance or V.list_not_empty(x))
    return result
fastvideo.pipelines.stages.ConditioningStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify conditioning stage outputs.

Source code in fastvideo/pipelines/stages/conditioning.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify conditioning stage outputs."""
    result = VerificationResult()
    result.add_check("prompt_embeds", batch.prompt_embeds, V.list_not_empty)
    return result

fastvideo.pipelines.stages.CosmosDenoisingStage

CosmosDenoisingStage(transformer, scheduler, pipeline=None)

Bases: DenoisingStage

Denoising stage for Cosmos models using FlowMatchEulerDiscreteScheduler.

Source code in fastvideo/pipelines/stages/denoising.py
def __init__(self, transformer, scheduler, pipeline=None) -> None:
    super().__init__(transformer, scheduler, pipeline)

Functions

fastvideo.pipelines.stages.CosmosDenoisingStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify Cosmos denoising stage inputs.

Source code in fastvideo/pipelines/stages/denoising.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify Cosmos denoising stage inputs."""
    result = VerificationResult()
    result.add_check("latents", batch.latents,
                     [V.is_tensor, V.with_dims(5)])
    result.add_check("prompt_embeds", batch.prompt_embeds, V.list_not_empty)
    result.add_check("num_inference_steps", batch.num_inference_steps,
                     V.positive_int)
    result.add_check("guidance_scale", batch.guidance_scale,
                     V.positive_float)
    result.add_check("do_classifier_free_guidance",
                     batch.do_classifier_free_guidance, V.bool_value)
    result.add_check(
        "negative_prompt_embeds", batch.negative_prompt_embeds, lambda x:
        not batch.do_classifier_free_guidance or V.list_not_empty(x))
    return result
fastvideo.pipelines.stages.CosmosDenoisingStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify Cosmos denoising stage outputs.

Source code in fastvideo/pipelines/stages/denoising.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify Cosmos denoising stage outputs."""
    result = VerificationResult()
    result.add_check("latents", batch.latents,
                     [V.is_tensor, V.with_dims(5)])
    return result

fastvideo.pipelines.stages.CosmosLatentPreparationStage

CosmosLatentPreparationStage(scheduler, transformer, vae=None)

Bases: PipelineStage

Cosmos-specific latent preparation stage that properly handles the tensor shapes and conditioning masks required by the Cosmos transformer.

This stage replicates the logic from diffusers' Cosmos2VideoToWorldPipeline.prepare_latents().

Source code in fastvideo/pipelines/stages/latent_preparation.py
def __init__(self, scheduler, transformer, vae=None) -> None:
    super().__init__()
    self.scheduler = scheduler
    self.transformer = transformer
    self.vae = vae

Functions

fastvideo.pipelines.stages.CosmosLatentPreparationStage.adjust_video_length
adjust_video_length(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> int

Adjust video length based on VAE version.

Parameters:

    batch (ForwardBatch): The current batch information. Required.
    fastvideo_args (FastVideoArgs): The inference arguments. Required.

Returns:

    int: The adjusted number of latent frames.

Source code in fastvideo/pipelines/stages/latent_preparation.py
def adjust_video_length(self, batch: ForwardBatch,
                        fastvideo_args: FastVideoArgs) -> int:
    """
    Adjust video length based on VAE version.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with adjusted video length.
    """

    video_length = batch.num_frames
    use_temporal_scaling_frames = fastvideo_args.pipeline_config.vae_config.use_temporal_scaling_frames
    if use_temporal_scaling_frames:
        temporal_scale_factor = fastvideo_args.pipeline_config.vae_config.arch_config.temporal_compression_ratio
        latent_num_frames = (video_length - 1) // temporal_scale_factor + 1
    else:  # stepvideo only
        latent_num_frames = video_length // 17 * 3
    return int(latent_num_frames)
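As a worked example of the two branches (the ratio here is illustrative and comes from the VAE config in practice): with a temporal compression ratio of 4, an 81-frame request maps to (81 - 1) // 4 + 1 = 21 latent frames, while the stepvideo branch maps every 17 input frames to 3 latent frames.

video_length = 81

# Temporal-scaling VAEs: illustrative compression ratio of 4.
temporal_compression_ratio = 4
print((video_length - 1) // temporal_compression_ratio + 1)  # 21

# stepvideo branch: 17 input frames -> 3 latent frames.
print(video_length // 17 * 3)  # 12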
fastvideo.pipelines.stages.CosmosLatentPreparationStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify Cosmos latent preparation stage inputs.

Source code in fastvideo/pipelines/stages/latent_preparation.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify Cosmos latent preparation stage inputs."""
    result = VerificationResult()
    result.add_check(
        "prompt_or_embeds", None, lambda _: V.string_or_list_strings(
            batch.prompt) or V.list_not_empty(batch.prompt_embeds))
    result.add_check("prompt_embeds", batch.prompt_embeds,
                     V.list_of_tensors)
    result.add_check("num_videos_per_prompt", batch.num_videos_per_prompt,
                     V.positive_int)
    result.add_check("generator", batch.generator,
                     V.generator_or_list_generators)
    result.add_check("num_frames", batch.num_frames, V.positive_int)
    result.add_check("height", batch.height, V.positive_int)
    result.add_check("width", batch.width, V.positive_int)
    result.add_check("latents", batch.latents, V.none_or_tensor)
    return result
fastvideo.pipelines.stages.CosmosLatentPreparationStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify latent preparation stage outputs.

Source code in fastvideo/pipelines/stages/latent_preparation.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify latent preparation stage outputs."""
    result = VerificationResult()
    result.add_check("latents", batch.latents,
                     [V.is_tensor, V.with_dims(5)])
    result.add_check("raw_latent_shape", batch.raw_latent_shape, V.is_tuple)
    return result

fastvideo.pipelines.stages.DecodingStage

DecodingStage(vae, pipeline=None)

Bases: PipelineStage

Stage for decoding latent representations into pixel space.

This stage handles the decoding of latent representations into the final output format (e.g., pixel values).

Source code in fastvideo/pipelines/stages/decoding.py
def __init__(self, vae, pipeline=None) -> None:
    self.vae: ParallelTiledVAE = vae
    self.pipeline = weakref.ref(pipeline) if pipeline else None

Functions

fastvideo.pipelines.stages.DecodingStage.decode
decode(latents: Tensor, fastvideo_args: FastVideoArgs) -> Tensor

Decode latent representations into pixel space using VAE.

Parameters:

    latents (Tensor): Input latent tensor with shape (batch, channels, frames, height_latents, width_latents). Required.
    fastvideo_args (FastVideoArgs): Configuration containing disable_autocast (whether to disable automatic mixed precision, default False), pipeline_config.vae_precision (VAE computation precision: "fp32", "fp16", or "bf16"), and pipeline_config.vae_tiling (whether to enable VAE tiling for memory efficiency). Required.

Returns:

    Tensor: Decoded video tensor with shape (batch, channels, frames, height, width), normalized to the [0, 1] range and moved to CPU as float32.

Source code in fastvideo/pipelines/stages/decoding.py
@torch.no_grad()
def decode(self, latents: torch.Tensor,
           fastvideo_args: FastVideoArgs) -> torch.Tensor:
    """
    Decode latent representations into pixel space using VAE.

    Args:
        latents: Input latent tensor with shape (batch, channels, frames, height_latents, width_latents)
        fastvideo_args: Configuration containing:
            - disable_autocast: Whether to disable automatic mixed precision (default: False)
            - pipeline_config.vae_precision: VAE computation precision ("fp32", "fp16", "bf16")
            - pipeline_config.vae_tiling: Whether to enable VAE tiling for memory efficiency

    Returns:
        Decoded video tensor with shape (batch, channels, frames, height, width), 
        normalized to [0, 1] range and moved to CPU as float32
    """
    self.vae = self.vae.to(get_local_torch_device())
    latents = latents.to(get_local_torch_device())

    # Setup VAE precision
    vae_dtype = PRECISION_TO_TYPE[
        fastvideo_args.pipeline_config.vae_precision]
    vae_autocast_enabled = (
        vae_dtype != torch.float32) and not fastvideo_args.disable_autocast

    if hasattr(self.vae, 'scaling_factor'):
        if isinstance(self.vae.scaling_factor, torch.Tensor):
            latents = latents / self.vae.scaling_factor.to(
                latents.device, latents.dtype)
        else:
            latents = latents / self.vae.scaling_factor

    # Apply shifting if needed
    if (hasattr(self.vae, "shift_factor")
            and self.vae.shift_factor is not None):
        if isinstance(self.vae.shift_factor, torch.Tensor):
            latents += self.vae.shift_factor.to(latents.device,
                                                latents.dtype)
        else:
            latents += self.vae.shift_factor

    # Decode latents
    with torch.autocast(device_type="cuda",
                        dtype=vae_dtype,
                        enabled=vae_autocast_enabled):
        if fastvideo_args.pipeline_config.vae_tiling:
            self.vae.enable_tiling()
        # if fastvideo_args.vae_sp:
        #     self.vae.enable_parallel()
        if not vae_autocast_enabled:
            latents = latents.to(vae_dtype)
        image = self.vae.decode(latents)

    # Normalize image to [0, 1] range
    image = (image / 2 + 0.5).clamp(0, 1)
    return image
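The arithmetic around the VAE call is simple: undo the encoder's scaling (and shift, when the VAE defines one), decode, then map the decoder's roughly [-1, 1] output into [0, 1]. A standalone sketch with dummy values (the scaling factor is illustrative and the decode call is replaced by a placeholder):

import torch

scaling_factor = 0.476986                    # illustrative value; comes from the VAE config
latents = torch.randn(1, 16, 21, 60, 104)    # [B, C, T, h, w], illustrative shape

latents = latents / scaling_factor           # undo encode-time scaling
# decoded = vae.decode(latents)              # real decoder output lies in ~[-1, 1]
decoded = torch.tanh(latents[:, :3])         # placeholder standing in for the decoder

video = (decoded / 2 + 0.5).clamp(0, 1)      # normalize to [0, 1] as the stage does
assert float(video.min()) >= 0.0 and float(video.max()) <= 1.0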
fastvideo.pipelines.stages.DecodingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Decode latent representations into pixel space.

This method processes the batch through the VAE decoder, converting latent representations to pixel-space video/images. It also optionally decodes trajectory latents for visualization purposes.

Parameters:

    batch (ForwardBatch): The current batch containing latents (tensor to decode, shaped (batch, channels, frames, height_latents, width_latents)), return_trajectory_decoded (optional flag to decode trajectory latents), trajectory_latents (optional latents at different timesteps), and trajectory_timesteps (the corresponding timesteps). Required.
    fastvideo_args (FastVideoArgs): Configuration containing output_type ("latent" to skip decoding, otherwise decode to pixels), vae_cpu_offload (whether to offload the VAE to CPU after decoding), model_loaded (tracks VAE loading state), and model_paths (path to the VAE model if loading is needed). Required.

Returns:

    ForwardBatch: Modified batch with output (decoded frames shaped (batch, channels, frames, height, width) as CPU float32) and, if requested, trajectory_decoded (list of decoded frames per timestep).

Source code in fastvideo/pipelines/stages/decoding.py
@torch.no_grad()
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Decode latent representations into pixel space.

    This method processes the batch through the VAE decoder, converting latent
    representations to pixel-space video/images. It also optionally decodes
    trajectory latents for visualization purposes.

    Args:
        batch: The current batch containing:
            - latents: Tensor to decode (batch, channels, frames, height_latents, width_latents)
            - return_trajectory_decoded (optional): Flag to decode trajectory latents
            - trajectory_latents (optional): Latents at different timesteps
            - trajectory_timesteps (optional): Corresponding timesteps
        fastvideo_args: Configuration containing:
            - output_type: "latent" to skip decoding, otherwise decode to pixels
            - vae_cpu_offload: Whether to offload VAE to CPU after decoding
            - model_loaded: Track VAE loading state
            - model_paths: Path to VAE model if loading needed

    Returns:
        Modified batch with:
            - output: Decoded frames (batch, channels, frames, height, width) as CPU float32
            - trajectory_decoded (if requested): List of decoded frames per timestep
    """
    # load vae if not already loaded (used for memory constrained devices)
    pipeline = self.pipeline() if self.pipeline else None
    if not fastvideo_args.model_loaded["vae"]:
        loader = VAELoader()
        self.vae = loader.load(fastvideo_args.model_paths["vae"],
                               fastvideo_args)
        if pipeline:
            pipeline.add_module("vae", self.vae)
        fastvideo_args.model_loaded["vae"] = True

    if fastvideo_args.output_type == "latent":
        frames = batch.latents
    else:
        frames = self.decode(batch.latents, fastvideo_args)

    # decode trajectory latents if needed
    if batch.return_trajectory_decoded:
        batch.trajectory_decoded = []
        assert batch.trajectory_latents is not None, "batch should have trajectory latents"
        for idx in range(batch.trajectory_latents.shape[1]):
            # batch.trajectory_latents is [batch_size, timesteps, channels, frames, height, width]
            cur_latent = batch.trajectory_latents[:, idx, :, :, :, :]
            cur_timestep = batch.trajectory_timesteps[idx]
            logger.info("decoding trajectory latent for timestep: %s",
                        cur_timestep)
            decoded_frames = self.decode(cur_latent, fastvideo_args)
            batch.trajectory_decoded.append(decoded_frames.cpu().float())

    # Convert to CPU float32 for compatibility
    frames = frames.cpu().float()

    # Update batch with decoded image
    batch.output = frames

    # Offload models if needed
    if hasattr(self, 'maybe_free_model_hooks'):
        self.maybe_free_model_hooks()

    if fastvideo_args.vae_cpu_offload:
        self.vae.to("cpu")

    if torch.backends.mps.is_available():
        del self.vae
        if pipeline is not None and "vae" in pipeline.modules:
            del pipeline.modules["vae"]
        fastvideo_args.model_loaded["vae"] = False

    return batch
fastvideo.pipelines.stages.DecodingStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify decoding stage inputs.

Source code in fastvideo/pipelines/stages/decoding.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify decoding stage inputs."""
    result = VerificationResult()
    # Denoised latents for VAE decoding: [batch_size, channels, frames, height_latents, width_latents]
    result.add_check("latents", batch.latents,
                     [V.is_tensor, V.with_dims(5)])
    return result
fastvideo.pipelines.stages.DecodingStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify decoding stage outputs.

Source code in fastvideo/pipelines/stages/decoding.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify decoding stage outputs."""
    result = VerificationResult()
    # Decoded video/images: [batch_size, channels, frames, height, width]
    result.add_check("output", batch.output, [V.is_tensor, V.with_dims(5)])
    return result

fastvideo.pipelines.stages.DenoisingStage

DenoisingStage(transformer, scheduler, pipeline=None, transformer_2=None, vae=None)

Bases: PipelineStage

Stage for running the denoising loop in diffusion pipelines.

This stage handles the iterative denoising process that transforms the initial noise into the final output.

Source code in fastvideo/pipelines/stages/denoising.py
def __init__(self,
             transformer,
             scheduler,
             pipeline=None,
             transformer_2=None,
             vae=None) -> None:
    super().__init__()
    self.transformer = transformer
    self.transformer_2 = transformer_2
    self.scheduler = scheduler
    self.vae = vae
    self.pipeline = weakref.ref(pipeline) if pipeline else None
    attn_head_size = self.transformer.hidden_size // self.transformer.num_attention_heads
    self.attn_backend = get_attn_backend(
        head_size=attn_head_size,
        dtype=torch.float16,  # TODO(will): hack
        supported_attention_backends=(
            AttentionBackendEnum.SLIDING_TILE_ATTN,
            AttentionBackendEnum.VIDEO_SPARSE_ATTN,
            AttentionBackendEnum.VMOBA_ATTN,
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.TORCH_SDPA,
            AttentionBackendEnum.SAGE_ATTN_THREE)  # hack
    )
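The head size passed to get_attn_backend is just the hidden size divided by the number of attention heads, for example (illustrative config values):

hidden_size, num_attention_heads = 3072, 24
attn_head_size = hidden_size // num_attention_heads  # 128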

Functions

fastvideo.pipelines.stages.DenoisingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Run the denoising loop.

Parameters:

    batch (ForwardBatch): The current batch information. Required.
    fastvideo_args (FastVideoArgs): The inference arguments. Required.

Returns:

    ForwardBatch: The batch with denoised latents.

Source code in fastvideo/pipelines/stages/denoising.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Run the denoising loop.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with denoised latents.
    """
    pipeline = self.pipeline() if self.pipeline else None
    if not fastvideo_args.model_loaded["transformer"]:
        loader = TransformerLoader()
        self.transformer = loader.load(
            fastvideo_args.model_paths["transformer"], fastvideo_args)
        if pipeline:
            pipeline.add_module("transformer", self.transformer)
        fastvideo_args.model_loaded["transformer"] = True

    # Prepare extra step kwargs for scheduler
    extra_step_kwargs = self.prepare_extra_func_kwargs(
        self.scheduler.step,
        {
            "generator": batch.generator,
            "eta": batch.eta
        },
    )

    # Setup precision and autocast settings
    # TODO(will): make the precision configurable for inference
    # target_dtype = PRECISION_TO_TYPE[fastvideo_args.precision]
    target_dtype = torch.bfloat16
    autocast_enabled = (target_dtype != torch.float32
                        ) and not fastvideo_args.disable_autocast

    # Handle sequence parallelism if enabled
    sp_world_size, rank_in_sp_group = get_sp_world_size(
    ), get_sp_parallel_rank()
    sp_group = sp_world_size > 1
    if sp_group:
        latents = rearrange(batch.latents,
                            "b c (n t) h w -> b c n t h w",
                            n=sp_world_size).contiguous()
        latents = latents[:, :, rank_in_sp_group, :, :, :]
        batch.latents = latents
        if batch.image_latent is not None:
            image_latent = rearrange(batch.image_latent,
                                     "b c (n t) h w -> b c n t h w",
                                     n=sp_world_size).contiguous()
            image_latent = image_latent[:, :, rank_in_sp_group, :, :, :]
            batch.image_latent = image_latent
    # Get timesteps and calculate warmup steps
    timesteps = batch.timesteps
    # TODO(will): remove this once we add input/output validation for stages
    if timesteps is None:
        raise ValueError("Timesteps must be provided")
    num_inference_steps = batch.num_inference_steps
    num_warmup_steps = len(
        timesteps) - num_inference_steps * self.scheduler.order

    # Prepare image latents and embeddings for I2V generation
    image_embeds = batch.image_embeds
    if len(image_embeds) > 0:
        assert not torch.isnan(
            image_embeds[0]).any(), "image_embeds contains nan"
        image_embeds = [
            image_embed.to(target_dtype) for image_embed in image_embeds
        ]

    image_kwargs = self.prepare_extra_func_kwargs(
        self.transformer.forward,
        {
            "encoder_hidden_states_image": image_embeds,
            "mask_strategy": dict_to_3d_list(
                None, t_max=50, l_max=60, h_max=24)
        },
    )

    pos_cond_kwargs = self.prepare_extra_func_kwargs(
        self.transformer.forward,
        {
            "encoder_hidden_states_2": batch.clip_embedding_pos,
            "encoder_attention_mask": batch.prompt_attention_mask,
        },
    )

    neg_cond_kwargs = self.prepare_extra_func_kwargs(
        self.transformer.forward,
        {
            "encoder_hidden_states_2": batch.clip_embedding_neg,
            "encoder_attention_mask": batch.negative_attention_mask,
        },
    )

    # Prepare STA parameters
    if st_attn_available and self.attn_backend == SlidingTileAttentionBackend:
        self.prepare_sta_param(batch, fastvideo_args)

    # Get latents and embeddings
    latents = batch.latents
    prompt_embeds = batch.prompt_embeds
    assert not torch.isnan(
        prompt_embeds[0]).any(), "prompt_embeds contains nan"
    if batch.do_classifier_free_guidance:
        neg_prompt_embeds = batch.negative_prompt_embeds
        assert neg_prompt_embeds is not None
        assert not torch.isnan(
            neg_prompt_embeds[0]).any(), "neg_prompt_embeds contains nan"

    # (Wan2.2) Calculate timestep to switch from high noise expert to low noise expert
    boundary_ratio = fastvideo_args.pipeline_config.dit_config.boundary_ratio
    if batch.boundary_ratio is not None:
        logger.info("Overriding boundary ratio from %s to %s",
                    boundary_ratio, batch.boundary_ratio)
        boundary_ratio = batch.boundary_ratio

    if boundary_ratio is not None:
        boundary_timestep = boundary_ratio * self.scheduler.num_train_timesteps
    else:
        boundary_timestep = None
    latent_model_input = latents.to(target_dtype)
    assert latent_model_input.shape[0] == 1, "only support batch size 1"

    if fastvideo_args.pipeline_config.ti2v_task and batch.pil_image is not None:
        # TI2V directly replaces the first frame of the latent with
        # the image latent instead of appending along the channel dim
        assert batch.image_latent is None, "TI2V task should not have image latents"
        assert self.vae is not None, "VAE is not provided for TI2V task"
        z = self.vae.encode(batch.pil_image).mean.float()
        if (hasattr(self.vae, "shift_factor")
                and self.vae.shift_factor is not None):
            if isinstance(self.vae.shift_factor, torch.Tensor):
                z -= self.vae.shift_factor.to(z.device, z.dtype)
            else:
                z -= self.vae.shift_factor

        if isinstance(self.vae.scaling_factor, torch.Tensor):
            z = z * self.vae.scaling_factor.to(z.device, z.dtype)
        else:
            z = z * self.vae.scaling_factor

        latent_model_input = latent_model_input.squeeze(0)
        _, mask2 = masks_like([latent_model_input], zero=True)

        latent_model_input = (1. -
                              mask2[0]) * z + mask2[0] * latent_model_input
        # latent_model_input = latent_model_input.unsqueeze(0)
        latent_model_input = latent_model_input.to(get_local_torch_device())
        latents = latent_model_input
        F = batch.num_frames
        temporal_scale = fastvideo_args.pipeline_config.vae_config.arch_config.scale_factor_temporal
        spatial_scale = fastvideo_args.pipeline_config.vae_config.arch_config.scale_factor_spatial
        patch_size = fastvideo_args.pipeline_config.dit_config.arch_config.patch_size
        seq_len = ((F - 1) // temporal_scale +
                   1) * (batch.height // spatial_scale) * (
                       batch.width // spatial_scale) // (patch_size[1] *
                                                         patch_size[2])
        seq_len = int(math.ceil(seq_len / sp_world_size)) * sp_world_size

    # Initialize lists for ODE trajectory
    trajectory_timesteps: list[torch.Tensor] = []
    trajectory_latents: list[torch.Tensor] = []

    # Run denoising loop
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # Skip if interrupted
            if hasattr(self, 'interrupt') and self.interrupt:
                continue

            if boundary_timestep is None or t >= boundary_timestep:
                if (fastvideo_args.dit_cpu_offload
                        and self.transformer_2 is not None and next(
                            self.transformer_2.parameters()).device.type
                        == 'cuda'):
                    self.transformer_2.to('cpu')
                current_model = self.transformer
                current_guidance_scale = batch.guidance_scale
            else:
                # low-noise stage in wan2.2
                if fastvideo_args.dit_cpu_offload and next(
                        self.transformer.parameters(
                        )).device.type == 'cuda':
                    self.transformer.to('cpu')
                current_model = self.transformer_2
                current_guidance_scale = batch.guidance_scale_2
            assert current_model is not None, "current_model is None"

            # Expand latents for V2V/I2V
            latent_model_input = latents.to(target_dtype)
            if batch.video_latent is not None:
                latent_model_input = torch.cat([
                    latent_model_input, batch.video_latent,
                    torch.zeros_like(latents)
                ],
                                               dim=1).to(target_dtype)
            elif batch.image_latent is not None:
                assert not fastvideo_args.pipeline_config.ti2v_task, "image latents should not be provided for TI2V task"
                latent_model_input = torch.cat(
                    [latent_model_input, batch.image_latent],
                    dim=1).to(target_dtype)

            assert not torch.isnan(
                latent_model_input).any(), "latent_model_input contains nan"
            if fastvideo_args.pipeline_config.ti2v_task and batch.pil_image is not None:
                timestep = torch.stack([t]).to(get_local_torch_device())
                temp_ts = (mask2[0][0][:, ::2, ::2] * timestep).flatten()
                temp_ts = torch.cat([
                    temp_ts,
                    temp_ts.new_ones(seq_len - temp_ts.size(0)) * timestep
                ])
                timestep = temp_ts.unsqueeze(0)
                t_expand = timestep.repeat(latent_model_input.shape[0], 1)
            else:
                t_expand = t.repeat(latent_model_input.shape[0])

            latent_model_input = self.scheduler.scale_model_input(
                latent_model_input, t)

            # Prepare inputs for transformer
            guidance_expand = (
                torch.tensor(
                    [fastvideo_args.pipeline_config.embedded_cfg_scale] *
                    latent_model_input.shape[0],
                    dtype=torch.float32,
                    device=get_local_torch_device(),
                ).to(target_dtype) *
                1000.0 if fastvideo_args.pipeline_config.embedded_cfg_scale
                is not None else None)

            # Predict noise residual
            with torch.autocast(device_type="cuda",
                                dtype=target_dtype,
                                enabled=autocast_enabled):
                if (st_attn_available
                        and self.attn_backend == SlidingTileAttentionBackend
                    ) or (vsa_available and self.attn_backend
                          == VideoSparseAttentionBackend):
                    self.attn_metadata_builder_cls = self.attn_backend.get_builder_cls(
                    )

                    if self.attn_metadata_builder_cls is not None:
                        self.attn_metadata_builder = self.attn_metadata_builder_cls(
                        )
                        # TODO(will): clean this up
                        attn_metadata = self.attn_metadata_builder.build(  # type: ignore
                            current_timestep=i,  # type: ignore
                            raw_latent_shape=batch.
                            raw_latent_shape[2:5],  # type: ignore
                            patch_size=fastvideo_args.
                            pipeline_config.  # type: ignore
                            dit_config.patch_size,  # type: ignore
                            STA_param=batch.STA_param,  # type: ignore
                            VSA_sparsity=fastvideo_args.
                            VSA_sparsity,  # type: ignore
                            device=get_local_torch_device(),
                        )
                        assert attn_metadata is not None, "attn_metadata cannot be None"
                    else:
                        attn_metadata = None
                elif (vmoba_attn_available
                      and self.attn_backend == VMOBAAttentionBackend):
                    self.attn_metadata_builder_cls = self.attn_backend.get_builder_cls(
                    )
                    if self.attn_metadata_builder_cls is not None:
                        self.attn_metadata_builder = self.attn_metadata_builder_cls(
                        )
                        # Prepare V-MoBA parameters from config
                        moba_params = fastvideo_args.moba_config.copy()
                        moba_params.update({
                            "current_timestep":
                            i,
                            "raw_latent_shape":
                            batch.raw_latent_shape[2:5],
                            "patch_size":
                            fastvideo_args.pipeline_config.dit_config.
                            patch_size,
                            "device":
                            get_local_torch_device(),
                        })
                        attn_metadata = self.attn_metadata_builder.build(
                            **moba_params)
                        assert attn_metadata is not None, "attn_metadata cannot be None"
                    else:
                        attn_metadata = None
                else:
                    attn_metadata = None
                # TODO(will): finalize the interface. vLLM uses this to
                # support torch dynamo compilation. They pass in
                # attn_metadata, vllm_config, and num_tokens. We can pass in
                # fastvideo_args or training_args, and attn_metadata.
                batch.is_cfg_negative = False
                with set_forward_context(
                        current_timestep=i,
                        attn_metadata=attn_metadata,
                        forward_batch=batch,
                        # fastvideo_args=fastvideo_args
                ):
                    # Run transformer
                    noise_pred = current_model(
                        latent_model_input,
                        prompt_embeds,
                        t_expand,
                        guidance=guidance_expand,
                        **image_kwargs,
                        **pos_cond_kwargs,
                    )

                if batch.do_classifier_free_guidance:
                    batch.is_cfg_negative = True
                    with set_forward_context(
                            current_timestep=i,
                            attn_metadata=attn_metadata,
                            forward_batch=batch,
                    ):
                        noise_pred_uncond = current_model(
                            latent_model_input,
                            neg_prompt_embeds,
                            t_expand,
                            guidance=guidance_expand,
                            **image_kwargs,
                            **neg_cond_kwargs,
                        )

                    noise_pred_text = noise_pred
                    noise_pred = noise_pred_uncond + current_guidance_scale * (
                        noise_pred_text - noise_pred_uncond)

                    # Apply guidance rescale if needed
                    if batch.guidance_rescale > 0.0:
                        # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                        noise_pred = self.rescale_noise_cfg(
                            noise_pred,
                            noise_pred_text,
                            guidance_rescale=batch.guidance_rescale,
                        )
                # Compute the previous noisy sample
                latents = self.scheduler.step(noise_pred,
                                              t,
                                              latents,
                                              **extra_step_kwargs,
                                              return_dict=False)[0]
                if fastvideo_args.pipeline_config.ti2v_task and batch.pil_image is not None:
                    latents = latents.squeeze(0)
                    latents = (1. - mask2[0]) * z + mask2[0] * latents
                    # latents = latents.unsqueeze(0)

            # save trajectory latents if needed
            if batch.return_trajectory_latents:
                trajectory_timesteps.append(t)
                trajectory_latents.append(latents)

            # Update progress bar
            if i == len(timesteps) - 1 or (
                (i + 1) > num_warmup_steps and
                (i + 1) % self.scheduler.order == 0
                    and progress_bar is not None):
                progress_bar.update()

    # Gather results if using sequence parallelism
    trajectory_tensor: torch.Tensor | None = None
    if trajectory_latents:
        trajectory_tensor = torch.stack(trajectory_latents, dim=1)
        trajectory_timesteps_tensor = torch.stack(trajectory_timesteps,
                                                  dim=0)
    else:
        trajectory_tensor = None
        trajectory_timesteps_tensor = None

    # Gather results if using sequence parallelism
    if sp_group:
        latents = sequence_model_parallel_all_gather(latents, dim=2)
        if batch.return_trajectory_latents:
            trajectory_tensor = trajectory_tensor.to(
                get_local_torch_device())
            trajectory_tensor = sequence_model_parallel_all_gather(
                trajectory_tensor, dim=3)

    if trajectory_tensor is not None and trajectory_timesteps_tensor is not None:
        batch.trajectory_timesteps = trajectory_timesteps_tensor.cpu()
        batch.trajectory_latents = trajectory_tensor.cpu()

    # Update batch with final latents
    batch.latents = latents

    # Save STA mask search results if needed
    if st_attn_available and self.attn_backend == SlidingTileAttentionBackend and fastvideo_args.STA_mode == STA_Mode.STA_SEARCHING:
        self.save_sta_search_results(batch)

    # deallocate transformer if on mps
    if torch.backends.mps.is_available():
        logger.info("Memory before deallocating transformer: %s",
                    torch.mps.current_allocated_memory())
        del self.transformer
        if pipeline is not None and "transformer" in pipeline.modules:
            del pipeline.modules["transformer"]
        fastvideo_args.model_loaded["transformer"] = False
        logger.info("Memory after deallocating transformer: %s",
                    torch.mps.current_allocated_memory())

    return batch
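The sequence-parallel path above shards the latents along the frame axis before the loop and all-gathers them afterwards. A standalone shape check of the rearrange pattern it uses, with dummy sizes:

import torch
from einops import rearrange

sp_world_size, rank = 2, 0
latents = torch.randn(1, 16, 20, 60, 104)   # [batch, channels, frames, height, width], illustrative

# Split the frame axis into (sp_world_size, frames_per_rank) and keep this rank's shard.
shards = rearrange(latents, "b c (n t) h w -> b c n t h w", n=sp_world_size)
local = shards[:, :, rank]
assert local.shape == (1, 16, 10, 60, 104)

# After denoising, sequence_model_parallel_all_gather(latents, dim=2) in the
# stage restores the full 20 frames across ranks.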
fastvideo.pipelines.stages.DenoisingStage.prepare_extra_func_kwargs
prepare_extra_func_kwargs(func, kwargs) -> dict[str, Any]

Prepare extra kwargs for the scheduler step / denoise step.

Parameters:

    func: The function to prepare kwargs for. Required.
    kwargs: The kwargs to prepare. Required.

Returns:

    dict[str, Any]: The prepared kwargs.

Source code in fastvideo/pipelines/stages/denoising.py
def prepare_extra_func_kwargs(self, func, kwargs) -> dict[str, Any]:
    """
    Prepare extra kwargs for the scheduler step / denoise step.

    Args:
        func: The function to prepare kwargs for.
        kwargs: The kwargs to prepare.

    Returns:
        The prepared kwargs.
    """
    extra_step_kwargs = {}
    for k, v in kwargs.items():
        accepts = k in set(inspect.signature(func).parameters.keys())
        if accepts:
            extra_step_kwargs[k] = v
    return extra_step_kwargs
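In other words, only keyword arguments that appear in the target function's signature are forwarded; anything else is silently dropped. For example:

import inspect

def scheduler_step(model_output, timestep, sample, eta=0.0):
    # Illustrative stand-in for a scheduler's step(); it accepts eta but not generator.
    return sample

candidate_kwargs = {"eta": 0.5, "generator": None}
accepted = set(inspect.signature(scheduler_step).parameters)
extra_step_kwargs = {k: v for k, v in candidate_kwargs.items() if k in accepted}
print(extra_step_kwargs)  # {'eta': 0.5} -- 'generator' is dropped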
fastvideo.pipelines.stages.DenoisingStage.prepare_sta_param
prepare_sta_param(batch: ForwardBatch, fastvideo_args: FastVideoArgs)

Prepare Sliding Tile Attention (STA) parameters and settings.

Parameters:

    batch (ForwardBatch): The current batch information. Required.
    fastvideo_args (FastVideoArgs): The inference arguments. Required.

Source code in fastvideo/pipelines/stages/denoising.py
def prepare_sta_param(self, batch: ForwardBatch,
                      fastvideo_args: FastVideoArgs):
    """
    Prepare Sliding Tile Attention (STA) parameters and settings.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.
    """
    # TODO(kevin): STA mask search, currently only support Wan2.1 with 69x768x1280
    from fastvideo.attention.backends.STA_configuration import configure_sta
    STA_mode = fastvideo_args.STA_mode
    skip_time_steps = fastvideo_args.skip_time_steps
    if batch.timesteps is None:
        raise ValueError("Timesteps must be provided")
    timesteps_num = batch.timesteps.shape[0]

    logger.info("STA_mode: %s", STA_mode)
    if (batch.num_frames, batch.height,
            batch.width) != (69, 768, 1280) and STA_mode != "STA_inference":
        raise NotImplementedError(
            "STA mask search/tuning is not supported for this resolution")

    if STA_mode == STA_Mode.STA_SEARCHING or STA_mode == STA_Mode.STA_TUNING or STA_mode == STA_Mode.STA_TUNING_CFG:
        size = (batch.width, batch.height)
        if size == (1280, 768):
            # TODO: make it configurable
            sparse_mask_candidates_searching = [
                "3, 1, 10", "1, 5, 7", "3, 3, 3", "1, 6, 5", "1, 3, 10",
                "3, 6, 1"
            ]
            sparse_mask_candidates_tuning = [
                "3, 1, 10", "1, 5, 7", "3, 3, 3", "1, 6, 5", "1, 3, 10",
                "3, 6, 1"
            ]
            full_mask = ["3,6,10"]
        else:
            raise NotImplementedError(
                "STA mask search is not supported for this resolution")
    layer_num = self.transformer.config.num_layers
    # specific for HunyuanVideo
    if hasattr(self.transformer.config, "num_single_layers"):
        layer_num += self.transformer.config.num_single_layers
    head_num = self.transformer.config.num_attention_heads

    if STA_mode == STA_Mode.STA_SEARCHING:
        STA_param = configure_sta(
            mode=STA_Mode.STA_SEARCHING,
            layer_num=layer_num,
            head_num=head_num,
            time_step_num=timesteps_num,
            mask_candidates=sparse_mask_candidates_searching +
            full_mask,  # last is full mask; Can add more sparse masks while keep last one as full mask
        )
    elif STA_mode == STA_Mode.STA_TUNING:
        STA_param = configure_sta(
            mode=STA_Mode.STA_TUNING,
            layer_num=layer_num,
            head_num=head_num,
            time_step_num=timesteps_num,
            mask_search_files_path=
            f'output/mask_search_result_pos_{size[0]}x{size[1]}/',
            mask_candidates=sparse_mask_candidates_tuning,
            full_attention_mask=[int(x) for x in full_mask[0].split(',')],
            skip_time_steps=
            skip_time_steps,  # Use full attention for first 12 steps
            save_dir=
            f'output/mask_search_strategy_{size[0]}x{size[1]}/',  # Custom save directory
            timesteps=timesteps_num)
    elif STA_mode == STA_Mode.STA_TUNING_CFG:
        STA_param = configure_sta(
            mode=STA_Mode.STA_TUNING_CFG,
            layer_num=layer_num,
            head_num=head_num,
            time_step_num=timesteps_num,
            mask_search_files_path_pos=
            f'output/mask_search_result_pos_{size[0]}x{size[1]}/',
            mask_search_files_path_neg=
            f'output/mask_search_result_neg_{size[0]}x{size[1]}/',
            mask_candidates=sparse_mask_candidates_tuning,
            full_attention_mask=[int(x) for x in full_mask[0].split(',')],
            skip_time_steps=skip_time_steps,
            save_dir=f'output/mask_search_strategy_{size[0]}x{size[1]}/',
            timesteps=timesteps_num)
    elif STA_mode == STA_Mode.STA_INFERENCE:
        import fastvideo.envs as envs
        config_file = envs.FASTVIDEO_ATTENTION_CONFIG
        if config_file is None:
            raise ValueError("FASTVIDEO_ATTENTION_CONFIG is not set")
        STA_param = configure_sta(mode=STA_Mode.STA_INFERENCE,
                                  layer_num=layer_num,
                                  head_num=head_num,
                                  time_step_num=timesteps_num,
                                  load_path=config_file)

    batch.STA_param = STA_param
    batch.mask_search_final_result_pos = [[] for _ in range(timesteps_num)]
    batch.mask_search_final_result_neg = [[] for _ in range(timesteps_num)]
fastvideo.pipelines.stages.DenoisingStage.progress_bar
progress_bar(iterable: Iterable | None = None, total: int | None = None) -> tqdm

Create a progress bar for the denoising process.

Parameters:

    iterable (Iterable | None): The iterable to iterate over. Default: None.
    total (int | None): The total number of items. Default: None.

Returns:

    tqdm: A tqdm progress bar.

Source code in fastvideo/pipelines/stages/denoising.py
def progress_bar(self,
                 iterable: Iterable | None = None,
                 total: int | None = None) -> tqdm:
    """
    Create a progress bar for the denoising process.

    Args:
        iterable: The iterable to iterate over.
        total: The total number of items.

    Returns:
        A tqdm progress bar.
    """
    local_rank = get_world_group().local_rank
    if local_rank == 0:
        return tqdm(iterable=iterable, total=total)
    else:
        return tqdm(iterable=iterable, total=total, disable=True)
fastvideo.pipelines.stages.DenoisingStage.rescale_noise_cfg
rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0) -> Tensor

Rescale noise prediction according to guidance_rescale.

Based on findings of "Common Diffusion Noise Schedules and Sample Steps are Flawed" (https://arxiv.org/pdf/2305.08891.pdf), Section 3.4.

Parameters:

    noise_cfg: The noise prediction with guidance. Required.
    noise_pred_text: The text-conditioned noise prediction. Required.
    guidance_rescale: The guidance rescale factor. Default: 0.0.

Returns:

    Tensor: The rescaled noise prediction.

Source code in fastvideo/pipelines/stages/denoising.py
def rescale_noise_cfg(self,
                      noise_cfg,
                      noise_pred_text,
                      guidance_rescale=0.0) -> torch.Tensor:
    """
    Rescale noise prediction according to guidance_rescale.

    Based on findings of "Common Diffusion Noise Schedules and Sample Steps are Flawed"
    (https://arxiv.org/pdf/2305.08891.pdf), Section 3.4.

    Args:
        noise_cfg: The noise prediction with guidance.
        noise_pred_text: The text-conditioned noise prediction.
        guidance_rescale: The guidance rescale factor.

    Returns:
        The rescaled noise prediction.
    """
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)),
                                   keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)),
                            keepdim=True)
    # Rescale the results from guidance (fixes overexposure)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # Mix with the original results from guidance by factor guidance_rescale
    noise_cfg = (guidance_rescale * noise_pred_rescaled +
                 (1 - guidance_rescale) * noise_cfg)
    return noise_cfg
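A small numeric check of the formula (shapes and magnitudes are illustrative): the guided prediction is first rescaled so its per-sample standard deviation matches the text-conditioned prediction, then blended back with the original by guidance_rescale.

import torch

noise_pred_text = torch.randn(2, 16, 8, 8)
noise_cfg = 3.0 * noise_pred_text + torch.randn(2, 16, 8, 8)  # over-amplified guided prediction

std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)

guidance_rescale = 0.7
rescaled = noise_cfg * (std_text / std_cfg)
out = guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg

# The output std sits between std_cfg and std_text, pulled toward std_text.
print(std_text.flatten(), std_cfg.flatten(), out.std(dim=[1, 2, 3]))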
fastvideo.pipelines.stages.DenoisingStage.save_sta_search_results
save_sta_search_results(batch: ForwardBatch)

Save the STA mask search results.

Parameters:

    batch (ForwardBatch): The current batch information. Required.

Source code in fastvideo/pipelines/stages/denoising.py
def save_sta_search_results(self, batch: ForwardBatch):
    """
    Save the STA mask search results.

    Args:
        batch: The current batch information.
    """
    size = (batch.width, batch.height)
    if size == (1280, 768):
        # TODO: make it configurable
        sparse_mask_candidates_searching = [
            "3, 1, 10", "1, 5, 7", "3, 3, 3", "1, 6, 5", "1, 3, 10",
            "3, 6, 1"
        ]
    else:
        raise NotImplementedError(
            "STA mask search is not supported for this resolution")

    from fastvideo.attention.backends.STA_configuration import save_mask_search_results
    if batch.mask_search_final_result_pos is not None and batch.prompt is not None:
        save_mask_search_results(
            [
                dict(layer_data)
                for layer_data in batch.mask_search_final_result_pos
            ],
            prompt=str(batch.prompt),
            mask_strategies=sparse_mask_candidates_searching,
            output_dir=f'output/mask_search_result_pos_{size[0]}x{size[1]}/'
        )
    if batch.mask_search_final_result_neg is not None and batch.prompt is not None:
        save_mask_search_results(
            [
                dict(layer_data)
                for layer_data in batch.mask_search_final_result_neg
            ],
            prompt=str(batch.prompt),
            mask_strategies=sparse_mask_candidates_searching,
            output_dir=f'output/mask_search_result_neg_{size[0]}x{size[1]}/'
        )
fastvideo.pipelines.stages.DenoisingStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify denoising stage inputs.

Source code in fastvideo/pipelines/stages/denoising.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify denoising stage inputs."""
    result = VerificationResult()
    result.add_check("timesteps", batch.timesteps,
                     [V.is_tensor, V.min_dims(1)])
    result.add_check("latents", batch.latents,
                     [V.is_tensor, V.with_dims(5)])
    result.add_check("prompt_embeds", batch.prompt_embeds, V.list_not_empty)
    result.add_check("image_embeds", batch.image_embeds, V.is_list)
    result.add_check("image_latent", batch.image_latent,
                     V.none_or_tensor_with_dims(5))
    result.add_check("num_inference_steps", batch.num_inference_steps,
                     V.positive_int)
    result.add_check("guidance_scale", batch.guidance_scale,
                     V.positive_float)
    result.add_check("eta", batch.eta, V.non_negative_float)
    result.add_check("generator", batch.generator,
                     V.generator_or_list_generators)
    result.add_check("do_classifier_free_guidance",
                     batch.do_classifier_free_guidance, V.bool_value)
    result.add_check(
        "negative_prompt_embeds", batch.negative_prompt_embeds, lambda x:
        not batch.do_classifier_free_guidance or V.list_not_empty(x))
    return result
fastvideo.pipelines.stages.DenoisingStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify denoising stage outputs.

Source code in fastvideo/pipelines/stages/denoising.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify denoising stage outputs."""
    result = VerificationResult()
    result.add_check("latents", batch.latents,
                     [V.is_tensor, V.with_dims(5)])
    return result

fastvideo.pipelines.stages.DmdDenoisingStage

DmdDenoisingStage(transformer, scheduler)

Bases: DenoisingStage

Denoising stage for DMD.

Source code in fastvideo/pipelines/stages/denoising.py
def __init__(self, transformer, scheduler) -> None:
    super().__init__(transformer, scheduler)
    self.scheduler = FlowMatchEulerDiscreteScheduler(shift=8.0)
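The shift value warps the sampling schedule toward the high-noise end. A sketch assuming the standard flow-matching time shift, sigma' = shift * sigma / (1 + (shift - 1) * sigma); treat the exact formula as an assumption here rather than a statement about FlowMatchEulerDiscreteScheduler's internals:

import torch

def shift_sigmas(sigmas: torch.Tensor, shift: float) -> torch.Tensor:
    # Assumed standard flow-matching time shift.
    return shift * sigmas / (1 + (shift - 1) * sigmas)

sigmas = torch.linspace(1.0, 0.0, 5)
print(shift_sigmas(sigmas, 8.0))
# tensor([1.0000, 0.9600, 0.8889, 0.7273, 0.0000]) -- shifted toward high noise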

Functions

fastvideo.pipelines.stages.DmdDenoisingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Run the denoising loop.

Parameters:

    batch (ForwardBatch): The current batch information. Required.
    fastvideo_args (FastVideoArgs): The inference arguments. Required.

Returns:

    ForwardBatch: The batch with denoised latents.

Source code in fastvideo/pipelines/stages/denoising.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Run the denoising loop.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with denoised latents.
    """
    # Setup precision and autocast settings
    # TODO(will): make the precision configurable for inference
    # target_dtype = PRECISION_TO_TYPE[fastvideo_args.precision]
    target_dtype = torch.bfloat16
    autocast_enabled = (target_dtype != torch.float32
                        ) and not fastvideo_args.disable_autocast

    # Get timesteps and calculate warmup steps
    timesteps = batch.timesteps

    # TODO(will): remove this once we add input/output validation for stages
    if timesteps is None:
        raise ValueError("Timesteps must be provided")
    num_inference_steps = batch.num_inference_steps
    num_warmup_steps = len(
        timesteps) - num_inference_steps * self.scheduler.order

    # Prepare image latents and embeddings for I2V generation
    image_embeds = batch.image_embeds
    if len(image_embeds) > 0:
        assert torch.isnan(image_embeds[0]).sum() == 0
        image_embeds = [
            image_embed.to(target_dtype) for image_embed in image_embeds
        ]

    image_kwargs = self.prepare_extra_func_kwargs(
        self.transformer.forward,
        {
            "encoder_hidden_states_image": image_embeds,
            "mask_strategy": dict_to_3d_list(
                None, t_max=50, l_max=60, h_max=24)
        },
    )

    pos_cond_kwargs = self.prepare_extra_func_kwargs(
        self.transformer.forward,
        {
            "encoder_hidden_states_2": batch.clip_embedding_pos,
            "encoder_attention_mask": batch.prompt_attention_mask,
        },
    )

    # Prepare STA parameters
    if st_attn_available and self.attn_backend == SlidingTileAttentionBackend:
        self.prepare_sta_param(batch, fastvideo_args)

    # Get latents and embeddings
    assert batch.latents is not None, "latents must be provided"
    latents = batch.latents
    latents = latents.permute(0, 2, 1, 3, 4)

    video_raw_latent_shape = latents.shape
    prompt_embeds = batch.prompt_embeds
    assert not torch.isnan(
        prompt_embeds[0]).any(), "prompt_embeds contains nan"
    timesteps = torch.tensor(
        fastvideo_args.pipeline_config.dmd_denoising_steps,
        dtype=torch.long,
        device=get_local_torch_device())

    # Handle sequence parallelism if enabled
    sp_world_size, rank_in_sp_group = get_sp_world_size(
    ), get_sp_parallel_rank()
    sp_group = sp_world_size > 1
    if sp_group:
        latents = rearrange(latents,
                            "b (n t) c h w -> b n t c h w",
                            n=sp_world_size).contiguous()
        latents = latents[:, rank_in_sp_group, :, :, :, :]
        if batch.image_latent is not None:
            image_latent = rearrange(batch.image_latent,
                                     "b c (n t) h w -> b c n t h w",
                                     n=sp_world_size).contiguous()

            image_latent = image_latent[:, :, rank_in_sp_group, :, :, :]
            batch.image_latent = image_latent

    # Run denoising loop
    with self.progress_bar(total=len(timesteps)) as progress_bar:
        for i, t in enumerate(timesteps):
            # Skip if interrupted
            if hasattr(self, 'interrupt') and self.interrupt:
                continue
            # Expand latents for I2V
            noise_latents = latents.clone()
            latent_model_input = latents.to(target_dtype)

            if batch.image_latent is not None:
                latent_model_input = torch.cat([
                    latent_model_input,
                    batch.image_latent.permute(0, 2, 1, 3, 4)
                ],
                                               dim=2).to(target_dtype)
            assert not torch.isnan(
                latent_model_input).any(), "latent_model_input contains nan"

            # Prepare inputs for transformer
            t_expand = t.repeat(latent_model_input.shape[0])
            guidance_expand = (
                torch.tensor(
                    [fastvideo_args.pipeline_config.embedded_cfg_scale] *
                    latent_model_input.shape[0],
                    dtype=torch.float32,
                    device=get_local_torch_device(),
                ).to(target_dtype) *
                1000.0 if fastvideo_args.pipeline_config.embedded_cfg_scale
                is not None else None)

            # Predict noise residual
            with torch.autocast(device_type="cuda",
                                dtype=target_dtype,
                                enabled=autocast_enabled):
                if (vsa_available and self.attn_backend
                        == VideoSparseAttentionBackend):
                    self.attn_metadata_builder_cls = self.attn_backend.get_builder_cls(
                    )

                    if self.attn_metadata_builder_cls is not None:
                        self.attn_metadata_builder = self.attn_metadata_builder_cls(
                        )
                        # TODO(will): clean this up
                        attn_metadata = self.attn_metadata_builder.build(  # type: ignore
                            current_timestep=i,  # type: ignore
                            raw_latent_shape=batch.
                            raw_latent_shape[2:5],  # type: ignore
                            patch_size=fastvideo_args.
                            pipeline_config.  # type: ignore
                            dit_config.patch_size,  # type: ignore
                            STA_param=batch.STA_param,  # type: ignore
                            VSA_sparsity=fastvideo_args.
                            VSA_sparsity,  # type: ignore
                            device=get_local_torch_device(),  # type: ignore
                        )  # type: ignore
                        assert attn_metadata is not None, "attn_metadata cannot be None"
                    else:
                        attn_metadata = None
                else:
                    attn_metadata = None

                batch.is_cfg_negative = False
                with set_forward_context(
                        current_timestep=i,
                        attn_metadata=attn_metadata,
                        forward_batch=batch,
                        # fastvideo_args=fastvideo_args
                ):
                    # Run transformer
                    pred_noise = self.transformer(
                        latent_model_input.permute(0, 2, 1, 3, 4),
                        prompt_embeds,
                        t_expand,
                        guidance=guidance_expand,
                        **image_kwargs,
                        **pos_cond_kwargs,
                    ).permute(0, 2, 1, 3, 4)

                pred_video = pred_noise_to_pred_video(
                    pred_noise=pred_noise.flatten(0, 1),
                    noise_input_latent=noise_latents.flatten(0, 1),
                    timestep=t_expand,
                    scheduler=self.scheduler).unflatten(
                        0, pred_noise.shape[:2])

                if i < len(timesteps) - 1:
                    next_timestep = timesteps[i + 1] * torch.ones(
                        [1], dtype=torch.long, device=pred_video.device)
                    noise = torch.randn(video_raw_latent_shape,
                                        dtype=pred_video.dtype,
                                        generator=batch.generator[0]).to(
                                            self.device)
                    if sp_group:
                        noise = rearrange(noise,
                                          "b (n t) c h w -> b n t c h w",
                                          n=sp_world_size).contiguous()
                        noise = noise[:, rank_in_sp_group, :, :, :, :]
                    latents = self.scheduler.add_noise(
                        pred_video.flatten(0, 1), noise.flatten(0, 1),
                        next_timestep).unflatten(0, pred_video.shape[:2])
                else:
                    latents = pred_video

                # Update progress bar
                if i == len(timesteps) - 1 or (
                    (i + 1) > num_warmup_steps and
                    (i + 1) % self.scheduler.order == 0
                        and progress_bar is not None):
                    progress_bar.update()

    # Gather results if using sequence parallelism
    if sp_group:
        latents = sequence_model_parallel_all_gather(latents, dim=1)
    latents = latents.permute(0, 2, 1, 3, 4)
    # Update batch with final latents
    batch.latents = latents

    return batch
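At each DMD step the transformer's noise prediction is converted into a clean-video estimate via pred_noise_to_pred_video and, unless this is the final step, re-noised to the next timestep with fresh Gaussian noise. A minimal sketch of that per-step update, assuming the standard flow-matching noising rule x_t = (1 - sigma_t) * x_0 + sigma_t * eps; the helper below is illustrative and not a FastVideo API:

import torch

def dmd_step(pred_video: torch.Tensor, sigma_next: float | None,
             generator: torch.Generator | None = None) -> torch.Tensor:
    if sigma_next is None:
        # Final step: the clean estimate is the output latent.
        return pred_video
    # Intermediate step: re-noise the clean estimate to the next noise level,
    # mirroring scheduler.add_noise(pred_video, noise, next_timestep) above.
    noise = torch.randn(pred_video.shape, generator=generator,
                        dtype=pred_video.dtype)
    return (1.0 - sigma_next) * pred_video + sigma_next * noise

x_next = dmd_step(torch.zeros(1, 16, 21, 60, 104), sigma_next=0.75)  # shape is illustrative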

fastvideo.pipelines.stages.EncodingStage

EncodingStage(vae: ParallelTiledVAE)

Bases: PipelineStage

Stage for encoding pixel space representations into latent space.

This stage handles the encoding of pixel-space video/images into latent representations for further processing in the diffusion pipeline.

Source code in fastvideo/pipelines/stages/encoding.py
def __init__(self, vae: ParallelTiledVAE) -> None:
    self.vae: ParallelTiledVAE = vae

Functions

fastvideo.pipelines.stages.EncodingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Encode pixel space representations into latent space.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The batch with encoded latents.

Source code in fastvideo/pipelines/stages/encoding.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Encode pixel space representations into latent space.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with encoded latents.
    """
    assert batch.latents is not None and isinstance(batch.latents,
                                                    torch.Tensor)

    self.vae = self.vae.to(get_local_torch_device())

    # Setup VAE precision
    vae_dtype = PRECISION_TO_TYPE[
        fastvideo_args.pipeline_config.vae_precision]
    vae_autocast_enabled = (
        vae_dtype != torch.float32) and not fastvideo_args.disable_autocast

    # Normalize input to [-1, 1] range (reverse of decoding normalization)
    latents = (batch.latents * 2.0 - 1.0).clamp(-1, 1)

    # Move to appropriate device and dtype
    latents = latents.to(get_local_torch_device())

    # Encode image to latents
    with torch.autocast(device_type="cuda",
                        dtype=vae_dtype,
                        enabled=vae_autocast_enabled):
        if fastvideo_args.pipeline_config.vae_tiling:
            self.vae.enable_tiling()
        # if fastvideo_args.vae_sp:
        #     self.vae.enable_parallel()
        if not vae_autocast_enabled:
            latents = latents.to(vae_dtype)
        latents = self.vae.encode(latents).mean

    # Update batch with encoded latents
    batch.latents = latents

    # Offload models if needed
    if hasattr(self, 'maybe_free_model_hooks'):
        self.maybe_free_model_hooks()

    if fastvideo_args.vae_cpu_offload:
        self.vae.to("cpu")

    return batch
fastvideo.pipelines.stages.EncodingStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify encoding stage inputs.

Source code in fastvideo/pipelines/stages/encoding.py
@torch.no_grad()
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify encoding stage inputs."""
    result = VerificationResult()
    # Input video/images for VAE encoding: [batch_size, channels, frames, height, width]
    result.add_check("latents", batch.latents,
                     [V.is_tensor, V.with_dims(5)])
    return result
fastvideo.pipelines.stages.EncodingStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify encoding stage outputs.

Source code in fastvideo/pipelines/stages/encoding.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify encoding stage outputs."""
    result = VerificationResult()
    # Encoded latents: [batch_size, channels, frames, height_latents, width_latents]
    result.add_check("latents", batch.latents,
                     [V.is_tensor, V.with_dims(5)])
    return result
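The stage expects pixel values in [0, 1] and maps them to the [-1, 1] range consumed by the VAE, the inverse of the usual decode-side mapping. A small runnable check of that round trip (the decode-side formula here is an assumption for illustration; the decoding stage is documented separately):

import torch

pixels = torch.rand(2, 3, 4, 8, 8)                 # [B, C, T, H, W] in [0, 1]
to_vae = (pixels * 2.0 - 1.0).clamp(-1, 1)         # what forward() feeds vae.encode(...)
back = (to_vae / 2 + 0.5).clamp(0, 1)              # assumed decode-side normalization
assert torch.allclose(pixels, back)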

fastvideo.pipelines.stages.ImageEncodingStage

ImageEncodingStage(image_encoder, image_processor)

Bases: PipelineStage

Stage for encoding image prompts into embeddings for diffusion models.

This stage handles the encoding of image prompts into the embedding space expected by the diffusion model.

Initialize the image encoding stage.

Parameters:

Name Type Description Default
image_encoder

The image encoder model used to produce image embeddings.

required
image_processor

The processor used to prepare images for the encoder.

required
Source code in fastvideo/pipelines/stages/image_encoding.py
def __init__(self, image_encoder, image_processor) -> None:
    """
    Initialize the prompt encoding stage.

    Args:
        enable_logging: Whether to enable logging for this stage.
        is_secondary: Whether this is a secondary image encoder.
    """
    super().__init__()
    self.image_processor = image_processor
    self.image_encoder = image_encoder

Functions

fastvideo.pipelines.stages.ImageEncodingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Encode the prompt into image encoder hidden states.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The batch with encoded prompt embeddings.

Source code in fastvideo/pipelines/stages/image_encoding.py
@torch.no_grad()
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Encode the prompt into image encoder hidden states.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with encoded prompt embeddings.
    """
    self.image_encoder = self.image_encoder.to(get_local_torch_device())

    image = batch.pil_image

    image_inputs = self.image_processor(
        images=image, return_tensors="pt").to(get_local_torch_device())
    with set_forward_context(current_timestep=0, attn_metadata=None):
        outputs = self.image_encoder(**image_inputs)
        image_embeds = outputs.last_hidden_state

    batch.image_embeds.append(image_embeds)

    if fastvideo_args.image_encoder_cpu_offload:
        self.image_encoder.to('cpu')

    return batch
fastvideo.pipelines.stages.ImageEncodingStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify image encoding stage inputs.

Source code in fastvideo/pipelines/stages/image_encoding.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify image encoding stage inputs."""
    result = VerificationResult()
    result.add_check("pil_image", batch.pil_image, V.not_none)
    result.add_check("image_embeds", batch.image_embeds, V.is_list)
    return result
fastvideo.pipelines.stages.ImageEncodingStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify image encoding stage outputs.

Source code in fastvideo/pipelines/stages/image_encoding.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify image encoding stage outputs."""
    result = VerificationResult()
    result.add_check("image_embeds", batch.image_embeds,
                     V.list_of_tensors_dims(3))
    return result
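For a concrete sense of the processor -> encoder -> last_hidden_state flow, here is a hedged standalone sketch using a Hugging Face CLIP vision model; the checkpoint is an assumption for illustration only, since FastVideo loads its own encoder/processor pair:

import torch
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModel

processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
encoder = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14")

image = Image.new("RGB", (512, 512))
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    image_embeds = encoder(**inputs).last_hidden_state   # [batch, seq_len, hidden]

# The stage appends a tensor like this to batch.image_embeds, which is why
# verify_output checks for a list of 3-D tensors (V.list_of_tensors_dims(3)).
print(image_embeds.shape)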

fastvideo.pipelines.stages.ImageVAEEncodingStage

ImageVAEEncodingStage(vae: ParallelTiledVAE)

Bases: PipelineStage

Stage for encoding image pixel representations into latent space.

This stage handles the encoding of image pixel representations into the final input format (e.g., latents) for image-to-video generation.

Source code in fastvideo/pipelines/stages/image_encoding.py
def __init__(self, vae: ParallelTiledVAE) -> None:
    self.vae: ParallelTiledVAE = vae

Functions

fastvideo.pipelines.stages.ImageVAEEncodingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Encode pixel representations into latent space.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The batch with encoded outputs.

Source code in fastvideo/pipelines/stages/image_encoding.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Encode pixel representations into latent space.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with encoded outputs.
    """
    assert batch.pil_image is not None
    if fastvideo_args.mode == ExecutionMode.INFERENCE:
        assert batch.pil_image is not None and isinstance(
            batch.pil_image, PIL.Image.Image)
        assert batch.height is not None and isinstance(batch.height, int)
        assert batch.width is not None and isinstance(batch.width, int)
        assert batch.num_frames is not None and isinstance(
            batch.num_frames, int)
        height = batch.height
        width = batch.width
        num_frames = batch.num_frames
    elif fastvideo_args.mode == ExecutionMode.PREPROCESS:
        assert batch.pil_image is not None and isinstance(
            batch.pil_image, torch.Tensor)
        assert batch.height is not None and isinstance(batch.height, list)
        assert batch.width is not None and isinstance(batch.width, list)
        assert batch.num_frames is not None and isinstance(
            batch.num_frames, list)
        num_frames = batch.num_frames[0]
        height = batch.height[0]
        width = batch.width[0]

    self.vae = self.vae.to(get_local_torch_device())

    # Process single image for I2V
    latent_height = height // self.vae.spatial_compression_ratio
    latent_width = width // self.vae.spatial_compression_ratio
    image = batch.pil_image
    image = self.preprocess(
        image,
        vae_scale_factor=self.vae.spatial_compression_ratio,
        height=height,
        width=width).to(get_local_torch_device(), dtype=torch.float32)

    # (B, C, H, W) -> (B, C, 1, H, W)
    image = image.unsqueeze(2)

    video_condition = torch.cat([
        image,
        image.new_zeros(image.shape[0], image.shape[1], num_frames - 1,
                        image.shape[3], image.shape[4])
    ],
                                dim=2)
    video_condition = video_condition.to(device=get_local_torch_device(),
                                         dtype=torch.float32)

    # Setup VAE precision
    vae_dtype = PRECISION_TO_TYPE[
        fastvideo_args.pipeline_config.vae_precision]
    vae_autocast_enabled = (
        vae_dtype != torch.float32) and not fastvideo_args.disable_autocast

    # Encode Image
    with torch.autocast(device_type="cuda",
                        dtype=vae_dtype,
                        enabled=vae_autocast_enabled):
        if fastvideo_args.pipeline_config.vae_tiling:
            self.vae.enable_tiling()
        # if fastvideo_args.vae_sp:
        #     self.vae.enable_parallel()
        if not vae_autocast_enabled:
            video_condition = video_condition.to(vae_dtype)
        encoder_output = self.vae.encode(video_condition)

    if fastvideo_args.mode == ExecutionMode.PREPROCESS:
        latent_condition = encoder_output.mean
    else:
        generator = batch.generator
        if generator is None:
            raise ValueError("Generator must be provided")
        latent_condition = self.retrieve_latents(encoder_output, generator)

    # Apply shifting if needed
    if (hasattr(self.vae, "shift_factor")
            and self.vae.shift_factor is not None):
        if isinstance(self.vae.shift_factor, torch.Tensor):
            latent_condition -= self.vae.shift_factor.to(
                latent_condition.device, latent_condition.dtype)
        else:
            latent_condition -= self.vae.shift_factor

    if isinstance(self.vae.scaling_factor, torch.Tensor):
        latent_condition = latent_condition * self.vae.scaling_factor.to(
            latent_condition.device, latent_condition.dtype)
    else:
        latent_condition = latent_condition * self.vae.scaling_factor

    if fastvideo_args.mode == ExecutionMode.PREPROCESS:
        batch.image_latent = latent_condition
    else:
        mask_lat_size = torch.ones(1, 1, num_frames, latent_height,
                                   latent_width)
        mask_lat_size[:, :, list(range(1, num_frames))] = 0
        first_frame_mask = mask_lat_size[:, :, 0:1]
        first_frame_mask = torch.repeat_interleave(
            first_frame_mask,
            dim=2,
            repeats=self.vae.temporal_compression_ratio)
        mask_lat_size = torch.concat(
            [first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
        mask_lat_size = mask_lat_size.view(
            1, -1, self.vae.temporal_compression_ratio, latent_height,
            latent_width)
        mask_lat_size = mask_lat_size.transpose(1, 2)
        mask_lat_size = mask_lat_size.to(latent_condition.device)

        batch.image_latent = torch.concat([mask_lat_size, latent_condition],
                                          dim=1)

    # Offload models if needed
    if hasattr(self, 'maybe_free_model_hooks'):
        self.maybe_free_model_hooks()

    self.vae.to("cpu")

    return batch
fastvideo.pipelines.stages.ImageVAEEncodingStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify encoding stage inputs.

Source code in fastvideo/pipelines/stages/image_encoding.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify encoding stage inputs."""
    result = VerificationResult()
    result.add_check("generator", batch.generator,
                     V.generator_or_list_generators)
    if fastvideo_args.mode == ExecutionMode.PREPROCESS:
        result.add_check("height", batch.height, V.list_not_empty)
        result.add_check("width", batch.width, V.list_not_empty)
        result.add_check("num_frames", batch.num_frames, V.list_not_empty)
    else:
        result.add_check("height", batch.height, V.positive_int)
        result.add_check("width", batch.width, V.positive_int)
        result.add_check("num_frames", batch.num_frames, V.positive_int)
    return result
fastvideo.pipelines.stages.ImageVAEEncodingStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify encoding stage outputs.

Source code in fastvideo/pipelines/stages/image_encoding.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify encoding stage outputs."""
    result = VerificationResult()
    result.add_check("image_latent", batch.image_latent,
                     [V.is_tensor, V.with_dims(5)])
    return result
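The first-frame mask built in forward() marks which latent frames are conditioned on the input image before being concatenated with the encoded latents along the channel dimension. A worked example of the shape bookkeeping, assuming a temporal compression ratio of 4 and an 81-frame request (both values are illustrative):

import torch

num_frames, tcr = 81, 4
latent_h, latent_w = 8, 8
latent_frames = (num_frames - 1) // tcr + 1          # 21 latent frames

mask = torch.ones(1, 1, num_frames, latent_h, latent_w)
mask[:, :, 1:] = 0                                    # only the first frame is given
first = mask[:, :, 0:1].repeat_interleave(tcr, dim=2)
mask = torch.cat([first, mask[:, :, 1:]], dim=2)      # 4 + 80 = 84 frames
mask = mask.view(1, -1, tcr, latent_h, latent_w)      # [1, 21, 4, H, W]
mask = mask.transpose(1, 2)                           # [1, 4, 21, H, W]

assert mask.shape == (1, tcr, latent_frames, latent_h, latent_w)
# batch.image_latent then becomes cat([mask, latent_condition], dim=1),
# i.e. [1, 4 + C_latent, 21, H, W].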

fastvideo.pipelines.stages.InputValidationStage

Bases: PipelineStage

Stage for validating and preparing inputs for diffusion pipelines.

This stage validates that all required inputs are present and properly formatted before proceeding with the diffusion process.

Functions

fastvideo.pipelines.stages.InputValidationStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Validate and prepare inputs.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The validated batch information.

Source code in fastvideo/pipelines/stages/input_validation.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Validate and prepare inputs.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The validated batch information.
    """

    self._generate_seeds(batch, fastvideo_args)

    # Ensure prompt is properly formatted
    if batch.prompt is None and batch.prompt_embeds is None:
        raise ValueError(
            "Either `prompt` or `prompt_embeds` must be provided")

    # Ensure negative prompt is properly formatted if using classifier-free guidance
    if (batch.do_classifier_free_guidance and batch.negative_prompt is None
            and batch.negative_prompt_embeds is None):
        raise ValueError(
            "For classifier-free guidance, either `negative_prompt` or "
            "`negative_prompt_embeds` must be provided")

    # Validate height and width
    if batch.height is None or batch.width is None:
        raise ValueError(
            "Height and width must be provided. Please set `height` and `width`."
        )
    if batch.height % 8 != 0 or batch.width % 8 != 0:
        raise ValueError(
            f"Height and width must be divisible by 8 but are {batch.height} and {batch.width}."
        )

    # Validate number of inference steps
    if batch.num_inference_steps <= 0:
        raise ValueError(
            f"Number of inference steps must be positive, but got {batch.num_inference_steps}"
        )

    # Validate guidance scale if using classifier-free guidance
    if batch.do_classifier_free_guidance and batch.guidance_scale <= 0:
        raise ValueError(
            f"Guidance scale must be positive, but got {batch.guidance_scale}"
        )

    # for i2v, get image from image_path
    # @TODO(Wei) hard-coded for wan2.2 5b ti2v for now. Should put this in image_encoding stage
    if batch.image_path is not None:
        if batch.image_path.endswith(".mp4"):
            image = load_video(batch.image_path)[0]
        else:
            image = load_image(batch.image_path)
        batch.pil_image = image

    # further processing for ti2v task
    if fastvideo_args.pipeline_config.ti2v_task and batch.pil_image is not None:
        img = batch.pil_image
        ih, iw = img.height, img.width
        patch_size = fastvideo_args.pipeline_config.dit_config.arch_config.patch_size
        vae_stride = fastvideo_args.pipeline_config.vae_config.arch_config.scale_factor_spatial
        dh, dw = patch_size[1] * vae_stride, patch_size[2] * vae_stride
        max_area = 704 * 1280
        ow, oh = best_output_size(iw, ih, dw, dh, max_area)

        scale = max(ow / iw, oh / ih)
        img = img.resize((round(iw * scale), round(ih * scale)),
                         Image.LANCZOS)
        logger.info("resized img height: %s, img width: %s", img.height,
                    img.width)

        # center-crop
        x1 = (img.width - ow) // 2
        y1 = (img.height - oh) // 2
        img = img.crop((x1, y1, x1 + ow, y1 + oh))
        assert img.width == ow and img.height == oh

        # to tensor
        img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(
            self.device).unsqueeze(1)
        img = img.unsqueeze(0)
        batch.height = oh
        batch.width = ow
        batch.pil_image = img

    # for v2v, get control video from video path
    if batch.video_path is not None:
        pil_images, original_fps = load_video(batch.video_path,
                                              return_fps=True)
        logger.info("Loaded video with %s frames, original FPS: %s",
                    len(pil_images), original_fps)

        # Get target parameters from batch
        target_fps = batch.fps
        target_num_frames = batch.num_frames
        target_height = batch.height
        target_width = batch.width

        if target_fps is not None and original_fps is not None:
            frame_skip = max(1, int(original_fps // target_fps))
            if frame_skip > 1:
                pil_images = pil_images[::frame_skip]
                effective_fps = original_fps / frame_skip
                logger.info(
                    "Resampled video from %.1f fps to %.1f fps (skip=%s)",
                    original_fps, effective_fps, frame_skip)

        # Limit to target number of frames
        if target_num_frames is not None and len(
                pil_images) > target_num_frames:
            pil_images = pil_images[:target_num_frames]
            logger.info("Limited video to %s frames (from %s total)",
                        target_num_frames, len(pil_images))

        # Resize each PIL image to target dimensions
        resized_images = []
        for pil_img in pil_images:
            resized_img = resize(pil_img,
                                 target_height,
                                 target_width,
                                 resize_mode="default",
                                 resample="lanczos")
            resized_images.append(resized_img)

        # Convert PIL images to numpy array
        video_numpy = pil_to_numpy(resized_images)
        video_numpy = normalize(video_numpy)
        video_tensor = numpy_to_pt(video_numpy)

        # Rearrange to [C, T, H, W] and add batch dimension -> [B, C, T, H, W]
        input_video = video_tensor.permute(1, 0, 2, 3).unsqueeze(0)

        batch.video_latent = input_video

    return batch
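The control-video path above resamples by integer frame skipping and then truncates to the requested frame count. A small runnable sketch of just that selection logic (the integer list stands in for the decoded PIL frames):

def select_frames(frames: list, original_fps: float, target_fps: float,
                  target_num_frames: int) -> list:
    # Integer skip, as in forward(): 30 fps -> 16 fps gives skip = 1 (no resampling),
    # while 30 fps -> 8 fps gives skip = 3.
    frame_skip = max(1, int(original_fps // target_fps))
    return frames[::frame_skip][:target_num_frames]

frames = list(range(120))                       # 120 decoded frames at 30 fps
picked = select_frames(frames, 30.0, 8.0, 33)   # skip=3 -> 40 frames -> keep first 33
assert len(picked) == 33 and picked[1] == 3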
fastvideo.pipelines.stages.InputValidationStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify input validation stage inputs.

Source code in fastvideo/pipelines/stages/input_validation.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify input validation stage inputs."""
    result = VerificationResult()
    result.add_check("seed", batch.seed, [V.not_none, V.positive_int])
    result.add_check("num_videos_per_prompt", batch.num_videos_per_prompt,
                     V.positive_int)
    result.add_check(
        "prompt_or_embeds", None, lambda _: V.string_or_list_strings(
            batch.prompt) or V.list_not_empty(batch.prompt_embeds))
    result.add_check("height", batch.height, V.positive_int)
    result.add_check("width", batch.width, V.positive_int)
    result.add_check("num_inference_steps", batch.num_inference_steps,
                     V.positive_int)
    result.add_check(
        "guidance_scale", batch.guidance_scale, lambda x: not batch.
        do_classifier_free_guidance or V.positive_float(x))
    return result
fastvideo.pipelines.stages.InputValidationStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify input validation stage outputs.

Source code in fastvideo/pipelines/stages/input_validation.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify input validation stage outputs."""
    result = VerificationResult()
    result.add_check("seeds", batch.seeds, V.list_not_empty)
    result.add_check("generator", batch.generator,
                     V.generator_or_list_generators)
    return result

fastvideo.pipelines.stages.LatentPreparationStage

LatentPreparationStage(scheduler, transformer)

Bases: PipelineStage

Stage for preparing initial latent variables for the diffusion process.

This stage handles the preparation of the initial latent variables that will be denoised during the diffusion process.

Source code in fastvideo/pipelines/stages/latent_preparation.py
def __init__(self, scheduler, transformer) -> None:
    super().__init__()
    self.scheduler = scheduler
    self.transformer = transformer

Functions

fastvideo.pipelines.stages.LatentPreparationStage.adjust_video_length
adjust_video_length(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> int

Adjust video length based on VAE version.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
int

The batch with adjusted video length.

Source code in fastvideo/pipelines/stages/latent_preparation.py
def adjust_video_length(self, batch: ForwardBatch,
                        fastvideo_args: FastVideoArgs) -> int:
    """
    Adjust video length based on VAE version.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with adjusted video length.
    """

    video_length = batch.num_frames
    use_temporal_scaling_frames = fastvideo_args.pipeline_config.vae_config.use_temporal_scaling_frames
    if use_temporal_scaling_frames:
        temporal_scale_factor = fastvideo_args.pipeline_config.vae_config.arch_config.temporal_compression_ratio
        latent_num_frames = (video_length - 1) // temporal_scale_factor + 1
    else:  # stepvideo only
        latent_num_frames = video_length // 17 * 3
    return int(latent_num_frames)
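For example, with a temporal compression ratio of 4 (an assumption; the actual value comes from the VAE arch config), an 81-frame request maps to 21 latent frames, while the StepVideo branch maps 102 frames to 102 // 17 * 3 = 18:

video_length, temporal_scale_factor = 81, 4
latent_num_frames = (video_length - 1) // temporal_scale_factor + 1
assert latent_num_frames == 21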
fastvideo.pipelines.stages.LatentPreparationStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Prepare initial latent variables for the diffusion process.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The batch with prepared latent variables.

Source code in fastvideo/pipelines/stages/latent_preparation.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Prepare initial latent variables for the diffusion process.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with prepared latent variables.
    """

    latent_num_frames = None
    # Adjust video length based on VAE version if needed
    if hasattr(self, 'adjust_video_length'):
        latent_num_frames = self.adjust_video_length(batch, fastvideo_args)
    # Determine batch size
    if isinstance(batch.prompt, list):
        batch_size = len(batch.prompt)
    elif batch.prompt is not None:
        batch_size = 1
    else:
        batch_size = batch.prompt_embeds[0].shape[0]

    # Adjust batch size for number of videos per prompt
    batch_size *= batch.num_videos_per_prompt

    # Get required parameters
    dtype = batch.prompt_embeds[0].dtype
    device = get_local_torch_device()
    generator = batch.generator
    latents = batch.latents
    num_frames = latent_num_frames if latent_num_frames is not None else batch.num_frames
    height = batch.height
    width = batch.width

    # TODO(will): remove this once we add input/output validation for stages
    if height is None or width is None:
        raise ValueError("Height and width must be provided")

    # Calculate latent shape
    shape = (
        batch_size,
        self.transformer.num_channels_latents,
        num_frames,
        height // fastvideo_args.pipeline_config.vae_config.arch_config.
        spatial_compression_ratio,
        width // fastvideo_args.pipeline_config.vae_config.arch_config.
        spatial_compression_ratio,
    )

    # Validate generator if it's a list
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )
    # Generate or use provided latents
    if latents is None:
        latents = randn_tensor(shape,
                               generator=generator,
                               device=device,
                               dtype=dtype)
    else:
        latents = latents.to(device)

    # Scale the initial noise if needed
    if hasattr(self.scheduler, "init_noise_sigma"):
        latents = latents * self.scheduler.init_noise_sigma
    # Update batch with prepared latents
    batch.latents = latents
    batch.raw_latent_shape = latents.shape

    return batch
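Putting the pieces together, a hedged example of the latent shape computed above, assuming 16 latent channels, an 8x spatial compression ratio, and one video per prompt (all three values are model-dependent assumptions):

batch_size, num_channels_latents = 1, 16
num_frames_latent = 21                       # from adjust_video_length above
height, width, spatial_ratio = 480, 832, 8

shape = (batch_size, num_channels_latents, num_frames_latent,
         height // spatial_ratio, width // spatial_ratio)
assert shape == (1, 16, 21, 60, 104)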
fastvideo.pipelines.stages.LatentPreparationStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify latent preparation stage inputs.

Source code in fastvideo/pipelines/stages/latent_preparation.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify latent preparation stage inputs."""
    result = VerificationResult()
    result.add_check(
        "prompt_or_embeds", None, lambda _: V.string_or_list_strings(
            batch.prompt) or V.list_not_empty(batch.prompt_embeds))
    result.add_check("prompt_embeds", batch.prompt_embeds,
                     V.list_of_tensors)
    result.add_check("num_videos_per_prompt", batch.num_videos_per_prompt,
                     V.positive_int)
    result.add_check("generator", batch.generator,
                     V.generator_or_list_generators)
    result.add_check("num_frames", batch.num_frames, V.positive_int)
    result.add_check("height", batch.height, V.positive_int)
    result.add_check("width", batch.width, V.positive_int)
    result.add_check("latents", batch.latents, V.none_or_tensor)
    return result
fastvideo.pipelines.stages.LatentPreparationStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify latent preparation stage outputs.

Source code in fastvideo/pipelines/stages/latent_preparation.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify latent preparation stage outputs."""
    result = VerificationResult()
    result.add_check("latents", batch.latents,
                     [V.is_tensor, V.with_dims(5)])
    result.add_check("raw_latent_shape", batch.raw_latent_shape, V.is_tuple)
    return result

fastvideo.pipelines.stages.PipelineStage

Bases: ABC

Abstract base class for all pipeline stages.

A pipeline stage represents a discrete step in the diffusion process that can be composed with other stages to create a complete pipeline. Each stage is responsible for a specific part of the process, such as prompt encoding, latent preparation, etc.

Attributes

fastvideo.pipelines.stages.PipelineStage.device property
device: device

Get the device for this stage.

Functions

fastvideo.pipelines.stages.PipelineStage.__call__
__call__(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Execute the stage's processing on the batch with optional verification and logging. Should not be overridden by subclasses.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The updated batch information after this stage's processing.

Source code in fastvideo/pipelines/stages/base.py
def __call__(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Execute the stage's processing on the batch with optional verification and logging.
    Should not be overridden by subclasses.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The updated batch information after this stage's processing.
    """
    stage_name = self.__class__.__name__

    # Check if verification is enabled (simple approach for prototype)
    enable_verification = getattr(fastvideo_args,
                                  'enable_stage_verification', False)

    if enable_verification:
        # Pre-execution input verification
        try:
            input_result = self.verify_input(batch, fastvideo_args)
            self._run_verification(input_result, stage_name, "input")
        except Exception as e:
            logger.error("Input verification failed for %s: %s", stage_name,
                         str(e))
            raise

    # Execute the actual stage logic
    if envs.FASTVIDEO_STAGE_LOGGING:
        logger.info("[%s] Starting execution", stage_name)
        start_time = time.perf_counter()

        try:
            result = self.forward(batch, fastvideo_args)
            execution_time = time.perf_counter() - start_time
            logger.info("[%s] Execution completed in %s ms", stage_name,
                        execution_time * 1000)
            batch.logging_info.add_stage_execution_time(
                stage_name, execution_time)
        except Exception as e:
            execution_time = time.perf_counter() - start_time
            logger.error("[%s] Error during execution after %s ms: %s",
                         stage_name, execution_time * 1000, e)
            logger.error("[%s] Traceback: %s", stage_name,
                         traceback.format_exc())
            raise
    else:
        # Direct execution (current behavior)
        result = self.forward(batch, fastvideo_args)

    if enable_verification:
        # Post-execution output verification
        try:
            output_result = self.verify_output(result, fastvideo_args)
            self._run_verification(output_result, stage_name, "output")
        except Exception as e:
            logger.error("Output verification failed for %s: %s",
                         stage_name, str(e))
            raise

    return result
fastvideo.pipelines.stages.PipelineStage.forward abstractmethod
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Forward pass of the stage's processing.

This method should be implemented by subclasses to provide the forward processing logic for the stage.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The updated batch information after this stage's processing.

Source code in fastvideo/pipelines/stages/base.py
@abstractmethod
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Forward pass of the stage's processing.

    This method should be implemented by subclasses to provide the forward
    processing logic for the stage.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The updated batch information after this stage's processing.
    """
    raise NotImplementedError
fastvideo.pipelines.stages.PipelineStage.set_logging
set_logging(enable: bool)

Enable or disable logging for this stage.

Parameters:

Name Type Description Default
enable bool

Whether to enable logging.

required
Source code in fastvideo/pipelines/stages/base.py
def set_logging(self, enable: bool):
    """
    Enable or disable logging for this stage.

    Args:
        enable: Whether to enable logging.
    """
    self._enable_logging = enable
fastvideo.pipelines.stages.PipelineStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify the input for the stage.

Example

from fastvideo.pipelines.stages.validators import V, VerificationResult

def verify_input(self, batch, fastvideo_args):
    result = VerificationResult()
    result.add_check("height", batch.height, V.positive_int_divisible(8))
    result.add_check("width", batch.width, V.positive_int_divisible(8))
    result.add_check("image_latent", batch.image_latent, V.is_tensor)
    return result

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
VerificationResult

A VerificationResult containing the verification status.

Source code in fastvideo/pipelines/stages/base.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """
    Verify the input for the stage.

    Example:
        from fastvideo.pipelines.stages.validators import V, VerificationResult

        def verify_input(self, batch, fastvideo_args):
            result = VerificationResult()
            result.add_check("height", batch.height, V.positive_int_divisible(8))
            result.add_check("width", batch.width, V.positive_int_divisible(8))
            result.add_check("image_latent", batch.image_latent, V.is_tensor)
            return result

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        A VerificationResult containing the verification status.

    """
    # Default implementation - no verification
    return VerificationResult()
fastvideo.pipelines.stages.PipelineStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify the output for the stage.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
VerificationResult

A VerificationResult containing the verification status.

Source code in fastvideo/pipelines/stages/base.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """
    Verify the output for the stage.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        A VerificationResult containing the verification status.
    """
    # Default implementation - no verification
    return VerificationResult()
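A minimal sketch of a custom stage built on this base class: only forward is required, while verify_input and verify_output are optional and are exercised by __call__ when stage verification is enabled. The stage below is illustrative and not part of FastVideo:

from fastvideo.pipelines.stages.base import PipelineStage
from fastvideo.pipelines.stages.validators import V, VerificationResult

class ScaleLatentsStage(PipelineStage):
    """Illustrative stage that rescales latents by a constant factor."""

    def __init__(self, factor: float = 0.5) -> None:
        super().__init__()
        self.factor = factor

    def forward(self, batch, fastvideo_args):
        batch.latents = batch.latents * self.factor
        return batch

    def verify_input(self, batch, fastvideo_args) -> VerificationResult:
        result = VerificationResult()
        result.add_check("latents", batch.latents, [V.is_tensor, V.with_dims(5)])
        return result

    def verify_output(self, batch, fastvideo_args) -> VerificationResult:
        result = VerificationResult()
        result.add_check("latents", batch.latents, [V.is_tensor, V.with_dims(5)])
        return result

# Invoked like any other stage: batch = ScaleLatentsStage()(batch, fastvideo_args)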

fastvideo.pipelines.stages.RefImageEncodingStage

RefImageEncodingStage(image_encoder, image_processor)

Bases: ImageEncodingStage

Stage for encoding reference image prompts into embeddings for Wan2.1 Control models.

This stage extends ImageEncodingStage with specialized preprocessing for reference images.

Source code in fastvideo/pipelines/stages/image_encoding.py
def __init__(self, image_encoder, image_processor) -> None:
    """
    Initialize the prompt encoding stage.

    Args:
        enable_logging: Whether to enable logging for this stage.
        is_secondary: Whether this is a secondary image encoder.
    """
    super().__init__()
    self.image_processor = image_processor
    self.image_encoder = image_encoder

Functions

fastvideo.pipelines.stages.RefImageEncodingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Encode the prompt into image encoder hidden states.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The batch with encoded prompt embeddings.

Source code in fastvideo/pipelines/stages/image_encoding.py
@torch.no_grad()
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Encode the prompt into image encoder hidden states.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with encoded prompt embeddings.
    """
    self.image_encoder = self.image_encoder.to(get_local_torch_device())

    image = batch.pil_image
    if image is None:
        image = create_default_image()
    # Preprocess reference image for CLIP encoder
    image_tensor = preprocess_reference_image_for_clip(
        image, get_local_torch_device())

    image_inputs = self.image_processor(images=image_tensor,
                                        return_tensors="pt").to(
                                            get_local_torch_device())
    with set_forward_context(current_timestep=0, attn_metadata=None):
        outputs = self.image_encoder(**image_inputs)
        image_embeds = outputs.last_hidden_state
    batch.image_embeds.append(image_embeds)

    if batch.pil_image is None:
        batch.image_embeds = [
            torch.zeros_like(x) for x in batch.image_embeds
        ]

    return batch

fastvideo.pipelines.stages.StepvideoPromptEncodingStage

StepvideoPromptEncodingStage(stepllm, clip)

Bases: PipelineStage

Stage for encoding prompts using the remote caption API.

This stage applies the magic string transformations and calls the remote caption service asynchronously to get:

- primary prompt embeddings,
- an attention mask,
- and a clip embedding.

Source code in fastvideo/pipelines/stages/stepvideo_encoding.py
def __init__(self, stepllm, clip) -> None:
    super().__init__()
    # self.caption_client = caption_client  # This should have a call_caption(prompts: List[str]) method.
    self.stepllm = stepllm
    self.clip = clip

Functions

fastvideo.pipelines.stages.StepvideoPromptEncodingStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify stepvideo encoding stage inputs.

Source code in fastvideo/pipelines/stages/stepvideo_encoding.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify stepvideo encoding stage inputs."""
    result = VerificationResult()
    result.add_check("prompt", batch.prompt, V.string_not_empty)
    return result
fastvideo.pipelines.stages.StepvideoPromptEncodingStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify stepvideo encoding stage outputs.

Source code in fastvideo/pipelines/stages/stepvideo_encoding.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify stepvideo encoding stage outputs."""
    result = VerificationResult()
    result.add_check("prompt_embeds", batch.prompt_embeds,
                     [V.is_tensor, V.with_dims(3)])
    result.add_check("negative_prompt_embeds", batch.negative_prompt_embeds,
                     [V.is_tensor, V.with_dims(3)])
    result.add_check("prompt_attention_mask", batch.prompt_attention_mask,
                     [V.is_tensor, V.with_dims(2)])
    result.add_check("negative_attention_mask",
                     batch.negative_attention_mask,
                     [V.is_tensor, V.with_dims(2)])
    result.add_check("clip_embedding_pos", batch.clip_embedding_pos,
                     [V.is_tensor, V.with_dims(2)])
    result.add_check("clip_embedding_neg", batch.clip_embedding_neg,
                     [V.is_tensor, V.with_dims(2)])
    return result

fastvideo.pipelines.stages.TextEncodingStage

TextEncodingStage(text_encoders, tokenizers)

Bases: PipelineStage

Stage for encoding text prompts into embeddings for diffusion models.

This stage handles the encoding of text prompts into the embedding space expected by the diffusion model.

Initialize the text encoding stage.

Parameters:

Name Type Description Default
text_encoders

The text encoder models used to produce prompt embeddings.

required
tokenizers

The tokenizers paired with each text encoder.

required
Source code in fastvideo/pipelines/stages/text_encoding.py
def __init__(self, text_encoders, tokenizers) -> None:
    """
    Initialize the prompt encoding stage.

    Args:
        enable_logging: Whether to enable logging for this stage.
        is_secondary: Whether this is a secondary text encoder.
    """
    super().__init__()
    self.tokenizers = tokenizers
    self.text_encoders = text_encoders

Functions

fastvideo.pipelines.stages.TextEncodingStage.encode_text
encode_text(text: str | list[str], fastvideo_args: FastVideoArgs, encoder_index: int | list[int] | None = None, return_attention_mask: bool = False, return_type: str = 'list', device: device | str | None = None, dtype: dtype | None = None, max_length: int | None = None, truncation: bool | None = None, padding: bool | str | None = None)

Encode plain text using selected text encoder(s) and return embeddings.

Parameters:

Name Type Description Default
text str | list[str]

A single string or a list of strings to encode.

required
fastvideo_args FastVideoArgs

The inference arguments providing pipeline config, including tokenizer and encoder settings, preprocess and postprocess functions.

required
encoder_index int | list[int] | None

Encoder selector by index. Accepts an int or list of ints.

None
return_attention_mask bool

If True, also return attention masks for each selected encoder.

False
return_type str

"list" (default) returns a list aligned with selection; "dict" returns a dict keyed by encoder index as a string; "stack" stacks along a new first dimension (requires matching shapes).

'list'
device device | str | None

Optional device override for inputs; defaults to local torch device.

None
dtype dtype | None

Optional dtype to cast returned embeddings to.

None
max_length int | None

Optional per-call tokenizer override.

None
truncation bool | None

Optional per-call tokenizer override.

None
padding bool | str | None

Optional per-call tokenizer override.

None

Returns:

Type Description

Depending on return_type and return_attention_mask:

  • list: List[Tensor] or (List[Tensor], List[Tensor])
  • dict: Dict[str, Tensor] or (Dict[str, Tensor], Dict[str, Tensor])
  • stack: Tensor of shape [num_encoders, ...] or a tuple with stacked attention masks
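A hedged usage sketch of encode_text, assuming a stage and fastvideo_args already produced by a loaded pipeline with at least two text encoders (neither object is constructed here):

# Assumed available from the loaded pipeline:
#   stage: TextEncodingStage, fastvideo_args: FastVideoArgs
embeds = stage.encode_text("a cat playing piano", fastvideo_args)
# -> list with one tensor from encoder 0 (the default selection)

embeds_by_idx, masks_by_idx = stage.encode_text(
    ["a cat", "a dog"],
    fastvideo_args,
    encoder_index=[0, 1],
    return_attention_mask=True,
    return_type="dict",
)
# -> ({"0": Tensor, "1": Tensor}, {"0": Tensor, "1": Tensor})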
Source code in fastvideo/pipelines/stages/text_encoding.py
@torch.no_grad()
def encode_text(
    self,
    text: str | list[str],
    fastvideo_args: FastVideoArgs,
    encoder_index: int | list[int] | None = None,
    return_attention_mask: bool = False,
    return_type: str = "list",  # one of: "list", "dict", "stack"
    device: torch.device | str | None = None,
    dtype: torch.dtype | None = None,
    max_length: int | None = None,
    truncation: bool | None = None,
    padding: bool | str | None = None,
):
    """
    Encode plain text using selected text encoder(s) and return embeddings.

    Args:
        text: A single string or a list of strings to encode.
        fastvideo_args: The inference arguments providing pipeline config,
            including tokenizer and encoder settings, preprocess and postprocess
            functions.
        encoder_index: Encoder selector by index. Accepts an int or list of ints.
        return_attention_mask: If True, also return attention masks for each
            selected encoder.
        return_type: "list" (default) returns a list aligned with selection;
            "dict" returns a dict keyed by encoder index as a string; "stack" stacks along a
            new first dimension (requires matching shapes).
        device: Optional device override for inputs; defaults to local torch device.
        dtype: Optional dtype to cast returned embeddings to.
        max_length: Optional per-call tokenizer override.
        truncation: Optional per-call tokenizer override.
        padding: Optional per-call tokenizer override.

    Returns:
        Depending on return_type and return_attention_mask:
        - list: List[Tensor] or (List[Tensor], List[Tensor])
        - dict: Dict[str, Tensor] or (Dict[str, Tensor], Dict[str, Tensor])
        - stack: Tensor of shape [num_encoders, ...] or a tuple with stacked
          attention masks
    """

    assert len(self.tokenizers) == len(self.text_encoders)
    assert len(self.text_encoders) == len(
        fastvideo_args.pipeline_config.text_encoder_configs)

    # Resolve selection into indices
    encoder_cfgs = fastvideo_args.pipeline_config.text_encoder_configs
    if encoder_index is None:
        indices: list[int] = [0]
    elif isinstance(encoder_index, int):
        indices = [encoder_index]
    else:
        indices = list(encoder_index)
    # validate range
    num_encoders = len(self.text_encoders)
    for idx in indices:
        if idx < 0 or idx >= num_encoders:
            raise IndexError(
                f"encoder index {idx} out of range [0, {num_encoders-1}]")

    # Normalize input to list[str]
    assert isinstance(text, str | list)
    if isinstance(text, str):
        texts: list[str] = [text]
    else:
        texts = text

    embeds_list: list[torch.Tensor] = []
    attn_masks_list: list[torch.Tensor] = []

    preprocess_funcs = fastvideo_args.pipeline_config.preprocess_text_funcs
    postprocess_funcs = fastvideo_args.pipeline_config.postprocess_text_funcs
    encoder_cfgs = fastvideo_args.pipeline_config.text_encoder_configs

    if return_type not in ("list", "dict", "stack"):
        raise ValueError(
            f"Invalid return_type '{return_type}'. Expected one of: 'list', 'dict', 'stack'"
        )

    target_device = device if device is not None else get_local_torch_device(
    )

    for i in indices:
        tokenizer = self.tokenizers[i]
        text_encoder = self.text_encoders[i]
        encoder_config = encoder_cfgs[i]
        preprocess_func = preprocess_funcs[i]
        postprocess_func = postprocess_funcs[i]

        processed_texts: list[str] = []
        for prompt_str in texts:
            processed_texts.append(preprocess_func(prompt_str))

        tok_kwargs = dict(encoder_config.tokenizer_kwargs)
        if max_length is not None:
            tok_kwargs["max_length"] = max_length
        if truncation is not None:
            tok_kwargs["truncation"] = truncation
        if padding is not None:
            tok_kwargs["padding"] = padding

        text_inputs = tokenizer(processed_texts,
                                **tok_kwargs).to(target_device)

        input_ids = text_inputs["input_ids"]
        attention_mask = text_inputs["attention_mask"]

        with set_forward_context(current_timestep=0, attn_metadata=None):
            outputs = text_encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True,
            )

        prompt_embeds = postprocess_func(outputs)
        if dtype is not None:
            prompt_embeds = prompt_embeds.to(dtype=dtype)
        embeds_list.append(prompt_embeds)
        if return_attention_mask:
            attn_masks_list.append(attention_mask)

    # Shape results according to return_type
    if return_type == "list":
        if return_attention_mask:
            return embeds_list, attn_masks_list
        return embeds_list

    if return_type == "dict":
        key_strs = [str(i) for i in indices]
        embeds_dict = {
            k: v
            for k, v in zip(key_strs, embeds_list, strict=False)
        }
        if return_attention_mask:
            attn_dict = {
                k: v
                for k, v in zip(key_strs, attn_masks_list, strict=False)
            }
            return embeds_dict, attn_dict
        return embeds_dict

    # return_type == "stack"
    # Validate shapes are compatible
    base_shape = list(embeds_list[0].shape)
    for t in embeds_list[1:]:
        if list(t.shape) != base_shape:
            raise ValueError(
                f"Cannot stack embeddings with differing shapes: {[list(t.shape) for t in embeds_list]}"
            )
    stacked_embeds = torch.stack(embeds_list, dim=0)
    if return_attention_mask:
        base_mask_shape = list(attn_masks_list[0].shape)
        for m in attn_masks_list[1:]:
            if list(m.shape) != base_mask_shape:
                raise ValueError(
                    f"Cannot stack attention masks with differing shapes: {[list(m.shape) for m in attn_masks_list]}"
                )
        stacked_masks = torch.stack(attn_masks_list, dim=0)
        return stacked_embeds, stacked_masks
    return stacked_embeds
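For illustration, a minimal usage sketch of encode_text. The names stage (an already-constructed TextEncodingStage) and args (a FastVideoArgs whose pipeline config matches the loaded encoders) are placeholders, and the prompt strings are arbitrary:

prompts = ["a red fox running through snow", "a calm ocean at sunset"]

# Default selection: encoder 0, returned as a list of tensors.
embeds = stage.encode_text(prompts, args)

# Select two encoders, also return attention masks, key results by encoder index.
embeds_dict, masks_dict = stage.encode_text(
    prompts,
    args,
    encoder_index=[0, 1],
    return_attention_mask=True,
    return_type="dict",
)

# "stack" requires all selected encoders to produce identically shaped embeddings.
stacked = stage.encode_text(prompts, args, encoder_index=[0, 1], return_type="stack")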
fastvideo.pipelines.stages.TextEncodingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Encode the prompt into text encoder hidden states.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The batch with encoded prompt embeddings.

Source code in fastvideo/pipelines/stages/text_encoding.py
@torch.no_grad()
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Encode the prompt into text encoder hidden states.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with encoded prompt embeddings.
    """
    assert len(self.tokenizers) == len(self.text_encoders)
    assert len(self.text_encoders) == len(
        fastvideo_args.pipeline_config.text_encoder_configs)

    # Encode positive prompt with all available encoders
    assert batch.prompt is not None
    prompt_text: str | list[str] = batch.prompt
    all_indices: list[int] = list(range(len(self.text_encoders)))
    prompt_embeds_list, prompt_masks_list = self.encode_text(
        prompt_text,
        fastvideo_args,
        encoder_index=all_indices,
        return_attention_mask=True,
    )

    for pe in prompt_embeds_list:
        batch.prompt_embeds.append(pe)
    if batch.prompt_attention_mask is not None:
        for am in prompt_masks_list:
            batch.prompt_attention_mask.append(am)

    # Encode negative prompt if CFG is enabled
    if batch.do_classifier_free_guidance:
        assert isinstance(batch.negative_prompt, str)
        neg_embeds_list, neg_masks_list = self.encode_text(
            batch.negative_prompt,
            fastvideo_args,
            encoder_index=all_indices,
            return_attention_mask=True,
        )

        assert batch.negative_prompt_embeds is not None
        for ne in neg_embeds_list:
            batch.negative_prompt_embeds.append(ne)
        if batch.negative_attention_mask is not None:
            for nm in neg_masks_list:
                batch.negative_attention_mask.append(nm)

    return batch
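A hedged driver sketch for this stage; stage, batch, and fastvideo_args are placeholders, since a real pipeline constructs and chains these objects itself:

batch.prompt = "a timelapse of clouds over a mountain ridge"
batch.negative_prompt = "blurry, low quality"
batch.do_classifier_free_guidance = True

# __call__ dispatches to forward, optionally wrapped with verification and logging.
batch = stage(batch, fastvideo_args)

# One embedding per configured text encoder is appended to the batch.
assert len(batch.prompt_embeds) == len(stage.text_encoders)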
fastvideo.pipelines.stages.TextEncodingStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify text encoding stage inputs.

Source code in fastvideo/pipelines/stages/text_encoding.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify text encoding stage inputs."""
    result = VerificationResult()
    result.add_check("prompt", batch.prompt, V.string_or_list_strings)
    result.add_check(
        "negative_prompt", batch.negative_prompt, lambda x: not batch.
        do_classifier_free_guidance or V.string_not_empty(x))
    result.add_check("do_classifier_free_guidance",
                     batch.do_classifier_free_guidance, V.bool_value)
    result.add_check("prompt_embeds", batch.prompt_embeds, V.is_list)
    result.add_check("negative_prompt_embeds", batch.negative_prompt_embeds,
                     V.none_or_list)
    return result
fastvideo.pipelines.stages.TextEncodingStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify text encoding stage outputs.

Source code in fastvideo/pipelines/stages/text_encoding.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify text encoding stage outputs."""
    result = VerificationResult()
    result.add_check("prompt_embeds", batch.prompt_embeds,
                     V.list_of_tensors_min_dims(2))
    result.add_check(
        "negative_prompt_embeds", batch.negative_prompt_embeds,
        lambda x: not batch.do_classifier_free_guidance or V.
        list_of_tensors_with_min_dims(x, 2))
    return result

fastvideo.pipelines.stages.TimestepPreparationStage

TimestepPreparationStage(scheduler)

Bases: PipelineStage

Stage for preparing timesteps for the diffusion process.

This stage handles the preparation of the timestep sequence that will be used during the diffusion process.

Source code in fastvideo/pipelines/stages/timestep_preparation.py
def __init__(self, scheduler) -> None:
    self.scheduler = scheduler

Functions

fastvideo.pipelines.stages.TimestepPreparationStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Prepare timesteps for the diffusion process.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The batch with prepared timesteps.

Source code in fastvideo/pipelines/stages/timestep_preparation.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Prepare timesteps for the diffusion process.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with prepared timesteps.
    """
    scheduler = self.scheduler
    device = get_local_torch_device()
    num_inference_steps = batch.num_inference_steps
    timesteps = batch.timesteps
    sigmas = batch.sigmas
    n_tokens = batch.n_tokens

    # Prepare extra kwargs for set_timesteps
    extra_set_timesteps_kwargs = {}
    if n_tokens is not None and "n_tokens" in inspect.signature(
            scheduler.set_timesteps).parameters:
        extra_set_timesteps_kwargs["n_tokens"] = n_tokens

    # Handle custom timesteps or sigmas
    if timesteps is not None and sigmas is not None:
        raise ValueError(
            "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
        )

    if timesteps is not None:
        accepts_timesteps = "timesteps" in inspect.signature(
            scheduler.set_timesteps).parameters
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps,
                                device=device,
                                **extra_set_timesteps_kwargs)
        timesteps = scheduler.timesteps
    elif sigmas is not None:
        accept_sigmas = "sigmas" in inspect.signature(
            scheduler.set_timesteps).parameters
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas,
                                device=device,
                                **extra_set_timesteps_kwargs)
        timesteps = scheduler.timesteps
    else:
        scheduler.set_timesteps(num_inference_steps,
                                device=device,
                                **extra_set_timesteps_kwargs)
        timesteps = scheduler.timesteps

    # Update batch with prepared timesteps
    batch.timesteps = timesteps

    return batch
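A small sketch of supplying a custom sigma schedule through the batch. The names stage, batch, and fastvideo_args are placeholders, and the scheduler is assumed to accept a sigmas argument on set_timesteps; otherwise the stage raises ValueError as shown above:

batch.num_inference_steps = 8
batch.sigmas = [1.0, 0.75, 0.5, 0.35, 0.25, 0.15, 0.05, 0.0]
batch.timesteps = None  # only one of timesteps / sigmas may be set

batch = stage(batch, fastvideo_args)
# batch.timesteps now holds the scheduler's 1-D timestep tensor (see verify_output below).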
fastvideo.pipelines.stages.TimestepPreparationStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify timestep preparation stage inputs.

Source code in fastvideo/pipelines/stages/timestep_preparation.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify timestep preparation stage inputs."""
    result = VerificationResult()
    result.add_check("num_inference_steps", batch.num_inference_steps,
                     V.positive_int)
    result.add_check("timesteps", batch.timesteps, V.none_or_tensor)
    result.add_check("sigmas", batch.sigmas, V.none_or_list)
    result.add_check("n_tokens", batch.n_tokens, V.none_or_positive_int)
    return result
fastvideo.pipelines.stages.TimestepPreparationStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify timestep preparation stage outputs.

Source code in fastvideo/pipelines/stages/timestep_preparation.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify timestep preparation stage outputs."""
    result = VerificationResult()
    result.add_check("timesteps", batch.timesteps,
                     [V.is_tensor, V.with_dims(1)])
    return result

fastvideo.pipelines.stages.VideoVAEEncodingStage

VideoVAEEncodingStage(vae: ParallelTiledVAE)

Bases: ImageVAEEncodingStage

Stage for encoding video pixel representations into latent space.

This stage handles the encoding of video pixel representations for video-to-video generation and control. Inherits from ImageVAEEncodingStage to reuse common functionality.

Source code in fastvideo/pipelines/stages/image_encoding.py
def __init__(self, vae: ParallelTiledVAE) -> None:
    self.vae: ParallelTiledVAE = vae

Functions

fastvideo.pipelines.stages.VideoVAEEncodingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Encode video pixel representations into latent space.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The batch with encoded outputs.

Source code in fastvideo/pipelines/stages/image_encoding.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Encode video pixel representations into latent space.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with encoded outputs.
    """
    assert batch.video_latent is not None, "Video latent input is required for VideoVAEEncodingStage"

    if fastvideo_args.mode == ExecutionMode.INFERENCE:
        assert batch.height is not None and isinstance(batch.height, int)
        assert batch.width is not None and isinstance(batch.width, int)
        assert batch.num_frames is not None and isinstance(
            batch.num_frames, int)
        height = batch.height
        width = batch.width
        num_frames = batch.num_frames
    elif fastvideo_args.mode == ExecutionMode.PREPROCESS:
        assert batch.height is not None and isinstance(batch.height, list)
        assert batch.width is not None and isinstance(batch.width, list)
        assert batch.num_frames is not None and isinstance(
            batch.num_frames, list)
        num_frames = batch.num_frames[0]
        height = batch.height[0]
        width = batch.width[0]

    self.vae = self.vae.to(get_local_torch_device())

    # Prepare video tensor from control video
    video_condition = self._prepare_control_video_tensor(
        batch.video_latent, num_frames, height,
        width).to(get_local_torch_device(), dtype=torch.float32)

    # Setup VAE precision
    vae_dtype = PRECISION_TO_TYPE[
        fastvideo_args.pipeline_config.vae_precision]
    vae_autocast_enabled = (
        vae_dtype != torch.float32) and not fastvideo_args.disable_autocast

    # Encode control video
    with torch.autocast(device_type="cuda",
                        dtype=vae_dtype,
                        enabled=vae_autocast_enabled):
        if fastvideo_args.pipeline_config.vae_tiling:
            self.vae.enable_tiling()
        if not vae_autocast_enabled:
            video_condition = video_condition.to(vae_dtype)
        encoder_output = self.vae.encode(video_condition)

    generator = batch.generator
    if generator is None:
        raise ValueError("Generator must be provided")
    latent_condition = self.retrieve_latents(encoder_output, generator)

    if (hasattr(self.vae, "shift_factor")
            and self.vae.shift_factor is not None):
        if isinstance(self.vae.shift_factor, torch.Tensor):
            latent_condition -= self.vae.shift_factor.to(
                latent_condition.device, latent_condition.dtype)
        else:
            latent_condition -= self.vae.shift_factor

    if isinstance(self.vae.scaling_factor, torch.Tensor):
        latent_condition = latent_condition * self.vae.scaling_factor.to(
            latent_condition.device, latent_condition.dtype)
    else:
        latent_condition = latent_condition * self.vae.scaling_factor

    batch.video_latent = latent_condition

    # Offload models if needed
    if hasattr(self, 'maybe_free_model_hooks'):
        self.maybe_free_model_hooks()

    self.vae.to("cpu")

    return batch
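The shift/scale handling above mirrors the usual VAE latent normalization. A standalone sketch of the same arithmetic follows; the helper name and signature are illustrative only and not part of the stage:

import torch

def normalize_latent(raw_latent: torch.Tensor, shift_factor, scaling_factor) -> torch.Tensor:
    # latent = (raw_latent - shift) * scale, matching the stage's encode path.
    if shift_factor is not None:
        if isinstance(shift_factor, torch.Tensor):
            raw_latent = raw_latent - shift_factor.to(raw_latent.device, raw_latent.dtype)
        else:
            raw_latent = raw_latent - shift_factor
    if isinstance(scaling_factor, torch.Tensor):
        return raw_latent * scaling_factor.to(raw_latent.device, raw_latent.dtype)
    return raw_latent * scaling_factor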
fastvideo.pipelines.stages.VideoVAEEncodingStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify video encoding stage inputs.

Source code in fastvideo/pipelines/stages/image_encoding.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify video encoding stage inputs."""
    result = VerificationResult()
    result.add_check("video_latent", batch.video_latent, V.not_none)
    result.add_check("generator", batch.generator,
                     V.generator_or_list_generators)
    if fastvideo_args.mode == ExecutionMode.PREPROCESS:
        result.add_check("height", batch.height, V.list_not_empty)
        result.add_check("width", batch.width, V.list_not_empty)
        result.add_check("num_frames", batch.num_frames, V.list_not_empty)
    else:
        result.add_check("height", batch.height, V.positive_int)
        result.add_check("width", batch.width, V.positive_int)
        result.add_check("num_frames", batch.num_frames, V.positive_int)
    return result
fastvideo.pipelines.stages.VideoVAEEncodingStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify video encoding stage outputs.

Source code in fastvideo/pipelines/stages/image_encoding.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify video encoding stage outputs."""
    result = VerificationResult()
    result.add_check("video_latent", batch.video_latent,
                     [V.is_tensor, V.with_dims(5)])
    return result

Modules

fastvideo.pipelines.stages.base

Base classes for pipeline stages.

This module defines the abstract base classes for pipeline stages that can be composed to create complete diffusion pipelines.

Classes

fastvideo.pipelines.stages.base.PipelineStage

Bases: ABC

Abstract base class for all pipeline stages.

A pipeline stage represents a discrete step in the diffusion process that can be composed with other stages to create a complete pipeline. Each stage is responsible for a specific part of the process, such as prompt encoding, latent preparation, etc.

Attributes
fastvideo.pipelines.stages.base.PipelineStage.device property
device: device

Get the device for this stage.

Functions
fastvideo.pipelines.stages.base.PipelineStage.__call__
__call__(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Execute the stage's processing on the batch with optional verification and logging. Should not be overridden by subclasses.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The updated batch information after this stage's processing.

Source code in fastvideo/pipelines/stages/base.py
def __call__(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Execute the stage's processing on the batch with optional verification and logging.
    Should not be overridden by subclasses.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The updated batch information after this stage's processing.
    """
    stage_name = self.__class__.__name__

    # Check if verification is enabled (simple approach for prototype)
    enable_verification = getattr(fastvideo_args,
                                  'enable_stage_verification', False)

    if enable_verification:
        # Pre-execution input verification
        try:
            input_result = self.verify_input(batch, fastvideo_args)
            self._run_verification(input_result, stage_name, "input")
        except Exception as e:
            logger.error("Input verification failed for %s: %s", stage_name,
                         str(e))
            raise

    # Execute the actual stage logic
    if envs.FASTVIDEO_STAGE_LOGGING:
        logger.info("[%s] Starting execution", stage_name)
        start_time = time.perf_counter()

        try:
            result = self.forward(batch, fastvideo_args)
            execution_time = time.perf_counter() - start_time
            logger.info("[%s] Execution completed in %s ms", stage_name,
                        execution_time * 1000)
            batch.logging_info.add_stage_execution_time(
                stage_name, execution_time)
        except Exception as e:
            execution_time = time.perf_counter() - start_time
            logger.error("[%s] Error during execution after %s ms: %s",
                         stage_name, execution_time * 1000, e)
            logger.error("[%s] Traceback: %s", stage_name,
                         traceback.format_exc())
            raise
    else:
        # Direct execution (current behavior)
        result = self.forward(batch, fastvideo_args)

    if enable_verification:
        # Post-execution output verification
        try:
            output_result = self.verify_output(result, fastvideo_args)
            self._run_verification(output_result, stage_name, "output")
        except Exception as e:
            logger.error("Output verification failed for %s: %s",
                         stage_name, str(e))
            raise

    return result
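As a hedged sketch of how stages compose through __call__: the scheduler, transformer, vae, batch, and fastvideo_args objects are assumed to exist already, and real pipelines wire stages together through their own builders rather than a bare loop:

fastvideo_args.enable_stage_verification = True  # opt in to verify_input / verify_output

stages = [
    TimestepPreparationStage(scheduler),
    DenoisingStage(transformer, scheduler),
    DecodingStage(vae),
]

for stage in stages:
    batch = stage(batch, fastvideo_args)  # each __call__ wraps forward with checks and logging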
fastvideo.pipelines.stages.base.PipelineStage.forward abstractmethod
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Forward pass of the stage's processing.

This method should be implemented by subclasses to provide the forward processing logic for the stage.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The updated batch information after this stage's processing.

Source code in fastvideo/pipelines/stages/base.py
@abstractmethod
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Forward pass of the stage's processing.

    This method should be implemented by subclasses to provide the forward
    processing logic for the stage.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The updated batch information after this stage's processing.
    """
    raise NotImplementedError
fastvideo.pipelines.stages.base.PipelineStage.set_logging
set_logging(enable: bool)

Enable or disable logging for this stage.

Parameters:

Name Type Description Default
enable bool

Whether to enable logging.

required
Source code in fastvideo/pipelines/stages/base.py
def set_logging(self, enable: bool):
    """
    Enable or disable logging for this stage.

    Args:
        enable: Whether to enable logging.
    """
    self._enable_logging = enable
fastvideo.pipelines.stages.base.PipelineStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify the input for the stage.

Example

from fastvideo.pipelines.stages.validators import V, VerificationResult

def verify_input(self, batch, fastvideo_args):
    result = VerificationResult()
    result.add_check("height", batch.height, V.positive_int_divisible(8))
    result.add_check("width", batch.width, V.positive_int_divisible(8))
    result.add_check("image_latent", batch.image_latent, V.is_tensor)
    return result

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
VerificationResult

A VerificationResult containing the verification status.

Source code in fastvideo/pipelines/stages/base.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """
    Verify the input for the stage.

    Example:
        from fastvideo.pipelines.stages.validators import V, VerificationResult

        def verify_input(self, batch, fastvideo_args):
            result = VerificationResult()
            result.add_check("height", batch.height, V.positive_int_divisible(8))
            result.add_check("width", batch.width, V.positive_int_divisible(8))
            result.add_check("image_latent", batch.image_latent, V.is_tensor)
            return result

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        A VerificationResult containing the verification status.

    """
    # Default implementation - no verification
    return VerificationResult()
fastvideo.pipelines.stages.base.PipelineStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify the output for the stage.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
VerificationResult

A VerificationResult containing the verification status.

Source code in fastvideo/pipelines/stages/base.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """
    Verify the output for the stage.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        A VerificationResult containing the verification status.
    """
    # Default implementation - no verification
    return VerificationResult()
fastvideo.pipelines.stages.base.StageVerificationError

Bases: Exception

Exception raised when stage verification fails.

Functions

fastvideo.pipelines.stages.causal_denoising

Classes

fastvideo.pipelines.stages.causal_denoising.CausalDMDDenosingStage
CausalDMDDenosingStage(transformer, scheduler, transformer_2=None)

Bases: DenoisingStage

Denoising stage for causal diffusion.

Source code in fastvideo/pipelines/stages/causal_denoising.py
def __init__(self, transformer, scheduler, transformer_2=None) -> None:
    super().__init__(transformer, scheduler, transformer_2)
    # KV and cross-attention cache state (initialized on first forward)
    self.transformer = transformer
    self.transformer_2 = transformer_2
    self.kv_cache1: list | None = None
    self.crossattn_cache: list | None = None
    # Model-dependent constants (aligned with causal_inference.py assumptions)
    self.num_transformer_blocks = len(self.transformer.blocks)
    self.num_frames_per_block = self.transformer.config.arch_config.num_frames_per_block
    self.sliding_window_num_frames = self.transformer.config.arch_config.sliding_window_num_frames

    try:
        self.local_attn_size = getattr(self.transformer.model,
                                       "local_attn_size",
                                       -1)  # type: ignore
    except Exception:
        self.local_attn_size = -1
Functions
fastvideo.pipelines.stages.causal_denoising.CausalDMDDenosingStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify denoising stage inputs.

Source code in fastvideo/pipelines/stages/causal_denoising.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify denoising stage inputs."""
    result = VerificationResult()
    result.add_check("latents", batch.latents,
                     [V.is_tensor, V.with_dims(5)])
    result.add_check("prompt_embeds", batch.prompt_embeds, V.list_not_empty)
    result.add_check("image_embeds", batch.image_embeds, V.is_list)
    result.add_check("image_latent", batch.image_latent,
                     V.none_or_tensor_with_dims(5))
    result.add_check("num_inference_steps", batch.num_inference_steps,
                     V.positive_int)
    result.add_check("guidance_scale", batch.guidance_scale,
                     V.positive_float)
    result.add_check("eta", batch.eta, V.non_negative_float)
    result.add_check("generator", batch.generator,
                     V.generator_or_list_generators)
    result.add_check("do_classifier_free_guidance",
                     batch.do_classifier_free_guidance, V.bool_value)
    result.add_check(
        "negative_prompt_embeds", batch.negative_prompt_embeds, lambda x:
        not batch.do_classifier_free_guidance or V.list_not_empty(x))
    return result

Functions

fastvideo.pipelines.stages.conditioning

Conditioning stage for diffusion pipelines.

Classes

fastvideo.pipelines.stages.conditioning.ConditioningStage

Bases: PipelineStage

Stage for applying conditioning to the diffusion process.

This stage handles the application of conditioning, such as classifier-free guidance, to the diffusion process.

Functions
fastvideo.pipelines.stages.conditioning.ConditioningStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Apply conditioning to the diffusion process.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The batch with applied conditioning.

Source code in fastvideo/pipelines/stages/conditioning.py
@torch.no_grad()
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Apply conditioning to the diffusion process.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with applied conditioning.
    """
    # TODO!!
    if not batch.do_classifier_free_guidance:
        return batch
    else:
        return batch

    logger.info("batch.negative_prompt_embeds: %s",
                batch.negative_prompt_embeds)
    logger.info("do_classifier_free_guidance: %s",
                batch.do_classifier_free_guidance)
    logger.info("cfg_scale: %s", batch.guidance_scale)

    # Ensure negative prompt embeddings are available
    assert batch.negative_prompt_embeds is not None, (
        "Negative prompt embeddings are required for classifier-free guidance"
    )

    # Concatenate primary embeddings and masks
    batch.prompt_embeds = torch.cat(
        [batch.negative_prompt_embeds, batch.prompt_embeds])
    if batch.attention_mask is not None:
        batch.attention_mask = torch.cat(
            [batch.negative_attention_mask, batch.attention_mask])

    # Concatenate secondary embeddings and masks if present
    if batch.prompt_embeds_2 is not None:
        batch.prompt_embeds_2 = torch.cat(
            [batch.negative_prompt_embeds_2, batch.prompt_embeds_2])
    if batch.attention_mask_2 is not None:
        batch.attention_mask_2 = torch.cat(
            [batch.negative_attention_mask_2, batch.attention_mask_2])

    return batch
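Since the guidance application itself is marked TODO above, here is a conceptual sketch of the standard classifier-free guidance blend that the concatenated conditional/unconditional embeddings ultimately feed into; the function name is illustrative and not part of this stage:

import torch

def cfg_combine(noise_uncond: torch.Tensor, noise_cond: torch.Tensor,
                guidance_scale: float) -> torch.Tensor:
    # Standard classifier-free guidance: push the prediction away from the
    # unconditional branch by guidance_scale times the conditional difference.
    return noise_uncond + guidance_scale * (noise_cond - noise_uncond)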
fastvideo.pipelines.stages.conditioning.ConditioningStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify conditioning stage inputs.

Source code in fastvideo/pipelines/stages/conditioning.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify conditioning stage inputs."""
    result = VerificationResult()
    result.add_check("do_classifier_free_guidance",
                     batch.do_classifier_free_guidance, V.bool_value)
    result.add_check("guidance_scale", batch.guidance_scale,
                     V.positive_float)
    result.add_check("prompt_embeds", batch.prompt_embeds, V.list_not_empty)
    result.add_check(
        "negative_prompt_embeds", batch.negative_prompt_embeds, lambda x:
        not batch.do_classifier_free_guidance or V.list_not_empty(x))
    return result
fastvideo.pipelines.stages.conditioning.ConditioningStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify conditioning stage outputs.

Source code in fastvideo/pipelines/stages/conditioning.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify conditioning stage outputs."""
    result = VerificationResult()
    result.add_check("prompt_embeds", batch.prompt_embeds, V.list_not_empty)
    return result

Functions

fastvideo.pipelines.stages.decoding

Decoding stage for diffusion pipelines.

Classes

fastvideo.pipelines.stages.decoding.DecodingStage
DecodingStage(vae, pipeline=None)

Bases: PipelineStage

Stage for decoding latent representations into pixel space.

This stage handles the decoding of latent representations into the final output format (e.g., pixel values).

Source code in fastvideo/pipelines/stages/decoding.py
def __init__(self, vae, pipeline=None) -> None:
    self.vae: ParallelTiledVAE = vae
    self.pipeline = weakref.ref(pipeline) if pipeline else None
Functions
fastvideo.pipelines.stages.decoding.DecodingStage.decode
decode(latents: Tensor, fastvideo_args: FastVideoArgs) -> Tensor

Decode latent representations into pixel space using VAE.

Parameters:

Name Type Description Default
latents Tensor

Input latent tensor with shape (batch, channels, frames, height_latents, width_latents)

required
fastvideo_args FastVideoArgs

Configuration containing:

  • disable_autocast: Whether to disable automatic mixed precision (default: False)
  • pipeline_config.vae_precision: VAE computation precision ("fp32", "fp16", "bf16")
  • pipeline_config.vae_tiling: Whether to enable VAE tiling for memory efficiency

required

Returns:

Type Description
Tensor

Decoded video tensor with shape (batch, channels, frames, height, width),

Tensor

normalized to [0, 1] range and moved to CPU as float32

Source code in fastvideo/pipelines/stages/decoding.py
@torch.no_grad()
def decode(self, latents: torch.Tensor,
           fastvideo_args: FastVideoArgs) -> torch.Tensor:
    """
    Decode latent representations into pixel space using VAE.

    Args:
        latents: Input latent tensor with shape (batch, channels, frames, height_latents, width_latents)
        fastvideo_args: Configuration containing:
            - disable_autocast: Whether to disable automatic mixed precision (default: False)
            - pipeline_config.vae_precision: VAE computation precision ("fp32", "fp16", "bf16")
            - pipeline_config.vae_tiling: Whether to enable VAE tiling for memory efficiency

    Returns:
        Decoded video tensor with shape (batch, channels, frames, height, width), 
        normalized to [0, 1] range and moved to CPU as float32
    """
    self.vae = self.vae.to(get_local_torch_device())
    latents = latents.to(get_local_torch_device())

    # Setup VAE precision
    vae_dtype = PRECISION_TO_TYPE[
        fastvideo_args.pipeline_config.vae_precision]
    vae_autocast_enabled = (
        vae_dtype != torch.float32) and not fastvideo_args.disable_autocast

    if hasattr(self.vae, 'scaling_factor'):
        if isinstance(self.vae.scaling_factor, torch.Tensor):
            latents = latents / self.vae.scaling_factor.to(
                latents.device, latents.dtype)
        else:
            latents = latents / self.vae.scaling_factor

    # Apply shifting if needed
    if (hasattr(self.vae, "shift_factor")
            and self.vae.shift_factor is not None):
        if isinstance(self.vae.shift_factor, torch.Tensor):
            latents += self.vae.shift_factor.to(latents.device,
                                                latents.dtype)
        else:
            latents += self.vae.shift_factor

    # Decode latents
    with torch.autocast(device_type="cuda",
                        dtype=vae_dtype,
                        enabled=vae_autocast_enabled):
        if fastvideo_args.pipeline_config.vae_tiling:
            self.vae.enable_tiling()
        # if fastvideo_args.vae_sp:
        #     self.vae.enable_parallel()
        if not vae_autocast_enabled:
            latents = latents.to(vae_dtype)
        image = self.vae.decode(latents)

    # Normalize image to [0, 1] range
    image = (image / 2 + 0.5).clamp(0, 1)
    return image
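A short usage sketch of decode; stage, latents, and fastvideo_args are placeholders, and latents is assumed to be the 5-D latent tensor described above:

frames = stage.decode(latents, fastvideo_args)       # values already clamped to [0, 1]
frames_u8 = (frames * 255).round().to(torch.uint8)   # e.g. for writing out video frames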
fastvideo.pipelines.stages.decoding.DecodingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Decode latent representations into pixel space.

This method processes the batch through the VAE decoder, converting latent representations to pixel-space video/images. It also optionally decodes trajectory latents for visualization purposes.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch containing:

  • latents: Tensor to decode (batch, channels, frames, height_latents, width_latents)
  • return_trajectory_decoded (optional): Flag to decode trajectory latents
  • trajectory_latents (optional): Latents at different timesteps
  • trajectory_timesteps (optional): Corresponding timesteps

required
fastvideo_args FastVideoArgs

Configuration containing:

  • output_type: "latent" to skip decoding, otherwise decode to pixels
  • vae_cpu_offload: Whether to offload VAE to CPU after decoding
  • model_loaded: Track VAE loading state
  • model_paths: Path to VAE model if loading needed

required

Returns:

Type Description
ForwardBatch

Modified batch with:

  • output: Decoded frames (batch, channels, frames, height, width) as CPU float32
  • trajectory_decoded (if requested): List of decoded frames per timestep

Source code in fastvideo/pipelines/stages/decoding.py
@torch.no_grad()
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Decode latent representations into pixel space.

    This method processes the batch through the VAE decoder, converting latent
    representations to pixel-space video/images. It also optionally decodes
    trajectory latents for visualization purposes.

    Args:
        batch: The current batch containing:
            - latents: Tensor to decode (batch, channels, frames, height_latents, width_latents)
            - return_trajectory_decoded (optional): Flag to decode trajectory latents
            - trajectory_latents (optional): Latents at different timesteps
            - trajectory_timesteps (optional): Corresponding timesteps
        fastvideo_args: Configuration containing:
            - output_type: "latent" to skip decoding, otherwise decode to pixels
            - vae_cpu_offload: Whether to offload VAE to CPU after decoding
            - model_loaded: Track VAE loading state
            - model_paths: Path to VAE model if loading needed

    Returns:
        Modified batch with:
            - output: Decoded frames (batch, channels, frames, height, width) as CPU float32
            - trajectory_decoded (if requested): List of decoded frames per timestep
    """
    # load vae if not already loaded (used for memory constrained devices)
    pipeline = self.pipeline() if self.pipeline else None
    if not fastvideo_args.model_loaded["vae"]:
        loader = VAELoader()
        self.vae = loader.load(fastvideo_args.model_paths["vae"],
                               fastvideo_args)
        if pipeline:
            pipeline.add_module("vae", self.vae)
        fastvideo_args.model_loaded["vae"] = True

    if fastvideo_args.output_type == "latent":
        frames = batch.latents
    else:
        frames = self.decode(batch.latents, fastvideo_args)

    # decode trajectory latents if needed
    if batch.return_trajectory_decoded:
        batch.trajectory_decoded = []
        assert batch.trajectory_latents is not None, "batch should have trajectory latents"
        for idx in range(batch.trajectory_latents.shape[1]):
            # batch.trajectory_latents is [batch_size, timesteps, channels, frames, height, width]
            cur_latent = batch.trajectory_latents[:, idx, :, :, :, :]
            cur_timestep = batch.trajectory_timesteps[idx]
            logger.info("decoding trajectory latent for timestep: %s",
                        cur_timestep)
            decoded_frames = self.decode(cur_latent, fastvideo_args)
            batch.trajectory_decoded.append(decoded_frames.cpu().float())

    # Convert to CPU float32 for compatibility
    frames = frames.cpu().float()

    # Update batch with decoded image
    batch.output = frames

    # Offload models if needed
    if hasattr(self, 'maybe_free_model_hooks'):
        self.maybe_free_model_hooks()

    if fastvideo_args.vae_cpu_offload:
        self.vae.to("cpu")

    if torch.backends.mps.is_available():
        del self.vae
        if pipeline is not None and "vae" in pipeline.modules:
            del pipeline.modules["vae"]
        fastvideo_args.model_loaded["vae"] = False

    return batch
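A hedged sketch of the latent-bypass path; decoding_stage, batch, and fastvideo_args are placeholders:

fastvideo_args.output_type = "latent"     # skip VAE decoding entirely
batch = decoding_stage(batch, fastvideo_args)
raw_latents = batch.output                # CPU float32, still latent-shaped [b, c, f, h_latents, w_latents]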
fastvideo.pipelines.stages.decoding.DecodingStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify decoding stage inputs.

Source code in fastvideo/pipelines/stages/decoding.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify decoding stage inputs."""
    result = VerificationResult()
    # Denoised latents for VAE decoding: [batch_size, channels, frames, height_latents, width_latents]
    result.add_check("latents", batch.latents,
                     [V.is_tensor, V.with_dims(5)])
    return result
fastvideo.pipelines.stages.decoding.DecodingStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify decoding stage outputs.

Source code in fastvideo/pipelines/stages/decoding.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify decoding stage outputs."""
    result = VerificationResult()
    # Decoded video/images: [batch_size, channels, frames, height, width]
    result.add_check("output", batch.output, [V.is_tensor, V.with_dims(5)])
    return result

Functions

fastvideo.pipelines.stages.denoising

Denoising stage for diffusion pipelines.

Classes

fastvideo.pipelines.stages.denoising.CosmosDenoisingStage
CosmosDenoisingStage(transformer, scheduler, pipeline=None)

Bases: DenoisingStage

Denoising stage for Cosmos models using FlowMatchEulerDiscreteScheduler.

Source code in fastvideo/pipelines/stages/denoising.py
def __init__(self, transformer, scheduler, pipeline=None) -> None:
    super().__init__(transformer, scheduler, pipeline)
Functions
fastvideo.pipelines.stages.denoising.CosmosDenoisingStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify Cosmos denoising stage inputs.

Source code in fastvideo/pipelines/stages/denoising.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify Cosmos denoising stage inputs."""
    result = VerificationResult()
    result.add_check("latents", batch.latents,
                     [V.is_tensor, V.with_dims(5)])
    result.add_check("prompt_embeds", batch.prompt_embeds, V.list_not_empty)
    result.add_check("num_inference_steps", batch.num_inference_steps,
                     V.positive_int)
    result.add_check("guidance_scale", batch.guidance_scale,
                     V.positive_float)
    result.add_check("do_classifier_free_guidance",
                     batch.do_classifier_free_guidance, V.bool_value)
    result.add_check(
        "negative_prompt_embeds", batch.negative_prompt_embeds, lambda x:
        not batch.do_classifier_free_guidance or V.list_not_empty(x))
    return result
fastvideo.pipelines.stages.denoising.CosmosDenoisingStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify Cosmos denoising stage outputs.

Source code in fastvideo/pipelines/stages/denoising.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify Cosmos denoising stage outputs."""
    result = VerificationResult()
    result.add_check("latents", batch.latents,
                     [V.is_tensor, V.with_dims(5)])
    return result
fastvideo.pipelines.stages.denoising.DenoisingStage
DenoisingStage(transformer, scheduler, pipeline=None, transformer_2=None, vae=None)

Bases: PipelineStage

Stage for running the denoising loop in diffusion pipelines.

This stage handles the iterative denoising process that transforms the initial noise into the final output.

Source code in fastvideo/pipelines/stages/denoising.py
def __init__(self,
             transformer,
             scheduler,
             pipeline=None,
             transformer_2=None,
             vae=None) -> None:
    super().__init__()
    self.transformer = transformer
    self.transformer_2 = transformer_2
    self.scheduler = scheduler
    self.vae = vae
    self.pipeline = weakref.ref(pipeline) if pipeline else None
    attn_head_size = self.transformer.hidden_size // self.transformer.num_attention_heads
    self.attn_backend = get_attn_backend(
        head_size=attn_head_size,
        dtype=torch.float16,  # TODO(will): hack
        supported_attention_backends=(
            AttentionBackendEnum.SLIDING_TILE_ATTN,
            AttentionBackendEnum.VIDEO_SPARSE_ATTN,
            AttentionBackendEnum.VMOBA_ATTN,
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.TORCH_SDPA,
            AttentionBackendEnum.SAGE_ATTN_THREE)  # hack
    )
Functions
fastvideo.pipelines.stages.denoising.DenoisingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Run the denoising loop.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The batch with denoised latents.

Source code in fastvideo/pipelines/stages/denoising.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Run the denoising loop.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with denoised latents.
    """
    pipeline = self.pipeline() if self.pipeline else None
    if not fastvideo_args.model_loaded["transformer"]:
        loader = TransformerLoader()
        self.transformer = loader.load(
            fastvideo_args.model_paths["transformer"], fastvideo_args)
        if pipeline:
            pipeline.add_module("transformer", self.transformer)
        fastvideo_args.model_loaded["transformer"] = True

    # Prepare extra step kwargs for scheduler
    extra_step_kwargs = self.prepare_extra_func_kwargs(
        self.scheduler.step,
        {
            "generator": batch.generator,
            "eta": batch.eta
        },
    )

    # Setup precision and autocast settings
    # TODO(will): make the precision configurable for inference
    # target_dtype = PRECISION_TO_TYPE[fastvideo_args.precision]
    target_dtype = torch.bfloat16
    autocast_enabled = (target_dtype != torch.float32
                        ) and not fastvideo_args.disable_autocast

    # Handle sequence parallelism if enabled
    sp_world_size, rank_in_sp_group = get_sp_world_size(
    ), get_sp_parallel_rank()
    sp_group = sp_world_size > 1
    if sp_group:
        latents = rearrange(batch.latents,
                            "b c (n t) h w -> b c n t h w",
                            n=sp_world_size).contiguous()
        latents = latents[:, :, rank_in_sp_group, :, :, :]
        batch.latents = latents
        if batch.image_latent is not None:
            image_latent = rearrange(batch.image_latent,
                                     "b c (n t) h w -> b c n t h w",
                                     n=sp_world_size).contiguous()
            image_latent = image_latent[:, :, rank_in_sp_group, :, :, :]
            batch.image_latent = image_latent
    # Get timesteps and calculate warmup steps
    timesteps = batch.timesteps
    # TODO(will): remove this once we add input/output validation for stages
    if timesteps is None:
        raise ValueError("Timesteps must be provided")
    num_inference_steps = batch.num_inference_steps
    num_warmup_steps = len(
        timesteps) - num_inference_steps * self.scheduler.order

    # Prepare image latents and embeddings for I2V generation
    image_embeds = batch.image_embeds
    if len(image_embeds) > 0:
        assert not torch.isnan(
            image_embeds[0]).any(), "image_embeds contains nan"
        image_embeds = [
            image_embed.to(target_dtype) for image_embed in image_embeds
        ]

    image_kwargs = self.prepare_extra_func_kwargs(
        self.transformer.forward,
        {
            "encoder_hidden_states_image": image_embeds,
            "mask_strategy": dict_to_3d_list(
                None, t_max=50, l_max=60, h_max=24)
        },
    )

    pos_cond_kwargs = self.prepare_extra_func_kwargs(
        self.transformer.forward,
        {
            "encoder_hidden_states_2": batch.clip_embedding_pos,
            "encoder_attention_mask": batch.prompt_attention_mask,
        },
    )

    neg_cond_kwargs = self.prepare_extra_func_kwargs(
        self.transformer.forward,
        {
            "encoder_hidden_states_2": batch.clip_embedding_neg,
            "encoder_attention_mask": batch.negative_attention_mask,
        },
    )

    # Prepare STA parameters
    if st_attn_available and self.attn_backend == SlidingTileAttentionBackend:
        self.prepare_sta_param(batch, fastvideo_args)

    # Get latents and embeddings
    latents = batch.latents
    prompt_embeds = batch.prompt_embeds
    assert not torch.isnan(
        prompt_embeds[0]).any(), "prompt_embeds contains nan"
    if batch.do_classifier_free_guidance:
        neg_prompt_embeds = batch.negative_prompt_embeds
        assert neg_prompt_embeds is not None
        assert not torch.isnan(
            neg_prompt_embeds[0]).any(), "neg_prompt_embeds contains nan"

    # (Wan2.2) Calculate timestep to switch from high noise expert to low noise expert
    boundary_ratio = fastvideo_args.pipeline_config.dit_config.boundary_ratio
    if batch.boundary_ratio is not None:
        logger.info("Overriding boundary ratio from %s to %s",
                    boundary_ratio, batch.boundary_ratio)
        boundary_ratio = batch.boundary_ratio

    if boundary_ratio is not None:
        boundary_timestep = boundary_ratio * self.scheduler.num_train_timesteps
    else:
        boundary_timestep = None
    latent_model_input = latents.to(target_dtype)
    assert latent_model_input.shape[0] == 1, "only support batch size 1"

    if fastvideo_args.pipeline_config.ti2v_task and batch.pil_image is not None:
        # TI2V directly replaces the first frame of the latent with
        # the image latent instead of appending along the channel dim
        assert batch.image_latent is None, "TI2V task should not have image latents"
        assert self.vae is not None, "VAE is not provided for TI2V task"
        z = self.vae.encode(batch.pil_image).mean.float()
        if (hasattr(self.vae, "shift_factor")
                and self.vae.shift_factor is not None):
            if isinstance(self.vae.shift_factor, torch.Tensor):
                z -= self.vae.shift_factor.to(z.device, z.dtype)
            else:
                z -= self.vae.shift_factor

        if isinstance(self.vae.scaling_factor, torch.Tensor):
            z = z * self.vae.scaling_factor.to(z.device, z.dtype)
        else:
            z = z * self.vae.scaling_factor

        latent_model_input = latent_model_input.squeeze(0)
        _, mask2 = masks_like([latent_model_input], zero=True)

        latent_model_input = (1. -
                              mask2[0]) * z + mask2[0] * latent_model_input
        # latent_model_input = latent_model_input.unsqueeze(0)
        latent_model_input = latent_model_input.to(get_local_torch_device())
        latents = latent_model_input
        F = batch.num_frames
        temporal_scale = fastvideo_args.pipeline_config.vae_config.arch_config.scale_factor_temporal
        spatial_scale = fastvideo_args.pipeline_config.vae_config.arch_config.scale_factor_spatial
        patch_size = fastvideo_args.pipeline_config.dit_config.arch_config.patch_size
        seq_len = ((F - 1) // temporal_scale +
                   1) * (batch.height // spatial_scale) * (
                       batch.width // spatial_scale) // (patch_size[1] *
                                                         patch_size[2])
        seq_len = int(math.ceil(seq_len / sp_world_size)) * sp_world_size

    # Initialize lists for ODE trajectory
    trajectory_timesteps: list[torch.Tensor] = []
    trajectory_latents: list[torch.Tensor] = []

    # Run denoising loop
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # Skip if interrupted
            if hasattr(self, 'interrupt') and self.interrupt:
                continue

            if boundary_timestep is None or t >= boundary_timestep:
                if (fastvideo_args.dit_cpu_offload
                        and self.transformer_2 is not None and next(
                            self.transformer_2.parameters()).device.type
                        == 'cuda'):
                    self.transformer_2.to('cpu')
                current_model = self.transformer
                current_guidance_scale = batch.guidance_scale
            else:
                # low-noise stage in wan2.2
                if fastvideo_args.dit_cpu_offload and next(
                        self.transformer.parameters(
                        )).device.type == 'cuda':
                    self.transformer.to('cpu')
                current_model = self.transformer_2
                current_guidance_scale = batch.guidance_scale_2
            assert current_model is not None, "current_model is None"

            # Expand latents for V2V/I2V
            latent_model_input = latents.to(target_dtype)
            if batch.video_latent is not None:
                latent_model_input = torch.cat([
                    latent_model_input, batch.video_latent,
                    torch.zeros_like(latents)
                ],
                                               dim=1).to(target_dtype)
            elif batch.image_latent is not None:
                assert not fastvideo_args.pipeline_config.ti2v_task, "image latents should not be provided for TI2V task"
                latent_model_input = torch.cat(
                    [latent_model_input, batch.image_latent],
                    dim=1).to(target_dtype)

            assert not torch.isnan(
                latent_model_input).any(), "latent_model_input contains nan"
            if fastvideo_args.pipeline_config.ti2v_task and batch.pil_image is not None:
                timestep = torch.stack([t]).to(get_local_torch_device())
                temp_ts = (mask2[0][0][:, ::2, ::2] * timestep).flatten()
                temp_ts = torch.cat([
                    temp_ts,
                    temp_ts.new_ones(seq_len - temp_ts.size(0)) * timestep
                ])
                timestep = temp_ts.unsqueeze(0)
                t_expand = timestep.repeat(latent_model_input.shape[0], 1)
            else:
                t_expand = t.repeat(latent_model_input.shape[0])

            latent_model_input = self.scheduler.scale_model_input(
                latent_model_input, t)

            # Prepare inputs for transformer
            guidance_expand = (
                torch.tensor(
                    [fastvideo_args.pipeline_config.embedded_cfg_scale] *
                    latent_model_input.shape[0],
                    dtype=torch.float32,
                    device=get_local_torch_device(),
                ).to(target_dtype) *
                1000.0 if fastvideo_args.pipeline_config.embedded_cfg_scale
                is not None else None)

            # Predict noise residual
            with torch.autocast(device_type="cuda",
                                dtype=target_dtype,
                                enabled=autocast_enabled):
                if (st_attn_available
                        and self.attn_backend == SlidingTileAttentionBackend
                    ) or (vsa_available and self.attn_backend
                          == VideoSparseAttentionBackend):
                    self.attn_metadata_builder_cls = self.attn_backend.get_builder_cls(
                    )

                    if self.attn_metadata_builder_cls is not None:
                        self.attn_metadata_builder = self.attn_metadata_builder_cls(
                        )
                        # TODO(will): clean this up
                        attn_metadata = self.attn_metadata_builder.build(  # type: ignore
                            current_timestep=i,  # type: ignore
                            raw_latent_shape=batch.
                            raw_latent_shape[2:5],  # type: ignore
                            patch_size=fastvideo_args.
                            pipeline_config.  # type: ignore
                            dit_config.patch_size,  # type: ignore
                            STA_param=batch.STA_param,  # type: ignore
                            VSA_sparsity=fastvideo_args.
                            VSA_sparsity,  # type: ignore
                            device=get_local_torch_device(),
                        )
                        assert attn_metadata is not None, "attn_metadata cannot be None"
                    else:
                        attn_metadata = None
                elif (vmoba_attn_available
                      and self.attn_backend == VMOBAAttentionBackend):
                    self.attn_metadata_builder_cls = self.attn_backend.get_builder_cls(
                    )
                    if self.attn_metadata_builder_cls is not None:
                        self.attn_metadata_builder = self.attn_metadata_builder_cls(
                        )
                        # Prepare V-MoBA parameters from config
                        moba_params = fastvideo_args.moba_config.copy()
                        moba_params.update({
                            "current_timestep":
                            i,
                            "raw_latent_shape":
                            batch.raw_latent_shape[2:5],
                            "patch_size":
                            fastvideo_args.pipeline_config.dit_config.
                            patch_size,
                            "device":
                            get_local_torch_device(),
                        })
                        attn_metadata = self.attn_metadata_builder.build(
                            **moba_params)
                        assert attn_metadata is not None, "attn_metadata cannot be None"
                    else:
                        attn_metadata = None
                else:
                    attn_metadata = None
                # TODO(will): finalize the interface. vLLM uses this to
                # support torch dynamo compilation. They pass in
                # attn_metadata, vllm_config, and num_tokens. We can pass in
                # fastvideo_args or training_args, and attn_metadata.
                batch.is_cfg_negative = False
                with set_forward_context(
                        current_timestep=i,
                        attn_metadata=attn_metadata,
                        forward_batch=batch,
                        # fastvideo_args=fastvideo_args
                ):
                    # Run transformer
                    noise_pred = current_model(
                        latent_model_input,
                        prompt_embeds,
                        t_expand,
                        guidance=guidance_expand,
                        **image_kwargs,
                        **pos_cond_kwargs,
                    )

                if batch.do_classifier_free_guidance:
                    batch.is_cfg_negative = True
                    with set_forward_context(
                            current_timestep=i,
                            attn_metadata=attn_metadata,
                            forward_batch=batch,
                    ):
                        noise_pred_uncond = current_model(
                            latent_model_input,
                            neg_prompt_embeds,
                            t_expand,
                            guidance=guidance_expand,
                            **image_kwargs,
                            **neg_cond_kwargs,
                        )

                    noise_pred_text = noise_pred
                    noise_pred = noise_pred_uncond + current_guidance_scale * (
                        noise_pred_text - noise_pred_uncond)

                    # Apply guidance rescale if needed
                    if batch.guidance_rescale > 0.0:
                        # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                        noise_pred = self.rescale_noise_cfg(
                            noise_pred,
                            noise_pred_text,
                            guidance_rescale=batch.guidance_rescale,
                        )
                # Compute the previous noisy sample
                latents = self.scheduler.step(noise_pred,
                                              t,
                                              latents,
                                              **extra_step_kwargs,
                                              return_dict=False)[0]
                if fastvideo_args.pipeline_config.ti2v_task and batch.pil_image is not None:
                    latents = latents.squeeze(0)
                    latents = (1. - mask2[0]) * z + mask2[0] * latents
                    # latents = latents.unsqueeze(0)

            # save trajectory latents if needed
            if batch.return_trajectory_latents:
                trajectory_timesteps.append(t)
                trajectory_latents.append(latents)

            # Update progress bar
            if i == len(timesteps) - 1 or (
                (i + 1) > num_warmup_steps and
                (i + 1) % self.scheduler.order == 0
                    and progress_bar is not None):
                progress_bar.update()

    # Stack trajectory latents and timesteps if they were collected
    trajectory_tensor: torch.Tensor | None = None
    if trajectory_latents:
        trajectory_tensor = torch.stack(trajectory_latents, dim=1)
        trajectory_timesteps_tensor = torch.stack(trajectory_timesteps,
                                                  dim=0)
    else:
        trajectory_tensor = None
        trajectory_timesteps_tensor = None

    # Gather results if using sequence parallelism
    if sp_group:
        latents = sequence_model_parallel_all_gather(latents, dim=2)
        if batch.return_trajectory_latents:
            trajectory_tensor = trajectory_tensor.to(
                get_local_torch_device())
            trajectory_tensor = sequence_model_parallel_all_gather(
                trajectory_tensor, dim=3)

    if trajectory_tensor is not None and trajectory_timesteps_tensor is not None:
        batch.trajectory_timesteps = trajectory_timesteps_tensor.cpu()
        batch.trajectory_latents = trajectory_tensor.cpu()

    # Update batch with final latents
    batch.latents = latents

    # Save STA mask search results if needed
    if st_attn_available and self.attn_backend == SlidingTileAttentionBackend and fastvideo_args.STA_mode == STA_Mode.STA_SEARCHING:
        self.save_sta_search_results(batch)

    # deallocate transformer if on mps
    if torch.backends.mps.is_available():
        logger.info("Memory before deallocating transformer: %s",
                    torch.mps.current_allocated_memory())
        del self.transformer
        if pipeline is not None and "transformer" in pipeline.modules:
            del pipeline.modules["transformer"]
        fastvideo_args.model_loaded["transformer"] = False
        logger.info("Memory after deallocating transformer: %s",
                    torch.mps.current_allocated_memory())

    return batch
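
The boundary handling at the top of this method drives the two-expert switching inside the loop: steps at or above boundary_ratio * num_train_timesteps run transformer (the high-noise expert), while later steps run transformer_2 (the low-noise expert, as in Wan2.2). A minimal standalone sketch of that selection rule, using hypothetical stand-in models and a toy schedule (none of these values come from a real config):

import torch

def high_noise_model(x, t):   # used while t >= boundary_timestep
    return x * 0.9

def low_noise_model(x, t):    # used once t < boundary_timestep
    return x * 1.1

num_train_timesteps = 1000
boundary_ratio = 0.875        # illustrative value only
boundary_timestep = boundary_ratio * num_train_timesteps

timesteps = torch.linspace(999, 0, steps=8)
latents = torch.randn(1, 16, 4, 8, 8)

for t in timesteps:
    # Same selection rule as the loop above: high-noise expert first,
    # low-noise expert once the timestep drops below the boundary.
    model = high_noise_model if t >= boundary_timestep else low_noise_model
    latents = model(latents, t)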
fastvideo.pipelines.stages.denoising.DenoisingStage.prepare_extra_func_kwargs
prepare_extra_func_kwargs(func, kwargs) -> dict[str, Any]

Prepare extra kwargs for the scheduler step / denoise step.

Parameters:

Name Type Description Default
func

The function to prepare kwargs for.

required
kwargs

The kwargs to prepare.

required

Returns:

Type Description
dict[str, Any]

The prepared kwargs.

Source code in fastvideo/pipelines/stages/denoising.py
def prepare_extra_func_kwargs(self, func, kwargs) -> dict[str, Any]:
    """
    Prepare extra kwargs for the scheduler step / denoise step.

    Args:
        func: The function to prepare kwargs for.
        kwargs: The kwargs to prepare.

    Returns:
        The prepared kwargs.
    """
    extra_step_kwargs = {}
    for k, v in kwargs.items():
        accepts = k in set(inspect.signature(func).parameters.keys())
        if accepts:
            extra_step_kwargs[k] = v
    return extra_step_kwargs
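
A minimal usage sketch: the helper keeps only the kwargs that the target callable actually accepts, which is how optional arguments such as eta and generator can be offered to schedulers with differing step signatures. The fake_step function below is a hypothetical stand-in, not the scheduler API:

import inspect

def fake_step(noise_pred, t, latents, eta=0.0, return_dict=True):
    """Hypothetical stand-in for a scheduler step; not the real API."""
    return latents

def prepare_extra_func_kwargs(func, kwargs):
    # Same filtering as above: keep only kwargs the callable accepts.
    accepted = set(inspect.signature(func).parameters)
    return {k: v for k, v in kwargs.items() if k in accepted}

# "generator" is dropped because fake_step does not accept it; "eta" is kept.
extra = prepare_extra_func_kwargs(fake_step, {"eta": 0.5, "generator": object()})
print(extra)  # {'eta': 0.5}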
fastvideo.pipelines.stages.denoising.DenoisingStage.prepare_sta_param
prepare_sta_param(batch: ForwardBatch, fastvideo_args: FastVideoArgs)

Prepare Sliding Tile Attention (STA) parameters and settings.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required
Source code in fastvideo/pipelines/stages/denoising.py
def prepare_sta_param(self, batch: ForwardBatch,
                      fastvideo_args: FastVideoArgs):
    """
    Prepare Sliding Tile Attention (STA) parameters and settings.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.
    """
    # TODO(kevin): STA mask search currently only supports Wan2.1 with 69x768x1280
    from fastvideo.attention.backends.STA_configuration import configure_sta
    STA_mode = fastvideo_args.STA_mode
    skip_time_steps = fastvideo_args.skip_time_steps
    if batch.timesteps is None:
        raise ValueError("Timesteps must be provided")
    timesteps_num = batch.timesteps.shape[0]

    logger.info("STA_mode: %s", STA_mode)
    if (batch.num_frames, batch.height,
            batch.width) != (69, 768, 1280) and STA_mode != "STA_inference":
        raise NotImplementedError(
            "STA mask search/tuning is not supported for this resolution")

    if STA_mode == STA_Mode.STA_SEARCHING or STA_mode == STA_Mode.STA_TUNING or STA_mode == STA_Mode.STA_TUNING_CFG:
        size = (batch.width, batch.height)
        if size == (1280, 768):
            # TODO: make it configurable
            sparse_mask_candidates_searching = [
                "3, 1, 10", "1, 5, 7", "3, 3, 3", "1, 6, 5", "1, 3, 10",
                "3, 6, 1"
            ]
            sparse_mask_candidates_tuning = [
                "3, 1, 10", "1, 5, 7", "3, 3, 3", "1, 6, 5", "1, 3, 10",
                "3, 6, 1"
            ]
            full_mask = ["3,6,10"]
        else:
            raise NotImplementedError(
                "STA mask search is not supported for this resolution")
    layer_num = self.transformer.config.num_layers
    # specific for HunyuanVideo
    if hasattr(self.transformer.config, "num_single_layers"):
        layer_num += self.transformer.config.num_single_layers
    head_num = self.transformer.config.num_attention_heads

    if STA_mode == STA_Mode.STA_SEARCHING:
        STA_param = configure_sta(
            mode=STA_Mode.STA_SEARCHING,
            layer_num=layer_num,
            head_num=head_num,
            time_step_num=timesteps_num,
            mask_candidates=sparse_mask_candidates_searching +
            full_mask,  # the last entry is the full mask; more sparse masks can be added as long as the last one stays the full mask
        )
    elif STA_mode == STA_Mode.STA_TUNING:
        STA_param = configure_sta(
            mode=STA_Mode.STA_TUNING,
            layer_num=layer_num,
            head_num=head_num,
            time_step_num=timesteps_num,
            mask_search_files_path=
            f'output/mask_search_result_pos_{size[0]}x{size[1]}/',
            mask_candidates=sparse_mask_candidates_tuning,
            full_attention_mask=[int(x) for x in full_mask[0].split(',')],
            skip_time_steps=
            skip_time_steps,  # Use full attention for the first skip_time_steps steps
            save_dir=
            f'output/mask_search_strategy_{size[0]}x{size[1]}/',  # Custom save directory
            timesteps=timesteps_num)
    elif STA_mode == STA_Mode.STA_TUNING_CFG:
        STA_param = configure_sta(
            mode=STA_Mode.STA_TUNING_CFG,
            layer_num=layer_num,
            head_num=head_num,
            time_step_num=timesteps_num,
            mask_search_files_path_pos=
            f'output/mask_search_result_pos_{size[0]}x{size[1]}/',
            mask_search_files_path_neg=
            f'output/mask_search_result_neg_{size[0]}x{size[1]}/',
            mask_candidates=sparse_mask_candidates_tuning,
            full_attention_mask=[int(x) for x in full_mask[0].split(',')],
            skip_time_steps=skip_time_steps,
            save_dir=f'output/mask_search_strategy_{size[0]}x{size[1]}/',
            timesteps=timesteps_num)
    elif STA_mode == STA_Mode.STA_INFERENCE:
        import fastvideo.envs as envs
        config_file = envs.FASTVIDEO_ATTENTION_CONFIG
        if config_file is None:
            raise ValueError("FASTVIDEO_ATTENTION_CONFIG is not set")
        STA_param = configure_sta(mode=STA_Mode.STA_INFERENCE,
                                  layer_num=layer_num,
                                  head_num=head_num,
                                  time_step_num=timesteps_num,
                                  load_path=config_file)

    batch.STA_param = STA_param
    batch.mask_search_final_result_pos = [[] for _ in range(timesteps_num)]
    batch.mask_search_final_result_neg = [[] for _ in range(timesteps_num)]
fastvideo.pipelines.stages.denoising.DenoisingStage.progress_bar
progress_bar(iterable: Iterable | None = None, total: int | None = None) -> tqdm

Create a progress bar for the denoising process.

Parameters:

Name Type Description Default
iterable Iterable | None

The iterable to iterate over.

None
total int | None

The total number of items.

None

Returns:

Type Description
tqdm

A tqdm progress bar.

Source code in fastvideo/pipelines/stages/denoising.py
def progress_bar(self,
                 iterable: Iterable | None = None,
                 total: int | None = None) -> tqdm:
    """
    Create a progress bar for the denoising process.

    Args:
        iterable: The iterable to iterate over.
        total: The total number of items.

    Returns:
        A tqdm progress bar.
    """
    local_rank = get_world_group().local_rank
    if local_rank == 0:
        return tqdm(iterable=iterable, total=total)
    else:
        return tqdm(iterable=iterable, total=total, disable=True)
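
A standalone sketch of the same rank gating with plain tqdm: every rank gets a progress-bar object so the with/update pattern used in the denoising loop works everywhere, but only rank 0 actually renders output. The helper name below is illustrative, not part of the library:

from tqdm import tqdm

def progress_bar_for_rank(local_rank, total):
    # Non-zero ranks get a disabled bar, so `with ... as pbar` and
    # `pbar.update()` still work without printing anything.
    return tqdm(total=total, disable=(local_rank != 0))

with progress_bar_for_rank(local_rank=0, total=10) as pbar:
    for _ in range(10):
        pbar.update()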
fastvideo.pipelines.stages.denoising.DenoisingStage.rescale_noise_cfg
rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0) -> Tensor

Rescale noise prediction according to guidance_rescale.

Based on findings of "Common Diffusion Noise Schedules and Sample Steps are Flawed" (https://arxiv.org/pdf/2305.08891.pdf), Section 3.4.

Parameters:

Name Type Description Default
noise_cfg

The noise prediction with guidance.

required
noise_pred_text

The text-conditioned noise prediction.

required
guidance_rescale

The guidance rescale factor.

0.0

Returns:

Type Description
Tensor

The rescaled noise prediction.

Source code in fastvideo/pipelines/stages/denoising.py
def rescale_noise_cfg(self,
                      noise_cfg,
                      noise_pred_text,
                      guidance_rescale=0.0) -> torch.Tensor:
    """
    Rescale noise prediction according to guidance_rescale.

    Based on findings of "Common Diffusion Noise Schedules and Sample Steps are Flawed"
    (https://arxiv.org/pdf/2305.08891.pdf), Section 3.4.

    Args:
        noise_cfg: The noise prediction with guidance.
        noise_pred_text: The text-conditioned noise prediction.
        guidance_rescale: The guidance rescale factor.

    Returns:
        The rescaled noise prediction.
    """
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)),
                                   keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)),
                            keepdim=True)
    # Rescale the results from guidance (fixes overexposure)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # Mix with the original results from guidance by factor guidance_rescale
    noise_cfg = (guidance_rescale * noise_pred_rescaled +
                 (1 - guidance_rescale) * noise_cfg)
    return noise_cfg
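
A standalone numeric sketch of the rescaling (mirroring the method above on random tensors): the guided prediction is scaled so its per-sample standard deviation matches the text-conditioned prediction, then blended back with the unrescaled result by guidance_rescale.

import torch

def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    # Same computation as the method above, reproduced so this snippet runs on its own.
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    rescaled = noise_cfg * (std_text / std_cfg)
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg

noise_pred_text = torch.randn(1, 16, 4, 8, 8)
noise_pred_uncond = torch.randn(1, 16, 4, 8, 8)
noise_cfg = noise_pred_uncond + 7.5 * (noise_pred_text - noise_pred_uncond)

out = rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.7)
# With guidance_rescale=1.0 the output's per-sample std matches noise_pred_text's.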
fastvideo.pipelines.stages.denoising.DenoisingStage.save_sta_search_results
save_sta_search_results(batch: ForwardBatch)

Save the STA mask search results.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
Source code in fastvideo/pipelines/stages/denoising.py
def save_sta_search_results(self, batch: ForwardBatch):
    """
    Save the STA mask search results.

    Args:
        batch: The current batch information.
    """
    size = (batch.width, batch.height)
    if size == (1280, 768):
        # TODO: make it configurable
        sparse_mask_candidates_searching = [
            "3, 1, 10", "1, 5, 7", "3, 3, 3", "1, 6, 5", "1, 3, 10",
            "3, 6, 1"
        ]
    else:
        raise NotImplementedError(
            "STA mask search is not supported for this resolution")

    from fastvideo.attention.backends.STA_configuration import save_mask_search_results
    if batch.mask_search_final_result_pos is not None and batch.prompt is not None:
        save_mask_search_results(
            [
                dict(layer_data)
                for layer_data in batch.mask_search_final_result_pos
            ],
            prompt=str(batch.prompt),
            mask_strategies=sparse_mask_candidates_searching,
            output_dir=f'output/mask_search_result_pos_{size[0]}x{size[1]}/'
        )
    if batch.mask_search_final_result_neg is not None and batch.prompt is not None:
        save_mask_search_results(
            [
                dict(layer_data)
                for layer_data in batch.mask_search_final_result_neg
            ],
            prompt=str(batch.prompt),
            mask_strategies=sparse_mask_candidates_searching,
            output_dir=f'output/mask_search_result_neg_{size[0]}x{size[1]}/'
        )
fastvideo.pipelines.stages.denoising.DenoisingStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify denoising stage inputs.

Source code in fastvideo/pipelines/stages/denoising.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify denoising stage inputs."""
    result = VerificationResult()
    result.add_check("timesteps", batch.timesteps,
                     [V.is_tensor, V.min_dims(1)])
    result.add_check("latents", batch.latents,
                     [V.is_tensor, V.with_dims(5)])
    result.add_check("prompt_embeds", batch.prompt_embeds, V.list_not_empty)
    result.add_check("image_embeds", batch.image_embeds, V.is_list)
    result.add_check("image_latent", batch.image_latent,
                     V.none_or_tensor_with_dims(5))
    result.add_check("num_inference_steps", batch.num_inference_steps,
                     V.positive_int)
    result.add_check("guidance_scale", batch.guidance_scale,
                     V.positive_float)
    result.add_check("eta", batch.eta, V.non_negative_float)
    result.add_check("generator", batch.generator,
                     V.generator_or_list_generators)
    result.add_check("do_classifier_free_guidance",
                     batch.do_classifier_free_guidance, V.bool_value)
    result.add_check(
        "negative_prompt_embeds", batch.negative_prompt_embeds, lambda x:
        not batch.do_classifier_free_guidance or V.list_not_empty(x))
    return result
fastvideo.pipelines.stages.denoising.DenoisingStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify denoising stage outputs.

Source code in fastvideo/pipelines/stages/denoising.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify denoising stage outputs."""
    result = VerificationResult()
    result.add_check("latents", batch.latents,
                     [V.is_tensor, V.with_dims(5)])
    return result
fastvideo.pipelines.stages.denoising.DmdDenoisingStage
DmdDenoisingStage(transformer, scheduler)

Bases: DenoisingStage

Denoising stage for DMD (Distribution Matching Distillation).

Source code in fastvideo/pipelines/stages/denoising.py
def __init__(self, transformer, scheduler) -> None:
    super().__init__(transformer, scheduler)
    self.scheduler = FlowMatchEulerDiscreteScheduler(shift=8.0)
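
The shift=8.0 value biases sampling toward high-noise timesteps, which matters when only a handful of DMD steps are taken. A small illustration using the common flow-matching timestep shift sigma' = s * sigma / (1 + (s - 1) * sigma); this assumes the standard formulation and is not a claim about FlowMatchEulerDiscreteScheduler's internals:

import torch

def shift_sigmas(sigmas, shift):
    # Common flow-matching timestep shift; assumed here, not read from the scheduler's source.
    return shift * sigmas / (1 + (shift - 1) * sigmas)

sigmas = torch.linspace(1.0, 0.0, steps=5)
print(shift_sigmas(sigmas, shift=1.0))  # 1.00, 0.75, 0.50, 0.25, 0.00 (unshifted)
print(shift_sigmas(sigmas, shift=8.0))  # ~1.00, 0.96, 0.89, 0.73, 0.00 (pushed toward high noise)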
Functions
fastvideo.pipelines.stages.denoising.DmdDenoisingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Run the denoising loop.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The batch with denoised latents.

Source code in fastvideo/pipelines/stages/denoising.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Run the denoising loop.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with denoised latents.
    """
    # Setup precision and autocast settings
    # TODO(will): make the precision configurable for inference
    # target_dtype = PRECISION_TO_TYPE[fastvideo_args.precision]
    target_dtype = torch.bfloat16
    autocast_enabled = (target_dtype != torch.float32
                        ) and not fastvideo_args.disable_autocast

    # Get timesteps and calculate warmup steps
    timesteps = batch.timesteps

    # TODO(will): remove this once we add input/output validation for stages
    if timesteps is None:
        raise ValueError("Timesteps must be provided")
    num_inference_steps = batch.num_inference_steps
    num_warmup_steps = len(
        timesteps) - num_inference_steps * self.scheduler.order

    # Prepare image latents and embeddings for I2V generation
    image_embeds = batch.image_embeds
    if len(image_embeds) > 0:
        assert torch.isnan(image_embeds[0]).sum() == 0
        image_embeds = [
            image_embed.to(target_dtype) for image_embed in image_embeds
        ]

    image_kwargs = self.prepare_extra_func_kwargs(
        self.transformer.forward,
        {
            "encoder_hidden_states_image": image_embeds,
            "mask_strategy": dict_to_3d_list(
                None, t_max=50, l_max=60, h_max=24)
        },
    )

    pos_cond_kwargs = self.prepare_extra_func_kwargs(
        self.transformer.forward,
        {
            "encoder_hidden_states_2": batch.clip_embedding_pos,
            "encoder_attention_mask": batch.prompt_attention_mask,
        },
    )

    # Prepare STA parameters
    if st_attn_available and self.attn_backend == SlidingTileAttentionBackend:
        self.prepare_sta_param(batch, fastvideo_args)

    # Get latents and embeddings
    assert batch.latents is not None, "latents must be provided"
    latents = batch.latents
    latents = latents.permute(0, 2, 1, 3, 4)

    video_raw_latent_shape = latents.shape
    prompt_embeds = batch.prompt_embeds
    assert not torch.isnan(
        prompt_embeds[0]).any(), "prompt_embeds contains nan"
    timesteps = torch.tensor(
        fastvideo_args.pipeline_config.dmd_denoising_steps,
        dtype=torch.long,
        device=get_local_torch_device())

    # Handle sequence parallelism if enabled
    sp_world_size, rank_in_sp_group = get_sp_world_size(
    ), get_sp_parallel_rank()
    sp_group = sp_world_size > 1
    if sp_group:
        latents = rearrange(latents,
                            "b (n t) c h w -> b n t c h w",
                            n=sp_world_size).contiguous()
        latents = latents[:, rank_in_sp_group, :, :, :, :]
        if batch.image_latent is not None:
            image_latent = rearrange(batch.image_latent,
                                     "b c (n t) h w -> b c n t h w",
                                     n=sp_world_size).contiguous()

            image_latent = image_latent[:, :, rank_in_sp_group, :, :, :]
            batch.image_latent = image_latent

    # Run denoising loop
    with self.progress_bar(total=len(timesteps)) as progress_bar:
        for i, t in enumerate(timesteps):
            # Skip if interrupted
            if hasattr(self, 'interrupt') and self.interrupt:
                continue
            # Expand latents for I2V
            noise_latents = latents.clone()
            latent_model_input = latents.to(target_dtype)

            if batch.image_latent is not None:
                latent_model_input = torch.cat([
                    latent_model_input,
                    batch.image_latent.permute(0, 2, 1, 3, 4)
                ],
                                               dim=2).to(target_dtype)
            assert not torch.isnan(
                latent_model_input).any(), "latent_model_input contains nan"

            # Prepare inputs for transformer
            t_expand = t.repeat(latent_model_input.shape[0])
            guidance_expand = (
                torch.tensor(
                    [fastvideo_args.pipeline_config.embedded_cfg_scale] *
                    latent_model_input.shape[0],
                    dtype=torch.float32,
                    device=get_local_torch_device(),
                ).to(target_dtype) *
                1000.0 if fastvideo_args.pipeline_config.embedded_cfg_scale
                is not None else None)

            # Predict noise residual
            with torch.autocast(device_type="cuda",
                                dtype=target_dtype,
                                enabled=autocast_enabled):
                if (vsa_available and self.attn_backend
                        == VideoSparseAttentionBackend):
                    self.attn_metadata_builder_cls = self.attn_backend.get_builder_cls(
                    )

                    if self.attn_metadata_builder_cls is not None:
                        self.attn_metadata_builder = self.attn_metadata_builder_cls(
                        )
                        # TODO(will): clean this up
                        attn_metadata = self.attn_metadata_builder.build(  # type: ignore
                            current_timestep=i,  # type: ignore
                            raw_latent_shape=batch.
                            raw_latent_shape[2:5],  # type: ignore
                            patch_size=fastvideo_args.
                            pipeline_config.  # type: ignore
                            dit_config.patch_size,  # type: ignore
                            STA_param=batch.STA_param,  # type: ignore
                            VSA_sparsity=fastvideo_args.
                            VSA_sparsity,  # type: ignore
                            device=get_local_torch_device(),  # type: ignore
                        )  # type: ignore
                        assert attn_metadata is not None, "attn_metadata cannot be None"
                    else:
                        attn_metadata = None
                else:
                    attn_metadata = None

                batch.is_cfg_negative = False
                with set_forward_context(
                        current_timestep=i,
                        attn_metadata=attn_metadata,
                        forward_batch=batch,
                        # fastvideo_args=fastvideo_args
                ):
                    # Run transformer
                    pred_noise = self.transformer(
                        latent_model_input.permute(0, 2, 1, 3, 4),
                        prompt_embeds,
                        t_expand,
                        guidance=guidance_expand,
                        **image_kwargs,
                        **pos_cond_kwargs,
                    ).permute(0, 2, 1, 3, 4)

                pred_video = pred_noise_to_pred_video(
                    pred_noise=pred_noise.flatten(0, 1),
                    noise_input_latent=noise_latents.flatten(0, 1),
                    timestep=t_expand,
                    scheduler=self.scheduler).unflatten(
                        0, pred_noise.shape[:2])

                if i < len(timesteps) - 1:
                    next_timestep = timesteps[i + 1] * torch.ones(
                        [1], dtype=torch.long, device=pred_video.device)
                    noise = torch.randn(video_raw_latent_shape,
                                        dtype=pred_video.dtype,
                                        generator=batch.generator[0]).to(
                                            self.device)
                    if sp_group:
                        noise = rearrange(noise,
                                          "b (n t) c h w -> b n t c h w",
                                          n=sp_world_size).contiguous()
                        noise = noise[:, rank_in_sp_group, :, :, :, :]
                    latents = self.scheduler.add_noise(
                        pred_video.flatten(0, 1), noise.flatten(0, 1),
                        next_timestep).unflatten(0, pred_video.shape[:2])
                else:
                    latents = pred_video

                # Update progress bar
                if i == len(timesteps) - 1 or (
                    (i + 1) > num_warmup_steps and
                    (i + 1) % self.scheduler.order == 0
                        and progress_bar is not None):
                    progress_bar.update()

    # Gather results if using sequence parallelism
    if sp_group:
        latents = sequence_model_parallel_all_gather(latents, dim=1)
    latents = latents.permute(0, 2, 1, 3, 4)
    # Update batch with final latents
    batch.latents = latents

    return batch
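
The loop above follows the few-step DMD pattern: convert the model's prediction into a clean-video estimate, then re-noise that estimate at the next timestep (except on the last step, which keeps the clean estimate). A toy, self-contained sketch of that structure under a standard flow-matching parameterization; it is not the library's pred_noise_to_pred_video or scheduler.add_noise implementation:

import torch

torch.manual_seed(0)

def toy_model(x_t, sigma_t):
    # Hypothetical stand-in for the transformer's prediction; the real model
    # is conditioned on text embeddings and the timestep.
    return torch.randn_like(x_t) - x_t

sigmas = [1.0, 0.75, 0.5, 0.25]           # example few-step schedule
latents = torch.randn(1, 16, 4, 8, 8)     # start from pure noise (sigma = 1)

for i, sigma in enumerate(sigmas):
    v_pred = toy_model(latents, sigma)
    x0_pred = latents - sigma * v_pred    # prediction -> clean-sample estimate
    if i < len(sigmas) - 1:
        next_sigma = sigmas[i + 1]
        noise = torch.randn_like(x0_pred)
        # Re-noise the clean estimate at the next noise level.
        latents = (1 - next_sigma) * x0_pred + next_sigma * noise
    else:
        latents = x0_pred                 # last step keeps the clean estimate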

Functions

fastvideo.pipelines.stages.encoding

Encoding stage for diffusion pipelines.

Classes

fastvideo.pipelines.stages.encoding.EncodingStage
EncodingStage(vae: ParallelTiledVAE)

Bases: PipelineStage

Stage for encoding pixel space representations into latent space.

This stage handles the encoding of pixel-space video/images into latent representations for further processing in the diffusion pipeline.

Source code in fastvideo/pipelines/stages/encoding.py
def __init__(self, vae: ParallelTiledVAE) -> None:
    self.vae: ParallelTiledVAE = vae
Functions
fastvideo.pipelines.stages.encoding.EncodingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Encode pixel space representations into latent space.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The batch with encoded latents.

Source code in fastvideo/pipelines/stages/encoding.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Encode pixel space representations into latent space.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with encoded latents.
    """
    assert batch.latents is not None and isinstance(batch.latents,
                                                    torch.Tensor)

    self.vae = self.vae.to(get_local_torch_device())

    # Setup VAE precision
    vae_dtype = PRECISION_TO_TYPE[
        fastvideo_args.pipeline_config.vae_precision]
    vae_autocast_enabled = (
        vae_dtype != torch.float32) and not fastvideo_args.disable_autocast

    # Normalize input to [-1, 1] range (reverse of decoding normalization)
    latents = (batch.latents * 2.0 - 1.0).clamp(-1, 1)

    # Move to appropriate device and dtype
    latents = latents.to(get_local_torch_device())

    # Encode image to latents
    with torch.autocast(device_type="cuda",
                        dtype=vae_dtype,
                        enabled=vae_autocast_enabled):
        if fastvideo_args.pipeline_config.vae_tiling:
            self.vae.enable_tiling()
        # if fastvideo_args.vae_sp:
        #     self.vae.enable_parallel()
        if not vae_autocast_enabled:
            latents = latents.to(vae_dtype)
        latents = self.vae.encode(latents).mean

    # Update batch with encoded latents
    batch.latents = latents

    # Offload models if needed
    if hasattr(self, 'maybe_free_model_hooks'):
        self.maybe_free_model_hooks()

    if fastvideo_args.vae_cpu_offload:
        self.vae.to("cpu")

    return batch
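
A standalone check of the input convention used above, assuming pixel values in [0, 1]: the stage rescales them to [-1, 1] before handing them to the VAE encoder.

import torch

# Pixel-space video in [0, 1], shaped (batch, channels, frames, height, width).
video = torch.rand(1, 3, 9, 64, 64)

# Same normalization as above, applied before VAE encoding.
normalized = (video * 2.0 - 1.0).clamp(-1, 1)
assert normalized.min() >= -1.0 and normalized.max() <= 1.0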
fastvideo.pipelines.stages.encoding.EncodingStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify encoding stage inputs.

Source code in fastvideo/pipelines/stages/encoding.py
@torch.no_grad()
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify encoding stage inputs."""
    result = VerificationResult()
    # Input video/images for VAE encoding: [batch_size, channels, frames, height, width]
    result.add_check("latents", batch.latents,
                     [V.is_tensor, V.with_dims(5)])
    return result
fastvideo.pipelines.stages.encoding.EncodingStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify encoding stage outputs.

Source code in fastvideo/pipelines/stages/encoding.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify encoding stage outputs."""
    result = VerificationResult()
    # Encoded latents: [batch_size, channels, frames, height_latents, width_latents]
    result.add_check("latents", batch.latents,
                     [V.is_tensor, V.with_dims(5)])
    return result

Functions

fastvideo.pipelines.stages.image_encoding

Image and video encoding stages for diffusion pipelines.

This module contains implementations of encoding stages for diffusion pipelines:

- ImageEncodingStage: Encodes images using image encoders (e.g., CLIP)
- RefImageEncodingStage: Encodes reference image for Wan2.1 control pipeline
- ImageVAEEncodingStage: Encodes images to latent space using VAE for I2V generation
- VideoVAEEncodingStage: Encodes videos to latent space using VAE for V2V and control tasks

Classes

fastvideo.pipelines.stages.image_encoding.ImageEncodingStage
ImageEncodingStage(image_encoder, image_processor)

Bases: PipelineStage

Stage for encoding image prompts into embeddings for diffusion models.

This stage handles the encoding of image prompts into the embedding space expected by the diffusion model.

Initialize the image encoding stage.

Parameters:

Name Type Description Default
image_encoder

The image encoder model used to produce image embeddings.

required
image_processor

The processor used to prepare images for the encoder.

required
Source code in fastvideo/pipelines/stages/image_encoding.py
def __init__(self, image_encoder, image_processor) -> None:
    """
    Initialize the prompt encoding stage.

    Args:
        enable_logging: Whether to enable logging for this stage.
        is_secondary: Whether this is a secondary image encoder.
    """
    super().__init__()
    self.image_processor = image_processor
    self.image_encoder = image_encoder
Functions
fastvideo.pipelines.stages.image_encoding.ImageEncodingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Encode the prompt into image encoder hidden states.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The batch with encoded prompt embeddings.

Source code in fastvideo/pipelines/stages/image_encoding.py
@torch.no_grad()
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Encode the prompt into image encoder hidden states.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with encoded prompt embeddings.
    """
    self.image_encoder = self.image_encoder.to(get_local_torch_device())

    image = batch.pil_image

    image_inputs = self.image_processor(
        images=image, return_tensors="pt").to(get_local_torch_device())
    with set_forward_context(current_timestep=0, attn_metadata=None):
        outputs = self.image_encoder(**image_inputs)
        image_embeds = outputs.last_hidden_state

    batch.image_embeds.append(image_embeds)

    if fastvideo_args.image_encoder_cpu_offload:
        self.image_encoder.to('cpu')

    return batch
fastvideo.pipelines.stages.image_encoding.ImageEncodingStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify image encoding stage inputs.

Source code in fastvideo/pipelines/stages/image_encoding.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify image encoding stage inputs."""
    result = VerificationResult()
    result.add_check("pil_image", batch.pil_image, V.not_none)
    result.add_check("image_embeds", batch.image_embeds, V.is_list)
    return result
fastvideo.pipelines.stages.image_encoding.ImageEncodingStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify image encoding stage outputs.

Source code in fastvideo/pipelines/stages/image_encoding.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify image encoding stage outputs."""
    result = VerificationResult()
    result.add_check("image_embeds", batch.image_embeds,
                     V.list_of_tensors_dims(3))
    return result
fastvideo.pipelines.stages.image_encoding.ImageVAEEncodingStage
ImageVAEEncodingStage(vae: ParallelTiledVAE)

Bases: PipelineStage

Stage for encoding image pixel representations into latent space.

This stage handles the encoding of image pixel representations into the final input format (e.g., latents) for image-to-video generation.

Source code in fastvideo/pipelines/stages/image_encoding.py
def __init__(self, vae: ParallelTiledVAE) -> None:
    self.vae: ParallelTiledVAE = vae
Functions
fastvideo.pipelines.stages.image_encoding.ImageVAEEncodingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Encode pixel representations into latent space.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The batch with encoded outputs.

Source code in fastvideo/pipelines/stages/image_encoding.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Encode pixel representations into latent space.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with encoded outputs.
    """
    assert batch.pil_image is not None
    if fastvideo_args.mode == ExecutionMode.INFERENCE:
        assert batch.pil_image is not None and isinstance(
            batch.pil_image, PIL.Image.Image)
        assert batch.height is not None and isinstance(batch.height, int)
        assert batch.width is not None and isinstance(batch.width, int)
        assert batch.num_frames is not None and isinstance(
            batch.num_frames, int)
        height = batch.height
        width = batch.width
        num_frames = batch.num_frames
    elif fastvideo_args.mode == ExecutionMode.PREPROCESS:
        assert batch.pil_image is not None and isinstance(
            batch.pil_image, torch.Tensor)
        assert batch.height is not None and isinstance(batch.height, list)
        assert batch.width is not None and isinstance(batch.width, list)
        assert batch.num_frames is not None and isinstance(
            batch.num_frames, list)
        num_frames = batch.num_frames[0]
        height = batch.height[0]
        width = batch.width[0]

    self.vae = self.vae.to(get_local_torch_device())

    # Process single image for I2V
    latent_height = height // self.vae.spatial_compression_ratio
    latent_width = width // self.vae.spatial_compression_ratio
    image = batch.pil_image
    image = self.preprocess(
        image,
        vae_scale_factor=self.vae.spatial_compression_ratio,
        height=height,
        width=width).to(get_local_torch_device(), dtype=torch.float32)

    # (B, C, H, W) -> (B, C, 1, H, W)
    image = image.unsqueeze(2)

    video_condition = torch.cat([
        image,
        image.new_zeros(image.shape[0], image.shape[1], num_frames - 1,
                        image.shape[3], image.shape[4])
    ],
                                dim=2)
    video_condition = video_condition.to(device=get_local_torch_device(),
                                         dtype=torch.float32)

    # Setup VAE precision
    vae_dtype = PRECISION_TO_TYPE[
        fastvideo_args.pipeline_config.vae_precision]
    vae_autocast_enabled = (
        vae_dtype != torch.float32) and not fastvideo_args.disable_autocast

    # Encode Image
    with torch.autocast(device_type="cuda",
                        dtype=vae_dtype,
                        enabled=vae_autocast_enabled):
        if fastvideo_args.pipeline_config.vae_tiling:
            self.vae.enable_tiling()
        # if fastvideo_args.vae_sp:
        #     self.vae.enable_parallel()
        if not vae_autocast_enabled:
            video_condition = video_condition.to(vae_dtype)
        encoder_output = self.vae.encode(video_condition)

    if fastvideo_args.mode == ExecutionMode.PREPROCESS:
        latent_condition = encoder_output.mean
    else:
        generator = batch.generator
        if generator is None:
            raise ValueError("Generator must be provided")
        latent_condition = self.retrieve_latents(encoder_output, generator)

    # Apply shifting if needed
    if (hasattr(self.vae, "shift_factor")
            and self.vae.shift_factor is not None):
        if isinstance(self.vae.shift_factor, torch.Tensor):
            latent_condition -= self.vae.shift_factor.to(
                latent_condition.device, latent_condition.dtype)
        else:
            latent_condition -= self.vae.shift_factor

    if isinstance(self.vae.scaling_factor, torch.Tensor):
        latent_condition = latent_condition * self.vae.scaling_factor.to(
            latent_condition.device, latent_condition.dtype)
    else:
        latent_condition = latent_condition * self.vae.scaling_factor

    if fastvideo_args.mode == ExecutionMode.PREPROCESS:
        batch.image_latent = latent_condition
    else:
        mask_lat_size = torch.ones(1, 1, num_frames, latent_height,
                                   latent_width)
        mask_lat_size[:, :, list(range(1, num_frames))] = 0
        first_frame_mask = mask_lat_size[:, :, 0:1]
        first_frame_mask = torch.repeat_interleave(
            first_frame_mask,
            dim=2,
            repeats=self.vae.temporal_compression_ratio)
        mask_lat_size = torch.concat(
            [first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
        mask_lat_size = mask_lat_size.view(
            1, -1, self.vae.temporal_compression_ratio, latent_height,
            latent_width)
        mask_lat_size = mask_lat_size.transpose(1, 2)
        mask_lat_size = mask_lat_size.to(latent_condition.device)

        batch.image_latent = torch.concat([mask_lat_size, latent_condition],
                                          dim=1)

    # Offload models if needed
    if hasattr(self, 'maybe_free_model_hooks'):
        self.maybe_free_model_hooks()

    self.vae.to("cpu")

    return batch
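
A standalone shape walkthrough of the conditioning mask built above, assuming illustrative values of num_frames=81, a temporal compression ratio of 4, and a 60x104 latent grid (none taken from a specific config): the first-frame mask is expanded to fill one temporal compression group, folded into latent-frame slots, and then concatenated with latent_condition along the channel dimension.

import torch

num_frames, tc = 81, 4                     # tc: VAE temporal compression ratio
latent_height, latent_width = 60, 104

mask = torch.ones(1, 1, num_frames, latent_height, latent_width)
mask[:, :, 1:] = 0                         # only the first frame is conditioned on
first = mask[:, :, 0:1].repeat_interleave(tc, dim=2)       # (1, 1, 4, 60, 104)
mask = torch.cat([first, mask[:, :, 1:]], dim=2)           # (1, 1, 84, 60, 104)
mask = mask.view(1, -1, tc, latent_height, latent_width).transpose(1, 2)
print(mask.shape)                          # torch.Size([1, 4, 21, 60, 104])

# (81 - 1) // 4 + 1 = 21 latent frames, so the mask lines up with a
# latent_condition of shape (1, C, 21, 60, 104) and concatenates along dim=1.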
fastvideo.pipelines.stages.image_encoding.ImageVAEEncodingStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify encoding stage inputs.

Source code in fastvideo/pipelines/stages/image_encoding.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify encoding stage inputs."""
    result = VerificationResult()
    result.add_check("generator", batch.generator,
                     V.generator_or_list_generators)
    if fastvideo_args.mode == ExecutionMode.PREPROCESS:
        result.add_check("height", batch.height, V.list_not_empty)
        result.add_check("width", batch.width, V.list_not_empty)
        result.add_check("num_frames", batch.num_frames, V.list_not_empty)
    else:
        result.add_check("height", batch.height, V.positive_int)
        result.add_check("width", batch.width, V.positive_int)
        result.add_check("num_frames", batch.num_frames, V.positive_int)
    return result
fastvideo.pipelines.stages.image_encoding.ImageVAEEncodingStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify encoding stage outputs.

Source code in fastvideo/pipelines/stages/image_encoding.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify encoding stage outputs."""
    result = VerificationResult()
    result.add_check("image_latent", batch.image_latent,
                     [V.is_tensor, V.with_dims(5)])
    return result
fastvideo.pipelines.stages.image_encoding.RefImageEncodingStage
RefImageEncodingStage(image_encoder, image_processor)

Bases: ImageEncodingStage

Stage for encoding reference image prompts into embeddings for Wan2.1 Control models.

This stage extends ImageEncodingStage with specialized preprocessing for reference images.

Source code in fastvideo/pipelines/stages/image_encoding.py
def __init__(self, image_encoder, image_processor) -> None:
    """
    Initialize the prompt encoding stage.

    Args:
        enable_logging: Whether to enable logging for this stage.
        is_secondary: Whether this is a secondary image encoder.
    """
    super().__init__()
    self.image_processor = image_processor
    self.image_encoder = image_encoder
Functions
fastvideo.pipelines.stages.image_encoding.RefImageEncodingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Encode the prompt into image encoder hidden states.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The batch with encoded prompt embeddings.

Source code in fastvideo/pipelines/stages/image_encoding.py
@torch.no_grad()
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Encode the prompt into image encoder hidden states.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with encoded prompt embeddings.
    """
    self.image_encoder = self.image_encoder.to(get_local_torch_device())

    image = batch.pil_image
    if image is None:
        image = create_default_image()
    # Preprocess reference image for CLIP encoder
    image_tensor = preprocess_reference_image_for_clip(
        image, get_local_torch_device())

    image_inputs = self.image_processor(images=image_tensor,
                                        return_tensors="pt").to(
                                            get_local_torch_device())
    with set_forward_context(current_timestep=0, attn_metadata=None):
        outputs = self.image_encoder(**image_inputs)
        image_embeds = outputs.last_hidden_state
    batch.image_embeds.append(image_embeds)

    if batch.pil_image is None:
        batch.image_embeds = [
            torch.zeros_like(x) for x in batch.image_embeds
        ]

    return batch
fastvideo.pipelines.stages.image_encoding.VideoVAEEncodingStage
VideoVAEEncodingStage(vae: ParallelTiledVAE)

Bases: ImageVAEEncodingStage

Stage for encoding video pixel representations into latent space.

This stage handles the encoding of video pixel representations for video-to-video generation and control. Inherits from ImageVAEEncodingStage to reuse common functionality.

Source code in fastvideo/pipelines/stages/image_encoding.py
def __init__(self, vae: ParallelTiledVAE) -> None:
    self.vae: ParallelTiledVAE = vae
Functions
fastvideo.pipelines.stages.image_encoding.VideoVAEEncodingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Encode video pixel representations into latent space.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The batch with encoded outputs.

Source code in fastvideo/pipelines/stages/image_encoding.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Encode video pixel representations into latent space.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with encoded outputs.
    """
    assert batch.video_latent is not None, "Video latent input is required for VideoVAEEncodingStage"

    if fastvideo_args.mode == ExecutionMode.INFERENCE:
        assert batch.height is not None and isinstance(batch.height, int)
        assert batch.width is not None and isinstance(batch.width, int)
        assert batch.num_frames is not None and isinstance(
            batch.num_frames, int)
        height = batch.height
        width = batch.width
        num_frames = batch.num_frames
    elif fastvideo_args.mode == ExecutionMode.PREPROCESS:
        assert batch.height is not None and isinstance(batch.height, list)
        assert batch.width is not None and isinstance(batch.width, list)
        assert batch.num_frames is not None and isinstance(
            batch.num_frames, list)
        num_frames = batch.num_frames[0]
        height = batch.height[0]
        width = batch.width[0]

    self.vae = self.vae.to(get_local_torch_device())

    # Prepare video tensor from control video
    video_condition = self._prepare_control_video_tensor(
        batch.video_latent, num_frames, height,
        width).to(get_local_torch_device(), dtype=torch.float32)

    # Setup VAE precision
    vae_dtype = PRECISION_TO_TYPE[
        fastvideo_args.pipeline_config.vae_precision]
    vae_autocast_enabled = (
        vae_dtype != torch.float32) and not fastvideo_args.disable_autocast

    # Encode control video
    with torch.autocast(device_type="cuda",
                        dtype=vae_dtype,
                        enabled=vae_autocast_enabled):
        if fastvideo_args.pipeline_config.vae_tiling:
            self.vae.enable_tiling()
        if not vae_autocast_enabled:
            video_condition = video_condition.to(vae_dtype)
        encoder_output = self.vae.encode(video_condition)

    generator = batch.generator
    if generator is None:
        raise ValueError("Generator must be provided")
    latent_condition = self.retrieve_latents(encoder_output, generator)

    if (hasattr(self.vae, "shift_factor")
            and self.vae.shift_factor is not None):
        if isinstance(self.vae.shift_factor, torch.Tensor):
            latent_condition -= self.vae.shift_factor.to(
                latent_condition.device, latent_condition.dtype)
        else:
            latent_condition -= self.vae.shift_factor

    if isinstance(self.vae.scaling_factor, torch.Tensor):
        latent_condition = latent_condition * self.vae.scaling_factor.to(
            latent_condition.device, latent_condition.dtype)
    else:
        latent_condition = latent_condition * self.vae.scaling_factor

    batch.video_latent = latent_condition

    # Offload models if needed
    if hasattr(self, 'maybe_free_model_hooks'):
        self.maybe_free_model_hooks()

    self.vae.to("cpu")

    return batch
fastvideo.pipelines.stages.image_encoding.VideoVAEEncodingStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify video encoding stage inputs.

Source code in fastvideo/pipelines/stages/image_encoding.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify video encoding stage inputs."""
    result = VerificationResult()
    result.add_check("video_latent", batch.video_latent, V.not_none)
    result.add_check("generator", batch.generator,
                     V.generator_or_list_generators)
    if fastvideo_args.mode == ExecutionMode.PREPROCESS:
        result.add_check("height", batch.height, V.list_not_empty)
        result.add_check("width", batch.width, V.list_not_empty)
        result.add_check("num_frames", batch.num_frames, V.list_not_empty)
    else:
        result.add_check("height", batch.height, V.positive_int)
        result.add_check("width", batch.width, V.positive_int)
        result.add_check("num_frames", batch.num_frames, V.positive_int)
    return result
fastvideo.pipelines.stages.image_encoding.VideoVAEEncodingStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify video encoding stage outputs.

Source code in fastvideo/pipelines/stages/image_encoding.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify video encoding stage outputs."""
    result = VerificationResult()
    result.add_check("video_latent", batch.video_latent,
                     [V.is_tensor, V.with_dims(5)])
    return result

Functions

fastvideo.pipelines.stages.input_validation

Input validation stage for diffusion pipelines.

Classes

fastvideo.pipelines.stages.input_validation.InputValidationStage

Bases: PipelineStage

Stage for validating and preparing inputs for diffusion pipelines.

This stage validates that all required inputs are present and properly formatted before proceeding with the diffusion process.

Functions
fastvideo.pipelines.stages.input_validation.InputValidationStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Validate and prepare inputs.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The validated batch information.

Source code in fastvideo/pipelines/stages/input_validation.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Validate and prepare inputs.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The validated batch information.
    """

    self._generate_seeds(batch, fastvideo_args)

    # Ensure prompt is properly formatted
    if batch.prompt is None and batch.prompt_embeds is None:
        raise ValueError(
            "Either `prompt` or `prompt_embeds` must be provided")

    # Ensure negative prompt is properly formatted if using classifier-free guidance
    if (batch.do_classifier_free_guidance and batch.negative_prompt is None
            and batch.negative_prompt_embeds is None):
        raise ValueError(
            "For classifier-free guidance, either `negative_prompt` or "
            "`negative_prompt_embeds` must be provided")

    # Validate height and width
    if batch.height is None or batch.width is None:
        raise ValueError(
            "Height and width must be provided. Please set `height` and `width`."
        )
    if batch.height % 8 != 0 or batch.width % 8 != 0:
        raise ValueError(
            f"Height and width must be divisible by 8 but are {batch.height} and {batch.width}."
        )

    # Validate number of inference steps
    if batch.num_inference_steps <= 0:
        raise ValueError(
            f"Number of inference steps must be positive, but got {batch.num_inference_steps}"
        )

    # Validate guidance scale if using classifier-free guidance
    if batch.do_classifier_free_guidance and batch.guidance_scale <= 0:
        raise ValueError(
            f"Guidance scale must be positive, but got {batch.guidance_scale}"
        )

    # for i2v, get image from image_path
    # @TODO(Wei) hard-coded for wan2.2 5b ti2v for now. Should put this in image_encoding stage
    if batch.image_path is not None:
        if batch.image_path.endswith(".mp4"):
            image = load_video(batch.image_path)[0]
        else:
            image = load_image(batch.image_path)
        batch.pil_image = image

    # further processing for ti2v task
    if fastvideo_args.pipeline_config.ti2v_task and batch.pil_image is not None:
        img = batch.pil_image
        ih, iw = img.height, img.width
        patch_size = fastvideo_args.pipeline_config.dit_config.arch_config.patch_size
        vae_stride = fastvideo_args.pipeline_config.vae_config.arch_config.scale_factor_spatial
        dh, dw = patch_size[1] * vae_stride, patch_size[2] * vae_stride
        max_area = 704 * 1280
        ow, oh = best_output_size(iw, ih, dw, dh, max_area)

        scale = max(ow / iw, oh / ih)
        img = img.resize((round(iw * scale), round(ih * scale)),
                         Image.LANCZOS)
        logger.info("resized img height: %s, img width: %s", img.height,
                    img.width)

        # center-crop
        x1 = (img.width - ow) // 2
        y1 = (img.height - oh) // 2
        img = img.crop((x1, y1, x1 + ow, y1 + oh))
        assert img.width == ow and img.height == oh

        # to tensor
        img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(
            self.device).unsqueeze(1)
        img = img.unsqueeze(0)
        batch.height = oh
        batch.width = ow
        batch.pil_image = img

    # for v2v, get control video from video path
    if batch.video_path is not None:
        pil_images, original_fps = load_video(batch.video_path,
                                              return_fps=True)
        logger.info("Loaded video with %s frames, original FPS: %s",
                    len(pil_images), original_fps)

        # Get target parameters from batch
        target_fps = batch.fps
        target_num_frames = batch.num_frames
        target_height = batch.height
        target_width = batch.width

        if target_fps is not None and original_fps is not None:
            frame_skip = max(1, int(original_fps // target_fps))
            if frame_skip > 1:
                pil_images = pil_images[::frame_skip]
                effective_fps = original_fps / frame_skip
                logger.info(
                    "Resampled video from %.1f fps to %.1f fps (skip=%s)",
                    original_fps, effective_fps, frame_skip)

        # Limit to target number of frames
        if target_num_frames is not None and len(
                pil_images) > target_num_frames:
            original_num_frames = len(pil_images)  # count before truncation, for logging
            pil_images = pil_images[:target_num_frames]
            logger.info("Limited video to %s frames (from %s total)",
                        target_num_frames, original_num_frames)

        # Resize each PIL image to target dimensions
        resized_images = []
        for pil_img in pil_images:
            resized_img = resize(pil_img,
                                 target_height,
                                 target_width,
                                 resize_mode="default",
                                 resample="lanczos")
            resized_images.append(resized_img)

        # Convert PIL images to numpy array
        video_numpy = pil_to_numpy(resized_images)
        video_numpy = normalize(video_numpy)
        video_tensor = numpy_to_pt(video_numpy)

        # Rearrange to [C, T, H, W] and add batch dimension -> [B, C, T, H, W]
        input_video = video_tensor.permute(1, 0, 2, 3).unsqueeze(0)

        batch.video_latent = input_video

    return batch
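
The fps-based resampling above keeps every frame_skip-th frame, with frame_skip = max(1, int(original_fps // target_fps)). A small self-contained sketch of that arithmetic with made-up frame counts and rates:

# Made-up inputs: 150 decoded frames at 30 fps.
original_fps = 30.0
pil_images = list(range(150))  # stand-ins for PIL frames

for target_fps in (16.0, 8.0):
    frame_skip = max(1, int(original_fps // target_fps))
    resampled = pil_images[::frame_skip]
    effective_fps = original_fps / frame_skip
    print(target_fps, frame_skip, effective_fps, len(resampled))
    # target 16 fps -> skip 1, effective 30.0 fps, all 150 frames kept
    # target 8 fps  -> skip 3, effective 10.0 fps, 50 frames kept

Note that the floor division means a 16 fps target against a 30 fps source performs no resampling; only the later frame-count truncation applies.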
fastvideo.pipelines.stages.input_validation.InputValidationStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify input validation stage inputs.

Source code in fastvideo/pipelines/stages/input_validation.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify input validation stage inputs."""
    result = VerificationResult()
    result.add_check("seed", batch.seed, [V.not_none, V.positive_int])
    result.add_check("num_videos_per_prompt", batch.num_videos_per_prompt,
                     V.positive_int)
    result.add_check(
        "prompt_or_embeds", None, lambda _: V.string_or_list_strings(
            batch.prompt) or V.list_not_empty(batch.prompt_embeds))
    result.add_check("height", batch.height, V.positive_int)
    result.add_check("width", batch.width, V.positive_int)
    result.add_check("num_inference_steps", batch.num_inference_steps,
                     V.positive_int)
    result.add_check(
        "guidance_scale", batch.guidance_scale, lambda x: not batch.
        do_classifier_free_guidance or V.positive_float(x))
    return result
fastvideo.pipelines.stages.input_validation.InputValidationStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify input validation stage outputs.

Source code in fastvideo/pipelines/stages/input_validation.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify input validation stage outputs."""
    result = VerificationResult()
    result.add_check("seeds", batch.seeds, V.list_not_empty)
    result.add_check("generator", batch.generator,
                     V.generator_or_list_generators)
    return result

Functions

fastvideo.pipelines.stages.latent_preparation

Latent preparation stage for diffusion pipelines.

Classes

fastvideo.pipelines.stages.latent_preparation.CosmosLatentPreparationStage
CosmosLatentPreparationStage(scheduler, transformer, vae=None)

Bases: PipelineStage

Cosmos-specific latent preparation stage that properly handles the tensor shapes and conditioning masks required by the Cosmos transformer.

This stage replicates the logic from diffusers' Cosmos2VideoToWorldPipeline.prepare_latents()

Source code in fastvideo/pipelines/stages/latent_preparation.py
def __init__(self, scheduler, transformer, vae=None) -> None:
    super().__init__()
    self.scheduler = scheduler
    self.transformer = transformer
    self.vae = vae
Functions
fastvideo.pipelines.stages.latent_preparation.CosmosLatentPreparationStage.adjust_video_length
adjust_video_length(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> int

Adjust video length based on VAE version.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
int

The adjusted number of latent frames.

Source code in fastvideo/pipelines/stages/latent_preparation.py
def adjust_video_length(self, batch: ForwardBatch,
                        fastvideo_args: FastVideoArgs) -> int:
    """
    Adjust video length based on VAE version.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The adjusted number of latent frames.
    """

    video_length = batch.num_frames
    use_temporal_scaling_frames = fastvideo_args.pipeline_config.vae_config.use_temporal_scaling_frames
    if use_temporal_scaling_frames:
        temporal_scale_factor = fastvideo_args.pipeline_config.vae_config.arch_config.temporal_compression_ratio
        latent_num_frames = (video_length - 1) // temporal_scale_factor + 1
    else:  # stepvideo only
        latent_num_frames = video_length // 17 * 3
    return int(latent_num_frames)
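
A worked example of the two branches above, assuming a temporal compression ratio of 4 in the first branch (the actual ratio is read from the VAE config at runtime):

# Branch 1: temporally compressing VAEs, assumed compression ratio of 4.
temporal_scale_factor = 4
for video_length in (17, 81, 121):
    print(video_length, "->", (video_length - 1) // temporal_scale_factor + 1)
# 17 -> 5, 81 -> 21, 121 -> 31

# Branch 2: the stepvideo-only mapping of every 17 input frames to 3 latent frames.
for video_length in (34, 102):
    print(video_length, "->", video_length // 17 * 3)
# 34 -> 6, 102 -> 18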
fastvideo.pipelines.stages.latent_preparation.CosmosLatentPreparationStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify Cosmos latent preparation stage inputs.

Source code in fastvideo/pipelines/stages/latent_preparation.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify Cosmos latent preparation stage inputs."""
    result = VerificationResult()
    result.add_check(
        "prompt_or_embeds", None, lambda _: V.string_or_list_strings(
            batch.prompt) or V.list_not_empty(batch.prompt_embeds))
    result.add_check("prompt_embeds", batch.prompt_embeds,
                     V.list_of_tensors)
    result.add_check("num_videos_per_prompt", batch.num_videos_per_prompt,
                     V.positive_int)
    result.add_check("generator", batch.generator,
                     V.generator_or_list_generators)
    result.add_check("num_frames", batch.num_frames, V.positive_int)
    result.add_check("height", batch.height, V.positive_int)
    result.add_check("width", batch.width, V.positive_int)
    result.add_check("latents", batch.latents, V.none_or_tensor)
    return result
fastvideo.pipelines.stages.latent_preparation.CosmosLatentPreparationStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify latent preparation stage outputs.

Source code in fastvideo/pipelines/stages/latent_preparation.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify latent preparation stage outputs."""
    result = VerificationResult()
    result.add_check("latents", batch.latents,
                     [V.is_tensor, V.with_dims(5)])
    result.add_check("raw_latent_shape", batch.raw_latent_shape, V.is_tuple)
    return result
fastvideo.pipelines.stages.latent_preparation.LatentPreparationStage
LatentPreparationStage(scheduler, transformer)

Bases: PipelineStage

Stage for preparing initial latent variables for the diffusion process.

This stage handles the preparation of the initial latent variables that will be denoised during the diffusion process.

Source code in fastvideo/pipelines/stages/latent_preparation.py
def __init__(self, scheduler, transformer) -> None:
    super().__init__()
    self.scheduler = scheduler
    self.transformer = transformer
Functions
fastvideo.pipelines.stages.latent_preparation.LatentPreparationStage.adjust_video_length
adjust_video_length(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> int

Adjust video length based on VAE version.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
int

The adjusted number of latent frames.

Source code in fastvideo/pipelines/stages/latent_preparation.py
def adjust_video_length(self, batch: ForwardBatch,
                        fastvideo_args: FastVideoArgs) -> int:
    """
    Adjust video length based on VAE version.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The adjusted number of latent frames.
    """

    video_length = batch.num_frames
    use_temporal_scaling_frames = fastvideo_args.pipeline_config.vae_config.use_temporal_scaling_frames
    if use_temporal_scaling_frames:
        temporal_scale_factor = fastvideo_args.pipeline_config.vae_config.arch_config.temporal_compression_ratio
        latent_num_frames = (video_length - 1) // temporal_scale_factor + 1
    else:  # stepvideo only
        latent_num_frames = video_length // 17 * 3
    return int(latent_num_frames)
fastvideo.pipelines.stages.latent_preparation.LatentPreparationStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Prepare initial latent variables for the diffusion process.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The batch with prepared latent variables.

Source code in fastvideo/pipelines/stages/latent_preparation.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Prepare initial latent variables for the diffusion process.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with prepared latent variables.
    """

    latent_num_frames = None
    # Adjust video length based on VAE version if needed
    if hasattr(self, 'adjust_video_length'):
        latent_num_frames = self.adjust_video_length(batch, fastvideo_args)
    # Determine batch size
    if isinstance(batch.prompt, list):
        batch_size = len(batch.prompt)
    elif batch.prompt is not None:
        batch_size = 1
    else:
        batch_size = batch.prompt_embeds[0].shape[0]

    # Adjust batch size for number of videos per prompt
    batch_size *= batch.num_videos_per_prompt

    # Get required parameters
    dtype = batch.prompt_embeds[0].dtype
    device = get_local_torch_device()
    generator = batch.generator
    latents = batch.latents
    num_frames = latent_num_frames if latent_num_frames is not None else batch.num_frames
    height = batch.height
    width = batch.width

    # TODO(will): remove this once we add input/output validation for stages
    if height is None or width is None:
        raise ValueError("Height and width must be provided")

    # Calculate latent shape
    shape = (
        batch_size,
        self.transformer.num_channels_latents,
        num_frames,
        height // fastvideo_args.pipeline_config.vae_config.arch_config.
        spatial_compression_ratio,
        width // fastvideo_args.pipeline_config.vae_config.arch_config.
        spatial_compression_ratio,
    )

    # Validate generator if it's a list
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )
    # Generate or use provided latents
    if latents is None:
        latents = randn_tensor(shape,
                               generator=generator,
                               device=device,
                               dtype=dtype)
    else:
        latents = latents.to(device)

    # Scale the initial noise if needed
    if hasattr(self.scheduler, "init_noise_sigma"):
        latents = latents * self.scheduler.init_noise_sigma
    # Update batch with prepared latents
    batch.latents = latents
    batch.raw_latent_shape = latents.shape

    return batch
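
A minimal sketch of the latent-shape calculation above, with assumed values for the latent channel count and spatial compression ratio (the real values come from the transformer and VAE configs, and torch.randn stands in for the randn_tensor helper used in forward()):

import torch

# Assumed model constants, for illustration only.
num_channels_latents = 16
spatial_compression_ratio = 8

batch_size = 1          # one prompt with num_videos_per_prompt == 1
latent_num_frames = 21  # e.g. 81 video frames under temporal compression 4
height, width = 480, 832

shape = (
    batch_size,
    num_channels_latents,
    latent_num_frames,
    height // spatial_compression_ratio,  # 60
    width // spatial_compression_ratio,   # 104
)
generator = torch.Generator().manual_seed(42)
latents = torch.randn(shape, generator=generator)
print(latents.shape)  # torch.Size([1, 16, 21, 60, 104])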
fastvideo.pipelines.stages.latent_preparation.LatentPreparationStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify latent preparation stage inputs.

Source code in fastvideo/pipelines/stages/latent_preparation.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify latent preparation stage inputs."""
    result = VerificationResult()
    result.add_check(
        "prompt_or_embeds", None, lambda _: V.string_or_list_strings(
            batch.prompt) or V.list_not_empty(batch.prompt_embeds))
    result.add_check("prompt_embeds", batch.prompt_embeds,
                     V.list_of_tensors)
    result.add_check("num_videos_per_prompt", batch.num_videos_per_prompt,
                     V.positive_int)
    result.add_check("generator", batch.generator,
                     V.generator_or_list_generators)
    result.add_check("num_frames", batch.num_frames, V.positive_int)
    result.add_check("height", batch.height, V.positive_int)
    result.add_check("width", batch.width, V.positive_int)
    result.add_check("latents", batch.latents, V.none_or_tensor)
    return result
fastvideo.pipelines.stages.latent_preparation.LatentPreparationStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify latent preparation stage outputs.

Source code in fastvideo/pipelines/stages/latent_preparation.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify latent preparation stage outputs."""
    result = VerificationResult()
    result.add_check("latents", batch.latents,
                     [V.is_tensor, V.with_dims(5)])
    result.add_check("raw_latent_shape", batch.raw_latent_shape, V.is_tuple)
    return result

Functions

fastvideo.pipelines.stages.stepvideo_encoding

Classes

fastvideo.pipelines.stages.stepvideo_encoding.StepvideoPromptEncodingStage
StepvideoPromptEncodingStage(stepllm, clip)

Bases: PipelineStage

Stage for encoding prompts using the remote caption API.

This stage applies the magic string transformations and calls the remote caption service asynchronously to get:

  • primary prompt embeddings,
  • an attention mask,
  • and a clip embedding.

Source code in fastvideo/pipelines/stages/stepvideo_encoding.py
def __init__(self, stepllm, clip) -> None:
    super().__init__()
    # self.caption_client = caption_client  # This should have a call_caption(prompts: List[str]) method.
    self.stepllm = stepllm
    self.clip = clip
Functions
fastvideo.pipelines.stages.stepvideo_encoding.StepvideoPromptEncodingStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify stepvideo encoding stage inputs.

Source code in fastvideo/pipelines/stages/stepvideo_encoding.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify stepvideo encoding stage inputs."""
    result = VerificationResult()
    result.add_check("prompt", batch.prompt, V.string_not_empty)
    return result
fastvideo.pipelines.stages.stepvideo_encoding.StepvideoPromptEncodingStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify stepvideo encoding stage outputs.

Source code in fastvideo/pipelines/stages/stepvideo_encoding.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify stepvideo encoding stage outputs."""
    result = VerificationResult()
    result.add_check("prompt_embeds", batch.prompt_embeds,
                     [V.is_tensor, V.with_dims(3)])
    result.add_check("negative_prompt_embeds", batch.negative_prompt_embeds,
                     [V.is_tensor, V.with_dims(3)])
    result.add_check("prompt_attention_mask", batch.prompt_attention_mask,
                     [V.is_tensor, V.with_dims(2)])
    result.add_check("negative_attention_mask",
                     batch.negative_attention_mask,
                     [V.is_tensor, V.with_dims(2)])
    result.add_check("clip_embedding_pos", batch.clip_embedding_pos,
                     [V.is_tensor, V.with_dims(2)])
    result.add_check("clip_embedding_neg", batch.clip_embedding_neg,
                     [V.is_tensor, V.with_dims(2)])
    return result

Functions

fastvideo.pipelines.stages.text_encoding

Prompt encoding stages for diffusion pipelines.

This module contains implementations of prompt encoding stages for diffusion pipelines.

Classes

fastvideo.pipelines.stages.text_encoding.TextEncodingStage
TextEncodingStage(text_encoders, tokenizers)

Bases: PipelineStage

Stage for encoding text prompts into embeddings for diffusion models.

This stage handles the encoding of text prompts into the embedding space expected by the diffusion model.

Initialize the prompt encoding stage.

Parameters:

Name Type Description Default
text_encoders

The text encoder models used to produce prompt embeddings.

required
tokenizers

The tokenizers paired with each text encoder, in the same order.

required
Source code in fastvideo/pipelines/stages/text_encoding.py
def __init__(self, text_encoders, tokenizers) -> None:
    """
    Initialize the prompt encoding stage.

    Args:
        text_encoders: The text encoder models used to produce prompt embeddings.
        tokenizers: The tokenizers paired with each text encoder, in the same order.
    """
    super().__init__()
    self.tokenizers = tokenizers
    self.text_encoders = text_encoders
Functions
fastvideo.pipelines.stages.text_encoding.TextEncodingStage.encode_text
encode_text(text: str | list[str], fastvideo_args: FastVideoArgs, encoder_index: int | list[int] | None = None, return_attention_mask: bool = False, return_type: str = 'list', device: device | str | None = None, dtype: dtype | None = None, max_length: int | None = None, truncation: bool | None = None, padding: bool | str | None = None)

Encode plain text using selected text encoder(s) and return embeddings.

Parameters:

Name Type Description Default
text str | list[str]

A single string or a list of strings to encode.

required
fastvideo_args FastVideoArgs

The inference arguments providing pipeline config, including tokenizer and encoder settings, preprocess and postprocess functions.

required
encoder_index int | list[int] | None

Encoder selector by index. Accepts an int or list of ints.

None
return_attention_mask bool

If True, also return attention masks for each selected encoder.

False
return_type str

"list" (default) returns a list aligned with selection; "dict" returns a dict keyed by encoder index as a string; "stack" stacks along a new first dimension (requires matching shapes).

'list'
device device | str | None

Optional device override for inputs; defaults to local torch device.

None
dtype dtype | None

Optional dtype to cast returned embeddings to.

None
max_length int | None

Optional per-call tokenizer override.

None
truncation bool | None

Optional per-call tokenizer override.

None
padding bool | str | None

Optional per-call tokenizer override.

None

Returns:

Type Description

Depending on return_type and return_attention_mask:

  • list: List[Tensor] or (List[Tensor], List[Tensor])
  • dict: Dict[str, Tensor] or (Dict[str, Tensor], Dict[str, Tensor])
  • stack: Tensor of shape [num_encoders, ...] or a tuple with stacked attention masks
Source code in fastvideo/pipelines/stages/text_encoding.py
@torch.no_grad()
def encode_text(
    self,
    text: str | list[str],
    fastvideo_args: FastVideoArgs,
    encoder_index: int | list[int] | None = None,
    return_attention_mask: bool = False,
    return_type: str = "list",  # one of: "list", "dict", "stack"
    device: torch.device | str | None = None,
    dtype: torch.dtype | None = None,
    max_length: int | None = None,
    truncation: bool | None = None,
    padding: bool | str | None = None,
):
    """
    Encode plain text using selected text encoder(s) and return embeddings.

    Args:
        text: A single string or a list of strings to encode.
        fastvideo_args: The inference arguments providing pipeline config,
            including tokenizer and encoder settings, preprocess and postprocess
            functions.
        encoder_index: Encoder selector by index. Accepts an int or list of ints.
        return_attention_mask: If True, also return attention masks for each
            selected encoder.
        return_type: "list" (default) returns a list aligned with selection;
            "dict" returns a dict keyed by encoder index as a string; "stack" stacks along a
            new first dimension (requires matching shapes).
        device: Optional device override for inputs; defaults to local torch device.
        dtype: Optional dtype to cast returned embeddings to.
        max_length: Optional per-call tokenizer override.
        truncation: Optional per-call tokenizer override.
        padding: Optional per-call tokenizer override.

    Returns:
        Depending on return_type and return_attention_mask:
        - list: List[Tensor] or (List[Tensor], List[Tensor])
        - dict: Dict[str, Tensor] or (Dict[str, Tensor], Dict[str, Tensor])
        - stack: Tensor of shape [num_encoders, ...] or a tuple with stacked
          attention masks
    """

    assert len(self.tokenizers) == len(self.text_encoders)
    assert len(self.text_encoders) == len(
        fastvideo_args.pipeline_config.text_encoder_configs)

    # Resolve selection into indices
    encoder_cfgs = fastvideo_args.pipeline_config.text_encoder_configs
    if encoder_index is None:
        indices: list[int] = [0]
    elif isinstance(encoder_index, int):
        indices = [encoder_index]
    else:
        indices = list(encoder_index)
    # validate range
    num_encoders = len(self.text_encoders)
    for idx in indices:
        if idx < 0 or idx >= num_encoders:
            raise IndexError(
                f"encoder index {idx} out of range [0, {num_encoders-1}]")

    # Normalize input to list[str]
    assert isinstance(text, str | list)
    if isinstance(text, str):
        texts: list[str] = [text]
    else:
        texts = text

    embeds_list: list[torch.Tensor] = []
    attn_masks_list: list[torch.Tensor] = []

    preprocess_funcs = fastvideo_args.pipeline_config.preprocess_text_funcs
    postprocess_funcs = fastvideo_args.pipeline_config.postprocess_text_funcs
    encoder_cfgs = fastvideo_args.pipeline_config.text_encoder_configs

    if return_type not in ("list", "dict", "stack"):
        raise ValueError(
            f"Invalid return_type '{return_type}'. Expected one of: 'list', 'dict', 'stack'"
        )

    target_device = device if device is not None else get_local_torch_device(
    )

    for i in indices:
        tokenizer = self.tokenizers[i]
        text_encoder = self.text_encoders[i]
        encoder_config = encoder_cfgs[i]
        preprocess_func = preprocess_funcs[i]
        postprocess_func = postprocess_funcs[i]

        processed_texts: list[str] = []
        for prompt_str in texts:
            processed_texts.append(preprocess_func(prompt_str))

        tok_kwargs = dict(encoder_config.tokenizer_kwargs)
        if max_length is not None:
            tok_kwargs["max_length"] = max_length
        if truncation is not None:
            tok_kwargs["truncation"] = truncation
        if padding is not None:
            tok_kwargs["padding"] = padding

        text_inputs = tokenizer(processed_texts,
                                **tok_kwargs).to(target_device)

        input_ids = text_inputs["input_ids"]
        attention_mask = text_inputs["attention_mask"]

        with set_forward_context(current_timestep=0, attn_metadata=None):
            outputs = text_encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True,
            )

        prompt_embeds = postprocess_func(outputs)
        if dtype is not None:
            prompt_embeds = prompt_embeds.to(dtype=dtype)
        embeds_list.append(prompt_embeds)
        if return_attention_mask:
            attn_masks_list.append(attention_mask)

    # Shape results according to return_type
    if return_type == "list":
        if return_attention_mask:
            return embeds_list, attn_masks_list
        return embeds_list

    if return_type == "dict":
        key_strs = [str(i) for i in indices]
        embeds_dict = {
            k: v
            for k, v in zip(key_strs, embeds_list, strict=False)
        }
        if return_attention_mask:
            attn_dict = {
                k: v
                for k, v in zip(key_strs, attn_masks_list, strict=False)
            }
            return embeds_dict, attn_dict
        return embeds_dict

    # return_type == "stack"
    # Validate shapes are compatible
    base_shape = list(embeds_list[0].shape)
    for t in embeds_list[1:]:
        if list(t.shape) != base_shape:
            raise ValueError(
                f"Cannot stack embeddings with differing shapes: {[list(t.shape) for t in embeds_list]}"
            )
    stacked_embeds = torch.stack(embeds_list, dim=0)
    if return_attention_mask:
        base_mask_shape = list(attn_masks_list[0].shape)
        for m in attn_masks_list[1:]:
            if list(m.shape) != base_mask_shape:
                raise ValueError(
                    f"Cannot stack attention masks with differing shapes: {[list(m.shape) for m in attn_masks_list]}"
                )
        stacked_masks = torch.stack(attn_masks_list, dim=0)
        return stacked_embeds, stacked_masks
    return stacked_embeds
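
A usage sketch of encode_text (not runnable on its own: stage and fastvideo_args stand for an already constructed TextEncodingStage and its matching inference arguments):

# `stage` is a constructed TextEncodingStage; `fastvideo_args` carries the
# pipeline config whose tokenizer/encoder settings match the stage.
embeds = stage.encode_text("a red panda surfing", fastvideo_args)
# -> [tensor]: list aligned with the selected encoders (default: encoder 0)

embeds, masks = stage.encode_text(
    ["a red panda surfing", "a koala playing drums"],
    fastvideo_args,
    encoder_index=[0],
    return_attention_mask=True,
)
# -> ([tensor], [mask]): one entry per selected encoder

by_key = stage.encode_text("a red panda surfing", fastvideo_args,
                           encoder_index=0, return_type="dict")
# -> {"0": tensor}: dict keyed by encoder index as a string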
fastvideo.pipelines.stages.text_encoding.TextEncodingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Encode the prompt into text encoder hidden states.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The batch with encoded prompt embeddings.

Source code in fastvideo/pipelines/stages/text_encoding.py
@torch.no_grad()
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Encode the prompt into text encoder hidden states.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with encoded prompt embeddings.
    """
    assert len(self.tokenizers) == len(self.text_encoders)
    assert len(self.text_encoders) == len(
        fastvideo_args.pipeline_config.text_encoder_configs)

    # Encode positive prompt with all available encoders
    assert batch.prompt is not None
    prompt_text: str | list[str] = batch.prompt
    all_indices: list[int] = list(range(len(self.text_encoders)))
    prompt_embeds_list, prompt_masks_list = self.encode_text(
        prompt_text,
        fastvideo_args,
        encoder_index=all_indices,
        return_attention_mask=True,
    )

    for pe in prompt_embeds_list:
        batch.prompt_embeds.append(pe)
    if batch.prompt_attention_mask is not None:
        for am in prompt_masks_list:
            batch.prompt_attention_mask.append(am)

    # Encode negative prompt if CFG is enabled
    if batch.do_classifier_free_guidance:
        assert isinstance(batch.negative_prompt, str)
        neg_embeds_list, neg_masks_list = self.encode_text(
            batch.negative_prompt,
            fastvideo_args,
            encoder_index=all_indices,
            return_attention_mask=True,
        )

        assert batch.negative_prompt_embeds is not None
        for ne in neg_embeds_list:
            batch.negative_prompt_embeds.append(ne)
        if batch.negative_attention_mask is not None:
            for nm in neg_masks_list:
                batch.negative_attention_mask.append(nm)

    return batch
fastvideo.pipelines.stages.text_encoding.TextEncodingStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify text encoding stage inputs.

Source code in fastvideo/pipelines/stages/text_encoding.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify text encoding stage inputs."""
    result = VerificationResult()
    result.add_check("prompt", batch.prompt, V.string_or_list_strings)
    result.add_check(
        "negative_prompt", batch.negative_prompt, lambda x: not batch.
        do_classifier_free_guidance or V.string_not_empty(x))
    result.add_check("do_classifier_free_guidance",
                     batch.do_classifier_free_guidance, V.bool_value)
    result.add_check("prompt_embeds", batch.prompt_embeds, V.is_list)
    result.add_check("negative_prompt_embeds", batch.negative_prompt_embeds,
                     V.none_or_list)
    return result
fastvideo.pipelines.stages.text_encoding.TextEncodingStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify text encoding stage outputs.

Source code in fastvideo/pipelines/stages/text_encoding.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify text encoding stage outputs."""
    result = VerificationResult()
    result.add_check("prompt_embeds", batch.prompt_embeds,
                     V.list_of_tensors_min_dims(2))
    result.add_check(
        "negative_prompt_embeds", batch.negative_prompt_embeds,
        lambda x: not batch.do_classifier_free_guidance or V.
        list_of_tensors_with_min_dims(x, 2))
    return result

Functions

fastvideo.pipelines.stages.timestep_preparation

Timestep preparation stages for diffusion pipelines.

This module contains implementations of timestep preparation stages for diffusion pipelines.

Classes

fastvideo.pipelines.stages.timestep_preparation.TimestepPreparationStage
TimestepPreparationStage(scheduler)

Bases: PipelineStage

Stage for preparing timesteps for the diffusion process.

This stage handles the preparation of the timestep sequence that will be used during the diffusion process.

Source code in fastvideo/pipelines/stages/timestep_preparation.py
def __init__(self, scheduler) -> None:
    self.scheduler = scheduler
Functions
fastvideo.pipelines.stages.timestep_preparation.TimestepPreparationStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Prepare timesteps for the diffusion process.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The batch with prepared timesteps.

Source code in fastvideo/pipelines/stages/timestep_preparation.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Prepare timesteps for the diffusion process.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with prepared timesteps.
    """
    scheduler = self.scheduler
    device = get_local_torch_device()
    num_inference_steps = batch.num_inference_steps
    timesteps = batch.timesteps
    sigmas = batch.sigmas
    n_tokens = batch.n_tokens

    # Prepare extra kwargs for set_timesteps
    extra_set_timesteps_kwargs = {}
    if n_tokens is not None and "n_tokens" in inspect.signature(
            scheduler.set_timesteps).parameters:
        extra_set_timesteps_kwargs["n_tokens"] = n_tokens

    # Handle custom timesteps or sigmas
    if timesteps is not None and sigmas is not None:
        raise ValueError(
            "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
        )

    if timesteps is not None:
        accepts_timesteps = "timesteps" in inspect.signature(
            scheduler.set_timesteps).parameters
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps,
                                device=device,
                                **extra_set_timesteps_kwargs)
        timesteps = scheduler.timesteps
    elif sigmas is not None:
        accept_sigmas = "sigmas" in inspect.signature(
            scheduler.set_timesteps).parameters
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas,
                                device=device,
                                **extra_set_timesteps_kwargs)
        timesteps = scheduler.timesteps
    else:
        scheduler.set_timesteps(num_inference_steps,
                                device=device,
                                **extra_set_timesteps_kwargs)
        timesteps = scheduler.timesteps

    # Update batch with prepared timesteps
    batch.timesteps = timesteps

    return batch
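
The custom-schedule handling above hinges on whether the scheduler's set_timesteps accepts timesteps or sigmas keywords. A self-contained sketch of that capability check, using a toy function in place of a real scheduler:

import inspect

def set_timesteps(num_inference_steps=None, device=None, sigmas=None):
    """Toy stand-in for a scheduler's set_timesteps."""

params = inspect.signature(set_timesteps).parameters
print("sigmas" in params)     # True  -> batch.sigmas would be forwarded
print("timesteps" in params)  # False -> passing batch.timesteps raises ValueError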
fastvideo.pipelines.stages.timestep_preparation.TimestepPreparationStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify timestep preparation stage inputs.

Source code in fastvideo/pipelines/stages/timestep_preparation.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify timestep preparation stage inputs."""
    result = VerificationResult()
    result.add_check("num_inference_steps", batch.num_inference_steps,
                     V.positive_int)
    result.add_check("timesteps", batch.timesteps, V.none_or_tensor)
    result.add_check("sigmas", batch.sigmas, V.none_or_list)
    result.add_check("n_tokens", batch.n_tokens, V.none_or_positive_int)
    return result
fastvideo.pipelines.stages.timestep_preparation.TimestepPreparationStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify timestep preparation stage outputs.

Source code in fastvideo/pipelines/stages/timestep_preparation.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify timestep preparation stage outputs."""
    result = VerificationResult()
    result.add_check("timesteps", batch.timesteps,
                     [V.is_tensor, V.with_dims(1)])
    return result

Functions

fastvideo.pipelines.stages.utils

Utility functions for pipeline stages.

Functions

fastvideo.pipelines.stages.utils.retrieve_timesteps
retrieve_timesteps(scheduler: Any, num_inference_steps: int | None = None, device: str | device | None = None, timesteps: list[int] | None = None, sigmas: list[float] | None = None, **kwargs: Any) -> tuple[Any, int]

Calls the scheduler's set_timesteps method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to scheduler.set_timesteps.

Parameters:

Name Type Description Default
scheduler `SchedulerMixin`

The scheduler to get timesteps from.

required
num_inference_steps `int`

The number of diffusion steps used when generating samples with a pre-trained model. If used, timesteps must be None.

None
device `str` or `torch.device`, *optional*

The device to which the timesteps should be moved. If None, the timesteps are not moved.

None
timesteps `List[int]`, *optional*

Custom timesteps used to override the timestep spacing strategy of the scheduler. If timesteps is passed, num_inference_steps and sigmas must be None.

None
sigmas `List[float]`, *optional*

Custom sigmas used to override the timestep spacing strategy of the scheduler. If sigmas is passed, num_inference_steps and timesteps must be None.

None

Returns:

Type Description
Tuple[torch.Tensor, int]

A tuple where the first element is the timestep schedule and the second element is the number of inference steps.

Source code in fastvideo/pipelines/stages/utils.py
def retrieve_timesteps(
    scheduler: Any,
    num_inference_steps: int | None = None,
    device: str | torch.device | None = None,
    timesteps: list[int] | None = None,
    sigmas: list[float] | None = None,
    **kwargs: Any,
) -> tuple[Any, int]:
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError(
            "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
        )
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(
            inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        if timesteps is None:
            raise ValueError("scheduler.timesteps is None after set_timesteps")
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(
            inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        if timesteps is None:
            raise ValueError("scheduler.timesteps is None after set_timesteps")
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        if timesteps is None:
            raise ValueError("scheduler.timesteps is None after set_timesteps")
        num_inference_steps = len(timesteps)
    return timesteps, num_inference_steps
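
A self-contained sketch of how retrieve_timesteps drives a scheduler's duck-typed interface, using a toy scheduler rather than one of the real FastVideo schedulers:

import torch

from fastvideo.pipelines.stages.utils import retrieve_timesteps

class ToyScheduler:
    """Minimal stand-in exposing the pieces retrieve_timesteps relies on."""

    def set_timesteps(self, num_inference_steps=None, device=None, sigmas=None):
        if sigmas is not None:
            self.timesteps = torch.tensor(sigmas, device=device) * 1000
        else:
            self.timesteps = torch.linspace(999, 0, num_inference_steps,
                                            device=device)

timesteps, n = retrieve_timesteps(ToyScheduler(), num_inference_steps=4)
print(timesteps, n)  # tensor([999., 666., 333., 0.]), 4 steps

timesteps, n = retrieve_timesteps(ToyScheduler(), sigmas=[0.9, 0.5, 0.1])
print(timesteps, n)  # tensor([900., 500., 100.]), 3 steps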

fastvideo.pipelines.stages.validators

Common validators for pipeline stage verification.

This module provides reusable validation functions that can be used across all pipeline stages for input/output verification.

Classes

fastvideo.pipelines.stages.validators.StageValidators

Common validators for pipeline stages.
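
Since the validators are plain callables (or factories that return callables), they can also be exercised directly outside of a VerificationResult; a short self-contained sketch using the documented module path:

import torch

from fastvideo.pipelines.stages.validators import StageValidators as V

latents = torch.randn(1, 16, 21, 60, 104)

print(V.is_tensor(latents))                         # True: tensor without NaNs
print(V.is_tensor(torch.tensor([float("nan")])))    # False: NaNs are rejected
print(V.positive_int(50))                           # True
print(V.divisible(8)(104))                          # True: factories return a callable
print(V.none_or_tensor_with_dims(5)(None))          # True: None is accepted
print(V.string_or_list_strings(["a prompt", "b"]))  # True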

Functions
fastvideo.pipelines.stages.validators.StageValidators.bool_value staticmethod
bool_value(value: Any) -> bool

Check if value is a boolean.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def bool_value(value: Any) -> bool:
    """Check if value is a boolean."""
    return isinstance(value, bool)
fastvideo.pipelines.stages.validators.StageValidators.divisible staticmethod
divisible(divisor: int) -> Callable[[Any], bool]

Return a validator that checks if value is divisible by divisor.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def divisible(divisor: int) -> Callable[[Any], bool]:
    """Return a validator that checks if value is divisible by divisor."""

    def validator(value: Any) -> bool:
        return StageValidators.divisible_by(value, divisor)

    return validator
fastvideo.pipelines.stages.validators.StageValidators.divisible_by staticmethod
divisible_by(value: Any, divisor: int) -> bool

Check if value is divisible by divisor.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def divisible_by(value: Any, divisor: int) -> bool:
    """Check if value is divisible by divisor."""
    return value is not None and isinstance(value,
                                            int) and value % divisor == 0
fastvideo.pipelines.stages.validators.StageValidators.generator_or_list_generators staticmethod
generator_or_list_generators(value: Any) -> bool

Check if value is a Generator or list of Generators.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def generator_or_list_generators(value: Any) -> bool:
    """Check if value is a Generator or list of Generators."""
    if isinstance(value, torch.Generator):
        return True
    if isinstance(value, list):
        return all(isinstance(item, torch.Generator) for item in value)
    return False
fastvideo.pipelines.stages.validators.StageValidators.is_list staticmethod
is_list(value: Any) -> bool

Check if value is a list (can be empty).

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def is_list(value: Any) -> bool:
    """Check if value is a list (can be empty)."""
    return isinstance(value, list)
fastvideo.pipelines.stages.validators.StageValidators.is_tensor staticmethod
is_tensor(value: Any) -> bool

Check if value is a torch tensor and doesn't contain NaN values.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def is_tensor(value: Any) -> bool:
    """Check if value is a torch tensor and doesn't contain NaN values."""
    if not isinstance(value, torch.Tensor):
        return False
    return not torch.isnan(value).any().item()
fastvideo.pipelines.stages.validators.StageValidators.is_tuple staticmethod
is_tuple(value: Any) -> bool

Check if value is a tuple.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def is_tuple(value: Any) -> bool:
    """Check if value is a tuple."""
    return isinstance(value, tuple)
fastvideo.pipelines.stages.validators.StageValidators.list_length staticmethod
list_length(value: Any, length: int) -> bool

Check if list has specific length.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def list_length(value: Any, length: int) -> bool:
    """Check if list has specific length."""
    return isinstance(value, list) and len(value) == length
fastvideo.pipelines.stages.validators.StageValidators.list_min_length staticmethod
list_min_length(value: Any, min_length: int) -> bool

Check if list has at least min_length items.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def list_min_length(value: Any, min_length: int) -> bool:
    """Check if list has at least min_length items."""
    return isinstance(value, list) and len(value) >= min_length
fastvideo.pipelines.stages.validators.StageValidators.list_not_empty staticmethod
list_not_empty(value: Any) -> bool

Check if value is a non-empty list.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def list_not_empty(value: Any) -> bool:
    """Check if value is a non-empty list."""
    return isinstance(value, list) and len(value) > 0
fastvideo.pipelines.stages.validators.StageValidators.list_of_tensors staticmethod
list_of_tensors(value: Any) -> bool

Check if value is a non-empty list where all items are tensors without NaN values.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def list_of_tensors(value: Any) -> bool:
    """Check if value is a non-empty list where all items are tensors without NaN values."""
    if not isinstance(value, list) or len(value) == 0:
        return False
    for item in value:
        if not isinstance(item, torch.Tensor):
            return False
        if torch.isnan(item).any().item():
            return False
    return True
fastvideo.pipelines.stages.validators.StageValidators.list_of_tensors_dims staticmethod
list_of_tensors_dims(dims: int) -> Callable[[Any], bool]

Return a validator that checks if value is a list of tensors with specific dimensions and no NaN values.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def list_of_tensors_dims(dims: int) -> Callable[[Any], bool]:
    """Return a validator that checks if value is a list of tensors with specific dimensions and no NaN values."""

    def validator(value: Any) -> bool:
        return StageValidators.list_of_tensors_with_dims(value, dims)

    return validator
fastvideo.pipelines.stages.validators.StageValidators.list_of_tensors_min_dims staticmethod
list_of_tensors_min_dims(min_dims: int) -> Callable[[Any], bool]

Return a validator that checks if value is a list of tensors with at least min_dims dimensions and no NaN values.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def list_of_tensors_min_dims(min_dims: int) -> Callable[[Any], bool]:
    """Return a validator that checks if value is a list of tensors with at least min_dims dimensions and no NaN values."""

    def validator(value: Any) -> bool:
        return StageValidators.list_of_tensors_with_min_dims(
            value, min_dims)

    return validator
fastvideo.pipelines.stages.validators.StageValidators.list_of_tensors_with_dims staticmethod
list_of_tensors_with_dims(value: Any, dims: int) -> bool

Check if value is a non-empty list where all items are tensors with specific dimensions and no NaN values.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def list_of_tensors_with_dims(value: Any, dims: int) -> bool:
    """Check if value is a non-empty list where all items are tensors with specific dimensions and no NaN values."""
    if not isinstance(value, list) or len(value) == 0:
        return False
    for item in value:
        if not isinstance(item, torch.Tensor):
            return False
        if item.dim() != dims:
            return False
        if torch.isnan(item).any().item():
            return False
    return True
fastvideo.pipelines.stages.validators.StageValidators.list_of_tensors_with_min_dims staticmethod
list_of_tensors_with_min_dims(value: Any, min_dims: int) -> bool

Check if value is a non-empty list where all items are tensors with at least min_dims dimensions and no NaN values.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def list_of_tensors_with_min_dims(value: Any, min_dims: int) -> bool:
    """Check if value is a non-empty list where all items are tensors with at least min_dims dimensions and no NaN values."""
    if not isinstance(value, list) or len(value) == 0:
        return False
    for item in value:
        if not isinstance(item, torch.Tensor):
            return False
        if item.dim() < min_dims:
            return False
        if torch.isnan(item).any().item():
            return False
    return True
fastvideo.pipelines.stages.validators.StageValidators.min_dims staticmethod
min_dims(min_dims: int) -> Callable[[Any], bool]

Return a validator that checks if tensor has at least min_dims dimensions and no NaN values.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def min_dims(min_dims: int) -> Callable[[Any], bool]:
    """Return a validator that checks if tensor has at least min_dims dimensions and no NaN values."""

    def validator(value: Any) -> bool:
        return StageValidators.tensor_min_dims(value, min_dims)

    return validator
fastvideo.pipelines.stages.validators.StageValidators.non_negative_float staticmethod
non_negative_float(value: Any) -> bool

Check if value is a non-negative float.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def non_negative_float(value: Any) -> bool:
    """Check if value is a non-negative float."""
    return isinstance(value, int | float) and value >= 0
fastvideo.pipelines.stages.validators.StageValidators.none_or_list staticmethod
none_or_list(value: Any) -> bool

Check if value is None or a list.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def none_or_list(value: Any) -> bool:
    """Check if value is None or a list."""
    return value is None or isinstance(value, list)
fastvideo.pipelines.stages.validators.StageValidators.none_or_positive_int staticmethod
none_or_positive_int(value: Any) -> bool

Check if value is None or a positive integer.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def none_or_positive_int(value: Any) -> bool:
    """Check if value is None or a positive integer."""
    return value is None or (isinstance(value, int) and value > 0)
fastvideo.pipelines.stages.validators.StageValidators.none_or_tensor staticmethod
none_or_tensor(value: Any) -> bool

Check if value is None or a tensor without NaN values.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def none_or_tensor(value: Any) -> bool:
    """Check if value is None or a tensor without NaN values."""
    if value is None:
        return True
    if not isinstance(value, torch.Tensor):
        return False
    return not torch.isnan(value).any().item()
fastvideo.pipelines.stages.validators.StageValidators.none_or_tensor_with_dims staticmethod
none_or_tensor_with_dims(dims: int) -> Callable[[Any], bool]

Return a validator that checks if value is None or a tensor with specific dimensions and no NaN values.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def none_or_tensor_with_dims(dims: int) -> Callable[[Any], bool]:
    """Return a validator that checks if value is None or a tensor with specific dimensions and no NaN values."""

    def validator(value: Any) -> bool:
        if value is None:
            return True
        if not isinstance(value, torch.Tensor):
            return False
        if value.dim() != dims:
            return False
        return not torch.isnan(value).any().item()

    return validator
fastvideo.pipelines.stages.validators.StageValidators.not_none staticmethod
not_none(value: Any) -> bool

Check if value is not None.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def not_none(value: Any) -> bool:
    """Check if value is not None."""
    return value is not None
fastvideo.pipelines.stages.validators.StageValidators.positive_float staticmethod
positive_float(value: Any) -> bool

Check if value is a positive float.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def positive_float(value: Any) -> bool:
    """Check if value is a positive float."""
    return isinstance(value, int | float) and value > 0
fastvideo.pipelines.stages.validators.StageValidators.positive_int staticmethod
positive_int(value: Any) -> bool

Check if value is a positive integer.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def positive_int(value: Any) -> bool:
    """Check if value is a positive integer."""
    return isinstance(value, int) and value > 0
fastvideo.pipelines.stages.validators.StageValidators.positive_int_divisible staticmethod
positive_int_divisible(divisor: int) -> Callable[[Any], bool]

Return a validator that checks if value is a positive integer divisible by divisor.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def positive_int_divisible(divisor: int) -> Callable[[Any], bool]:
    """Return a validator that checks if value is a positive integer divisible by divisor."""

    def validator(value: Any) -> bool:
        return (isinstance(value, int) and value > 0
                and StageValidators.divisible_by(value, divisor))

    return validator
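
A short sketch of the divisibility check, the kind of constraint typically applied to spatial sizes that must align with a patch or VAE scale factor; the divisor of 8 is just an example.

from fastvideo.pipelines.stages.validators import StageValidators as V

divisible_by_8 = V.positive_int_divisible(8)

print(divisible_by_8(512))   # True: positive int and 512 % 8 == 0
print(divisible_by_8(500))   # False: not divisible by 8
print(divisible_by_8(-8))    # False: not positive
print(divisible_by_8(8.0))   # False: floats are rejected, only ints pass
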
fastvideo.pipelines.stages.validators.StageValidators.string_not_empty staticmethod
string_not_empty(value: Any) -> bool

Check if value is a non-empty string.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def string_not_empty(value: Any) -> bool:
    """Check if value is a non-empty string."""
    return isinstance(value, str) and len(value.strip()) > 0
fastvideo.pipelines.stages.validators.StageValidators.string_or_list_strings staticmethod
string_or_list_strings(value: Any) -> bool

Check if value is a string or list of strings.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def string_or_list_strings(value: Any) -> bool:
    """Check if value is a string or list of strings."""
    if isinstance(value, str):
        return True
    if isinstance(value, list):
        return all(isinstance(item, str) for item in value)
    return False
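
A quick sketch showing that both a single prompt and a list of prompts pass; the prompt strings are illustrative.

from fastvideo.pipelines.stages.validators import StageValidators as V

print(V.string_or_list_strings("a cat surfing a wave"))   # True: single string
print(V.string_or_list_strings(["a cat", "a dog"]))       # True: list of strings
print(V.string_or_list_strings(["a cat", 42]))            # False: non-string element
print(V.string_or_list_strings(None))                     # False: neither string nor list
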
fastvideo.pipelines.stages.validators.StageValidators.tensor_min_dims staticmethod
tensor_min_dims(value: Any, min_dims: int) -> bool

Check if value is a tensor with at least min_dims dimensions and no NaN values.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def tensor_min_dims(value: Any, min_dims: int) -> bool:
    """Check if value is a tensor with at least min_dims dimensions and no NaN values."""
    if not isinstance(value, torch.Tensor):
        return False
    if value.dim() < min_dims:
        return False
    return not torch.isnan(value).any().item()
fastvideo.pipelines.stages.validators.StageValidators.tensor_shape_matches staticmethod
tensor_shape_matches(value: Any, expected_shape: tuple) -> bool

Check if tensor shape matches expected shape (None for any size) and no NaN values.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def tensor_shape_matches(value: Any, expected_shape: tuple) -> bool:
    """Check if tensor shape matches expected shape (None for any size) and no NaN values."""
    if not isinstance(value, torch.Tensor):
        return False
    if len(value.shape) != len(expected_shape):
        return False
    for actual, expected in zip(value.shape, expected_shape, strict=True):
        if expected is not None and actual != expected:
            return False
    return not torch.isnan(value).any().item()
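
A brief sketch of the None wildcard semantics, assuming the same import alias; the concrete latent shape (batch, channels, frames, height, width) is illustrative, with batch and frame counts left unconstrained.

import torch

from fastvideo.pipelines.stages.validators import StageValidators as V

expected = (None, 16, None, 60, 104)  # None positions accept any size

print(V.tensor_shape_matches(torch.zeros(1, 16, 21, 60, 104), expected))  # True
print(V.tensor_shape_matches(torch.zeros(1, 8, 21, 60, 104), expected))   # False: channel mismatch
print(V.tensor_shape_matches(torch.zeros(16, 60, 104), expected))         # False: rank mismatch
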
fastvideo.pipelines.stages.validators.StageValidators.tensor_with_dims staticmethod
tensor_with_dims(value: Any, dims: int) -> bool

Check if value is a tensor with specific dimensions and no NaN values.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def tensor_with_dims(value: Any, dims: int) -> bool:
    """Check if value is a tensor with specific dimensions and no NaN values."""
    if not isinstance(value, torch.Tensor):
        return False
    if value.dim() != dims:
        return False
    return not torch.isnan(value).any().item()
fastvideo.pipelines.stages.validators.StageValidators.with_dims staticmethod
with_dims(dims: int) -> Callable[[Any], bool]

Return a validator that checks if tensor has specific dimensions and no NaN values.

Source code in fastvideo/pipelines/stages/validators.py
@staticmethod
def with_dims(dims: int) -> Callable[[Any], bool]:
    """Return a validator that checks if tensor has specific dimensions and no NaN values."""

    def validator(value: Any) -> bool:
        return StageValidators.tensor_with_dims(value, dims)

    return validator
fastvideo.pipelines.stages.validators.ValidationFailure
ValidationFailure(validator_name: str, actual_value: Any, expected: str | None = None, error_msg: str | None = None)

Details about a specific validation failure.

Source code in fastvideo/pipelines/stages/validators.py
def __init__(self,
             validator_name: str,
             actual_value: Any,
             expected: str | None = None,
             error_msg: str | None = None):
    self.validator_name = validator_name
    self.actual_value = actual_value
    self.expected = expected
    self.error_msg = error_msg
fastvideo.pipelines.stages.validators.VerificationResult
VerificationResult()

Wrapper class for stage verification results.

Source code in fastvideo/pipelines/stages/validators.py
def __init__(self) -> None:
    self._checks: dict[str, bool] = {}
    self._failures: dict[str, list[ValidationFailure]] = {}
Functions
fastvideo.pipelines.stages.validators.VerificationResult.add_check
add_check(field_name: str, value: Any, validators: Callable[[Any], bool] | list[Callable[[Any], bool]]) -> VerificationResult

Add a validation check for a field.

Parameters:

    field_name (str): Name of the field being checked. Required.
    value (Any): The actual value to validate. Required.
    validators (Callable[[Any], bool] | list[Callable[[Any], bool]]): Single validation function or list of validation functions. Each function will be called with the value as its first argument. Required.

Returns:

    VerificationResult: Self for method chaining.


Examples:

Single validator

result.add_check("tensor", my_tensor, V.is_tensor)

Multiple validators (all must pass)

result.add_check("latents", batch.latents, [V.is_tensor, V.with_dims(5)])

Using partial functions for parameters

result.add_check("height", batch.height, [V.not_none, V.divisible(8)])

Source code in fastvideo/pipelines/stages/validators.py
def add_check(
    self, field_name: str, value: Any,
    validators: Callable[[Any], bool] | list[Callable[[Any], bool]]
) -> 'VerificationResult':
    """
    Add a validation check for a field.

    Args:
        field_name: Name of the field being checked
        value: The actual value to validate
        validators: Single validation function or list of validation functions.
                   Each function will be called with the value as its first argument.

    Returns:
        Self for method chaining

    Examples:
        # Single validator
        result.add_check("tensor", my_tensor, V.is_tensor)

        # Multiple validators (all must pass)
        result.add_check("latents", batch.latents, [V.is_tensor, V.with_dims(5)])

        # Using partial functions for parameters
        result.add_check("height", batch.height, [V.not_none, V.divisible(8)])
    """
    if not isinstance(validators, list):
        validators = [validators]

    failures = []
    all_passed = True

    # Apply all validators and collect detailed failure info
    for validator in validators:
        try:
            passed = validator(value)
            if not passed:
                all_passed = False
                failure = self._create_validation_failure(validator, value)
                failures.append(failure)
        except Exception as e:
            # If any validator raises an exception, consider the check failed
            all_passed = False
            validator_name = getattr(validator, '__name__', str(validator))
            failure = ValidationFailure(
                validator_name=validator_name,
                actual_value=value,
                error_msg=f"Exception during validation: {str(e)}")
            failures.append(failure)

    self._checks[field_name] = all_passed
    if not all_passed:
        self._failures[field_name] = failures

    return self
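
One behaviour worth noting from the source above: a validator that raises is recorded as a ValidationFailure rather than propagating, so verification never aborts partway through. A small sketch, with the lambda chosen purely for illustration.

from fastvideo.pipelines.stages.validators import (StageValidators as V,
                                                   VerificationResult)

result = VerificationResult()
# The lambda raises AttributeError on None; the exception is caught inside
# add_check and converted into a ValidationFailure instead of escaping.
result.add_check("latents", None, [V.not_none, lambda x: x.dim() == 5])

print(result.is_valid())            # False
print(result.get_failed_fields())   # ['latents']
print(result.get_failure_summary())
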
fastvideo.pipelines.stages.validators.VerificationResult.get_detailed_failures
get_detailed_failures() -> dict[str, list[ValidationFailure]]

Get detailed failure information for each failed field.

Source code in fastvideo/pipelines/stages/validators.py
def get_detailed_failures(self) -> dict[str, list[ValidationFailure]]:
    """Get detailed failure information for each failed field."""
    return self._failures.copy()
fastvideo.pipelines.stages.validators.VerificationResult.get_failed_fields
get_failed_fields() -> list[str]

Get list of fields that failed validation.

Source code in fastvideo/pipelines/stages/validators.py
def get_failed_fields(self) -> list[str]:
    """Get list of fields that failed validation."""
    return [field for field, passed in self._checks.items() if not passed]
fastvideo.pipelines.stages.validators.VerificationResult.get_failure_summary
get_failure_summary() -> str

Get a comprehensive summary of all validation failures.

Source code in fastvideo/pipelines/stages/validators.py
def get_failure_summary(self) -> str:
    """Get a comprehensive summary of all validation failures."""
    if self.is_valid():
        return "All validations passed"

    summary_parts = []
    for field_name, failures in self._failures.items():
        field_summary = f"\n  Field '{field_name}':"
        for i, failure in enumerate(failures, 1):
            field_summary += f"\n    {i}. {failure}"
        summary_parts.append(field_summary)

    return "Validation failures:" + "".join(summary_parts)
fastvideo.pipelines.stages.validators.VerificationResult.is_valid
is_valid() -> bool

Check if all validations passed.

Source code in fastvideo/pipelines/stages/validators.py
def is_valid(self) -> bool:
    """Check if all validations passed."""
    return all(self._checks.values())
fastvideo.pipelines.stages.validators.VerificationResult.to_dict
to_dict() -> dict

Convert to dictionary for backward compatibility.

Source code in fastvideo/pipelines/stages/validators.py
def to_dict(self) -> dict:
    """Convert to dictionary for backward compatibility."""
    return self._checks.copy()
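
Putting the pieces together, a hedged end-to-end sketch of how a stage's verification result is typically consumed; the field names, values, and tensor shape are illustrative rather than taken from a specific pipeline.

import torch

from fastvideo.pipelines.stages.validators import (StageValidators as V,
                                                   VerificationResult)

result = (VerificationResult()
          .add_check("latents", torch.zeros(1, 16, 21, 60, 104),
                     [V.not_none, V.with_dims(5)])
          .add_check("num_inference_steps", 0, V.positive_int)
          .add_check("guidance_scale", 5.0, V.positive_float))

if not result.is_valid():
    # Only the failing fields are reported; here num_inference_steps is 0.
    print(result.get_failed_fields())        # ['num_inference_steps']
    raise ValueError(result.get_failure_summary())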