Skip to content

wan_pipeline

Wan video diffusion pipeline implementation.

This module contains an implementation of the Wan video diffusion pipeline using the modular pipeline architecture.

Classes

fastvideo.pipelines.basic.wan.wan_pipeline.WanPipeline

WanPipeline(*args, **kwargs)

Bases: LoRAPipeline, ComposedPipelineBase

Wan video diffusion pipeline with LoRA support.

Source code in fastvideo/pipelines/lora_pipeline.py
def __init__(self, *args, **kwargs) -> None:
    """Build the base pipeline, then attach LoRA layers if configured.

    All positional and keyword arguments are forwarded unchanged to the
    parent constructor. LoRA settings are then copied from
    ``self.fastvideo_args`` onto the instance, and one of three paths is
    taken:

    * Training mode with ``lora_training`` set: ``lora_alpha`` falls back
      to ``lora_rank`` when it is ``None``, a default list of target
      modules is used when none was given, and
      ``convert_to_lora_layers`` is called.
    * Inference mode with a ``lora_path``: ``convert_to_lora_layers`` is
      called and the adapter is set from ``lora_nickname`` and
      ``lora_path``.
    * Anything else: no LoRA setup is performed.
    """
    super().__init__(*args, **kwargs)
    self.device = get_local_torch_device()

    transformer_arch = self.modules["transformer"].config.arch_config
    self.exclude_lora_layers = transformer_arch.exclude_lora_layers

    # Mirror the LoRA-related options onto the pipeline instance.
    fv_args = self.fastvideo_args
    self.lora_target_modules = fv_args.lora_target_modules
    self.lora_path = fv_args.lora_path
    self.lora_nickname = fv_args.lora_nickname
    self.training_mode = fv_args.training_mode

    # `lora_training` may be absent on the args object, hence getattr.
    wants_lora_training = self.training_mode and getattr(
        fv_args, "lora_training", False)

    if wants_lora_training:
        assert isinstance(fv_args, TrainingArgs)

        # An unset alpha defaults to the rank.
        if fv_args.lora_alpha is None:
            fv_args.lora_alpha = fv_args.lora_rank
        self.lora_rank = fv_args.lora_rank  # type: ignore
        self.lora_alpha = fv_args.lora_alpha  # type: ignore
        logger.info("Using LoRA training with rank %d and alpha %d",
                    self.lora_rank, self.lora_alpha)

        # Fall back to the built-in projection-layer names.
        if self.lora_target_modules is None:
            self.lora_target_modules = [
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
                "to_q",
                "to_k",
                "to_v",
                "to_out",
                "to_qkv",
            ]
        self.convert_to_lora_layers()
    elif not self.training_mode and self.lora_path is not None:
        # Inference: convert the layers first, then set the adapter.
        self.convert_to_lora_layers()
        self.set_lora_adapter(
            self.lora_nickname,  # type: ignore
            self.lora_path,  # type: ignore
        )

Functions

fastvideo.pipelines.basic.wan.wan_pipeline.WanPipeline.create_pipeline_stages
create_pipeline_stages(fastvideo_args: FastVideoArgs) -> None

Set up pipeline stages with proper dependency injection.

Source code in fastvideo/pipelines/basic/wan/wan_pipeline.py
def create_pipeline_stages(self, fastvideo_args: FastVideoArgs) -> None:
    """Register the pipeline's stages, wiring each to its modules.

    Stages are added in this order: input validation, prompt encoding,
    conditioning, timestep preparation, latent preparation, denoising,
    and decoding. Each stage is constructed immediately before it is
    registered, using modules looked up through ``self.get_module``.

    Args:
        fastvideo_args: Accepted as part of the method's interface; not
            read by this implementation.
    """
    validation = InputValidationStage()
    self.add_stage(stage_name="input_validation_stage", stage=validation)

    prompt_encoder = TextEncodingStage(
        text_encoders=[self.get_module("text_encoder")],
        tokenizers=[self.get_module("tokenizer")],
    )
    self.add_stage(stage_name="prompt_encoding_stage", stage=prompt_encoder)

    conditioning = ConditioningStage()
    self.add_stage(stage_name="conditioning_stage", stage=conditioning)

    timestep_prep = TimestepPreparationStage(
        scheduler=self.get_module("scheduler"))
    self.add_stage(stage_name="timestep_preparation_stage",
                   stage=timestep_prep)

    # The transformer lookup here passes None as a second argument,
    # unlike the denoising stage below which passes no second argument.
    latent_prep = LatentPreparationStage(
        scheduler=self.get_module("scheduler"),
        transformer=self.get_module("transformer", None),
    )
    self.add_stage(stage_name="latent_preparation_stage", stage=latent_prep)

    # `transformer_2` is likewise looked up with None as second argument.
    denoiser = DenoisingStage(
        transformer=self.get_module("transformer"),
        transformer_2=self.get_module("transformer_2", None),
        scheduler=self.get_module("scheduler"),
        vae=self.get_module("vae"),
        pipeline=self,
    )
    self.add_stage(stage_name="denoising_stage", stage=denoiser)

    decoder = DecodingStage(vae=self.get_module("vae"), pipeline=self)
    self.add_stage(stage_name="decoding_stage", stage=decoder)

Functions