def set_lora_adapter(self,
lora_nickname: str,
lora_path: str | None = None): # type: ignore
"""
Load a LoRA adapter into the pipeline and merge it into the transformer.
Args:
lora_nickname: The "nick name" of the adapter when referenced in the pipeline.
lora_path: The path to the adapter, either a local path or a Hugging Face repo id.
"""
if lora_nickname not in self.lora_adapters and lora_path is None:
raise ValueError(
f"Adapter {lora_nickname} not found in the pipeline. Please provide lora_path to load it."
)
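        # Convert the pipeline's layers to LoRA layers the first time an
        # adapter is set.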
if not self.lora_initialized:
self.convert_to_lora_layers()
adapter_updated = False
rank = dist.get_rank()
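        # Only (re)load weights from disk when a new adapter path is provided;
        # switching back to an already-loaded adapter skips this block.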
if lora_path is not None and lora_path != self.cur_adapter_path:
lora_local_path = maybe_download_lora(lora_path)
lora_state_dict = load_file(lora_local_path)
            # Map the Hugging Face layer names to our custom layer names.
param_names_mapping_fn = get_param_names_mapping(
self.modules["transformer"].param_names_mapping)
lora_param_names_mapping_fn = get_param_names_mapping(
self.modules["transformer"].lora_param_names_mapping)
# Extract alpha values and weights in a single pass
to_merge_params: defaultdict[Hashable,
dict[Any, Any]] = (defaultdict(dict))
for name, weight in lora_state_dict.items():
# Extract weights (lora_A, lora_B, and lora_alpha)
name = name.replace("diffusion_model.", "")
name = name.replace(".weight", "")
if "lora_alpha" in name:
                    # Alphas go through the same name mapping as lora_A/lora_B
                    # and are stored in lora_adapters under a ".lora_alpha"
                    # suffix.
layer_name = name.replace(".lora_alpha", "")
layer_name, _, _ = lora_param_names_mapping_fn(layer_name)
target_name, _, _ = param_names_mapping_fn(layer_name)
# Store alpha alongside weights with same target_name base
alpha_key = target_name + ".lora_alpha"
self.lora_adapters[lora_nickname][alpha_key] = (
weight.item()
if weight.numel() == 1 else float(weight.mean()))
continue
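                # Map lora_A/lora_B parameter names onto the pipeline's layer
                # names; merge_index and num_params_to_merge describe layers
                # that are fused into a single target weight.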
name, _, _ = lora_param_names_mapping_fn(name)
target_name, merge_index, num_params_to_merge = (
param_names_mapping_fn(name))
                # For (in_dim, r) @ (r, out_dim), only the lora_B halves are
                # concatenated into (r, out_dim * n), where n is the number of
                # linear layers being fused; see the param mapping in
                # HunyuanVideoArchConfig.
if merge_index is not None and "lora_B" in name:
to_merge_params[target_name][merge_index] = weight
if len(to_merge_params[target_name]) == num_params_to_merge:
                        # Concatenate along the output dim in merge_index order.
sorted_tensors = [
to_merge_params[target_name][i]
for i in range(num_params_to_merge)
]
weight = torch.cat(sorted_tensors, dim=1)
del to_merge_params[target_name]
else:
continue
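                # Guard against two source parameters mapping to the same
                # target key.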
if target_name in self.lora_adapters[lora_nickname]:
raise ValueError(
f"Target name {target_name} already exists in lora_adapters[{lora_nickname}]"
)
self.lora_adapters[lora_nickname][target_name] = weight.to(
self.device)
adapter_updated = True
self.cur_adapter_path = lora_path
logger.info("Rank %d: loaded LoRA adapter %s", rank, lora_path)
if not adapter_updated and self.cur_adapter_name == lora_nickname:
return
self.cur_adapter_name = lora_nickname
# Merge the new adapter
adapted_count = 0
for (
transformer_name,
transformer_lora_layers,
) in self.lora_layers.items():
for (
module,
layers,
) in transformer_lora_layers.lora_layers_by_block():
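                # Set weights for each layer in this block inside the module's
                # hook context.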
with _get_hook_ctx(module):
for name, layer in layers.items():
lora_A_name = name + ".lora_A"
lora_B_name = name + ".lora_B"
lora_alpha_name = name + ".lora_alpha"
if (lora_A_name in self.lora_adapters[lora_nickname]
and lora_B_name
in self.lora_adapters[lora_nickname]):
                            # Fetch this layer's A and B matrices from the
                            # active adapter.
lora_A = self.lora_adapters[lora_nickname][
lora_A_name]
lora_B = self.lora_adapters[lora_nickname][
lora_B_name]
                            # Alpha is stored under the same naming scheme as
                            # lora_A/lora_B and is only consulted when the
                            # adapter was just (re)loaded.
alpha = (self.lora_adapters[lora_nickname].get(
lora_alpha_name) if adapter_updated else None)
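                            # Hand the A/B matrices (and optional alpha) to the
                            # layer via set_lora_weights.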
try:
layer.set_lora_weights(
lora_A,
lora_B,
lora_alpha=alpha,
training_mode=self.fastvideo_args.
training_mode,
lora_path=lora_path,
)
except Exception as e:
logger.error(
"Error setting LoRA weights for layer %s: %s",
name,
str(e),
)
                                raise
adapted_count += 1
else:
if rank == 0:
logger.warning(
"LoRA adapter %s does not contain the weights for layer %s. LoRA will not be applied to it.",
lora_path,
name,
)
layer.disable_lora = True
logger.info(
"Rank %d: LoRA adapter %s applied to %d layers",
rank,
lora_path,
adapted_count,
)
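    # Usage sketch (illustrative; the `pipeline` instance and the repo id
    # below are hypothetical):
    #
    #   pipeline.set_lora_adapter(lora_nickname="anime-style",
    #                             lora_path="some-user/some-lora-repo")
    #   # Re-activate an adapter that was loaded earlier:
    #   pipeline.set_lora_adapter(lora_nickname="anime-style")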