Skip to content

base

Base classes for pipeline stages.

This module defines the abstract base classes for pipeline stages that can be composed to create complete diffusion pipelines.

Classes

fastvideo.pipelines.stages.base.PipelineStage

Bases: ABC

Abstract base class for all pipeline stages.

A pipeline stage represents a discrete step in the diffusion process that can be composed with other stages to create a complete pipeline. Each stage is responsible for a specific part of the process, such as prompt encoding, latent preparation, etc.

Attributes

fastvideo.pipelines.stages.base.PipelineStage.device property
device: device

Get the device for this stage.

Functions

fastvideo.pipelines.stages.base.PipelineStage.__call__
__call__(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Execute the stage's processing on the batch with optional verification and logging. Should not be overridden by subclasses.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The updated batch information after this stage's processing.

Source code in fastvideo/pipelines/stages/base.py
def __call__(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Execute the stage's processing on the batch with optional verification and logging.
    Should not be overridden by subclasses.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The updated batch information after this stage's processing.
    """
    stage_name = self.__class__.__name__

    # Check if verification is enabled (simple approach for prototype)
    enable_verification = getattr(fastvideo_args,
                                  'enable_stage_verification', False)

    if enable_verification:
        # Pre-execution input verification
        try:
            input_result = self.verify_input(batch, fastvideo_args)
            self._run_verification(input_result, stage_name, "input")
        except Exception as e:
            logger.error("Input verification failed for %s: %s", stage_name,
                         str(e))
            raise

    # Execute the actual stage logic
    if envs.FASTVIDEO_STAGE_LOGGING:
        logger.info("[%s] Starting execution", stage_name)
        start_time = time.perf_counter()

        try:
            result = self.forward(batch, fastvideo_args)
            execution_time = time.perf_counter() - start_time
            logger.info("[%s] Execution completed in %s ms", stage_name,
                        execution_time * 1000)
            batch.logging_info.add_stage_execution_time(
                stage_name, execution_time)
        except Exception as e:
            execution_time = time.perf_counter() - start_time
            logger.error("[%s] Error during execution after %s ms: %s",
                         stage_name, execution_time * 1000, e)
            logger.error("[%s] Traceback: %s", stage_name,
                         traceback.format_exc())
            raise
    else:
        # Direct execution (current behavior)
        result = self.forward(batch, fastvideo_args)

    if enable_verification:
        # Post-execution output verification
        try:
            output_result = self.verify_output(result, fastvideo_args)
            self._run_verification(output_result, stage_name, "output")
        except Exception as e:
            logger.error("Output verification failed for %s: %s",
                         stage_name, str(e))
            raise

    return result
fastvideo.pipelines.stages.base.PipelineStage.forward abstractmethod
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Forward pass of the stage's processing.

This method should be implemented by subclasses to provide the forward processing logic for the stage.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
ForwardBatch

The updated batch information after this stage's processing.

Source code in fastvideo/pipelines/stages/base.py
@abstractmethod
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Forward pass of the stage's processing.

    This method should be implemented by subclasses to provide the forward
    processing logic for the stage.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The updated batch information after this stage's processing.
    """
    raise NotImplementedError
fastvideo.pipelines.stages.base.PipelineStage.set_logging
set_logging(enable: bool)

Enable or disable logging for this stage.

Parameters:

Name Type Description Default
enable bool

Whether to enable logging.

required
Source code in fastvideo/pipelines/stages/base.py
def set_logging(self, enable: bool):
    """
    Enable or disable logging for this stage.

    Args:
        enable: Whether to enable logging.
    """
    self._enable_logging = enable
fastvideo.pipelines.stages.base.PipelineStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify the input for the stage.

Example

from fastvideo.pipelines.stages.validators import V, VerificationResult

def verify_input(self, batch, fastvideo_args): result = VerificationResult() result.add_check("height", batch.height, V.positive_int_divisible(8)) result.add_check("width", batch.width, V.positive_int_divisible(8)) result.add_check("image_latent", batch.image_latent, V.is_tensor) return result

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
VerificationResult

A VerificationResult containing the verification status.

Source code in fastvideo/pipelines/stages/base.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """
    Verify the input for the stage.

    Example:
        from fastvideo.pipelines.stages.validators import V, VerificationResult

        def verify_input(self, batch, fastvideo_args):
            result = VerificationResult()
            result.add_check("height", batch.height, V.positive_int_divisible(8))
            result.add_check("width", batch.width, V.positive_int_divisible(8))
            result.add_check("image_latent", batch.image_latent, V.is_tensor)
            return result

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        A VerificationResult containing the verification status.

    """
    # Default implementation - no verification
    return VerificationResult()
fastvideo.pipelines.stages.base.PipelineStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify the output for the stage.

Parameters:

Name Type Description Default
batch ForwardBatch

The current batch information.

required
fastvideo_args FastVideoArgs

The inference arguments.

required

Returns:

Type Description
VerificationResult

A VerificationResult containing the verification status.

Source code in fastvideo/pipelines/stages/base.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """
    Verify the output for the stage.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        A VerificationResult containing the verification status.
    """
    # Default implementation - no verification
    return VerificationResult()

fastvideo.pipelines.stages.base.StageVerificationError

Bases: Exception

Exception raised when stage verification fails.

Functions