Skip to content

pipeline_batch_info

Data structures for functional pipeline processing.

This module defines the dataclasses used to pass state between pipeline components in a functional manner, reducing the need for explicit parameter passing.

Classes

fastvideo.pipelines.pipeline_batch_info.ForwardBatch dataclass

ForwardBatch(data_type: str, generator: Generator | list[Generator] | None = None, image_path: str | None = None, image_embeds: list[Tensor] = list(), pil_image: Tensor | Image | None = None, preprocessed_image: Tensor | None = None, prompt: str | list[str] | None = None, negative_prompt: str | list[str] | None = None, prompt_path: str | None = None, output_path: str = 'outputs/', output_video_name: str | None = None, video_path: str | None = None, video_latent: Tensor | None = None, prompt_embeds: list[Tensor] = list(), negative_prompt_embeds: list[Tensor] | None = None, prompt_attention_mask: list[Tensor] | None = None, negative_attention_mask: list[Tensor] | None = None, clip_embedding_pos: list[Tensor] | None = None, clip_embedding_neg: list[Tensor] | None = None, max_sequence_length: int | None = None, prompt_template: dict[str, Any] | None = None, do_classifier_free_guidance: bool = False, batch_size: int | None = None, num_videos_per_prompt: int = 1, seed: int | None = None, seeds: list[int] | None = None, is_prompt_processed: bool = False, latents: Tensor | None = None, raw_latent_shape: Tensor | None = None, noise_pred: Tensor | None = None, image_latent: Tensor | None = None, height_latents: list[int] | int | None = None, width_latents: list[int] | int | None = None, num_frames: list[int] | int = 1, num_frames_round_down: bool = False, height: list[int] | int | None = None, width: list[int] | int | None = None, fps: list[int] | int | None = None, timesteps: Tensor | None = None, timestep: Tensor | float | int | None = None, step_index: int | None = None, boundary_ratio: float | None = None, num_inference_steps: int = 50, guidance_scale: float = 1.0, guidance_scale_2: float | None = None, guidance_rescale: float = 0.0, eta: float = 0.0, sigmas: list[float] | None = None, n_tokens: int | None = None, extra_step_kwargs: dict[str, Any] = dict(), modules: dict[str, Any] = dict(), output: Tensor | None = None, return_trajectory_latents: bool = False, 
return_trajectory_decoded: bool = False, trajectory_timesteps: list[Tensor] | None = None, trajectory_latents: Tensor | None = None, trajectory_decoded: list[Tensor] | None = None, extra: dict[str, Any] = dict(), save_video: bool = True, return_frames: bool = False, enable_teacache: bool = False, teacache_params: TeaCacheParams | WanTeaCacheParams | None = None, STA_param: list | None = None, is_cfg_negative: bool = False, mask_search_final_result_pos: list[list] | None = None, mask_search_final_result_neg: list[list] | None = None, VSA_sparsity: float = 0.0, logging_info: PipelineLoggingInfo = PipelineLoggingInfo())

Complete state passed through the pipeline execution.

This dataclass contains all information needed during the diffusion pipeline execution, allowing methods to update specific components without needing to manage numerous individual parameters.

Functions

fastvideo.pipelines.pipeline_batch_info.ForwardBatch.__post_init__
__post_init__()

Initialize dependent fields after dataclass initialization.

Source code in fastvideo/pipelines/pipeline_batch_info.py
def __post_init__(self):
    """Fill in fields whose defaults depend on other fields.

    Runs after the dataclass-generated initializer and applies three
    independent fix-ups:

    * ``negative_prompt_embeds`` left as ``None`` becomes an empty list.
    * ``guidance_scale_2`` left as ``None`` inherits ``guidance_scale``.
    * ``do_classifier_free_guidance`` is switched on when
      ``guidance_scale`` is strictly greater than 1.0. It is never
      switched off here, so a caller-supplied ``True`` is kept.
    """
    # Normalise the "not provided" sentinel to an empty container.
    if self.negative_prompt_embeds is None:
        self.negative_prompt_embeds = []

    # The secondary guidance scale falls back to the primary one.
    if self.guidance_scale_2 is None:
        self.guidance_scale_2 = self.guidance_scale

    # Only the guidance scale is consulted when enabling CFG.
    if self.guidance_scale > 1.0:
        self.do_classifier_free_guidance = True

fastvideo.pipelines.pipeline_batch_info.PipelineLoggingInfo

PipelineLoggingInfo()

Tracks per-stage pipeline metrics (such as execution time) in an OrderedDict keyed by stage name, preserving the order in which stages were first recorded.

Source code in fastvideo/pipelines/pipeline_batch_info.py
def __init__(self):
    """Create an empty per-stage metrics table."""
    # Maps stage name -> dict of metrics for that stage. An OrderedDict
    # keeps the stages in the order they were first inserted.
    stage_table: OrderedDict[str, dict[str, Any]] = OrderedDict()
    self.stages = stage_table

Functions

fastvideo.pipelines.pipeline_batch_info.PipelineLoggingInfo.add_stage_execution_time
add_stage_execution_time(stage_name: str, execution_time: float)

Add execution time for a stage.

Source code in fastvideo/pipelines/pipeline_batch_info.py
def add_stage_execution_time(self, stage_name: str, execution_time: float):
    """Record how long a stage took, plus when that was recorded.

    Args:
        stage_name: Key of the stage in ``self.stages``; an entry is
            created if the stage has not been seen yet.
        execution_time: Value stored under ``'execution_time'``,
            overwriting any previous value for this stage.

    The current ``time.time()`` is stored under ``'timestamp'``. Other
    metrics already stored for the stage are left untouched.
    """
    stage_entry = self.stages.setdefault(stage_name, {})
    stage_entry['execution_time'] = execution_time
    stage_entry['timestamp'] = time.time()
fastvideo.pipelines.pipeline_batch_info.PipelineLoggingInfo.add_stage_metric
add_stage_metric(stage_name: str, metric_name: str, value: Any)

Add any metric for a stage.

Source code in fastvideo/pipelines/pipeline_batch_info.py
def add_stage_metric(self, stage_name: str, metric_name: str, value: Any):
    """Store an arbitrary metric value for a stage.

    Args:
        stage_name: Key of the stage in ``self.stages``; an entry is
            created if the stage has not been seen yet.
        metric_name: Key under which the value is stored for that stage.
        value: The metric value; replaces any existing value with the
            same ``metric_name``.
    """
    stage_entry = self.stages.setdefault(stage_name, {})
    stage_entry[metric_name] = value
fastvideo.pipelines.pipeline_batch_info.PipelineLoggingInfo.get_execution_order
get_execution_order() -> list[str]

Get stages in execution order.

Source code in fastvideo/pipelines/pipeline_batch_info.py
def get_execution_order(self) -> list[str]:
    """Return the recorded stage names as a new list.

    The names come out in the iteration order of ``self.stages``, i.e.
    the order in which the stages were first inserted.
    """
    stage_names = [*self.stages.keys()]
    return stage_names
fastvideo.pipelines.pipeline_batch_info.PipelineLoggingInfo.get_stage_info
get_stage_info(stage_name: str) -> dict[str, Any]

Get all info for a specific stage.

Source code in fastvideo/pipelines/pipeline_batch_info.py
def get_stage_info(self, stage_name: str) -> dict[str, Any]:
    """Look up everything recorded for one stage.

    Args:
        stage_name: Key of the stage in ``self.stages``.

    Returns:
        The stored metrics dict for the stage (the same object held in
        ``self.stages``, not a copy), or a fresh empty dict when the
        stage is unknown. An unknown stage is not added to
        ``self.stages``.
    """
    if stage_name in self.stages:
        return self.stages[stage_name]
    return {}
fastvideo.pipelines.pipeline_batch_info.PipelineLoggingInfo.get_total_execution_time
get_total_execution_time() -> float

Get total pipeline execution time.

Source code in fastvideo/pipelines/pipeline_batch_info.py
def get_total_execution_time(self) -> float:
    """Add up the ``'execution_time'`` of every recorded stage.

    Stages that have no ``'execution_time'`` entry contribute 0. When no
    stages are recorded at all, the result is the integer 0.
    """
    durations = [
        stage_entry.get('execution_time', 0)
        for stage_entry in self.stages.values()
    ]
    return sum(durations)