
wan_distillation_pipeline

Classes

fastvideo.training.wan_distillation_pipeline.WanDistillationPipeline

WanDistillationPipeline(model_path: str, fastvideo_args: TrainingArgs, required_config_modules: list[str] | None = None, loaded_modules: dict[str, Module] | None = None)

Bases: DistillationPipeline

A distillation pipeline for Wan that is built around a single transformer model: the main transformer serves as the student, and copies of it are made to serve as the teacher and the critic.
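The single-model setup can be illustrated with a minimal sketch. `ToyTransformer` and `make_distillation_models` are hypothetical names, not FastVideo APIs. The sketch also assumes the usual DMD-style arrangement, where the teacher is frozen and the critic keeps training alongside the student; the source above only states that copies are made. Deep copies are used so that updating one model never silently changes another model's weights.

```python
import copy
from dataclasses import dataclass, field


@dataclass
class ToyTransformer:
    # Hypothetical stand-in for the Wan transformer; not a FastVideo class
    weights: list = field(default_factory=lambda: [0.1, 0.2, 0.3])
    trainable: bool = True


def make_distillation_models(student: ToyTransformer):
    # Deep copies give each model its own weights, so none of them alias the student
    teacher = copy.deepcopy(student)
    teacher.trainable = False  # assumption: the teacher stays frozen
    critic = copy.deepcopy(student)  # assumption: the critic is trained too
    return teacher, critic


student = ToyTransformer()
teacher, critic = make_distillation_models(student)
student.weights[0] = 9.9  # a student update leaves both copies unchanged
```

With real `torch.nn.Module` objects, `copy.deepcopy` likewise duplicates parameters. Freezing would then typically be done with `requires_grad_(False)` rather than a flag.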

Source code in fastvideo/training/training_pipeline.py
def __init__(
        self,
        model_path: str,
        fastvideo_args: TrainingArgs,
        required_config_modules: list[str] | None = None,
        loaded_modules: dict[str, torch.nn.Module] | None = None) -> None:
    fastvideo_args.inference_mode = False
    self.lora_training = fastvideo_args.lora_training
    if self.lora_training and fastvideo_args.lora_rank is None:
        raise ValueError("lora rank must be set when using lora training")

    set_random_seed(fastvideo_args.seed)  # for lora param init
    super().__init__(model_path, fastvideo_args, required_config_modules,
                     loaded_modules)  # type: ignore
    self.tracker = DummyTracker()
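Before it delegates to the parent constructor, the constructor switches off inference mode, rejects LoRA training when no rank is set, and seeds the RNG so that LoRA parameter initialization is reproducible. The sketch below mirrors only those checks. `ToyTrainingArgs` and `prepare_training_args` are hypothetical, and `random.seed` stands in for FastVideo's `set_random_seed`, which may seed additional generators.

```python
import random
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyTrainingArgs:
    # Hypothetical subset of the TrainingArgs fields read by __init__
    lora_training: bool = False
    lora_rank: Optional[int] = None
    seed: int = 42
    inference_mode: bool = True


def prepare_training_args(args: ToyTrainingArgs) -> ToyTrainingArgs:
    # Training pipelines always run with inference mode disabled
    args.inference_mode = False
    # LoRA training is only valid when a rank has been chosen
    if args.lora_training and args.lora_rank is None:
        raise ValueError("lora rank must be set when using lora training")
    # Stand-in for set_random_seed: makes parameter init reproducible
    random.seed(args.seed)
    return args
```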

Functions

fastvideo.training.wan_distillation_pipeline.WanDistillationPipeline.create_training_stages
create_training_stages(training_args: TrainingArgs)

Currently a no-op placeholder (the body is `pass`); may be used in future refactors.

Source code in fastvideo/training/wan_distillation_pipeline.py
def create_training_stages(self, training_args: TrainingArgs):
    """
    May be used in future refactors.
    """
    pass

fastvideo.training.wan_distillation_pipeline.WanDistillationPipeline.initialize_pipeline
initialize_pipeline(fastvideo_args: FastVideoArgs)

Initialize the Wan-specific scheduler: a FlowMatchEulerDiscreteScheduler whose shift is taken from the pipeline config's flow_shift.

Source code in fastvideo/training/wan_distillation_pipeline.py
def initialize_pipeline(self, fastvideo_args: FastVideoArgs):
    """Initialize Wan-specific scheduler."""
    self.modules["scheduler"] = FlowMatchEulerDiscreteScheduler(
        shift=fastvideo_args.pipeline_config.flow_shift)
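The effect of `flow_shift` can be sketched in plain Python. This assumes FastVideo's `FlowMatchEulerDiscreteScheduler` follows the diffusers convention for a static shift, `sigma' = shift * sigma / (1 + (shift - 1) * sigma)`, and takes plain Euler steps along the predicted velocity. The helper names below are hypothetical and are not part of either library.

```python
def shift_sigmas(sigmas: list, shift: float) -> list:
    # Static time shift (diffusers convention, assumed here):
    # sigma' = shift * sigma / (1 + (shift - 1) * sigma)
    # shift > 1 pushes intermediate noise levels toward 1.0 (noisier);
    # the endpoints 0.0 and 1.0 are left unchanged.
    return [shift * s / (1 + (shift - 1) * s) for s in sigmas]


def euler_step(sample: list, velocity: list, sigma: float, sigma_next: float) -> list:
    # One flow-matching Euler update: x_next = x + (sigma_next - sigma) * v
    return [x + (sigma_next - sigma) * v for x, v in zip(sample, velocity)]


shifted = shift_sigmas([0.0, 0.5, 1.0], shift=3.0)
next_sample = euler_step([1.0], [2.0], sigma=0.75, sigma_next=0.5)
```

With `shift=3.0`, the midpoint noise level 0.5 maps to 0.75. A larger `flow_shift` therefore spends more of the schedule at high noise levels.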

Functions