
loader

Modules

fastvideo.models.loader.component_loader

Classes

fastvideo.models.loader.component_loader.ComponentLoader
ComponentLoader(device=None)

Bases: ABC

Base class for loading a specific type of model component.

Source code in fastvideo/models/loader/component_loader.py
def __init__(self, device=None) -> None:
    self.device = device
Functions
fastvideo.models.loader.component_loader.ComponentLoader.for_module_type classmethod
for_module_type(module_type: str, transformers_or_diffusers: str) -> ComponentLoader

Factory method to create a component loader for a specific module type.

Parameters:

Name Type Description Default
module_type str

Type of module (e.g., "vae", "text_encoder", "transformer", "scheduler")

required
transformers_or_diffusers str

Whether the module is from transformers or diffusers

required

Returns:

Type Description
ComponentLoader

A component loader for the specified module type

Source code in fastvideo/models/loader/component_loader.py
@classmethod
def for_module_type(cls, module_type: str,
                    transformers_or_diffusers: str) -> 'ComponentLoader':
    """
    Factory method to create a component loader for a specific module type.

    Args:
        module_type: Type of module (e.g., "vae", "text_encoder", "transformer", "scheduler")
        transformers_or_diffusers: Whether the module is from transformers or diffusers

    Returns:
        A component loader for the specified module type
    """
    # Map of module types to their loader classes and expected library
    module_loaders = {
        "scheduler": (SchedulerLoader, "diffusers"),
        "transformer": (TransformerLoader, "diffusers"),
        "transformer_2": (TransformerLoader, "diffusers"),
        "vae": (VAELoader, "diffusers"),
        "text_encoder": (TextEncoderLoader, "transformers"),
        "text_encoder_2": (TextEncoderLoader, "transformers"),
        "tokenizer": (TokenizerLoader, "transformers"),
        "tokenizer_2": (TokenizerLoader, "transformers"),
        "image_processor": (ImageProcessorLoader, "transformers"),
        "image_encoder": (ImageEncoderLoader, "transformers"),
    }

    if module_type in module_loaders:
        loader_cls, expected_library = module_loaders[module_type]
        # Assert that the library matches what's expected for this module type
        assert transformers_or_diffusers == expected_library, f"{module_type} must be loaded from {expected_library}, got {transformers_or_diffusers}"
        return loader_cls()

    # For unknown module types, use a generic loader
    logger.warning(
        "No specific loader found for module type: %s. Using generic loader.",
        module_type)
    return GenericComponentLoader(transformers_or_diffusers)
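Example (illustrative): resolving loaders through the factory. The module/library pairs below mirror the mapping above; unknown module types fall back to GenericComponentLoader.

from fastvideo.models.loader.component_loader import ComponentLoader

# Each pair mirrors the module_loaders mapping documented above.
for module_type, library in [("vae", "diffusers"),
                             ("text_encoder", "transformers"),
                             ("transformer", "diffusers")]:
    loader = ComponentLoader.for_module_type(module_type, library)
    print(module_type, "->", type(loader).__name__)
# vae -> VAELoader
# text_encoder -> TextEncoderLoader
# transformer -> TransformerLoader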
fastvideo.models.loader.component_loader.ComponentLoader.load abstractmethod
load(model_path: str, fastvideo_args: FastVideoArgs)

Load the component based on the model path, architecture, and inference args.

Parameters:

Name Type Description Default
model_path str

Path to the component model

required
fastvideo_args FastVideoArgs

FastVideoArgs

required

Returns:

Type Description

The loaded component

Source code in fastvideo/models/loader/component_loader.py
@abstractmethod
def load(self, model_path: str, fastvideo_args: FastVideoArgs):
    """
    Load the component based on the model path, architecture, and inference args.

    Args:
        model_path: Path to the component model
        fastvideo_args: FastVideoArgs

    Returns:
        The loaded component
    """
    raise NotImplementedError
fastvideo.models.loader.component_loader.GenericComponentLoader
GenericComponentLoader(library='transformers')

Bases: ComponentLoader

Generic loader for components that don't have a specific loader.

Source code in fastvideo/models/loader/component_loader.py
def __init__(self, library="transformers") -> None:
    super().__init__()
    self.library = library
Functions
fastvideo.models.loader.component_loader.GenericComponentLoader.load
load(model_path: str, fastvideo_args: FastVideoArgs)

Load a generic component based on the model path, and inference args.

Source code in fastvideo/models/loader/component_loader.py
def load(self, model_path: str, fastvideo_args: FastVideoArgs):
    """Load a generic component based on the model path, and inference args."""
    logger.warning("Using generic loader for %s with library %s",
                   model_path, self.library)

    if self.library == "transformers":
        from transformers import AutoModel

        model = AutoModel.from_pretrained(
            model_path,
            trust_remote_code=fastvideo_args.trust_remote_code,
            revision=fastvideo_args.revision,
        )
        logger.info("Loaded generic transformers model: %s",
                    model.__class__.__name__)
        return model
    elif self.library == "diffusers":
        logger.warning(
            "Generic loading for diffusers components is not fully implemented"
        )

        model_config = get_diffusers_config(model=model_path)
        logger.info("Diffusers Model config: %s", model_config)
        # This is a placeholder - in a real implementation, you'd need to handle this properly
        return None
    else:
        raise ValueError(f"Unsupported library: {self.library}")
fastvideo.models.loader.component_loader.ImageEncoderLoader
ImageEncoderLoader(device=None)

Bases: TextEncoderLoader

Source code in fastvideo/models/loader/component_loader.py
def __init__(self, device=None) -> None:
    self.device = device
Functions
fastvideo.models.loader.component_loader.ImageEncoderLoader.load
load(model_path: str, fastvideo_args: FastVideoArgs)

Load the image encoder based on the model path, and inference args.

Source code in fastvideo/models/loader/component_loader.py
def load(self, model_path: str, fastvideo_args: FastVideoArgs):
    """Load the text encoders based on the model path, and inference args."""
    # model_config: PretrainedConfig = get_hf_config(
    #     model=model_path,
    #     trust_remote_code=fastvideo_args.trust_remote_code,
    #     revision=fastvideo_args.revision,
    #     model_override_args=None,
    # )
    with open(os.path.join(model_path, "config.json")) as f:
        model_config = json.load(f)
    model_config.pop("_name_or_path", None)
    model_config.pop("transformers_version", None)
    model_config.pop("torch_dtype", None)
    model_config.pop("model_type", None)
    logger.info("HF Model config: %s", model_config)

    encoder_config = fastvideo_args.pipeline_config.image_encoder_config
    encoder_config.update_model_arch(model_config)

    from fastvideo.platforms import current_platform

    if fastvideo_args.image_encoder_cpu_offload:
        target_device = torch.device("mps") if current_platform.is_mps() else torch.device("cpu")
    else:
        target_device = get_local_torch_device()
    # TODO(will): add support for other dtypes
    return self.load_model(
        model_path, encoder_config, target_device, fastvideo_args,
        fastvideo_args.pipeline_config.image_encoder_precision)
fastvideo.models.loader.component_loader.ImageProcessorLoader
ImageProcessorLoader(device=None)

Bases: ComponentLoader

Loader for image processor.

Source code in fastvideo/models/loader/component_loader.py
def __init__(self, device=None) -> None:
    self.device = device
Functions
fastvideo.models.loader.component_loader.ImageProcessorLoader.load
load(model_path: str, fastvideo_args: FastVideoArgs)

Load the image processor based on the model path, and inference args.

Source code in fastvideo/models/loader/component_loader.py
def load(self, model_path: str, fastvideo_args: FastVideoArgs):
    """Load the image processor based on the model path, and inference args."""
    logger.info("Loading image processor from %s", model_path)

    image_processor = AutoImageProcessor.from_pretrained(model_path, )
    logger.info("Loaded image processor: %s",
                image_processor.__class__.__name__)
    return image_processor
fastvideo.models.loader.component_loader.PipelineComponentLoader

Utility class for loading pipeline components. This replaces the chain of if-else statements in load_pipeline_module.

Functions
fastvideo.models.loader.component_loader.PipelineComponentLoader.load_module staticmethod
load_module(module_name: str, component_model_path: str, transformers_or_diffusers: str, fastvideo_args: FastVideoArgs)

Load a pipeline module.

Parameters:

Name Type Description Default
module_name str

Name of the module (e.g., "vae", "text_encoder", "transformer", "scheduler")

required
component_model_path str

Path to the component model

required
transformers_or_diffusers str

Whether the module is from transformers or diffusers

required
fastvideo_args FastVideoArgs

Inference arguments

required

Returns:

Type Description

The loaded module

Source code in fastvideo/models/loader/component_loader.py
@staticmethod
def load_module(module_name: str, component_model_path: str,
                transformers_or_diffusers: str,
                fastvideo_args: FastVideoArgs):
    """
    Load a pipeline module.

    Args:
        module_name: Name of the module (e.g., "vae", "text_encoder", "transformer", "scheduler")
        component_model_path: Path to the component model
        transformers_or_diffusers: Whether the module is from transformers or diffusers
        fastvideo_args: Inference arguments

    Returns:
        The loaded module
    """
    logger.info(
        "Loading %s using %s from %s",
        module_name,
        transformers_or_diffusers,
        component_model_path,
    )

    # Get the appropriate loader for this module type
    loader = ComponentLoader.for_module_type(module_name,
                                             transformers_or_diffusers)

    # Load the module
    return loader.load(component_model_path, fastvideo_args)
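A hedged usage sketch of load_module; the component path is a placeholder and fastvideo_args is assumed to be an already constructed FastVideoArgs.

from fastvideo.models.loader.component_loader import PipelineComponentLoader

# Placeholder path; in practice this points at the component subfolder of a
# diffusers-style model directory (e.g. "<model_root>/vae").
vae = PipelineComponentLoader.load_module(
    module_name="vae",
    component_model_path="/path/to/model/vae",
    transformers_or_diffusers="diffusers",
    fastvideo_args=fastvideo_args,  # assumed to exist in the surrounding code
)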
fastvideo.models.loader.component_loader.SchedulerLoader
SchedulerLoader(device=None)

Bases: ComponentLoader

Loader for scheduler.

Source code in fastvideo/models/loader/component_loader.py
def __init__(self, device=None) -> None:
    self.device = device
Functions
fastvideo.models.loader.component_loader.SchedulerLoader.load
load(model_path: str, fastvideo_args: FastVideoArgs)

Load the scheduler based on the model path, and inference args.

Source code in fastvideo/models/loader/component_loader.py
def load(self, model_path: str, fastvideo_args: FastVideoArgs):
    """Load the scheduler based on the model path, and inference args."""
    config = get_diffusers_config(model=model_path)

    class_name = config.pop("_class_name")
    assert class_name is not None, "Model config does not contain a _class_name attribute. Only diffusers format is supported."

    scheduler_cls, _ = ModelRegistry.resolve_model_cls(class_name)

    scheduler = scheduler_cls(**config)
    if fastvideo_args.pipeline_config.flow_shift is not None:
        scheduler.set_shift(fastvideo_args.pipeline_config.flow_shift)
    if fastvideo_args.pipeline_config.timesteps_scale is not None:
        scheduler.set_timesteps_scale(
            fastvideo_args.pipeline_config.timesteps_scale)
    return scheduler
fastvideo.models.loader.component_loader.TextEncoderLoader
TextEncoderLoader(device=None)

Bases: ComponentLoader

Loader for text encoders.

Source code in fastvideo/models/loader/component_loader.py
def __init__(self, device=None) -> None:
    self.device = device
Classes
fastvideo.models.loader.component_loader.TextEncoderLoader.Source dataclass
Source(model_or_path: str, prefix: str = '', fall_back_to_pt: bool = True, allow_patterns_overrides: list[str] | None = None)

A source for weights.

Attributes
fastvideo.models.loader.component_loader.TextEncoderLoader.Source.allow_patterns_overrides class-attribute instance-attribute
allow_patterns_overrides: list[str] | None = None

If defined, weights will load exclusively using these patterns.

fastvideo.models.loader.component_loader.TextEncoderLoader.Source.fall_back_to_pt class-attribute instance-attribute
fall_back_to_pt: bool = True

Whether .pt weights can be used.

fastvideo.models.loader.component_loader.TextEncoderLoader.Source.model_or_path instance-attribute
model_or_path: str

The model ID or path.

fastvideo.models.loader.component_loader.TextEncoderLoader.Source.prefix class-attribute instance-attribute
prefix: str = ''

A prefix to prepend to all weights.
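An illustrative construction of Source; the path and patterns are placeholders.

from fastvideo.models.loader.component_loader import TextEncoderLoader

# Restrict weight loading to safetensors shards only (illustrative values).
source = TextEncoderLoader.Source(
    model_or_path="/path/to/model/text_encoder",
    prefix="",
    fall_back_to_pt=False,
    allow_patterns_overrides=["*.safetensors"],
)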

Functions
fastvideo.models.loader.component_loader.TextEncoderLoader.load
load(model_path: str, fastvideo_args: FastVideoArgs)

Load the text encoders based on the model path, and inference args.

Source code in fastvideo/models/loader/component_loader.py
def load(self, model_path: str, fastvideo_args: FastVideoArgs):
    """Load the text encoders based on the model path, and inference args."""
    # model_config: PretrainedConfig = get_hf_config(
    #     model=model_path,
    #     trust_remote_code=fastvideo_args.trust_remote_code,
    #     revision=fastvideo_args.revision,
    #     model_override_args=None,
    # )
    model_config = get_diffusers_config(model=model_path)
    model_config.pop("_name_or_path", None)
    model_config.pop("transformers_version", None)
    model_config.pop("model_type", None)
    model_config.pop("tokenizer_class", None)
    model_config.pop("torch_dtype", None)
    logger.info("HF Model config: %s", model_config)

    # @TODO(Wei): Better way to handle this?
    try:
        encoder_config = fastvideo_args.pipeline_config.text_encoder_configs[
            0]
        encoder_config.update_model_arch(model_config)
        encoder_precision = fastvideo_args.pipeline_config.text_encoder_precisions[
            0]
    except Exception:
        encoder_config = fastvideo_args.pipeline_config.text_encoder_configs[
            1]
        encoder_config.update_model_arch(model_config)
        encoder_precision = fastvideo_args.pipeline_config.text_encoder_precisions[
            1]

    target_device = get_local_torch_device()
    # TODO(will): add support for other dtypes
    return self.load_model(model_path, encoder_config, target_device,
                           fastvideo_args, encoder_precision)
fastvideo.models.loader.component_loader.TokenizerLoader
TokenizerLoader(device=None)

Bases: ComponentLoader

Loader for tokenizers.

Source code in fastvideo/models/loader/component_loader.py
def __init__(self, device=None) -> None:
    self.device = device
Functions
fastvideo.models.loader.component_loader.TokenizerLoader.load
load(model_path: str, fastvideo_args: FastVideoArgs)

Load the tokenizer based on the model path, and inference args.

Source code in fastvideo/models/loader/component_loader.py
def load(self, model_path: str, fastvideo_args: FastVideoArgs):
    """Load the tokenizer based on the model path, and inference args."""
    logger.info("Loading tokenizer from %s", model_path)

    tokenizer = AutoTokenizer.from_pretrained(
        model_path,  # "<path to model>/tokenizer"
        # in v0, this was same string as encoder_name "ClipTextModel"
        # TODO(will): pass these tokenizer kwargs from inference args? Maybe
        # other method of config?
        padding_side='right',
    )
    logger.info("Loaded tokenizer: %s", tokenizer.__class__.__name__)
    return tokenizer
fastvideo.models.loader.component_loader.TransformerLoader
TransformerLoader(device=None)

Bases: ComponentLoader

Loader for transformer.

Source code in fastvideo/models/loader/component_loader.py
def __init__(self, device=None) -> None:
    self.device = device
Functions
fastvideo.models.loader.component_loader.TransformerLoader.load
load(model_path: str, fastvideo_args: FastVideoArgs)

Load the transformer based on the model path, and inference args.

Source code in fastvideo/models/loader/component_loader.py
def load(self, model_path: str, fastvideo_args: FastVideoArgs):
    """Load the transformer based on the model path, and inference args."""
    config = get_diffusers_config(model=model_path)
    hf_config = deepcopy(config)
    cls_name = config.pop("_class_name")
    if cls_name is None:
        raise ValueError(
            "Model config does not contain a _class_name attribute. "
            "Only diffusers format is supported.")

    logger.info("transformer cls_name: %s", cls_name)
    if fastvideo_args.override_transformer_cls_name is not None:
        cls_name = fastvideo_args.override_transformer_cls_name
        logger.info("Overriding transformer cls_name to %s", cls_name)

    fastvideo_args.model_paths["transformer"] = model_path

    # Config from Diffusers supersedes fastvideo's model config
    dit_config = fastvideo_args.pipeline_config.dit_config
    dit_config.update_model_arch(config)

    model_cls, _ = ModelRegistry.resolve_model_cls(cls_name)

    # Find all safetensors files
    safetensors_list = glob.glob(
        os.path.join(str(model_path), "*.safetensors"))
    if not safetensors_list:
        raise ValueError(f"No safetensors files found in {model_path}")

    # Check if we should use custom initialization weights
    custom_weights_path = getattr(fastvideo_args, 'init_weights_from_safetensors', None)
    use_custom_weights = (custom_weights_path and os.path.exists(custom_weights_path) and 
                        not hasattr(fastvideo_args, '_loading_teacher_critic_model'))

    if use_custom_weights:
        if 'transformer_2' in model_path:
            custom_weights_path = getattr(fastvideo_args, 'init_weights_from_safetensors_2', None)
        assert custom_weights_path is not None, "Custom initialization weights must be provided"
        if os.path.isdir(custom_weights_path):
            safetensors_list = glob.glob(
                os.path.join(str(custom_weights_path), "*.safetensors"))
        else:
            assert custom_weights_path.endswith(".safetensors"), "Custom initialization weights must be a safetensors file"
            safetensors_list = [custom_weights_path]

    logger.info("Loading model from %s safetensors files: %s",
                len(safetensors_list), safetensors_list)

    default_dtype = PRECISION_TO_TYPE[
        fastvideo_args.pipeline_config.dit_precision]

    # Load the model using FSDP loader
    logger.info("Loading model from %s, default_dtype: %s", cls_name,
                default_dtype)
    assert fastvideo_args.hsdp_shard_dim is not None
    model = maybe_load_fsdp_model(
        model_cls=model_cls,
        init_params={
            "config": dit_config,
            "hf_config": hf_config
        },
        weight_dir_list=safetensors_list,
        device=get_local_torch_device(),
        hsdp_replicate_dim=fastvideo_args.hsdp_replicate_dim,
        hsdp_shard_dim=fastvideo_args.hsdp_shard_dim,
        cpu_offload=fastvideo_args.dit_cpu_offload,
        pin_cpu_memory=fastvideo_args.pin_cpu_memory,
        fsdp_inference=fastvideo_args.use_fsdp_inference,
        # TODO(will): make these configurable
        default_dtype=default_dtype,
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
        output_dtype=None,
        training_mode=fastvideo_args.training_mode,
        enable_torch_compile=fastvideo_args.enable_torch_compile,
        torch_compile_kwargs=fastvideo_args.torch_compile_kwargs)


    total_params = sum(p.numel() for p in model.parameters())
    logger.info("Loaded model with %.2fB parameters", total_params / 1e9)

    assert next(model.parameters()).dtype == default_dtype, "Model dtype does not match default dtype"

    model = model.eval()
    return model
fastvideo.models.loader.component_loader.VAELoader
VAELoader(device=None)

Bases: ComponentLoader

Loader for VAE.

Source code in fastvideo/models/loader/component_loader.py
def __init__(self, device=None) -> None:
    self.device = device
Functions
fastvideo.models.loader.component_loader.VAELoader.load
load(model_path: str, fastvideo_args: FastVideoArgs)

Load the VAE based on the model path, and inference args.

Source code in fastvideo/models/loader/component_loader.py
def load(self, model_path: str, fastvideo_args: FastVideoArgs):
    """Load the VAE based on the model path, and inference args."""
    config = get_diffusers_config(model=model_path)
    class_name = config.pop("_class_name")
    assert class_name is not None, "Model config does not contain a _class_name attribute. Only diffusers format is supported."
    fastvideo_args.model_paths["vae"] = model_path

    vae_config = fastvideo_args.pipeline_config.vae_config
    vae_config.update_model_arch(config)

    from fastvideo.platforms import current_platform

    if fastvideo_args.vae_cpu_offload:
        target_device = torch.device("mps") if current_platform.is_mps() else torch.device("cpu")
    else:
        target_device = get_local_torch_device()

    with set_default_torch_dtype(PRECISION_TO_TYPE[
            fastvideo_args.pipeline_config.vae_precision]):
        vae_cls, _ = ModelRegistry.resolve_model_cls(class_name)
        vae = vae_cls(vae_config).to(target_device)

    # Find all safetensors files
    safetensors_list = glob.glob(
        os.path.join(str(model_path), "*.safetensors"))
    # TODO(PY)
    assert len(
        safetensors_list
    ) == 1, f"Found {len(safetensors_list)} safetensors files in {model_path}"
    loaded = safetensors_load_file(safetensors_list[0])
    vae.load_state_dict(
        loaded, strict=False)  # We might only load encoder or decoder

    return vae.eval()

Functions

fastvideo.models.loader.fsdp_load

Functions

fastvideo.models.loader.fsdp_load.load_model_from_full_model_state_dict
load_model_from_full_model_state_dict(model: FSDPModule | Module, full_sd_iterator: Generator[tuple[str, Tensor], None, None], device: device, param_dtype: dtype, strict: bool = False, cpu_offload: bool = False, param_names_mapping: Callable[[str], tuple[str, Any, Any]] | None = None, training_mode: bool = True) -> _IncompatibleKeys

Converts a full state dict into a sharded state dict and loads it into an FSDP model (if training) or a plain Hugging Face model.

Parameters:

Name Type Description Default
model FSDPModule | Module

Model to generate fully qualified names for cpu_state_dict

required
full_sd_iterator Generator[tuple[str, Tensor], None, None]

An iterator yielding (param_name, tensor) pairs

required
device device

Device used to move full state dict tensors

required
param_dtype dtype

Dtype used to move full state dict tensors

required
strict bool

Whether to load the model in strict mode

False
cpu_offload bool

Whether FSDP CPU offload is enabled

False
param_names_mapping Callable[[str], tuple[str, Any, Any]] | None

A function that maps a full param name to a sharded param name

None
training_mode bool

Apply FSDP only for training

True

Returns:

Type Description
_IncompatibleKeys

NamedTuple with missing_keys and unexpected_keys fields: missing_keys is a list of str containing the missing keys, and unexpected_keys is a list of str containing the unexpected keys.

Raises:

Type Description
NotImplementedError

If got FSDP with more than 1D.

Source code in fastvideo/models/loader/fsdp_load.py
def load_model_from_full_model_state_dict(
    model: FSDPModule | torch.nn.Module,
    full_sd_iterator: Generator[tuple[str, torch.Tensor], None, None],
    device: torch.device,
    param_dtype: torch.dtype,
    strict: bool = False,
    cpu_offload: bool = False,
    param_names_mapping: Callable[[str], tuple[str, Any, Any]] | None = None,
    training_mode: bool = True,
) -> _IncompatibleKeys:
    """
    Converting full state dict into a sharded state dict
    and loading it into FSDP model (if training) or normal huggingface model
    Args:
        model (Union[FSDPModule, torch.nn.Module]): Model to generate fully qualified names for cpu_state_dict
        full_sd_iterator (Generator): an iterator yielding (param_name, tensor) pairs
        device (torch.device): device used to move full state dict tensors
        param_dtype (torch.dtype): dtype used to move full state dict tensors
        strict (bool): flag to check if to load the model in strict mode
        cpu_offload (bool): flag to check if FSDP offload is enabled
        param_names_mapping (Optional[Callable[[str], str]]): a function that maps full param name to sharded param name
        training_mode (bool): apply FSDP only for training
    Returns:
        ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
            * **missing_keys** is a list of str containing the missing keys
            * **unexpected_keys** is a list of str containing the unexpected keys

    Raises:
        NotImplementedError: If got FSDP with more than 1D.
    """
    meta_sd = model.state_dict()
    sharded_sd = {}
    custom_param_sd, reverse_param_names_mapping = hf_to_custom_state_dict(
        full_sd_iterator, param_names_mapping)  # type: ignore
    for target_param_name, full_tensor in custom_param_sd.items():
        meta_sharded_param = meta_sd.get(target_param_name)
        if meta_sharded_param is None:
            raise ValueError(
                f"Parameter {target_param_name} not found in custom model state dict. The hf to custom mapping may be incorrect."
            )
        if not hasattr(meta_sharded_param, "device_mesh"):
            full_tensor = full_tensor.to(device=device, dtype=param_dtype)
            # In cases where parts of the model aren't sharded, some parameters will be plain tensors
            sharded_tensor = full_tensor
        else:
            full_tensor = full_tensor.to(device=device, dtype=param_dtype)
            sharded_tensor = distribute_tensor(
                full_tensor,
                meta_sharded_param.device_mesh,
                meta_sharded_param.placements,
            )
            if cpu_offload:
                sharded_tensor = sharded_tensor.cpu()
        sharded_sd[target_param_name] = nn.Parameter(sharded_tensor)

    model.reverse_param_names_mapping = reverse_param_names_mapping
    unused_keys = set(meta_sd.keys()) - set(sharded_sd.keys())
    if unused_keys:
        logger.warning("Found unloaded parameters in meta state dict: %s",
                       unused_keys)

    # List of allowed parameter name patterns
    ALLOWED_NEW_PARAM_PATTERNS = ["gate_compress"]  # Can be extended as needed
    for new_param_name in unused_keys:
        if not any(pattern in new_param_name
                   for pattern in ALLOWED_NEW_PARAM_PATTERNS):
            logger.error("Unsupported new parameter: %s. Allowed patterns: %s",
                         new_param_name, ALLOWED_NEW_PARAM_PATTERNS)
            raise ValueError(
                f"New parameter '{new_param_name}' is not supported. "
                f"Currently only parameters containing {ALLOWED_NEW_PARAM_PATTERNS} are allowed."
            )
        meta_sharded_param = meta_sd.get(new_param_name)
        if not hasattr(meta_sharded_param, "device_mesh"):
            # Initialize with zeros
            sharded_tensor = torch.zeros_like(meta_sharded_param,
                                              device=device,
                                              dtype=param_dtype)
        else:
            # Initialize with zeros and distribute
            full_tensor = torch.zeros_like(meta_sharded_param,
                                           device=device,
                                           dtype=param_dtype)
            sharded_tensor = distribute_tensor(
                full_tensor,
                meta_sharded_param.device_mesh,
                meta_sharded_param.placements,
            )
            if cpu_offload:
                sharded_tensor = sharded_tensor.cpu()
        sharded_sd[new_param_name] = nn.Parameter(sharded_tensor)

    # choose `assign=True` since we cannot call `copy_` on meta tensor
    return model.load_state_dict(sharded_sd, strict=strict, assign=True)
fastvideo.models.loader.fsdp_load.maybe_load_fsdp_model
maybe_load_fsdp_model(model_cls: type[Module], init_params: dict[str, Any], weight_dir_list: list[str], device: device, hsdp_replicate_dim: int, hsdp_shard_dim: int, default_dtype: dtype, param_dtype: dtype, reduce_dtype: dtype, cpu_offload: bool = False, fsdp_inference: bool = False, output_dtype: dtype | None = None, training_mode: bool = True, pin_cpu_memory: bool = True, enable_torch_compile: bool = False, torch_compile_kwargs: dict[str, Any] | None = None) -> Module

Load the model with FSDP if training (or if FSDP inference is enabled), otherwise load it without FSDP.

Source code in fastvideo/models/loader/fsdp_load.py
def maybe_load_fsdp_model(
    model_cls: type[nn.Module],
    init_params: dict[str, Any],
    weight_dir_list: list[str],
    device: torch.device,
    hsdp_replicate_dim: int,
    hsdp_shard_dim: int,
    default_dtype: torch.dtype,
    param_dtype: torch.dtype,
    reduce_dtype: torch.dtype,
    cpu_offload: bool = False,
    fsdp_inference: bool = False,
    output_dtype: torch.dtype | None = None,
    training_mode: bool = True,
    pin_cpu_memory: bool = True,
    enable_torch_compile: bool = False,
    torch_compile_kwargs: dict[str, Any] | None = None,
) -> torch.nn.Module:
    """
    Load the model with FSDP if training (or if FSDP inference is enabled), otherwise load the model without FSDP.
    """
    # NOTE(will): cast_forward_inputs=True shouldn't be needed as we are
    # manually casting the inputs to the model
    mp_policy = MixedPrecisionPolicy(param_dtype,
                                     reduce_dtype,
                                     output_dtype,
                                     cast_forward_inputs=False)

    set_mixed_precision_policy(
        param_dtype=param_dtype,
        reduce_dtype=reduce_dtype,
        output_dtype=output_dtype,
        mp_policy=mp_policy,
    )

    logger.info("Loading model with default_dtype: %s", default_dtype)
    with set_default_dtype(default_dtype), torch.device("meta"):
        model = model_cls(**init_params)

    # Check if we should use FSDP
    use_fsdp = training_mode or fsdp_inference

    # Disable FSDP for MPS as it's not compatible
    from fastvideo.platforms import current_platform
    if current_platform.is_mps():
        use_fsdp = False
        logger.info("Disabling FSDP for MPS platform as it's not compatible")

    if use_fsdp:
        world_size = hsdp_replicate_dim * hsdp_shard_dim
        if not training_mode and not fsdp_inference:
            hsdp_replicate_dim = world_size
            hsdp_shard_dim = 1

        if current_platform.is_npu():
            with torch.device("cpu"):
                device_mesh = init_device_mesh(
                    "npu",
                    # (Replicate(), Shard(dim=0))
                    mesh_shape=(hsdp_replicate_dim, hsdp_shard_dim),
                    mesh_dim_names=("replicate", "shard"),
                )
        else:
            device_mesh = init_device_mesh(
                "cuda",
                # (Replicate(), Shard(dim=0))
                mesh_shape=(hsdp_replicate_dim, hsdp_shard_dim),
                mesh_dim_names=("replicate", "shard"),
            )
        shard_model(model,
                    cpu_offload=cpu_offload,
                    reshard_after_forward=True,
                    mp_policy=mp_policy,
                    mesh=device_mesh,
                    fsdp_shard_conditions=model._fsdp_shard_conditions,
                    pin_cpu_memory=pin_cpu_memory)

    weight_iterator = safetensors_weights_iterator(weight_dir_list)
    param_names_mapping_fn = get_param_names_mapping(model.param_names_mapping)
    load_model_from_full_model_state_dict(
        model,
        weight_iterator,
        device,
        default_dtype,
        strict=True,
        cpu_offload=cpu_offload,
        param_names_mapping=param_names_mapping_fn,
    )
    for n, p in chain(model.named_parameters(), model.named_buffers()):
        if p.is_meta:
            raise RuntimeError(
                f"Unexpected param or buffer {n} on meta device.")
        # Avoid unintended computation graph accumulation during inference
        if isinstance(p, torch.nn.Parameter):
            p.requires_grad = False

    compile_in_loader = enable_torch_compile and training_mode
    if compile_in_loader:
        compile_kwargs = torch_compile_kwargs or {}
        logger.info("Enabling torch.compile for FSDP training module with kwargs=%s",
                    compile_kwargs)
        model = torch.compile(model, **compile_kwargs)
        logger.info("torch.compile enabled for %s", type(model).__name__)
    return model
fastvideo.models.loader.fsdp_load.set_default_dtype
set_default_dtype(dtype: dtype) -> Generator[None, None, None]

Context manager to set torch's default dtype.

Parameters:

Name Type Description Default
dtype dtype

The desired default dtype inside the context manager.

required

Returns:

Name Type Description
ContextManager None

context manager for setting default dtype.

Example

with set_default_dtype(torch.bfloat16):
    x = torch.tensor([1, 2, 3])
    x.dtype  # torch.bfloat16

Source code in fastvideo/models/loader/fsdp_load.py
@contextlib.contextmanager
def set_default_dtype(dtype: torch.dtype) -> Generator[None, None, None]:
    """
    Context manager to set torch's default dtype.

    Args:
        dtype (torch.dtype): The desired default dtype inside the context manager.

    Returns:
        ContextManager: context manager for setting default dtype.

    Example:
        >>> with set_default_dtype(torch.bfloat16):
        >>>     x = torch.tensor([1, 2, 3])
        >>>     x.dtype
        torch.bfloat16


    """
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(old_dtype)
fastvideo.models.loader.fsdp_load.shard_model
shard_model(model, *, cpu_offload: bool, reshard_after_forward: bool = True, mp_policy: MixedPrecisionPolicy | None = MixedPrecisionPolicy(), mesh: DeviceMesh | None = None, fsdp_shard_conditions: list[Callable[[str, Module], bool]] = [], pin_cpu_memory: bool = True) -> None

Utility to shard a model with FSDP using the PyTorch Distributed fully_shard API.

This method will iterate over the model's named modules from the bottom up and shard modules based on whether they meet any of the criteria from shard_conditions.

Parameters:

Name Type Description Default
model Module

Model to shard with FSDP.

required
cpu_offload bool

If set to True, FSDP will offload parameters, gradients, and optimizer states to CPU.

required
reshard_after_forward bool

Whether to reshard parameters and buffers after the forward pass. Setting this to True corresponds to the FULL_SHARD sharding strategy from FSDP1, while setting it to False corresponds to the SHARD_GRAD_OP sharding strategy.

True
mesh Optional[DeviceMesh]

Device mesh to use for FSDP sharding under multiple parallelism. Default to None.

None
fsdp_shard_conditions List[Callable[[str, Module], bool]]

A list of functions to determine which modules to shard with FSDP. Each function should take the module name (relative to root) and the module itself, returning True if FSDP should shard the module and False otherwise. If any condition returns True for a given module, it will be sharded by FSDP.

[]
pin_cpu_memory bool

If set to True, FSDP will pin the CPU memory of the offloaded parameters.

True

Raises:

Type Description
ValueError

If no layer modules were sharded, indicating that no shard_condition was triggered.

Source code in fastvideo/models/loader/fsdp_load.py
def shard_model(
    model,
    *,
    cpu_offload: bool,
    reshard_after_forward: bool = True,
    mp_policy: MixedPrecisionPolicy | None = MixedPrecisionPolicy(),  # noqa
    mesh: DeviceMesh | None = None,
    fsdp_shard_conditions: list[Callable[[str, nn.Module], bool]] = [],  # noqa
    pin_cpu_memory: bool = True,
) -> None:
    """
    Utility to shard a model with FSDP using the PyTorch Distributed fully_shard API.

    This method will iterate over the model's named modules from the bottom up and shard modules
    based on whether they meet any of the criteria from shard_conditions.

    Args:
        model (nn.Module): Model to shard with FSDP.
        cpu_offload (bool): If set to True, FSDP will offload parameters, gradients, and optimizer
            states to CPU.
        reshard_after_forward (bool): Whether to reshard parameters and buffers after
            the forward pass. Setting this to True corresponds to the FULL_SHARD sharding strategy
            from FSDP1, while setting it to False corresponds to the SHARD_GRAD_OP sharding strategy.
        mesh (Optional[DeviceMesh]): Device mesh to use for FSDP sharding under multiple parallelism.
            Default to None.
        fsdp_shard_conditions (List[Callable[[str, nn.Module], bool]]): A list of functions to determine
            which modules to shard with FSDP. Each function should take the module name (relative to root)
            and the module itself, returning True if FSDP should shard the module and False otherwise.
            If any condition returns True for a given module, it will be sharded by FSDP.
        pin_cpu_memory (bool): If set to True, FSDP will pin the CPU memory of the offloaded parameters.

    Raises:
        ValueError: If no layer modules were sharded, indicating that no shard_condition was triggered.
    """
    # Check if we should use size-based filtering
    use_size_filtering = os.environ.get("FASTVIDEO_FSDP2_AUTOWRAP", "0") == "1"

    if not fsdp_shard_conditions:
        logger.warning("No FSDP shard conditions provided; nothing will be sharded.")
        return

    fsdp_kwargs = {
        "reshard_after_forward": reshard_after_forward,
        "mesh": mesh,
        "mp_policy": mp_policy,
    }
    if cpu_offload:
        fsdp_kwargs["offload_policy"] = CPUOffloadPolicy(
            pin_memory=pin_cpu_memory)

    # iterating in reverse to start with
    # lowest-level modules first
    num_layers_sharded = 0

    if use_size_filtering:
        # Size-based filtering mode
        min_params = int(os.environ.get("FASTVIDEO_FSDP2_MIN_PARAMS", "10000000"))
        logger.info("Using size-based filtering with threshold: %.2fM", min_params / 1e6)

        for n, m in reversed(list(model.named_modules())):
            if any([shard_condition(n, m) for shard_condition in fsdp_shard_conditions]):
                # Count all parameters
                param_count = sum(p.numel() for p in m.parameters(recurse=True))

                # Skip small modules
                if param_count < min_params:
                    logger.info("Skipping module %s (%.2fM params < %.2fM threshold)", 
                               n, param_count / 1e6, min_params / 1e6)
                    continue

                # Shard this module
                logger.info("Sharding module %s (%.2fM params)", n, param_count / 1e6)
                fully_shard(m, **fsdp_kwargs)
                num_layers_sharded += 1
    else:
        # Shard all modules matching conditions        
        for n, m in reversed(list(model.named_modules())):
            if any([shard_condition(n, m) for shard_condition in fsdp_shard_conditions]):
                fully_shard(m, **fsdp_kwargs)
                num_layers_sharded += 1

        if num_layers_sharded == 0:
            raise ValueError(
                "No layer modules were sharded. Please check if shard conditions are working as expected."
            )

    # Finally shard the entire model to account for any stragglers
    fully_shard(model, **fsdp_kwargs)
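A sketch of a shard condition as consumed by shard_model. In FastVideo the conditions come from the model's _fsdp_shard_conditions attribute; the block-name pattern below is only an assumption for illustration.

import re

import torch.nn as nn

# Hypothetical condition: shard every module whose qualified name looks like a
# transformer block, e.g. "blocks.0", "blocks.17".
def shard_transformer_blocks(name: str, module: nn.Module) -> bool:
    return re.fullmatch(r"blocks\.\d+", name) is not None

# shard_model(model,
#             cpu_offload=False,
#             fsdp_shard_conditions=[shard_transformer_blocks],
#             mesh=device_mesh)
#
# Setting FASTVIDEO_FSDP2_AUTOWRAP=1 additionally skips matched modules smaller
# than FASTVIDEO_FSDP2_MIN_PARAMS (default 10M parameters), per the code above.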

fastvideo.models.loader.utils

Utilities for selecting and loading models.

Functions

fastvideo.models.loader.utils.get_param_names_mapping
get_param_names_mapping(mapping_dict: dict[str, str]) -> Callable[[str], tuple[str, Any, Any]]

Creates a mapping function that transforms parameter names using regex patterns.

Parameters:

Name Type Description Default
mapping_dict Dict[str, str]

Dictionary mapping regex patterns to replacement patterns

required

Returns:

Type Description
Callable[[str], tuple[str, Any, Any]]

A function that takes a parameter name and maps it from source to target format, returning the new name plus an optional merge index and total number of split parameters.

Source code in fastvideo/models/loader/utils.py
def get_param_names_mapping(
        mapping_dict: dict[str, str]) -> Callable[[str], tuple[str, Any, Any]]:
    """
    Creates a mapping function that transforms parameter names using regex patterns.

    Args:
        mapping_dict (Dict[str, str]): Dictionary mapping regex patterns to replacement patterns

    Returns:
        Callable[[str], tuple[str, Any, Any]]: A function that maps a parameter name from source to target format, returning the new name plus an optional merge index and total split count
    """

    def mapping_fn(name: str) -> tuple[str, Any, Any]:
        # Try to match and transform the name using the regex patterns in mapping_dict
        for pattern, replacement in mapping_dict.items():
            match = re.match(pattern, name)
            if match:
                merge_index = None
                total_splitted_params = None
                if isinstance(replacement, tuple):
                    merge_index = replacement[1]
                    total_splitted_params = replacement[2]
                    replacement = replacement[0]
                name = re.sub(pattern, replacement, name)
                return name, merge_index, total_splitted_params

        # If no pattern matches, return the original name
        return name, None, None

    return mapping_fn
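An illustrative mapping dict. The regex patterns and the fused-qkv target below are hypothetical; they only demonstrate the plain-string and (replacement, merge_index, total_splitted_params) tuple forms handled above.

from fastvideo.models.loader.utils import get_param_names_mapping

mapping_fn = get_param_names_mapping({
    # plain rename
    r"^encoder\.(.*)$": r"text_encoder.\1",
    # merge three source params into one fused target param
    r"^blocks\.(\d+)\.attn\.to_q\.(.*)$": (r"blocks.\1.attn.qkv.\2", 0, 3),
    r"^blocks\.(\d+)\.attn\.to_k\.(.*)$": (r"blocks.\1.attn.qkv.\2", 1, 3),
    r"^blocks\.(\d+)\.attn\.to_v\.(.*)$": (r"blocks.\1.attn.qkv.\2", 2, 3),
})

print(mapping_fn("encoder.layer.0.weight"))     # ('text_encoder.layer.0.weight', None, None)
print(mapping_fn("blocks.3.attn.to_k.weight"))  # ('blocks.3.attn.qkv.weight', 1, 3)
print(mapping_fn("unrelated.param"))            # ('unrelated.param', None, None)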
fastvideo.models.loader.utils.hf_to_custom_state_dict
hf_to_custom_state_dict(hf_param_sd: dict[str, Tensor] | Iterator[tuple[str, Tensor]], param_names_mapping: Callable[[str], tuple[str, Any, Any]]) -> tuple[dict[str, Tensor], dict[str, tuple[str, Any, Any]]]

Converts a Hugging Face parameter state dictionary to a custom parameter state dictionary.

Parameters:

Name Type Description Default
hf_param_sd Dict[str, Tensor]

The Hugging Face parameter state dictionary

required
param_names_mapping Callable[[str], tuple[str, Any, Any]]

A function that maps parameter names from source to target format

required

Returns:

Name Type Description
custom_param_sd Dict[str, Tensor]

The custom formatted parameter state dict

reverse_param_names_mapping Dict[str, Tuple[str, Any, Any]]

Maps back from custom to hf

Source code in fastvideo/models/loader/utils.py
def hf_to_custom_state_dict(
    hf_param_sd: dict[str, torch.Tensor] | Iterator[tuple[str, torch.Tensor]],
    param_names_mapping: Callable[[str], tuple[str, Any, Any]]
) -> tuple[dict[str, torch.Tensor], dict[str, tuple[str, Any, Any]]]:
    """
    Converts a Hugging Face parameter state dictionary to a custom parameter state dictionary.

    Args:
        hf_param_sd (Dict[str, torch.Tensor]): The Hugging Face parameter state dictionary
        param_names_mapping (Callable[[str], tuple[str, Any, Any]]): A function that maps parameter names from source to target format

    Returns:
        custom_param_sd (Dict[str, torch.Tensor]): The custom formatted parameter state dict
        reverse_param_names_mapping (Dict[str, Tuple[str, Any, Any]]): Maps back from custom to hf
    """
    custom_param_sd = {}
    to_merge_params = defaultdict(dict)  # type: ignore
    reverse_param_names_mapping = {}
    if isinstance(hf_param_sd, dict):
        hf_param_sd = hf_param_sd.items()  # type: ignore
    for source_param_name, full_tensor in hf_param_sd:  # type: ignore
        target_param_name, merge_index, num_params_to_merge = param_names_mapping(
            source_param_name)
        reverse_param_names_mapping[target_param_name] = (source_param_name,
                                                          merge_index,
                                                          num_params_to_merge)
        if merge_index is not None:
            to_merge_params[target_param_name][merge_index] = full_tensor
            if len(to_merge_params[target_param_name]) == num_params_to_merge:
                # cat at output dim according to the merge_index order
                sorted_tensors = [
                    to_merge_params[target_param_name][i]
                    for i in range(num_params_to_merge)
                ]
                full_tensor = torch.cat(sorted_tensors, dim=0)
                del to_merge_params[target_param_name]
            else:
                continue
        custom_param_sd[target_param_name] = full_tensor
    return custom_param_sd, reverse_param_names_mapping
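A small sketch of the merge behavior, reusing the hypothetical fused-qkv mapping from the previous example with tiny random tensors.

import torch

from fastvideo.models.loader.utils import (get_param_names_mapping,
                                            hf_to_custom_state_dict)

mapping_fn = get_param_names_mapping({
    r"^blocks\.(\d+)\.attn\.to_q\.(.*)$": (r"blocks.\1.attn.qkv.\2", 0, 3),
    r"^blocks\.(\d+)\.attn\.to_k\.(.*)$": (r"blocks.\1.attn.qkv.\2", 1, 3),
    r"^blocks\.(\d+)\.attn\.to_v\.(.*)$": (r"blocks.\1.attn.qkv.\2", 2, 3),
})

hf_sd = {
    "blocks.0.attn.to_q.weight": torch.randn(4, 8),
    "blocks.0.attn.to_k.weight": torch.randn(4, 8),
    "blocks.0.attn.to_v.weight": torch.randn(4, 8),
}
custom_sd, reverse_mapping = hf_to_custom_state_dict(hf_sd, mapping_fn)
print(custom_sd["blocks.0.attn.qkv.weight"].shape)  # torch.Size([12, 8]), concatenated on dim 0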
fastvideo.models.loader.utils.set_default_torch_dtype
set_default_torch_dtype(dtype: dtype)

Sets the default torch dtype to the given dtype.

Source code in fastvideo/models/loader/utils.py
@contextlib.contextmanager
def set_default_torch_dtype(dtype: torch.dtype):
    """Sets the default torch dtype to the given dtype."""
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    yield
    torch.set_default_dtype(old_dtype)
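A minimal usage example: parameters created inside the context pick up the requested dtype.

import torch

from fastvideo.models.loader.utils import set_default_torch_dtype

with set_default_torch_dtype(torch.float16):
    layer = torch.nn.Linear(4, 4)

print(layer.weight.dtype)         # torch.float16
print(torch.get_default_dtype())  # restored (typically torch.float32)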

fastvideo.models.loader.weight_utils

Utilities for downloading and initializing model weights.

Functions

fastvideo.models.loader.weight_utils.default_weight_loader
default_weight_loader(param: Tensor, loaded_weight: Tensor) -> None

Default weight loader.

Source code in fastvideo/models/loader/weight_utils.py
def default_weight_loader(param: torch.Tensor,
                          loaded_weight: torch.Tensor) -> None:
    """Default weight loader."""
    try:
        if param.numel() == 1 and loaded_weight.numel() == 1:
            # Sometimes scalar values aren't considered tensors with shapes
            # so if both param and loaded_weight are a scalar,
            # "broadcast" instead of copy
            param.data.fill_(loaded_weight.item())
        else:
            assert param.size() == loaded_weight.size(), (
                f"Attempted to load weight ({loaded_weight.size()}) "
                f"into parameter ({param.size()})")

            param.data.copy_(loaded_weight)
    except Exception:
        # NOTE: This exception is added for the purpose of setting breakpoint to
        # debug weight loading issues.
        raise
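A minimal usage example covering both branches: an exact-shape copy and the scalar "broadcast" path.

import torch

from fastvideo.models.loader.weight_utils import default_weight_loader

param = torch.nn.Parameter(torch.zeros(2, 3))
default_weight_loader(param, torch.ones(2, 3))    # shapes match: plain copy
print(param.data.sum())                           # tensor(6.)

scalar = torch.nn.Parameter(torch.zeros(()))
default_weight_loader(scalar, torch.tensor(0.5))  # both scalars: fill_ is used
print(scalar.data)                                # tensor(0.5000)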
fastvideo.models.loader.weight_utils.enable_hf_transfer
enable_hf_transfer() -> None

Automatically activates hf_transfer when the package is available.

Source code in fastvideo/models/loader/weight_utils.py
def enable_hf_transfer() -> None:
    """automatically activates hf_transfer
    """
    if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ:
        try:
            # enable hf hub transfer if available
            import hf_transfer  # type: ignore # noqa
            huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True
        except ImportError:
            pass
fastvideo.models.loader.weight_utils.filter_files_not_needed_for_inference
filter_files_not_needed_for_inference(hf_weights_files: list[str]) -> list[str]

Exclude files that are not needed for inference.

See https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233

Source code in fastvideo/models/loader/weight_utils.py
def filter_files_not_needed_for_inference(
        hf_weights_files: list[str]) -> list[str]:
    """
    Exclude files that are not needed for inference.

    See https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
    """
    blacklist = [
        "training_args.bin",
        "optimizer.bin",
        "optimizer.pt",
        "scheduler.pt",
        "scaler.pt",
    ]
    hf_weights_files = [
        f for f in hf_weights_files
        if not any(f.endswith(x) for x in blacklist)
    ]
    return hf_weights_files
fastvideo.models.loader.weight_utils.maybe_remap_kv_scale_name
maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None

Remap the name of FP8 k/v_scale parameters.

This function handles the remapping of FP8 k/v_scale parameter names. It detects if the given name ends with a suffix and attempts to remap it to the expected name format in the model. If the remapped name is not found in the params_dict, a warning is printed and None is returned.

Parameters:

Name Type Description Default
name str

The original loaded checkpoint parameter name.

required
params_dict dict

Dictionary containing the model's named parameters.

required

Returns:

Name Type Description
str str | None

The remapped parameter name if successful, or the original name if no remapping is needed.

None str | None

If the remapped name is not found in params_dict.

Source code in fastvideo/models/loader/weight_utils.py
def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None:
    """Remap the name of FP8 k/v_scale parameters.

    This function handles the remapping of FP8 k/v_scale parameter names.
    It detects if the given name ends with a suffix and attempts to remap
    it to the expected name format in the model. If the remapped name is not
    found in the params_dict, a warning is printed and None is returned.

    Args:
        name (str): The original loaded checkpoint parameter name.
        params_dict (dict): Dictionary containing the model's named parameters.

    Returns:
        str: The remapped parameter name if successful, or the original name
             if no remapping is needed.
        None: If the remapped name is not found in params_dict.
    """
    if name.endswith(".kv_scale"):
        logger.warning_once(
            "DEPRECATED. Found kv_scale in the checkpoint. "
            "This format is deprecated in favor of separate k_scale and "
            "v_scale tensors and will be removed in a future release. "
            "Functionally, we will remap kv_scale to k_scale and duplicate "
            "k_scale to v_scale")
        # NOTE: we remap the deprecated kv_scale to k_scale
        remapped_name = name.replace(".kv_scale", ".attn.k_scale")
        if remapped_name not in params_dict:
            logger.warning_once(
                f"Found kv_scale in the checkpoint (e.g. {name}), "
                "but not found the expected name in the model "
                f"(e.g. {remapped_name}). kv_scale is "
                "not loaded.")
            return None
        return remapped_name

    possible_scale_names = [".k_scale", ".v_scale"]
    modelopt_scale_names = [
        ".self_attn.k_proj.k_scale", ".self_attn.v_proj.v_scale"
    ]
    for scale_name in possible_scale_names:
        if name.endswith(scale_name):
            if any(mo_scale_name in name
                   for mo_scale_name in modelopt_scale_names):
                remapped_name = name.replace(
                    f".self_attn.{scale_name[1]}_proj{scale_name}",
                    f".self_attn.attn{scale_name}")
            else:
                remapped_name = name.replace(scale_name, f".attn{scale_name}")
            if remapped_name not in params_dict:
                logger.warning_once(
                    f"Found {scale_name} in the checkpoint (e.g. {name}), "
                    "but not found the expected name in the model "
                    f"(e.g. {remapped_name}). {scale_name} is "
                    "not loaded.")
                return None
            return remapped_name

    # If there were no matches, return the untouched param name
    return name
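An illustrative call; the parameter names below are hypothetical and only show the deprecated kv_scale remap described above.

from fastvideo.models.loader.weight_utils import maybe_remap_kv_scale_name

# Hypothetical params_dict keyed by the model's expected parameter names.
params_dict = {
    "model.layers.0.self_attn.attn.k_scale": None,
    "model.layers.0.self_attn.attn.v_scale": None,
}

print(maybe_remap_kv_scale_name("model.layers.0.self_attn.kv_scale", params_dict))
# model.layers.0.self_attn.attn.k_scale
print(maybe_remap_kv_scale_name("model.layers.0.mlp.weight", params_dict))
# model.layers.0.mlp.weight  (no scale suffix: returned untouched)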
fastvideo.models.loader.weight_utils.pt_weights_iterator
pt_weights_iterator(hf_weights_files: list[str], to_cpu: bool = True) -> Generator[tuple[str, Tensor], None, None]

Iterate over the weights in the model bin/pt files.

Source code in fastvideo/models/loader/weight_utils.py
def pt_weights_iterator(
    hf_weights_files: list[str],
    to_cpu: bool = True,
) -> Generator[tuple[str, torch.Tensor], None, None]:
    """Iterate over the weights in the model bin/pt files."""
    device = "cpu" if to_cpu else str(get_local_torch_device())
    enable_tqdm = not torch.distributed.is_initialized(
    ) or torch.distributed.get_rank() == 0
    for bin_file in tqdm(
            hf_weights_files,
            desc="Loading pt checkpoint shards",
            disable=not enable_tqdm,
            bar_format=_BAR_FORMAT,
    ):
        state = torch.load(bin_file, map_location=device, weights_only=True)
        yield from state.items()
        del state
fastvideo.models.loader.weight_utils.safetensors_weights_iterator
safetensors_weights_iterator(hf_weights_files: list[str], to_cpu: bool = True) -> Generator[tuple[str, Tensor], None, None]

Iterate over the weights in the model safetensor files.

Source code in fastvideo/models/loader/weight_utils.py
def safetensors_weights_iterator(
    hf_weights_files: list[str],
    to_cpu: bool = True,
) -> Generator[tuple[str, torch.Tensor], None, None]:
    """Iterate over the weights in the model safetensor files."""
    enable_tqdm = not torch.distributed.is_initialized(
    ) or torch.distributed.get_rank() == 0
    device = "cpu" if to_cpu else str(get_local_torch_device())
    for st_file in tqdm(
            hf_weights_files,
            desc="Loading safetensors checkpoint shards",
            disable=not enable_tqdm,
            bar_format=_BAR_FORMAT,
    ):
        with safe_open(st_file, framework="pt", device=device) as f:
            for name in f.keys():  # noqa: SIM118
                param = f.get_tensor(name)
                yield name, param
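A usage sketch; the directory is a placeholder for a diffusers-style component folder containing one or more .safetensors shards.

import glob
import os

from fastvideo.models.loader.weight_utils import safetensors_weights_iterator

shard_files = sorted(glob.glob(os.path.join("/path/to/model/transformer", "*.safetensors")))
for name, tensor in safetensors_weights_iterator(shard_files, to_cpu=True):
    print(name, tuple(tensor.shape))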