fastvideo.v1.models.loader.fsdp_load#

Module Contents#

Functions#

load_model_from_full_model_state_dict

Convert a full state dict into a sharded state dict and load it into an FSDP model (if training) or a regular Hugging Face model

maybe_load_fsdp_model

Load the model with FSDP if training, otherwise load it without FSDP.

set_default_dtype

Context manager to set torch’s default dtype.

shard_model

Utility to shard a model with FSDP using the PyTorch Distributed fully_shard API.

Data#

API#

fastvideo.v1.models.loader.fsdp_load.load_model_from_full_model_state_dict(model: Union[torch.distributed.fsdp.FSDPModule, torch.nn.Module], full_sd_iterator: Generator[Tuple[str, torch.Tensor], None, None], device: torch.device, param_dtype: torch.dtype, strict: bool = False, cpu_offload: bool = False, param_names_mapping: Optional[Callable[[str], tuple[str, Any, Any]]] = None, training_mode: bool = True) torch.nn.modules.module._IncompatibleKeys[source]#

Convert a full state dict into a sharded state dict and load it into an FSDP model (if training) or a regular Hugging Face model.

Parameters:
  • model (Union[FSDPModule, torch.nn.Module]) – Model into which the state dict is loaded; its named parameters define the fully qualified names expected in the state dict

  • full_sd_iterator (Generator) – An iterator yielding (param_name, tensor) pairs from the full state dict

  • device (torch.device) – Device to which full state dict tensors are moved

  • param_dtype (torch.dtype) – dtype to which full state dict tensors are cast

  • strict (bool) – Whether to load the state dict in strict mode

  • cpu_offload (bool) – Whether FSDP CPU offloading is enabled

  • param_names_mapping (Optional[Callable[[str], tuple[str, Any, Any]]]) – A function that maps a full (checkpoint) parameter name to the target parameter name, along with any extra mapping metadata

  • training_mode (bool) – Whether the model is being loaded for training; FSDP sharding is applied only in training mode

Returns:

  • missing_keys is a list of str containing the missing keys

  • unexpected_keys is a list of str containing the unexpected keys

Return type:

NamedTuple with missing_keys and unexpected_keys fields

Raises:

NotImplementedError – If an FSDP sharding with more than one dimension is encountered.
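A minimal usage sketch (not taken from the FastVideo codebase): it loads a full state dict into a plain nn.Module with training_mode=False, so no FSDP sharding is involved. The toy model and in-memory state dict are placeholders; in practice the iterator would be fed from checkpoint shards.

import torch
from torch import nn

from fastvideo.v1.models.loader.fsdp_load import load_model_from_full_model_state_dict

model = nn.Linear(8, 8)  # placeholder model
full_sd = {"weight": torch.randn(8, 8), "bias": torch.zeros(8)}  # placeholder checkpoint

# The function consumes an iterator of (param_name, tensor) pairs.
sd_iterator = ((name, tensor) for name, tensor in full_sd.items())

result = load_model_from_full_model_state_dict(
    model,
    sd_iterator,
    device=torch.device("cpu"),
    param_dtype=torch.float32,
    strict=True,
    training_mode=False,  # load into the plain module, no FSDP sharding
)
print(result.missing_keys, result.unexpected_keys)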

fastvideo.v1.models.loader.fsdp_load.logger[source]#

'init_logger(…)'

fastvideo.v1.models.loader.fsdp_load.maybe_load_fsdp_model(model_cls: Type[torch.nn.Module], init_params: Dict[str, Any], weight_dir_list: List[str], device: torch.device, hsdp_replicate_dim: int, hsdp_shard_dim: int, default_dtype: torch.dtype, param_dtype: torch.dtype, reduce_dtype: torch.dtype, cpu_offload: bool = False, fsdp_inference: bool = False, output_dtype: Optional[torch.dtype] = None, training_mode: bool = True) torch.nn.Module[source]#

Load the model with FSDP if training, otherwise load it without FSDP.
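An illustrative call, based only on the signature above. The model class, constructor kwargs, checkpoint directory, and mesh sizes are hypothetical example values, and a torch.distributed process group is assumed to be initialized (e.g. via torchrun) when training_mode=True.

import torch
from torch import nn

from fastvideo.v1.models.loader.fsdp_load import maybe_load_fsdp_model

class PlaceholderTransformer(nn.Module):  # hypothetical model class
    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

model = maybe_load_fsdp_model(
    model_cls=PlaceholderTransformer,
    init_params={"hidden_size": 1024},        # hypothetical constructor kwargs
    weight_dir_list=["/path/to/checkpoint"],  # hypothetical weight directory
    device=torch.device("cuda", 0),
    hsdp_replicate_dim=1,                     # example mesh sizes (1 x 8 ranks)
    hsdp_shard_dim=8,
    default_dtype=torch.bfloat16,
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    cpu_offload=False,
    training_mode=True,  # shard with FSDP; False loads a plain module
)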

fastvideo.v1.models.loader.fsdp_load.set_default_dtype(dtype: torch.dtype) Generator[None, None, None][source]#

Context manager to set torch’s default dtype.

Parameters:

dtype (torch.dtype) – The desired default dtype inside the context manager.

Returns:

context manager for setting default dtype.

Return type:

ContextManager

Example

>>> with set_default_dtype(torch.bfloat16):
...     x = torch.tensor([1.0, 2.0, 3.0])
...     x.dtype
torch.bfloat16

fastvideo.v1.models.loader.fsdp_load.shard_model(model, *, cpu_offload: bool, reshard_after_forward: bool = True, mp_policy: Optional[torch.distributed.fsdp.MixedPrecisionPolicy] = None, dp_mesh: Optional[torch.distributed.DeviceMesh] = None, mesh: Optional[torch.distributed.DeviceMesh] = None) None[source]#

Utility to shard a model with FSDP using the PyTorch Distributed fully_shard API.

This utility iterates over the model’s named modules from the bottom up and shards each module that meets any of the criteria in shard_conditions.

Parameters:
  • model (TransformerDecoder) – Model to shard with FSDP.

  • shard_conditions (List[Callable[[str, nn.Module], bool]]) – A list of functions used to determine which modules to shard with FSDP. Each function should take the module name (relative to the root) and the module itself, and return True if FSDP should shard the module and False otherwise. If any of the shard_conditions returns True for a given module, it will be sharded by FSDP.

  • cpu_offload (bool) – If set to True, FSDP will offload parameters, gradients, and optimizer states to CPU.

  • reshard_after_forward (bool) – Whether to reshard parameters and buffers after the forward pass. Setting this to True corresponds to the FULL_SHARD sharding strategy from FSDP1, while setting it to False corresponds to the SHARD_GRAD_OP sharding strategy.

  • dp_mesh (Optional[DeviceMesh]) – Device mesh to use for FSDP sharding under multiple parallelism. Defaults to None.

Raises:

ValueError – If no layer modules were sharded, indicating that no shard_condition was triggered.
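A minimal sketch of sharding an existing module with shard_model. It assumes a torch.distributed process group is already initialized (e.g. via torchrun) so that fully_shard can build its default device mesh; the toy model and the mixed-precision dtypes are illustrative.

import torch
from torch import nn
from torch.distributed.fsdp import MixedPrecisionPolicy

from fastvideo.v1.models.loader.fsdp_load import shard_model

model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024))

shard_model(
    model,
    cpu_offload=False,
    reshard_after_forward=True,  # free sharded params after forward (FULL_SHARD-like)
    mp_policy=MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,  # compute and communicate params in bf16
        reduce_dtype=torch.float32,  # reduce gradients in fp32
    ),
)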