Wan video diffusion pipeline implementation.
This module contains an implementation of the Wan video diffusion pipeline
using the modular pipeline architecture.
Classes
fastvideo.pipelines.basic.wan.wan_i2v_dmd_pipeline.WanImageToVideoDmdPipeline
WanImageToVideoDmdPipeline(*args, **kwargs)
Bases: LoRAPipeline, ComposedPipelineBase
Source code in fastvideo/pipelines/lora_pipeline.py
def __init__(self, *args, **kwargs) -> None:
    """Initialise the pipeline, then attach LoRA layers when configured.

    All positional and keyword arguments are forwarded unchanged to the
    parent constructor. Afterwards the LoRA-related settings are copied
    from ``self.fastvideo_args`` onto the instance and one of two paths
    is taken:

    * training mode with ``lora_training`` enabled: rank/alpha are
      resolved, default target modules are filled in if none were given,
      and the layers are converted to LoRA layers;
    * non-training mode with a ``lora_path``: the layers are converted
      and the adapter at that path is activated under ``lora_nickname``.

    In every other configuration no LoRA conversion happens.
    """
    super().__init__(*args, **kwargs)
    self.device = get_local_torch_device()

    transformer_arch = self.modules["transformer"].config.arch_config
    self.exclude_lora_layers = transformer_arch.exclude_lora_layers

    # Mirror the LoRA settings from the argument object onto the pipeline.
    cfg = self.fastvideo_args
    self.lora_target_modules = cfg.lora_target_modules
    self.lora_path = cfg.lora_path
    self.lora_nickname = cfg.lora_nickname
    self.training_mode = cfg.training_mode

    # Training
    if self.training_mode and getattr(cfg, "lora_training", False):
        assert isinstance(cfg, TrainingArgs)
        # An unset alpha falls back to the rank.
        if cfg.lora_alpha is None:
            cfg.lora_alpha = cfg.lora_rank
        self.lora_rank = cfg.lora_rank  # type: ignore
        self.lora_alpha = cfg.lora_alpha  # type: ignore
        logger.info("Using LoRA training with rank %d and alpha %d",
                    self.lora_rank, self.lora_alpha)
        if self.lora_target_modules is None:
            self.lora_target_modules = [
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
                "to_q",
                "to_k",
                "to_v",
                "to_out",
                "to_qkv",
            ]
        self.convert_to_lora_layers()
        return

    # Nothing else to do unless this is inference with an adapter path.
    if self.training_mode or self.lora_path is None:
        return

    # Inference
    self.convert_to_lora_layers()
    self.set_lora_adapter(
        self.lora_nickname,  # type: ignore
        self.lora_path,  # type: ignore
    )
|
Functions
fastvideo.pipelines.basic.wan.wan_i2v_dmd_pipeline.WanImageToVideoDmdPipeline.create_pipeline_stages
Set up pipeline stages with proper dependency injection.
Source code in fastvideo/pipelines/basic/wan/wan_i2v_dmd_pipeline.py
def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
    """Register every stage of the image-to-video DMD pipeline.

    Stages are added one after another via ``self.add_stage``; each stage
    receives the modules it needs from ``self.get_module``. The
    ``scheduler``, ``transformer`` and ``vae`` module keys are looked up
    by more than one stage.

    Args:
        fastvideo_args: Accepted as part of the method signature; this
            implementation does not read it.
    """
    register = self.add_stage
    module = self.get_module

    register(stage_name="input_validation_stage",
             stage=InputValidationStage())

    text_stage = TextEncodingStage(
        text_encoders=[module("text_encoder")],
        tokenizers=[module("tokenizer")],
    )
    register(stage_name="prompt_encoding_stage", stage=text_stage)

    image_stage = ImageEncodingStage(
        image_encoder=module("image_encoder"),
        image_processor=module("image_processor"),
    )
    register(stage_name="image_encoding_stage", stage=image_stage)

    register(stage_name="conditioning_stage", stage=ConditioningStage())

    timestep_stage = TimestepPreparationStage(scheduler=module("scheduler"))
    register(stage_name="timestep_preparation_stage", stage=timestep_stage)

    latent_stage = LatentPreparationStage(
        scheduler=module("scheduler"),
        transformer=module("transformer"),
    )
    register(stage_name="latent_preparation_stage", stage=latent_stage)

    image_latent_stage = ImageVAEEncodingStage(vae=module("vae"))
    register(stage_name="image_latent_preparation_stage",
             stage=image_latent_stage)

    denoise_stage = DmdDenoisingStage(
        transformer=module("transformer"),
        scheduler=module("scheduler"),
    )
    register(stage_name="denoising_stage", stage=denoise_stage)

    decode_stage = DecodingStage(vae=module("vae"))
    register(stage_name="decoding_stage", stage=decode_stage)
|
Functions