def set_lora_adapter(self,
lora_nickname: str,
lora_path: str | None = None): # type: ignore
"""
Load a LoRA adapter into the pipeline and merge it into the transformer.
Args:
lora_nickname: The "nick name" of the adapter when referenced in the pipeline.
lora_path: The path to the adapter, either a local path or a Hugging Face repo id.
"""
if lora_nickname not in self.lora_adapters and lora_path is None:
raise ValueError(
f"Adapter {lora_nickname} not found in the pipeline. Please provide lora_path to load it."
)
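        # Lazily initialize the LoRA layers the first time an adapter is set.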
if not self.lora_initialized:
self.convert_to_lora_layers()
adapter_updated = False
rank = dist.get_rank()
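        # Only load weights from disk when a path is given and it differs from the currently loaded adapter path.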
if lora_path is not None and lora_path != self.cur_adapter_path:
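            # Resolve lora_path (local path or Hugging Face repo id) to a local file and load its state dict.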
lora_local_path = maybe_download_lora(lora_path)
lora_state_dict = load_file(lora_local_path)
# Map the hf layer names to our custom layer names
param_names_mapping_fn = get_param_names_mapping(
self.modules["transformer"].param_names_mapping)
lora_param_names_mapping_fn = get_param_names_mapping(
self.modules["transformer"].lora_param_names_mapping)
# Extract alpha values and weights in a single pass
to_merge_params: defaultdict[Hashable,
dict[Any, Any]] = defaultdict(dict)
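            # Buffers lora_B shards of fused layers, keyed by merge_index, until every shard has been seen.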
for name, weight in lora_state_dict.items():
# Extract weights (lora_A, lora_B, and lora_alpha)
name = name.replace("diffusion_model.", "")
name = name.replace(".weight", "")
if "lora_alpha" in name:
# Store alpha with minimal mapping - same processing as lora_A/lora_B
# but store in lora_adapters with ".lora_alpha" suffix
layer_name = name.replace(".lora_alpha", "")
layer_name, _, _ = lora_param_names_mapping_fn(layer_name)
target_name, _, _ = param_names_mapping_fn(layer_name)
# Store alpha alongside weights with same target_name base
alpha_key = target_name + ".lora_alpha"
                    self.lora_adapters[lora_nickname][alpha_key] = (
                        weight.item()
                        if weight.numel() == 1 else float(weight.mean()))
continue
name, _, _ = lora_param_names_mapping_fn(name)
target_name, merge_index, num_params_to_merge = param_names_mapping_fn(
name)
                # For LoRA factors (in_dim, r) @ (r, out_dim), only lora_B is merged,
                # producing (r, out_dim * n) where n is the number of linear layers
                # fused into one; see the param mapping in HunyuanVideoArchConfig.
if merge_index is not None and "lora_B" in name:
to_merge_params[target_name][merge_index] = weight
if len(to_merge_params[target_name]) == num_params_to_merge:
# cat at output dim according to the merge_index order
sorted_tensors = [
to_merge_params[target_name][i]
for i in range(num_params_to_merge)
]
weight = torch.cat(sorted_tensors, dim=1)
del to_merge_params[target_name]
else:
continue
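            # A duplicate target name means the name mapping collided; fail loudly rather than silently overwrite weights.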
if target_name in self.lora_adapters[lora_nickname]:
raise ValueError(
f"Target name {target_name} already exists in lora_adapters[{lora_nickname}]"
)
self.lora_adapters[lora_nickname][target_name] = weight.to(
self.device)
adapter_updated = True
self.cur_adapter_path = lora_path
logger.info("Rank %d: loaded LoRA adapter %s", rank, lora_path)
if not adapter_updated and self.cur_adapter_name == lora_nickname:
return
self.cur_adapter_name = lora_nickname
# Merge the new adapter
adapted_count = 0
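        # Push the selected adapter's A/B (and alpha) tensors into each LoRA layer; disable LoRA for layers the adapter does not cover.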
for name, layer in self.lora_layers.items():
lora_A_name = name + ".lora_A"
lora_B_name = name + ".lora_B"
lora_alpha_name = name + ".lora_alpha"
if lora_A_name in self.lora_adapters[lora_nickname]\
and lora_B_name in self.lora_adapters[lora_nickname]:
                lora_A = self.lora_adapters[lora_nickname][lora_A_name]
                lora_B = self.lora_adapters[lora_nickname][lora_B_name]
                # Alpha is stored under the same naming scheme as lora_A/lora_B.
                # It is only passed through when the adapter weights were (re)loaded
                # in this call and defaults to None if absent.
                alpha = self.lora_adapters[lora_nickname].get(
                    lora_alpha_name) if adapter_updated else None
layer.set_lora_weights(
lora_A,
lora_B,
lora_alpha=alpha,
training_mode=self.fastvideo_args.training_mode,
lora_path=lora_path)
adapted_count += 1
else:
if rank == 0:
logger.warning(
"LoRA adapter %s does not contain the weights for layer %s. LoRA will not be applied to it.",
lora_path, name)
layer.disable_lora = True
logger.info("Rank %d: LoRA adapter %s applied to %d layers", rank,
lora_path, adapted_count)