
composed_pipeline_base

Base class for composed pipelines.

This module defines the base class for pipelines that are composed of multiple stages.

Classes

fastvideo.pipelines.composed_pipeline_base.ComposedPipelineBase

ComposedPipelineBase(model_path: str, fastvideo_args: FastVideoArgs | TrainingArgs, required_config_modules: list[str] | None = None, loaded_modules: dict[str, Module] | None = None)

Bases: ABC

Base class for pipelines composed of multiple stages.

This class provides the framework for creating pipelines by composing multiple stages together. Each stage is responsible for a specific part of the diffusion process, and the pipeline orchestrates the execution of these stages.
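As a sketch (with hypothetical names, not an official example from this module), a subclass typically declares which diffusers modules it needs and implements create_pipeline_stages. The NoOpStage below is only an illustrative stand-in for a real PipelineStage, and the import path for FastVideoArgs may differ by version:

from fastvideo.fastvideo_args import FastVideoArgs  # import path is an assumption
from fastvideo.pipelines.composed_pipeline_base import ComposedPipelineBase


class NoOpStage:
    """Illustrative stand-in for a real PipelineStage; it is called as
    stage(batch, fastvideo_args) and must return the (possibly updated) batch."""

    def __call__(self, batch, fastvideo_args):
        return batch


class MyTextToVideoPipeline(ComposedPipelineBase):
    # Names must match the entries in the model's model_index.json.
    _required_config_modules = [
        "text_encoder", "tokenizer", "transformer", "scheduler", "vae"
    ]

    def create_pipeline_stages(self, fastvideo_args: FastVideoArgs) -> None:
        # Register stages in execution order; forward() iterates over them.
        self._stage_name_mapping = {"noop": NoOpStage()}
        self._stages = list(self._stage_name_mapping.values())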

Initialize the pipeline. After __init__, the pipeline should be ready to use. The pipeline should be stateless and not hold any batch state.

Source code in fastvideo/pipelines/composed_pipeline_base.py
def __init__(self,
             model_path: str,
             fastvideo_args: FastVideoArgs | TrainingArgs,
             required_config_modules: list[str] | None = None,
             loaded_modules: dict[str, torch.nn.Module] | None = None):
    """
    Initialize the pipeline. After __init__, the pipeline should be ready to
    use. The pipeline should be stateless and not hold any batch state.
    """
    self.fastvideo_args = fastvideo_args

    self.model_path: str = model_path
    self._stages: list[PipelineStage] = []
    self._stage_name_mapping: dict[str, PipelineStage] = {}

    if required_config_modules is not None:
        self._required_config_modules = required_config_modules

    if self._required_config_modules is None:
        raise NotImplementedError(
            "Subclass must set _required_config_modules")

    maybe_init_distributed_environment_and_model_parallel(
        fastvideo_args.tp_size, fastvideo_args.sp_size)

    # Torch profiler. Enabled and configured through env vars:
    # FASTVIDEO_TORCH_PROFILER_DIR=/path/to/save/trace
    trace_dir = envs.FASTVIDEO_TORCH_PROFILER_DIR
    self.profiler_controller = get_or_create_profiler(trace_dir)
    self.profiler = self.profiler_controller.profiler

    self.local_rank = get_world_group().local_rank

    # Load modules directly in initialization
    logger.info("Loading pipeline modules...")
    with self.profiler_controller.region("profiler_region_model_loading"):
        self.modules = self.load_modules(fastvideo_args, loaded_modules)

Attributes

fastvideo.pipelines.composed_pipeline_base.ComposedPipelineBase.required_config_modules property
required_config_modules: list[str]

List of modules that are required by the pipeline. The names should match the diffusers directory and model_index.json file. These modules will be loaded using the PipelineComponentLoader and made available in the modules dictionary. Access these modules using the get_module method.

class ConcretePipeline(ComposedPipelineBase):
    _required_config_modules = ["vae", "text_encoder", "transformer", "scheduler", "tokenizer"]

@property
def required_config_modules(self):
    return self._required_config_modules
fastvideo.pipelines.composed_pipeline_base.ComposedPipelineBase.stages property
stages: list[PipelineStage]

List of stages in the pipeline.

Functions

fastvideo.pipelines.composed_pipeline_base.ComposedPipelineBase.create_pipeline_stages abstractmethod
create_pipeline_stages(fastvideo_args: FastVideoArgs)

Create the inference pipeline stages.

Source code in fastvideo/pipelines/composed_pipeline_base.py
@abstractmethod
def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
    """
    Create the inference pipeline stages.
    """
    raise NotImplementedError
fastvideo.pipelines.composed_pipeline_base.ComposedPipelineBase.create_training_stages
create_training_stages(training_args: TrainingArgs)

Create the training pipeline stages.

Source code in fastvideo/pipelines/composed_pipeline_base.py
def create_training_stages(self, training_args: TrainingArgs):
    """
    Create the training pipeline stages.
    """
    raise NotImplementedError
fastvideo.pipelines.composed_pipeline_base.ComposedPipelineBase.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Generate a video or image using the pipeline.

Parameters:

batch (ForwardBatch): The batch to generate from. Required.
fastvideo_args (FastVideoArgs): The inference arguments. Required.

Returns:

ForwardBatch: The batch with the generated video or image.

Source code in fastvideo/pipelines/composed_pipeline_base.py
@torch.no_grad()
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Generate a video or image using the pipeline.

    Args:
        batch: The batch to generate from.
        fastvideo_args: The inference arguments.
    Returns:
        ForwardBatch: The batch with the generated video or image.
    """
    if not self.post_init_called:
        self.post_init()

    # Execute each stage
    logger.info("Running pipeline stages: %s",
                self._stage_name_mapping.keys())
    # logger.info("Batch: %s", batch)
    for stage in self.stages:
        batch = stage(batch, fastvideo_args)

    # Return the output
    return batch
fastvideo.pipelines.composed_pipeline_base.ComposedPipelineBase.from_pretrained classmethod
from_pretrained(model_path: str, device: str | None = None, torch_dtype: dtype | None = None, pipeline_config: str | PipelineConfig | None = None, args: Namespace | None = None, required_config_modules: list[str] | None = None, loaded_modules: dict[str, Module] | None = None, **kwargs) -> ComposedPipelineBase

Load a pipeline from a pretrained model. If loaded_modules is provided, those modules will be used instead of being loaded from the config/pretrained weights.
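Usage sketch (the pipeline class and the ForwardBatch field shown are illustrative assumptions; consult fastvideo.pipelines.ForwardBatch for the actual field names):

from fastvideo.pipelines import ForwardBatch  # import path is an assumption

pipe = MyTextToVideoPipeline.from_pretrained(
    "path/to/diffusers/model",  # directory containing model_index.json
    device="cuda",
)

batch = ForwardBatch(prompt="a red panda surfing")  # field name is illustrative
output = pipe.forward(batch, pipe.fastvideo_args)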

Source code in fastvideo/pipelines/composed_pipeline_base.py
@classmethod
def from_pretrained(cls,
                    model_path: str,
                    device: str | None = None,
                    torch_dtype: torch.dtype | None = None,
                    pipeline_config: str | PipelineConfig | None = None,
                    args: argparse.Namespace | None = None,
                    required_config_modules: list[str] | None = None,
                    loaded_modules: dict[str, torch.nn.Module]
                    | None = None,
                    **kwargs) -> "ComposedPipelineBase":
    """
    Load a pipeline from a pretrained model.
    loaded_modules: Optional[Dict[str, torch.nn.Module]] = None,
    If provided, loaded_modules will be used instead of loading from config/pretrained weights.
    """
    if args is None or args.inference_mode:

        kwargs['model_path'] = model_path
        fastvideo_args = FastVideoArgs.from_kwargs(**kwargs)
    else:
        assert args is not None, "args must be provided for training mode"
        fastvideo_args = TrainingArgs.from_cli_args(args)
        # TODO(will): fix this so that its not so ugly
        fastvideo_args.model_path = model_path
        for key, value in kwargs.items():
            setattr(fastvideo_args, key, value)

        fastvideo_args.dit_cpu_offload = False
        # we hijack the precision to be the master weight type so that the
        # model is loaded with the correct precision. Subsequently we will
        # use FSDP2's MixedPrecisionPolicy to set the precision for the
        # fwd, bwd, and other operations' precision.
        assert fastvideo_args.pipeline_config.dit_precision == 'fp32', 'only fp32 is supported for training'

    logger.info("fastvideo_args in from_pretrained: %s", fastvideo_args)

    pipe = cls(model_path,
               fastvideo_args,
               required_config_modules=required_config_modules,
               loaded_modules=loaded_modules)
    pipe.post_init()
    return pipe
fastvideo.pipelines.composed_pipeline_base.ComposedPipelineBase.initialize_pipeline
initialize_pipeline(fastvideo_args: FastVideoArgs)

Initialize the pipeline.

Source code in fastvideo/pipelines/composed_pipeline_base.py
def initialize_pipeline(self, fastvideo_args: FastVideoArgs):
    """
    Initialize the pipeline.
    """
    return
fastvideo.pipelines.composed_pipeline_base.ComposedPipelineBase.load_modules
load_modules(fastvideo_args: FastVideoArgs, loaded_modules: dict[str, Module] | None = None) -> dict[str, Any]

Load the modules from the config. If loaded_modules is provided, those modules will be used instead of being loaded from the config/pretrained weights.
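For example (a sketch; the pipeline class and the my_vae instance are hypothetical), already-instantiated components can be passed through from_pretrained so they are reused instead of being read from disk:

# `my_vae` is assumed to be a torch.nn.Module instantiated elsewhere.
pipe = MyTextToVideoPipeline.from_pretrained(
    "path/to/diffusers/model",
    loaded_modules={"vae": my_vae},  # keys must match model_index.json entries
)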

Source code in fastvideo/pipelines/composed_pipeline_base.py
def load_modules(
    self,
    fastvideo_args: FastVideoArgs,
    loaded_modules: dict[str, torch.nn.Module] | None = None
) -> dict[str, Any]:
    """
    Load the modules from the config.
    loaded_modules: Optional[Dict[str, torch.nn.Module]] = None, 
    If provided, loaded_modules will be used instead of loading from config/pretrained weights.
    """

    model_index = self._load_config(self.model_path)
    logger.info("Loading pipeline modules from config: %s", model_index)

    # remove keys that are not pipeline modules
    model_index.pop("_class_name")
    model_index.pop("_diffusers_version")
    if "boundary_ratio" in model_index and model_index[
            "boundary_ratio"] is not None:
        logger.info(
            "MoE pipeline detected. Adding transformer_2 to self.required_config_modules..."
        )
        self.required_config_modules.append("transformer_2")
        logger.info("MoE pipeline detected. Setting boundary ratio to %s",
                    model_index["boundary_ratio"])
        fastvideo_args.pipeline_config.dit_config.boundary_ratio = model_index[
            "boundary_ratio"]

    model_index.pop("boundary_ratio", None)
    # used by Wan2.2 ti2v
    model_index.pop("expand_timesteps", None)

    # some sanity checks
    assert len(
        model_index
    ) > 1, "model_index.json must contain at least one pipeline module"

    for module_name in self.required_config_modules:
        if module_name not in model_index and module_name in self._extra_config_module_map:
            extra_module_value = self._extra_config_module_map[module_name]
            logger.warning(
                "model_index.json does not contain a %s module, but found {%s: %s} in _extra_config_module_map, adding to model_index.",
                module_name, module_name, extra_module_value)
            if extra_module_value in model_index:
                logger.info("Using module %s for %s", extra_module_value,
                            module_name)
                model_index[module_name] = model_index[extra_module_value]
                continue
            else:
                raise ValueError(
                    f"Required module key: {module_name} value: {model_index.get(module_name)} was not found in loaded modules {model_index.keys()}"
                )

    # all the component models used by the pipeline
    required_modules = self.required_config_modules
    logger.info("Loading required modules: %s", required_modules)

    modules = {}
    for module_name, (transformers_or_diffusers,
                      architecture) in model_index.items():
        if transformers_or_diffusers is None:
            logger.warning(
                "Module %s in model_index.json has null value, removing from required_config_modules",
                module_name)
            if module_name in self.required_config_modules:
                self.required_config_modules.remove(module_name)
            continue
        if module_name not in required_modules:
            logger.info("Skipping module %s", module_name)
            continue
        if loaded_modules is not None and module_name in loaded_modules:
            logger.info("Using module %s already provided", module_name)
            modules[module_name] = loaded_modules[module_name]
            continue

        # we load the module from the extra config module map if it exists
        if module_name in self._extra_config_module_map:
            load_module_name = self._extra_config_module_map[module_name]
        else:
            load_module_name = module_name

        component_model_path = os.path.join(self.model_path,
                                            load_module_name)
        module = PipelineComponentLoader.load_module(
            module_name=load_module_name,
            component_model_path=component_model_path,
            transformers_or_diffusers=transformers_or_diffusers,
            fastvideo_args=fastvideo_args,
        )
        logger.info("Loaded module %s from %s", module_name,
                    component_model_path)

        if module_name in modules:
            logger.warning("Overwriting module %s", module_name)
        modules[module_name] = module

    # Check if all required modules were loaded
    for module_name in required_modules:
        if module_name not in modules or modules[module_name] is None:
            raise ValueError(
                f"Required module key: {module_name} value: {modules.get(module_name)} was not found in loaded modules {modules.keys()}"
            )

    return modules

Functions