Skip to content

utils

Utilities for selecting and loading models.

Functions

fastvideo.models.loader.utils.get_param_names_mapping

get_param_names_mapping(mapping_dict: dict[str, str]) -> Callable[[str], tuple[str, Any, Any]]

Creates a mapping function that transforms parameter names using regex patterns.

Parameters:

Name Type Description Default
mapping_dict Dict[str, str]

Dictionary mapping regex patterns to replacement patterns

required
param_name str

The parameter name to be transformed (this is the argument of the returned mapping function, not of get_param_names_mapping itself)

required

Returns:

Type Description
Callable[[str], tuple[str, Any, Any]]

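A function that maps a source parameter name to a tuple (target name, merge index, total number of split params); the last two are None when the matched replacement is a plain string or when no pattern matches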

Source code in fastvideo/models/loader/utils.py
def get_param_names_mapping(
        mapping_dict: dict[str, str]) -> Callable[[str], tuple[str, Any, Any]]:
    """
    Builds a function that renames parameters according to regex rules.

    Args:
        mapping_dict (Dict[str, str]): Dictionary mapping regex patterns to replacements.
            A replacement is either a replacement string, or a tuple whose first three
            items are ``(replacement, merge_index, total_splitted_params)``.

    Returns:
        Callable[[str], tuple[str, Any, Any]]: A function that takes a source parameter
            name and returns ``(target_name, merge_index, total_splitted_params)``.
            The last two items are ``None`` when the matched replacement is a plain
            string or when no pattern matches.
    """

    def mapping_fn(name: str) -> tuple[str, Any, Any]:
        # Patterns are tried in the dict's iteration order; the first one that
        # matches at the start of the name decides the result.
        for regex, target in mapping_dict.items():
            if re.match(regex, name) is None:
                continue

            if isinstance(target, tuple):
                # Tuple replacements also carry the merge bookkeeping values.
                merge_index = target[1]
                total_splitted_params = target[2]
                template = target[0]
            else:
                merge_index = None
                total_splitted_params = None
                template = target

            return re.sub(regex, template, name), merge_index, total_splitted_params

        # Nothing matched: hand the name back untouched.
        return name, None, None

    return mapping_fn

fastvideo.models.loader.utils.hf_to_custom_state_dict

hf_to_custom_state_dict(hf_param_sd: dict[str, Tensor] | Iterator[tuple[str, Tensor]], param_names_mapping: Callable[[str], tuple[str, Any, Any]]) -> tuple[dict[str, Tensor], dict[str, tuple[str, Any, Any]]]

Converts a Hugging Face parameter state dictionary to a custom parameter state dictionary.

Parameters:

Name Type Description Default
hf_param_sd Dict[str, Tensor]

The Hugging Face parameter state dictionary, or an iterator of (name, tensor) pairs

required
param_names_mapping Callable[[str], tuple[str, Any, Any]]

A function that maps parameter names from source to target format

required

Returns:

Name Type Description
custom_param_sd Dict[str, Tensor]

The custom formatted parameter state dict

reverse_param_names_mapping Dict[str, Tuple[str, Any, Any]]

Maps each custom parameter name back to a tuple of (Hugging Face parameter name, merge index, number of params to merge)

Source code in fastvideo/models/loader/utils.py
def hf_to_custom_state_dict(
    hf_param_sd: dict[str, torch.Tensor] | Iterator[tuple[str, torch.Tensor]],
    param_names_mapping: Callable[[str], tuple[str, Any, Any]]
) -> tuple[dict[str, torch.Tensor], dict[str, tuple[str, Any, Any]]]:
    """
    Converts a Hugging Face parameter state dictionary to a custom parameter state dictionary.

    Args:
        hf_param_sd (Dict[str, torch.Tensor]): The Hugging Face parameter state dictionary
        param_names_mapping (Callable[[str], tuple[str, Any, Any]]): A function that maps parameter names from source to target format

    Returns:
        custom_param_sd (Dict[str, torch.Tensor]): The custom formatted parameter state dict
        reverse_param_names_mapping (Dict[str, Tuple[str, Any, Any]]): Maps back from custom to hf
    """
    custom_param_sd = {}
    to_merge_params = defaultdict(dict)  # type: ignore
    reverse_param_names_mapping = {}
    if isinstance(hf_param_sd, dict):
        hf_param_sd = hf_param_sd.items()  # type: ignore
    for source_param_name, full_tensor in hf_param_sd:  # type: ignore
        target_param_name, merge_index, num_params_to_merge = param_names_mapping(
            source_param_name)
        reverse_param_names_mapping[target_param_name] = (source_param_name,
                                                          merge_index,
                                                          num_params_to_merge)
        if merge_index is not None:
            to_merge_params[target_param_name][merge_index] = full_tensor
            if len(to_merge_params[target_param_name]) == num_params_to_merge:
                # cat at output dim according to the merge_index order
                sorted_tensors = [
                    to_merge_params[target_param_name][i]
                    for i in range(num_params_to_merge)
                ]
                full_tensor = torch.cat(sorted_tensors, dim=0)
                del to_merge_params[target_param_name]
            else:
                continue
        custom_param_sd[target_param_name] = full_tensor
    return custom_param_sd, reverse_param_names_mapping

fastvideo.models.loader.utils.set_default_torch_dtype

set_default_torch_dtype(dtype: dtype)

Sets the default torch dtype to the given dtype.

Source code in fastvideo/models/loader/utils.py
@contextlib.contextmanager
def set_default_torch_dtype(dtype: torch.dtype):
    """Temporarily sets the default torch dtype to the given dtype.

    The dtype that was the default on entry is restored when the ``with``
    block exits, including when the block exits by raising an exception.

    Args:
        dtype (torch.dtype): The dtype to use as the default inside the ``with`` block.

    Yields:
        None
    """
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # The default dtype is process-wide state; without the finally an
        # exception in the with-body would leave it changed for later code.
        torch.set_default_dtype(old_dtype)