
weight_utils

Utilities for downloading and initializing model weights.

Functions

fastvideo.models.loader.weight_utils.default_weight_loader

default_weight_loader(param: Tensor, loaded_weight: Tensor) -> None

Default weight loader.

Source code in fastvideo/models/loader/weight_utils.py
def default_weight_loader(param: torch.Tensor,
                          loaded_weight: torch.Tensor) -> None:
    """Default weight loader."""
    try:
        if param.numel() == 1 and loaded_weight.numel() == 1:
            # Sometimes scalar values aren't considered tensors with shapes
            # so if both param and loaded_weight are a scalar,
            # "broadcast" instead of copy
            param.data.fill_(loaded_weight.item())
        else:
            assert param.size() == loaded_weight.size(), (
                f"Attempted to load weight ({loaded_weight.size()}) "
                f"into parameter ({param.size()})")

            param.data.copy_(loaded_weight)
    except Exception:
        # NOTE: This exception handler exists so a breakpoint can be set here
        # to debug weight loading issues.
        raise
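
Example usage (a minimal sketch; the parameter shapes below are illustrative and not taken from a real model):

import torch

from fastvideo.models.loader.weight_utils import default_weight_loader

# Matching shapes: the checkpoint tensor is copied into param.data.
param = torch.nn.Parameter(torch.empty(4, 4))
default_weight_loader(param, torch.randn(4, 4))

# Scalar case: both tensors have a single element, so the value is
# "broadcast" with fill_ instead of copied.
scalar_param = torch.nn.Parameter(torch.empty(()))
default_weight_loader(scalar_param, torch.tensor(0.5))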

fastvideo.models.loader.weight_utils.enable_hf_transfer

enable_hf_transfer() -> None

Automatically activates hf_transfer when the optional package is available.

Source code in fastvideo/models/loader/weight_utils.py
def enable_hf_transfer() -> None:
    """automatically activates hf_transfer
    """
    if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ:
        try:
            # enable hf hub transfer if available
            import hf_transfer  # type: ignore # noqa
            huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True
        except ImportError:
            pass
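
Example usage (a sketch; snapshot_download and the repository id below are illustrative, the function itself only toggles the hf_transfer backend):

from huggingface_hub import snapshot_download

from fastvideo.models.loader.weight_utils import enable_hf_transfer

# No-op if hf_transfer is not installed or HF_HUB_ENABLE_HF_TRANSFER is
# already set in the environment.
enable_hf_transfer()

# Subsequent downloads use the faster hf_transfer backend when available.
local_dir = snapshot_download("some-org/some-model")  # hypothetical repo id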

fastvideo.models.loader.weight_utils.filter_files_not_needed_for_inference

filter_files_not_needed_for_inference(hf_weights_files: list[str]) -> list[str]

Exclude files that are not needed for inference.

See https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233

Source code in fastvideo/models/loader/weight_utils.py
def filter_files_not_needed_for_inference(
        hf_weights_files: list[str]) -> list[str]:
    """
    Exclude files that are not needed for inference.

    See https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
    """
    blacklist = [
        "training_args.bin",
        "optimizer.bin",
        "optimizer.pt",
        "scheduler.pt",
        "scaler.pt",
    ]
    hf_weights_files = [
        f for f in hf_weights_files
        if not any(f.endswith(x) for x in blacklist)
    ]
    return hf_weights_files
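
Example usage (the file paths below are made up for illustration):

from fastvideo.models.loader.weight_utils import (
    filter_files_not_needed_for_inference)

files = [
    "/ckpt/pytorch_model-00001-of-00002.bin",
    "/ckpt/pytorch_model-00002-of-00002.bin",
    "/ckpt/training_args.bin",
    "/ckpt/optimizer.pt",
    "/ckpt/scheduler.pt",
]
# Only the weight shards remain; training-only state is dropped.
print(filter_files_not_needed_for_inference(files))
# ['/ckpt/pytorch_model-00001-of-00002.bin', '/ckpt/pytorch_model-00002-of-00002.bin']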

fastvideo.models.loader.weight_utils.maybe_remap_kv_scale_name

maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None

Remap the name of FP8 k/v_scale parameters.

This function handles the remapping of FP8 k/v_scale parameter names. It detects whether the given name ends with a kv_scale, k_scale, or v_scale suffix and attempts to remap it to the expected name format in the model. If the remapped name is not found in params_dict, a warning is logged and None is returned.

Parameters:

name (str): The original loaded checkpoint parameter name. Required.

params_dict (dict): Dictionary containing the model's named parameters. Required.

Returns:

str: The remapped parameter name if successful, or the original name if no remapping is needed.

None: If the remapped name is not found in params_dict.

Source code in fastvideo/models/loader/weight_utils.py
def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None:
    """Remap the name of FP8 k/v_scale parameters.

    This function handles the remapping of FP8 k/v_scale parameter names.
    It detects if the given name ends with a suffix and attempts to remap
    it to the expected name format in the model. If the remapped name is not
    found in the params_dict, a warning is printed and None is returned.

    Args:
        name (str): The original loaded checkpoint parameter name.
        params_dict (dict): Dictionary containing the model's named parameters.

    Returns:
        str: The remapped parameter name if successful, or the original name
             if no remapping is needed.
        None: If the remapped name is not found in params_dict.
    """
    if name.endswith(".kv_scale"):
        logger.warning_once(
            "DEPRECATED. Found kv_scale in the checkpoint. "
            "This format is deprecated in favor of separate k_scale and "
            "v_scale tensors and will be removed in a future release. "
            "Functionally, we will remap kv_scale to k_scale and duplicate "
            "k_scale to v_scale")
        # NOTE: we remap the deprecated kv_scale to k_scale
        remapped_name = name.replace(".kv_scale", ".attn.k_scale")
        if remapped_name not in params_dict:
            logger.warning_once(
                f"Found kv_scale in the checkpoint (e.g. {name}), "
                "but not found the expected name in the model "
                f"(e.g. {remapped_name}). kv_scale is "
                "not loaded.")
            return None
        return remapped_name

    possible_scale_names = [".k_scale", ".v_scale"]
    modelopt_scale_names = [
        ".self_attn.k_proj.k_scale", ".self_attn.v_proj.v_scale"
    ]
    for scale_name in possible_scale_names:
        if name.endswith(scale_name):
            if any(mo_scale_name in name
                   for mo_scale_name in modelopt_scale_names):
                remapped_name = name.replace(
                    f".self_attn.{scale_name[1]}_proj{scale_name}",
                    f".self_attn.attn{scale_name}")
            else:
                remapped_name = name.replace(scale_name, f".attn{scale_name}")
            if remapped_name not in params_dict:
                logger.warning_once(
                    f"Found {scale_name} in the checkpoint (e.g. {name}), "
                    "but not found the expected name in the model "
                    f"(e.g. {remapped_name}). {scale_name} is "
                    "not loaded.")
                return None
            return remapped_name

    # If there were no matches, return the untouched param name
    return name

fastvideo.models.loader.weight_utils.pt_weights_iterator

pt_weights_iterator(hf_weights_files: list[str], to_cpu: bool = True) -> Generator[tuple[str, Tensor], None, None]

Iterate over the weights in the model bin/pt files.

Source code in fastvideo/models/loader/weight_utils.py
def pt_weights_iterator(
    hf_weights_files: list[str],
    to_cpu: bool = True,
) -> Generator[tuple[str, torch.Tensor], None, None]:
    """Iterate over the weights in the model bin/pt files."""
    device = "cpu" if to_cpu else str(get_local_torch_device())
    enable_tqdm = (not torch.distributed.is_initialized()
                   or torch.distributed.get_rank() == 0)
    for bin_file in tqdm(
            hf_weights_files,
            desc="Loading pt checkpoint shards",
            disable=not enable_tqdm,
            bar_format=_BAR_FORMAT,
    ):
        state = torch.load(bin_file, map_location=device, weights_only=True)
        yield from state.items()
        del state
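
Example usage (a sketch of a typical loading loop; the model and checkpoint path are stand-ins):

import glob

import torch

from fastvideo.models.loader.weight_utils import (default_weight_loader,
                                                  pt_weights_iterator)

model = torch.nn.Linear(4, 4)  # stand-in for a real model
params_dict = dict(model.named_parameters())
bin_files = sorted(glob.glob("/path/to/ckpt/*.bin"))  # illustrative path

for name, loaded_weight in pt_weights_iterator(bin_files, to_cpu=True):
    if name in params_dict:
        default_weight_loader(params_dict[name], loaded_weight)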

fastvideo.models.loader.weight_utils.safetensors_weights_iterator

safetensors_weights_iterator(hf_weights_files: list[str], to_cpu: bool = True) -> Generator[tuple[str, Tensor], None, None]

Iterate over the weights in the model safetensor files.

Source code in fastvideo/models/loader/weight_utils.py
def safetensors_weights_iterator(
    hf_weights_files: list[str],
    to_cpu: bool = True,
) -> Generator[tuple[str, torch.Tensor], None, None]:
    """Iterate over the weights in the model safetensor files."""
    enable_tqdm = (not torch.distributed.is_initialized()
                   or torch.distributed.get_rank() == 0)
    device = "cpu" if to_cpu else str(get_local_torch_device())
    for st_file in tqdm(
            hf_weights_files,
            desc="Loading safetensors checkpoint shards",
            disable=not enable_tqdm,
            bar_format=_BAR_FORMAT,
    ):
        with safe_open(st_file, framework="pt", device=device) as f:
            for name in f.keys():  # noqa: SIM118
                param = f.get_tensor(name)
                yield name, param
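
Example usage (a sketch; the shard path below is illustrative):

from fastvideo.models.loader.weight_utils import safetensors_weights_iterator

st_files = ["/path/to/ckpt/model-00001-of-00002.safetensors"]  # illustrative path
for name, tensor in safetensors_weights_iterator(st_files, to_cpu=True):
    print(name, tuple(tensor.shape))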