fastvideo.v1.models.loader.fsdp_load#

Module Contents#

Functions#

get_param_names_mapping

Creates a mapping function that transforms parameter names using regex patterns.

load_fsdp_model

load_fsdp_model_from_full_model_state_dict

Converts a full state dict into a sharded state dict and loads it into an FSDP model.

set_default_dtype

Context manager to set torch’s default dtype.

shard_model

Utility to shard a model with FSDP using the PyTorch Distributed fully_shard API.

Data#

API#

fastvideo.v1.models.loader.fsdp_load.get_param_names_mapping(mapping_dict: Dict[str, str]) → Callable[[str], tuple[str, Any, Any]][source]#

Creates a mapping function that transforms parameter names using regex patterns.

Parameters:
  • mapping_dict (Dict[str, str]) – Dictionary mapping regex patterns to replacement patterns

  • param_name (str) – The parameter name to be transformed (the argument of the returned mapping function)

Returns:

A function that maps parameter names from source to target format

Return type:

Callable[[str], tuple[str, Any, Any]]
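For illustration, a minimal sketch of the regex-driven renaming this helper performs; the apply_mapping helper and the example patterns below are hypothetical, not part of fastvideo, and only the renaming idea is shown (the real function returns a callable producing a (name, Any, Any) tuple).

```python
import re
from typing import Dict

# Hypothetical helper for illustration only: apply the first matching regex
# pattern from mapping_dict to a parameter name.
def apply_mapping(mapping_dict: Dict[str, str], param_name: str) -> str:
    for pattern, replacement in mapping_dict.items():
        new_name, num_subs = re.subn(pattern, replacement, param_name)
        if num_subs > 0:
            return new_name
    return param_name

# Example: map a source checkpoint prefix onto a target module layout.
mapping = {r"^transformer\.(.*)$": r"model.\1"}
print(apply_mapping(mapping, "transformer.blocks.0.attn.to_q.weight"))
# model.blocks.0.attn.to_q.weight
```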

fastvideo.v1.models.loader.fsdp_load.load_fsdp_model(model_cls: Type[torch.nn.Module], init_params: Dict[str, Any], weight_dir_list: List[str], device: torch.device, default_dtype: torch.dtype, param_dtype: torch.dtype, reduce_dtype: torch.dtype, cpu_offload: bool = False, output_dtype: Optional[torch.dtype] = None) → torch.nn.Module[source]#
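A hedged usage sketch for the signature above; the toy model class, init kwargs, and weight directory are placeholders, and an initialized torch.distributed process group is assumed.

```python
import torch
import torch.nn as nn

from fastvideo.v1.models.loader.fsdp_load import load_fsdp_model

# Placeholder model class for illustration; any nn.Module subclass whose
# constructor accepts the init_params kwargs could stand in here.
class TinyModel(nn.Module):
    def __init__(self, hidden_size: int):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

# Assumes torch.distributed is already initialized and the (placeholder)
# checkpoint directory contains weights compatible with the model.
model = load_fsdp_model(
    model_cls=TinyModel,
    init_params={"hidden_size": 1024},
    weight_dir_list=["/path/to/weights"],
    device=torch.device("cuda", 0),
    default_dtype=torch.bfloat16,
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    cpu_offload=False,
)
```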
fastvideo.v1.models.loader.fsdp_load.load_fsdp_model_from_full_model_state_dict(model: torch.nn.Module, full_sd_iterator: Generator[Tuple[str, torch.Tensor], None, None], device: torch.device, strict: bool = False, cpu_offload: bool = False, param_names_mapping: Optional[Callable[[str], tuple[str, Any, Any]]] = None) → torch.nn.modules.module._IncompatibleKeys[source]#

Converts a full state dict into a sharded state dict and loads it into an FSDP model.

Parameters:
  • model (FSDPModule) – sharded model used to generate fully qualified names for the full state dict

  • full_sd_iterator (Generator) – an iterator yielding (param_name, tensor) pairs

  • device (torch.device) – device onto which full state dict tensors are moved

  • strict (bool) – whether to load the state dict in strict mode

  • cpu_offload (bool) – whether CPU offload is enabled

  • param_names_mapping (Optional[Callable[[str], tuple[str, Any, Any]]]) – a function that maps a full parameter name to the corresponding sharded parameter name

Returns:

  • missing_keys is a list of str containing the missing keys

  • unexpected_keys is a list of str containing the unexpected keys

Return type:

NamedTuple with missing_keys and unexpected_keys fields

Raises:

NotImplementedError – If FSDP sharding uses more than one device mesh dimension.
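A hedged sketch of streaming a full (unsharded) state dict into an already-sharded model; sharded_model and full_state_dict below are placeholders, not objects provided by fastvideo.

```python
import torch

from fastvideo.v1.models.loader.fsdp_load import (
    load_fsdp_model_from_full_model_state_dict,
)

def iter_full_sd(state_dict):
    # Yield (param_name, tensor) pairs, matching the expected generator shape.
    for name, tensor in state_dict.items():
        yield name, tensor

# `sharded_model` is assumed to be an FSDP-sharded module and
# `full_state_dict` an ordinary unsharded state dict (both placeholders).
result = load_fsdp_model_from_full_model_state_dict(
    model=sharded_model,
    full_sd_iterator=iter_full_sd(full_state_dict),
    device=torch.device("cuda", 0),
    strict=False,
    cpu_offload=False,
)
print(result.missing_keys, result.unexpected_keys)
```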

fastvideo.v1.models.loader.fsdp_load.logger[source]#

‘init_logger(…)’

fastvideo.v1.models.loader.fsdp_load.set_default_dtype(dtype: torch.dtype) → Generator[None, None, None][source]#

Context manager to set torch’s default dtype.

Parameters:

dtype (torch.dtype) – The desired default dtype inside the context manager.

Returns:

context manager for setting default dtype.

Return type:

ContextManager

Example

>>> with set_default_dtype(torch.bfloat16):
...     x = torch.tensor([1.0, 2.0, 3.0])
>>> x.dtype
torch.bfloat16

fastvideo.v1.models.loader.fsdp_load.shard_model(model, *, cpu_offload: bool, reshard_after_forward: bool = True, mp_policy: Optional[torch.distributed.fsdp.MixedPrecisionPolicy] = None, dp_mesh: Optional[torch.distributed.DeviceMesh] = None) → None[source]#

Utility to shard a model with FSDP using the PyTorch Distributed fully_shard API.

This function iterates over the model’s named modules from the bottom up and shards modules based on whether they meet any of the criteria from shard_conditions.

Parameters:
  • model (TransformerDecoder) – Model to shard with FSDP.

  • shard_conditions (List[Callable[[str, nn.Module], bool]]) – A list of functions to determine which modules to shard with FSDP. Each function should take module name (relative to root) and the module itself, returning True if FSDP should shard the module and False otherwise. If any of shard_conditions return True for a given module, it will be sharded by FSDP.

  • cpu_offload (bool) – If set to True, FSDP will offload parameters, gradients, and optimizer states to CPU.

  • reshard_after_forward (bool) – Whether to reshard parameters and buffers after the forward pass. Setting this to True corresponds to the FULL_SHARD sharding strategy from FSDP1, while setting it to False corresponds to the SHARD_GRAD_OP sharding strategy.

  • dp_mesh (Optional[DeviceMesh]) – Device mesh to use for FSDP sharding under multiple parallelism. Defaults to None.

Raises:

ValueError – If no layer modules were sharded, indicating that no shard_condition was triggered.
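A hedged sketch of a typical call, assuming `model` is an already-constructed nn.Module (placeholder) and a distributed process group has been initialized; the dtype choices are illustrative only.

```python
import torch
from torch.distributed.fsdp import MixedPrecisionPolicy

from fastvideo.v1.models.loader.fsdp_load import shard_model

# Keep parameters in bfloat16 and reduce gradients in float32; these dtypes
# are an illustrative assumption, not a recommendation from fastvideo.
mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
)

shard_model(
    model,                        # placeholder: an existing nn.Module
    cpu_offload=False,
    reshard_after_forward=True,   # FULL_SHARD-like behavior
    mp_policy=mp_policy,
    dp_mesh=None,
)
```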