fastvideo.v1.models.loader.utils
Utilities for selecting and loading models.
Module Contents
Functions
get_param_names_mapping | Creates a mapping function that transforms parameter names using regex patterns.
hf_to_custom_state_dict | Converts a Hugging Face parameter state dictionary to a custom parameter state dictionary.
set_default_torch_dtype | Sets the default torch dtype to the given dtype.
API
- fastvideo.v1.models.loader.utils.get_param_names_mapping(mapping_dict: dict[str, str]) → collections.abc.Callable[[str], tuple[str, Any, Any]]
Creates a mapping function that transforms parameter names using regex patterns.
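The page does not show how the returned function behaves. As a rough illustration only, the sketch below assumes that each key of mapping_dict is a regex pattern and each value its replacement string, that the first matching pattern wins, and that names matching no pattern pass through unchanged. The meaning of the two trailing Any elements of the returned tuple is not documented here, so the sketch returns None for both. The name make_param_names_mapping and the example pattern are hypothetical.

```python
import re
from collections.abc import Callable
from typing import Any


def make_param_names_mapping(
    mapping_dict: dict[str, str],
) -> Callable[[str], tuple[str, Any, Any]]:
    """Hypothetical sketch: build a regex-based parameter renamer."""

    def mapping_fn(name: str) -> tuple[str, Any, Any]:
        # Patterns are tried in dict insertion order; the first match wins.
        for pattern, replacement in mapping_dict.items():
            if re.match(pattern, name):
                # The two trailing elements are undocumented, so return None.
                return re.sub(pattern, replacement, name), None, None
        # No pattern matched: keep the original name unchanged.
        return name, None, None

    return mapping_fn


# Hypothetical pattern, for illustration only.
mapping = make_param_names_mapping({
    r"^transformer_blocks\.(\d+)\.attn1\.to_q\.(.*)$": r"blocks.\1.attn.q_proj.\2",
})
print(mapping("transformer_blocks.3.attn1.to_q.weight"))
# ('blocks.3.attn.q_proj.weight', None, None)
print(mapping("norm_out.weight"))
# ('norm_out.weight', None, None)
```

Anchoring patterns with ^ and $ keeps re.sub from rewriting more than the one intended match.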
- fastvideo.v1.models.loader.utils.hf_to_custom_state_dict(hf_param_sd: dict[str, torch.Tensor] | collections.abc.Iterator[tuple[str, torch.Tensor]], param_names_mapping: collections.abc.Callable[[str], tuple[str, Any, Any]]) → tuple[dict[str, torch.Tensor], dict[str, tuple[str, Any, Any]]]
Converts a Hugging Face parameter state dictionary to a custom parameter state dictionary.
- Parameters:
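hf_param_sd (dict[str, torch.Tensor] | Iterator[tuple[str, torch.Tensor]]) – The Hugging Face parameter state dictionary, or an iterator of (name, tensor) pairs.
param_names_mapping (Callable[[str], tuple[str, Any, Any]]) – A function that maps parameter names from the source (Hugging Face) format to the target (custom) format.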
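- Returns:
A tuple (custom_param_sd, reverse_param_names_mapping), where custom_param_sd (dict[str, torch.Tensor]) is the parameter state dict in the custom format and reverse_param_names_mapping (dict[str, tuple[str, Any, Any]]) maps custom parameter names back to their Hugging Face counterparts.
- Return type:
tuple[dict[str, torch.Tensor], dict[str, tuple[str, Any, Any]]]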
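A minimal sketch of the conversion, assuming it renames every parameter with param_names_mapping and records, under each new name, the original Hugging Face name together with the two extra tuple elements. Whatever the real function may do with those extra elements is not covered here. To keep the sketch runnable without PyTorch, a TypeVar stands in for torch.Tensor; convert_state_dict is a hypothetical name.

```python
from __future__ import annotations

from collections.abc import Callable, Iterator
from typing import Any, TypeVar

# Stand-in for torch.Tensor so the sketch runs without PyTorch installed.
T = TypeVar("T")


def convert_state_dict(
    hf_param_sd: dict[str, T] | Iterator[tuple[str, T]],
    param_names_mapping: Callable[[str], tuple[str, Any, Any]],
) -> tuple[dict[str, T], dict[str, tuple[str, Any, Any]]]:
    """Hypothetical sketch: rename every parameter and remember the way back."""
    # Accept either a dict or an iterator of (name, value) pairs.
    items = hf_param_sd.items() if isinstance(hf_param_sd, dict) else hf_param_sd

    custom_param_sd: dict[str, T] = {}
    reverse_param_names_mapping: dict[str, tuple[str, Any, Any]] = {}
    for hf_name, value in items:
        custom_name, extra_a, extra_b = param_names_mapping(hf_name)
        # Later entries overwrite earlier ones if two names collide.
        custom_param_sd[custom_name] = value
        # Keep the original name plus the two undocumented extra elements.
        reverse_param_names_mapping[custom_name] = (hf_name, extra_a, extra_b)
    return custom_param_sd, reverse_param_names_mapping


sd, rev = convert_state_dict(
    {"model.proj.weight": [1.0, 2.0]},
    lambda name: (name.replace("model.", "", 1), None, None),
)
print(sd)   # {'proj.weight': [1.0, 2.0]}
print(rev)  # {'proj.weight': ('model.proj.weight', None, None)}
```

In this sketch, if two Hugging Face names map to the same custom name, the later value silently replaces the earlier one; the real function may handle that case differently.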
- fastvideo.v1.models.loader.utils.set_default_torch_dtype(dtype: torch.dtype)
Sets the default torch dtype to the given dtype.
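This page does not say whether the default dtype stays changed after the call or is restored. One common way to write such a helper is as a context manager that sets the dtype with torch.set_default_dtype and restores the value read from torch.get_default_dtype on exit. The sketch below assumes that design and should not be read as FastVideo's implementation; temporary_default_dtype is a hypothetical name. Because torch.set_default_dtype sets the default floating-point dtype, integer dtypes are not valid arguments.

```python
import contextlib
from collections.abc import Iterator

import torch


@contextlib.contextmanager
def temporary_default_dtype(dtype: torch.dtype) -> Iterator[None]:
    """Hypothetical sketch: set torch's default dtype, then restore it."""
    previous = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # Restore the previous default even if the body raises.
        torch.set_default_dtype(previous)


with temporary_default_dtype(torch.float64):
    print(torch.zeros(2).dtype)  # torch.float64
print(torch.get_default_dtype())  # torch.float32 in a fresh process
```

The try/finally means the previous default is restored even if, for example, model construction inside the with block raises.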