fastvideo.v1.models.loader.fsdp_load#
Module Contents#
Functions#
| Function | Summary |
|---|---|
| get_param_names_mapping | Creates a mapping function that transforms parameter names using regex patterns. |
| load_fsdp_model | |
| load_fsdp_model_from_full_model_state_dict | Converts a full state dict into a sharded state dict and loads it into an FSDP model. |
| set_default_dtype | Context manager to set torch's default dtype. |
| shard_model | Utility to shard a model with FSDP using the PyTorch Distributed fully_shard API. |
Data#
API#
- fastvideo.v1.models.loader.fsdp_load.get_param_names_mapping(mapping_dict: Dict[str, str]) → Callable[[str], tuple[str, Any, Any]] [source]#
Creates a mapping function that transforms parameter names using regex patterns.
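A minimal sketch of what a regex-driven name mapper of this shape could look like. The mapping patterns and the placeholder metadata in the returned tuple are illustrative assumptions, not fastvideo's actual implementation:

```python
import re
from typing import Any, Callable, Dict


def get_param_names_mapping_sketch(
    mapping_dict: Dict[str, str],
) -> Callable[[str], tuple[str, Any, Any]]:
    """Illustrative only: map checkpoint parameter names to model names via regex."""

    def mapper(name: str) -> tuple[str, Any, Any]:
        for pattern, replacement in mapping_dict.items():
            if re.match(pattern, name):
                # re.sub supports backreferences such as r"\1" in the replacement.
                new_name = re.sub(pattern, replacement, name)
                # Placeholder metadata; the real mapper's extra tuple slots are not documented here.
                return new_name, None, None
        return name, None, None

    return mapper


# Hypothetical usage: rename "blocks.N.*" parameters to "layers.N.*".
mapper = get_param_names_mapping_sketch({r"^blocks\.(\d+)\.(.*)$": r"layers.\1.\2"})
print(mapper("blocks.3.attn.qkv.weight"))  # ('layers.3.attn.qkv.weight', None, None)
```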
- fastvideo.v1.models.loader.fsdp_load.load_fsdp_model(model_cls: Type[torch.nn.Module], init_params: Dict[str, Any], weight_dir_list: List[str], device: torch.device, default_dtype: torch.dtype, param_dtype: torch.dtype, reduce_dtype: torch.dtype, cpu_offload: bool = False, output_dtype: Optional[torch.dtype] = None) → torch.nn.Module [source]#
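A hedged usage sketch for load_fsdp_model; the tiny model class, the weight directory, and the dtype choices below are placeholders, and it assumes torch.distributed is already initialized:

```python
import torch
import torch.nn as nn

from fastvideo.v1.models.loader.fsdp_load import load_fsdp_model


class TinyModel(nn.Module):  # placeholder model class for illustration
    def __init__(self, hidden_size: int = 64):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        return self.proj(x)


# Assumes torch.distributed is initialized and weights exist under the
# hypothetical directory below.
model = load_fsdp_model(
    model_cls=TinyModel,
    init_params={"hidden_size": 64},
    weight_dir_list=["/path/to/weights"],  # placeholder path
    device=torch.device("cuda"),
    default_dtype=torch.bfloat16,
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    cpu_offload=False,
)
```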
- fastvideo.v1.models.loader.fsdp_load.load_fsdp_model_from_full_model_state_dict(model: torch.nn.Module, full_sd_iterator: Generator[Tuple[str, torch.Tensor], None, None], device: torch.device, strict: bool = False, cpu_offload: bool = False, param_names_mapping: Optional[Callable[[str], tuple[str, Any, Any]]] = None) → torch.nn.modules.module._IncompatibleKeys [source]#
Converts a full state dict into a sharded state dict and loads it into an FSDP model.
- Parameters:
model (FSDPModule) – Model to generate fully qualified names for cpu_state_dict
full_sd_iterator (Generator) – an iterator yielding (param_name, tensor) pairs
device (torch.device) – device used to move full state dict tensors
strict (bool) – whether to load the model in strict mode
cpu_offload (bool) – whether offloading to CPU is enabled
param_names_mapping (Optional[Callable[[str], tuple[str, Any, Any]]]) – a function that maps a full parameter name to its sharded parameter name
- Returns:
missing_keys is a list of str containing the missing keys
unexpected_keys is a list of str containing the unexpected keys
- Return type:
NamedTuple with missing_keys and unexpected_keys fields
- Raises:
NotImplementedError – If FSDP with more than 1D sharding is encountered.
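A hedged usage sketch; the checkpoint path, the loader helper, and the identity name mapping below are illustrative assumptions (the two trailing tuple slots of the mapping are passed through as placeholders), and sharded_model is assumed to be a module already wrapped with fully_shard:

```python
import torch

from fastvideo.v1.models.loader.fsdp_load import (
    load_fsdp_model_from_full_model_state_dict,
)


def iter_full_state_dict(path: str):
    """Placeholder loader: yield (param_name, tensor) pairs from a checkpoint."""
    full_sd = torch.load(path, map_location="cpu")
    yield from full_sd.items()


# `sharded_model` is assumed to already be wrapped with fully_shard.
result = load_fsdp_model_from_full_model_state_dict(
    model=sharded_model,
    full_sd_iterator=iter_full_state_dict("/path/to/model.pt"),  # placeholder path
    device=torch.device("cuda"),
    strict=True,
    cpu_offload=False,
    param_names_mapping=lambda name: (name, None, None),  # identity mapping, placeholder metadata
)
print(result.missing_keys, result.unexpected_keys)
```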
- fastvideo.v1.models.loader.fsdp_load.set_default_dtype(dtype: torch.dtype) → Generator[None, None, None] [source]#
Context manager to set torch's default dtype.
- Parameters:
dtype (torch.dtype) – The desired default dtype inside the context manager.
- Returns:
context manager for setting default dtype.
- Return type:
ContextManager
Example:

    >>> with set_default_dtype(torch.bfloat16):
    ...     x = torch.tensor([1.0, 2.0, 3.0])
    ...     x.dtype
    torch.bfloat16
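For reference, a context manager with this behavior can be sketched with torch.get_default_dtype and torch.set_default_dtype; this is an assumption about the shape of the implementation, not necessarily fastvideo's exact code:

```python
from contextlib import contextmanager
from typing import Generator

import torch


@contextmanager
def set_default_dtype_sketch(dtype: torch.dtype) -> Generator[None, None, None]:
    """Temporarily set torch's default dtype, restoring the previous one on exit."""
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # Restore the previous default even if the body raises.
        torch.set_default_dtype(old_dtype)
```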
- fastvideo.v1.models.loader.fsdp_load.shard_model(model, *, cpu_offload: bool, reshard_after_forward: bool = True, mp_policy: Optional[torch.distributed.fsdp.MixedPrecisionPolicy] = None, dp_mesh: Optional[torch.distributed.DeviceMesh] = None) → None [source]#
Utility to shard a model with FSDP using the PyTorch Distributed fully_shard API.
This method iterates over the model's named modules from the bottom up and shards any module that meets one of the criteria from shard_conditions.
- Parameters:
model (TransformerDecoder) – Model to shard with FSDP.
shard_conditions (List[Callable[[str, nn.Module], bool]]) – A list of functions to determine which modules to shard with FSDP. Each function should take the module name (relative to the root) and the module itself, returning True if FSDP should shard the module and False otherwise. If any of shard_conditions returns True for a given module, it will be sharded by FSDP.
cpu_offload (bool) – If set to True, FSDP will offload parameters, gradients, and optimizer states to CPU.
reshard_after_forward (bool) – Whether to reshard parameters and buffers after the forward pass. Setting this to True corresponds to the FULL_SHARD sharding strategy from FSDP1, while setting it to False corresponds to the SHARD_GRAD_OP sharding strategy.
dp_mesh (Optional[DeviceMesh]) – Device mesh to use for FSDP sharding under multiple parallelism. Defaults to None.
- Raises:
ValueError – If no layer modules were sharded, indicating that no shard_condition was triggered.
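A hedged sketch of the bottom-up fully_shard pattern described above; the module walk and shard-condition handling are assumptions rather than fastvideo's exact code, and the imports assume a recent PyTorch where fully_shard, MixedPrecisionPolicy, and CPUOffloadPolicy are exported from torch.distributed.fsdp (older releases expose them under torch.distributed._composable.fsdp):

```python
from typing import Callable, List, Optional

import torch.nn as nn
from torch.distributed.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, fully_shard


def shard_model_sketch(
    model: nn.Module,
    shard_conditions: List[Callable[[str, nn.Module], bool]],
    *,
    cpu_offload: bool,
    reshard_after_forward: bool = True,
    mp_policy: Optional[MixedPrecisionPolicy] = None,
) -> None:
    fsdp_kwargs = {"reshard_after_forward": reshard_after_forward}
    if mp_policy is not None:
        fsdp_kwargs["mp_policy"] = mp_policy
    if cpu_offload:
        fsdp_kwargs["offload_policy"] = CPUOffloadPolicy()

    # Walk non-root modules bottom-up and shard any that match a condition.
    num_sharded = 0
    for name, module in reversed(list(model.named_modules())):
        if name and any(cond(name, module) for cond in shard_conditions):
            fully_shard(module, **fsdp_kwargs)
            num_sharded += 1

    if num_sharded == 0:
        raise ValueError("No layer modules were sharded; check shard_conditions.")

    # Shard the root module last so the remaining parameters form a group.
    fully_shard(model, **fsdp_kwargs)
```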