StepVideo video diffusion pipeline implementation.
This module contains an implementation of the StepVideo video diffusion pipeline
using the modular pipeline architecture.
Classes
fastvideo.pipelines.basic.stepvideo.stepvideo_pipeline.StepVideoPipeline
StepVideoPipeline(*args, **kwargs)
Bases: LoRAPipeline, ComposedPipelineBase
Source code in fastvideo/pipelines/lora_pipeline.py
def __init__(self, *args, **kwargs) -> None:
    """Collect the trainable transformers and set up LoRA for them.

    After the parent constructor runs, every name listed in
    ``self.trainable_transformer_names`` that maps to a non-``None`` entry
    in ``self.modules`` is recorded in ``self.trainable_transformer_modules``,
    together with its ``"<name>_2"`` companion when that one is present.
    The LoRA settings are then copied from ``self.fastvideo_args`` and the
    layers are converted for either LoRA training or LoRA inference.

    Raises:
        ValueError: If a configured trainable transformer name already
            ends with ``"_2"``.
    """
    super().__init__(*args, **kwargs)
    self.device = get_local_torch_device()

    # build list of trainable transformers
    for base_name in self.trainable_transformer_names:
        primary = self.modules[base_name] if base_name in self.modules else None
        if primary is not None:
            self.trainable_transformer_modules[base_name] = primary

        # check for transformer_2 in case of Wan2.2 MoE or fake_score_transformer_2
        if base_name.endswith("_2"):
            raise ValueError(
                "trainable_transformer_name override in pipelines should not "
                f"include _2 suffix: {base_name}")

        companion_name = base_name + "_2"
        companion = (self.modules[companion_name]
                     if companion_name in self.modules else None)
        if companion is not None:
            self.trainable_transformer_modules[companion_name] = companion

    logger.info("trainable_transformer_modules: %s",
                self.trainable_transformer_modules.keys())

    # Each transformer's architecture config names the layers LoRA must skip.
    for name, module in self.trainable_transformer_modules.items():
        arch_config = module.config.arch_config
        self.exclude_lora_layers[name] = arch_config.exclude_lora_layers

    fv_args = self.fastvideo_args
    self.lora_target_modules = fv_args.lora_target_modules
    self.lora_path = fv_args.lora_path
    self.lora_nickname = fv_args.lora_nickname
    self.training_mode = fv_args.training_mode

    lora_training = self.training_mode and getattr(fv_args, "lora_training",
                                                   False)
    if lora_training:
        assert isinstance(fv_args, TrainingArgs)
        # Alpha falls back to the rank when it was not given explicitly.
        if fv_args.lora_alpha is None:
            fv_args.lora_alpha = fv_args.lora_rank
        self.lora_rank = fv_args.lora_rank  # type: ignore
        self.lora_alpha = fv_args.lora_alpha  # type: ignore
        logger.info("Using LoRA training with rank %d and alpha %d",
                    self.lora_rank, self.lora_alpha)

        if self.lora_target_modules is not None:
            logger.warning(
                "Using custom lora_target_modules for all transformers, which may not be intended: %s",
                self.lora_target_modules)
        else:
            self.lora_target_modules = [
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
                "to_q",
                "to_k",
                "to_v",
                "to_out",
                "to_qkv",
                "to_gate_compress",
            ]
            logger.info(
                "Using default lora_target_modules for all transformers: %s",
                self.lora_target_modules)

        self.convert_to_lora_layers()
    # Inference
    elif not self.training_mode and self.lora_path is not None:
        self.convert_to_lora_layers()
        self.set_lora_adapter(
            self.lora_nickname,  # type: ignore
            self.lora_path)  # type: ignore
Functions
fastvideo.pipelines.basic.stepvideo.stepvideo_pipeline.StepVideoPipeline.create_pipeline_stages
Set up pipeline stages with proper dependency injection.
Source code in fastvideo/pipelines/basic/stepvideo/stepvideo_pipeline.py
def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
    """Set up pipeline stages with proper dependency injection.

    Stages are registered in this order: input validation, prompt
    encoding, timestep preparation, latent preparation, denoising and
    decoding. Each stage receives the modules it needs from
    ``self.get_module``. ``fastvideo_args`` is not read by this method.
    """
    validation_stage = InputValidationStage()
    self.add_stage(stage_name="input_validation_stage",
                   stage=validation_stage)

    # The prompt encoder takes both text encoders.
    prompt_stage = StepvideoPromptEncodingStage(
        stepllm=self.get_module("text_encoder"),
        clip=self.get_module("text_encoder_2"),
    )
    self.add_stage(stage_name="prompt_encoding_stage", stage=prompt_stage)

    timestep_stage = TimestepPreparationStage(
        scheduler=self.get_module("scheduler"))
    self.add_stage(stage_name="timestep_preparation_stage",
                   stage=timestep_stage)

    latent_stage = LatentPreparationStage(
        scheduler=self.get_module("scheduler"),
        transformer=self.get_module("transformer"),
    )
    self.add_stage(stage_name="latent_preparation_stage", stage=latent_stage)

    denoise_stage = DenoisingStage(
        transformer=self.get_module("transformer"),
        scheduler=self.get_module("scheduler"))
    self.add_stage(stage_name="denoising_stage", stage=denoise_stage)

    decode_stage = DecodingStage(vae=self.get_module("vae"))
    self.add_stage(stage_name="decoding_stage", stage=decode_stage)
fastvideo.pipelines.basic.stepvideo.stepvideo_pipeline.StepVideoPipeline.initialize_pipeline
Initialize the pipeline.
Source code in fastvideo/pipelines/basic/stepvideo/stepvideo_pipeline.py
def initialize_pipeline(self, fastvideo_args: FastVideoArgs):
    """
    Initialize the pipeline.

    Builds the two text encoders from the ``step_llm`` and ``hunyuan_clip``
    sub-directories of ``self.model_path``, registers them as the
    ``text_encoder`` and ``text_encoder_2`` modules, and loads the
    ``liboptimus_ths`` shared library into ``torch.ops``. The library is
    read from ``fastvideo_args.model_path`` when that is a local directory;
    otherwise it is fetched with ``hf_hub_download`` using the model path
    as the repo id.
    """
    device = get_local_torch_device()

    llm_dir = os.path.join(self.model_path, "step_llm")
    clip_dir = os.path.join(self.model_path, "hunyuan_clip")
    llm_encoder = self.build_llm(llm_dir, device)
    clip_encoder = self.build_clip(clip_dir, device)
    self.add_module("text_encoder", llm_encoder)
    self.add_module("text_encoder_2", clip_encoder)

    # Path of the custom-ops library relative to the model root.
    ops_lib = 'lib/liboptimus_ths-torch2.5-cu124.cpython-310-x86_64-linux-gnu.so'
    if os.path.isdir(fastvideo_args.model_path):  # local checkout
        lib_path = os.path.join(fastvideo_args.model_path, ops_lib)
    else:
        lib_path = hf_hub_download(repo_id=fastvideo_args.model_path,
                                   filename=ops_lib)
    torch.ops.load_library(lib_path)
fastvideo.pipelines.basic.stepvideo.stepvideo_pipeline.StepVideoPipeline.load_modules
Load the modules from the config.
Source code in fastvideo/pipelines/basic/stepvideo/stepvideo_pipeline.py
def load_modules(self, fastvideo_args: FastVideoArgs) -> dict[str, Any]:
    """
    Load the modules from the config.

    Reads the model index via ``self._load_config(self.model_path)``, drops
    the metadata entries, checks that the ``transformer``, ``scheduler`` and
    ``vae`` entries exist, and loads every remaining entry through
    ``PipelineComponentLoader.load_module`` from the sub-directory of
    ``self.model_path`` named after the entry.

    Args:
        fastvideo_args: Arguments forwarded unchanged to the component loader.

    Returns:
        A mapping from module name to the loaded module.

    Raises:
        ValueError: If a mandatory entry is missing from the model index, or
            if a module listed in ``self.required_config_modules`` was not
            loaded (absent or ``None``).
    """
    model_index = self._load_config(self.model_path)
    logger.info("Loading pipeline modules from config: %s", model_index)

    # remove keys that are not pipeline modules; a config that omits one of
    # these metadata entries is still usable, so do not fail on its absence
    for metadata_key in ("_class_name", "_diffusers_version"):
        model_index.pop(metadata_key, None)

    # some sanity checks
    assert len(
        model_index
    ) > 1, "model_index.json must contain at least one pipeline module"

    for module_name in ("transformer", "scheduler", "vae"):
        if module_name not in model_index:
            raise ValueError(
                f"model_index.json must contain a {module_name} module")
    logger.info("Diffusers config passed sanity checks")

    # all the component models used by the pipeline
    modules: dict[str, Any] = {}
    # Each entry is unpacked as a (library, architecture) pair; only the
    # library name is passed to the loader. Keys of a dict are unique, so no
    # module can be loaded twice in this loop.
    for module_name, (transformers_or_diffusers,
                      _architecture) in model_index.items():
        component_model_path = os.path.join(self.model_path, module_name)
        module = PipelineComponentLoader.load_module(
            module_name=module_name,
            component_model_path=component_model_path,
            transformers_or_diffusers=transformers_or_diffusers,
            fastvideo_args=fastvideo_args,
        )
        logger.info("Loaded module %s from %s", module_name,
                    component_model_path)
        modules[module_name] = module

    # Check if all required modules were loaded
    for module_name in self.required_config_modules:
        if modules.get(module_name) is None:
            raise ValueError(
                f"Required module {module_name} was not loaded properly")

    return modules
Functions