Skip to content

stepvideo

Modules

fastvideo.pipelines.basic.stepvideo.stepvideo_pipeline

StepVideo diffusion pipeline implementation.

This module contains an implementation of the StepVideo diffusion pipeline using the modular pipeline architecture.

Classes

fastvideo.pipelines.basic.stepvideo.stepvideo_pipeline.StepVideoPipeline
StepVideoPipeline(*args, **kwargs)

Bases: LoRAPipeline, ComposedPipelineBase

Source code in fastvideo/pipelines/lora_pipeline.py
def __init__(self, *args, **kwargs) -> None:
    """Build the pipeline, then convert layers to LoRA when configured to.

    LoRA conversion happens on exactly one of two paths:

    * training -- ``training_mode`` is truthy and ``fastvideo_args`` has a
      truthy ``lora_training``: ``lora_alpha`` falls back to ``lora_rank``
      when unset (the args object is updated in place), and
      ``lora_target_modules`` falls back to a built-in list of layer names
      when unset;
    * inference -- ``training_mode`` is falsy and ``lora_path`` is set: the
      layers are converted and the adapter at ``lora_path`` is registered
      under ``lora_nickname``.

    In every other configuration no LoRA conversion is performed.
    """
    super().__init__(*args, **kwargs)
    self.device = get_local_torch_device()
    transformer_arch = self.modules["transformer"].config.arch_config
    self.exclude_lora_layers = transformer_arch.exclude_lora_layers

    cfg = self.fastvideo_args
    self.lora_target_modules = cfg.lora_target_modules
    self.lora_path = cfg.lora_path
    self.lora_nickname = cfg.lora_nickname
    self.training_mode = cfg.training_mode

    if not self.training_mode:
        # Inference: only touch the model when an adapter path was given.
        if self.lora_path is not None:
            self.convert_to_lora_layers()
            self.set_lora_adapter(
                self.lora_nickname,  # type: ignore
                self.lora_path)  # type: ignore
        return

    if not getattr(cfg, "lora_training", False):
        # Full fine-tuning (training without LoRA): nothing to convert.
        return

    # LoRA training: resolve rank/alpha defaults before wrapping layers.
    assert isinstance(cfg, TrainingArgs)
    if cfg.lora_alpha is None:
        cfg.lora_alpha = cfg.lora_rank
    self.lora_rank = cfg.lora_rank  # type: ignore
    self.lora_alpha = cfg.lora_alpha  # type: ignore
    logger.info("Using LoRA training with rank %d and alpha %d",
                self.lora_rank, self.lora_alpha)
    if self.lora_target_modules is None:
        self.lora_target_modules = [
            "q_proj", "k_proj", "v_proj", "o_proj", "to_q", "to_k",
            "to_v", "to_out", "to_qkv"
        ]
    self.convert_to_lora_layers()
Functions
fastvideo.pipelines.basic.stepvideo.stepvideo_pipeline.StepVideoPipeline.create_pipeline_stages
create_pipeline_stages(fastvideo_args: FastVideoArgs)

Set up pipeline stages with proper dependency injection.

Source code in fastvideo/pipelines/basic/stepvideo/stepvideo_pipeline.py
def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
    """Register the StepVideo stages on this pipeline, in execution order.

    Order: input validation -> prompt encoding (``text_encoder`` passed as
    ``stepllm``, ``text_encoder_2`` passed as ``clip``) -> timestep
    preparation -> latent preparation -> denoising -> decoding (``vae``).
    Each stage receives the modules it depends on, looked up through
    ``self.get_module``.

    ``fastvideo_args`` is part of the interface but is not read here.
    """
    module = self.get_module

    # (stage_name, factory) pairs. Each factory runs immediately before its
    # stage is registered, so module lookups, stage construction and
    # registration interleave exactly as in a straight-line add_stage chain.
    stage_factories = (
        ("input_validation_stage", InputValidationStage),
        ("prompt_encoding_stage",
         lambda: StepvideoPromptEncodingStage(
             stepllm=module("text_encoder"),
             clip=module("text_encoder_2"),
         )),
        ("timestep_preparation_stage",
         lambda: TimestepPreparationStage(scheduler=module("scheduler"))),
        ("latent_preparation_stage",
         lambda: LatentPreparationStage(
             scheduler=module("scheduler"),
             transformer=module("transformer"),
         )),
        ("denoising_stage",
         lambda: DenoisingStage(
             transformer=module("transformer"),
             scheduler=module("scheduler"),
         )),
        ("decoding_stage", lambda: DecodingStage(vae=module("vae"))),
    )

    for stage_name, build_stage in stage_factories:
        self.add_stage(stage_name=stage_name, stage=build_stage())
fastvideo.pipelines.basic.stepvideo.stepvideo_pipeline.StepVideoPipeline.initialize_pipeline
initialize_pipeline(fastvideo_args: FastVideoArgs)

Initialize the pipeline.

Source code in fastvideo/pipelines/basic/stepvideo/stepvideo_pipeline.py
def initialize_pipeline(self, fastvideo_args: FastVideoArgs):
    """
    Initialize the pipeline.

    Builds the two text encoders from subdirectories of ``self.model_path``
    on the local torch device and registers them: ``step_llm`` becomes
    ``text_encoder`` and ``hunyuan_clip`` becomes ``text_encoder_2``.
    It then loads a prebuilt shared library into the process with
    ``torch.ops.load_library``. The library comes from
    ``<model_path>/lib/...`` when ``fastvideo_args.model_path`` is a local
    directory. Otherwise it is downloaded with ``hf_hub_download``, using
    ``fastvideo_args.model_path`` as the repo id.
    """
    device = get_local_torch_device()
    encoder_dirs = [
        os.path.join(self.model_path, subdir)
        for subdir in ("step_llm", "hunyuan_clip")
    ]
    step_llm = self.build_llm(encoder_dirs[0], device)
    clip_model = self.build_clip(encoder_dirs[1], device)
    self.add_module("text_encoder", step_llm)
    self.add_module("text_encoder_2", clip_model)

    # NOTE(review): judging by its name, this binary targets torch 2.5,
    # CUDA 12.4, CPython 3.10 and x86_64 Linux -- confirm before relying on
    # other environments.
    ops_library = (
        'lib/liboptimus_ths-torch2.5-cu124.cpython-310-x86_64-linux-gnu.so')
    if os.path.isdir(fastvideo_args.model_path):
        # local checkout
        lib_path = os.path.join(fastvideo_args.model_path, ops_library)
    else:
        lib_path = hf_hub_download(repo_id=fastvideo_args.model_path,
                                   filename=ops_library)
    torch.ops.load_library(lib_path)
fastvideo.pipelines.basic.stepvideo.stepvideo_pipeline.StepVideoPipeline.load_modules
load_modules(fastvideo_args: FastVideoArgs) -> dict[str, Any]

Load the modules from the config.

Source code in fastvideo/pipelines/basic/stepvideo/stepvideo_pipeline.py
def load_modules(self, fastvideo_args: FastVideoArgs) -> dict[str, Any]:
    """
    Load the modules from the config.

    Reads ``model_index.json`` for ``self.model_path`` via
    ``self._load_config`` and drops the diffusers metadata keys. It checks
    that ``transformer``, ``scheduler`` and ``vae`` are listed, then loads
    every listed component from ``<model_path>/<module_name>`` with
    ``PipelineComponentLoader.load_module``.

    Args:
        fastvideo_args: Forwarded unchanged to each component loader.

    Returns:
        Mapping from module name (as listed in ``model_index.json``) to the
        loaded component.

    Raises:
        ValueError: If the index lists no pipeline modules, if
            ``transformer``, ``scheduler`` or ``vae`` is missing from it, or
            if a name in ``self.required_config_modules`` was not loaded (or
            was loaded as ``None``).
    """
    model_index = self._load_config(self.model_path)
    logger.info("Loading pipeline modules from config: %s", model_index)

    # Remove diffusers metadata keys; they are not pipeline modules. Their
    # absence is not fatal by itself -- the module checks below still reject
    # an index that cannot describe a pipeline.
    model_index.pop("_class_name", None)
    model_index.pop("_diffusers_version", None)

    # Sanity checks. Explicit raises (not ``assert``) so they still run under
    # ``python -O``, and so the condition matches the error message.
    if not model_index:
        raise ValueError(
            "model_index.json must contain at least one pipeline module")

    for module_name in ("transformer", "scheduler", "vae"):
        if module_name not in model_index:
            raise ValueError(
                f"model_index.json must contain a {module_name} module")
    logger.info("Diffusers config passed sanity checks")

    # Load every component listed in the index. Each entry must unpack into a
    # ``(transformers_or_diffusers, architecture)`` pair; only the first item
    # is forwarded to the loader. Dict keys are unique, so no entry can
    # overwrite an earlier one.
    modules: dict[str, Any] = {}
    for module_name, (transformers_or_diffusers,
                      _architecture) in model_index.items():
        component_model_path = os.path.join(self.model_path, module_name)
        module = PipelineComponentLoader.load_module(
            module_name=module_name,
            component_model_path=component_model_path,
            transformers_or_diffusers=transformers_or_diffusers,
            fastvideo_args=fastvideo_args,
        )
        logger.info("Loaded module %s from %s", module_name,
                    component_model_path)
        modules[module_name] = module

    # Check that every module this pipeline requires is present and non-None.
    for module_name in self.required_config_modules:
        if modules.get(module_name) is None:
            raise ValueError(
                f"Required module {module_name} was not loaded properly")

    return modules

Functions