def __init__(self, *args, **kwargs) -> None:
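        """Collect the trainable transformer modules, record layers excluded
        from LoRA, and either convert them to LoRA layers for training or
        load the configured LoRA adapter for inference."""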
super().__init__(*args, **kwargs)
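        # Torch device for the current process/rank.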
self.device = get_local_torch_device()
        # Build the mapping of trainable transformer modules.
        for transformer_name in self.trainable_transformer_names:
            # Names must refer to the primary transformer; the "_2" variant
            # (e.g. Wan2.2 MoE or fake_score_transformer_2) is discovered
            # automatically below.
            if transformer_name.endswith("_2"):
                raise ValueError(
                    "trainable_transformer_names override in pipelines "
                    f"should not include the _2 suffix: {transformer_name}")
            if (transformer_name in self.modules
                    and self.modules[transformer_name] is not None):
                self.trainable_transformer_modules[transformer_name] = (
                    self.modules[transformer_name])
            # Check for transformer_2 in case of Wan2.2 MoE or
            # fake_score_transformer_2.
            secondary_transformer_name = transformer_name + "_2"
            if (secondary_transformer_name in self.modules
                    and self.modules[secondary_transformer_name] is not None):
                self.trainable_transformer_modules[
                    secondary_transformer_name] = self.modules[
                        secondary_transformer_name]
logger.info(
"trainable_transformer_modules: %s",
self.trainable_transformer_modules.keys(),
)
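        # Record, for each trainable transformer, the layer names its arch
        # config excludes from LoRA.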
for (
transformer_name,
transformer_module,
) in self.trainable_transformer_modules.items():
self.exclude_lora_layers[transformer_name] = (
transformer_module.config.arch_config.exclude_lora_layers)
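        # Cache LoRA-related settings from fastvideo_args.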
self.lora_target_modules = self.fastvideo_args.lora_target_modules
self.lora_path = self.fastvideo_args.lora_path
self.lora_nickname = self.fastvideo_args.lora_nickname
self.training_mode = self.fastvideo_args.training_mode
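        # LoRA training: resolve rank/alpha and convert the target modules of
        # every trainable transformer into LoRA layers.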
if self.training_mode and getattr(self.fastvideo_args, "lora_training",
False):
assert isinstance(self.fastvideo_args, TrainingArgs)
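            # Default lora_alpha to lora_rank when it is not explicitly set.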
if self.fastvideo_args.lora_alpha is None:
self.fastvideo_args.lora_alpha = self.fastvideo_args.lora_rank
self.lora_rank = self.fastvideo_args.lora_rank # type: ignore
self.lora_alpha = self.fastvideo_args.lora_alpha # type: ignore
logger.info(
"Using LoRA training with rank %d and alpha %d",
self.lora_rank,
self.lora_alpha,
)
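            # No target modules provided: fall back to common attention
            # projection layer names.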
if self.lora_target_modules is None:
self.lora_target_modules = [
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"to_q",
"to_k",
"to_v",
"to_out",
"to_qkv",
"to_gate_compress",
]
logger.info(
"Using default lora_target_modules for all transformers: %s",
self.lora_target_modules,
)
else:
                logger.warning(
                    "Using custom lora_target_modules for all trainable "
                    "transformers; make sure this is intended: %s",
                    self.lora_target_modules,
                )
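            # Wrap the selected target modules with LoRA layers.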
self.convert_to_lora_layers()
        # Inference: wrap the target layers and load the LoRA adapter from
        # lora_path under the given nickname.
        elif not self.training_mode and self.lora_path is not None:
            self.convert_to_lora_layers()
            self.set_lora_adapter(
                self.lora_nickname,  # type: ignore
                self.lora_path)