
models

Modules

fastvideo.models.hf_transformer_utils

Utilities for Huggingface Transformers.

Functions

fastvideo.models.hf_transformer_utils.check_gguf_file
check_gguf_file(model: str | PathLike) -> bool

Check if the file is a GGUF model.

Source code in fastvideo/models/hf_transformer_utils.py
def check_gguf_file(model: str | os.PathLike) -> bool:
    """Check if the file is a GGUF model."""
    model = Path(model)
    if not model.is_file():
        return False
    elif model.suffix == ".gguf":
        return True

    with open(model, "rb") as f:
        header = f.read(4)
    return header == b"GGUF"
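
A minimal usage sketch (the paths here are hypothetical):

from pathlib import Path

from fastvideo.models.hf_transformer_utils import check_gguf_file

# True only if the file exists and either has a .gguf suffix or starts
# with the GGUF magic bytes.
print(check_gguf_file("weights/model.gguf"))
print(check_gguf_file(Path("weights/model.safetensors")))  # False
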
fastvideo.models.hf_transformer_utils.get_diffusers_config
get_diffusers_config(model: str) -> dict[str, Any]

Gets a configuration for the given diffusers model.

Parameters:

Name Type Description Default
model str

The model name or path.

required

Returns:

Type Description
dict[str, Any]

The loaded configuration.

Source code in fastvideo/models/hf_transformer_utils.py
def get_diffusers_config(model: str) -> dict[str, Any]:
    """Gets a configuration for the given diffusers model.

    Args:
        model: The model name or path.

    Returns:
        The loaded configuration.
    """
    config_name = "config.json"
    if "scheduler" in model:
        config_name = "scheduler_config.json"
    # Check if the model path exists
    if os.path.exists(model):
        config_file = os.path.join(model, config_name)
        if os.path.exists(config_file):
            try:
                # Load the config directly from the file
                with open(config_file) as f:
                    config_dict: dict[str, Any] = json.load(f)
                if "_diffusers_version" in config_dict:
                    config_dict.pop("_diffusers_version")
                # TODO(will): apply any overrides from inference args
                return config_dict
            except Exception as e:
                raise RuntimeError(
                    f"Failed to load diffusers config from {config_file}: {e}"
                ) from e
        raise RuntimeError(f"Config file not found at {config_file}")
    else:
        raise RuntimeError(f"Diffusers config file not found at {model}")

fastvideo.models.loader

Modules

fastvideo.models.loader.component_loader
Classes
fastvideo.models.loader.component_loader.ComponentLoader
ComponentLoader(device=None)

Bases: ABC

Base class for loading a specific type of model component.

Source code in fastvideo/models/loader/component_loader.py
def __init__(self, device=None) -> None:
    self.device = device
Functions
fastvideo.models.loader.component_loader.ComponentLoader.for_module_type classmethod
for_module_type(module_type: str, transformers_or_diffusers: str) -> ComponentLoader

Factory method to create a component loader for a specific module type.

Parameters:

Name Type Description Default
module_type str

Type of module (e.g., "vae", "text_encoder", "transformer", "scheduler")

required
transformers_or_diffusers str

Whether the module is from transformers or diffusers

required

Returns:

Type Description
ComponentLoader

A component loader for the specified module type

Source code in fastvideo/models/loader/component_loader.py
@classmethod
def for_module_type(cls, module_type: str,
                    transformers_or_diffusers: str) -> 'ComponentLoader':
    """
    Factory method to create a component loader for a specific module type.

    Args:
        module_type: Type of module (e.g., "vae", "text_encoder", "transformer", "scheduler")
        transformers_or_diffusers: Whether the module is from transformers or diffusers

    Returns:
        A component loader for the specified module type
    """
    # Map of module types to their loader classes and expected library
    module_loaders = {
        "scheduler": (SchedulerLoader, "diffusers"),
        "transformer": (TransformerLoader, "diffusers"),
        "transformer_2": (TransformerLoader, "diffusers"),
        "vae": (VAELoader, "diffusers"),
        "text_encoder": (TextEncoderLoader, "transformers"),
        "text_encoder_2": (TextEncoderLoader, "transformers"),
        "tokenizer": (TokenizerLoader, "transformers"),
        "tokenizer_2": (TokenizerLoader, "transformers"),
        "image_processor": (ImageProcessorLoader, "transformers"),
        "image_encoder": (ImageEncoderLoader, "transformers"),
    }

    if module_type in module_loaders:
        loader_cls, expected_library = module_loaders[module_type]
        # Assert that the library matches what's expected for this module type
        assert transformers_or_diffusers == expected_library, f"{module_type} must be loaded from {expected_library}, got {transformers_or_diffusers}"
        return loader_cls()

    # For unknown module types, use a generic loader
    logger.warning(
        "No specific loader found for module type: %s. Using generic loader.",
        module_type)
    return GenericComponentLoader(transformers_or_diffusers)
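
A minimal usage sketch of the factory method:

from fastvideo.models.loader.component_loader import ComponentLoader

# Known module types must come from the expected library, otherwise the
# assertion above fires.
vae_loader = ComponentLoader.for_module_type("vae", "diffusers")
text_loader = ComponentLoader.for_module_type("text_encoder", "transformers")

# Unknown module types fall back to GenericComponentLoader with a warning.
other_loader = ComponentLoader.for_module_type("feature_extractor", "transformers")
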
fastvideo.models.loader.component_loader.ComponentLoader.load abstractmethod
load(model_path: str, fastvideo_args: FastVideoArgs)

Load the component based on the model path, architecture, and inference args.

Parameters:

Name Type Description Default
model_path str

Path to the component model

required
fastvideo_args FastVideoArgs

FastVideoArgs

required

Returns:

Type Description

The loaded component

Source code in fastvideo/models/loader/component_loader.py
@abstractmethod
def load(self, model_path: str, fastvideo_args: FastVideoArgs):
    """
    Load the component based on the model path, architecture, and inference args.

    Args:
        model_path: Path to the component model
        fastvideo_args: FastVideoArgs

    Returns:
        The loaded component
    """
    raise NotImplementedError
fastvideo.models.loader.component_loader.GenericComponentLoader
GenericComponentLoader(library='transformers')

Bases: ComponentLoader

Generic loader for components that don't have a specific loader.

Source code in fastvideo/models/loader/component_loader.py
def __init__(self, library="transformers") -> None:
    super().__init__()
    self.library = library
Functions
fastvideo.models.loader.component_loader.GenericComponentLoader.load
load(model_path: str, fastvideo_args: FastVideoArgs)

Load a generic component based on the model path, and inference args.

Source code in fastvideo/models/loader/component_loader.py
def load(self, model_path: str, fastvideo_args: FastVideoArgs):
    """Load a generic component based on the model path, and inference args."""
    logger.warning("Using generic loader for %s with library %s",
                   model_path, self.library)

    if self.library == "transformers":
        from transformers import AutoModel

        model = AutoModel.from_pretrained(
            model_path,
            trust_remote_code=fastvideo_args.trust_remote_code,
            revision=fastvideo_args.revision,
        )
        logger.info("Loaded generic transformers model: %s",
                    model.__class__.__name__)
        return model
    elif self.library == "diffusers":
        logger.warning(
            "Generic loading for diffusers components is not fully implemented"
        )

        model_config = get_diffusers_config(model=model_path)
        logger.info("Diffusers Model config: %s", model_config)
        # This is a placeholder - in a real implementation, you'd need to handle this properly
        return None
    else:
        raise ValueError(f"Unsupported library: {self.library}")
fastvideo.models.loader.component_loader.ImageEncoderLoader
ImageEncoderLoader(device=None)

Bases: TextEncoderLoader

Source code in fastvideo/models/loader/component_loader.py
def __init__(self, device=None) -> None:
    self.device = device
Functions
fastvideo.models.loader.component_loader.ImageEncoderLoader.load
load(model_path: str, fastvideo_args: FastVideoArgs)

Load the image encoder based on the model path and inference args.

Source code in fastvideo/models/loader/component_loader.py
def load(self, model_path: str, fastvideo_args: FastVideoArgs):
    """Load the text encoders based on the model path, and inference args."""
    # model_config: PretrainedConfig = get_hf_config(
    #     model=model_path,
    #     trust_remote_code=fastvideo_args.trust_remote_code,
    #     revision=fastvideo_args.revision,
    #     model_override_args=None,
    # )
    with open(os.path.join(model_path, "config.json")) as f:
        model_config = json.load(f)
    model_config.pop("_name_or_path", None)
    model_config.pop("transformers_version", None)
    model_config.pop("torch_dtype", None)
    model_config.pop("model_type", None)
    logger.info("HF Model config: %s", model_config)

    encoder_config = fastvideo_args.pipeline_config.image_encoder_config
    encoder_config.update_model_arch(model_config)

    from fastvideo.platforms import current_platform

    if fastvideo_args.image_encoder_cpu_offload:
        target_device = torch.device("mps") if current_platform.is_mps() else torch.device("cpu")
    else:
        target_device = get_local_torch_device()
    # TODO(will): add support for other dtypes
    return self.load_model(
        model_path, encoder_config, target_device, fastvideo_args,
        fastvideo_args.pipeline_config.image_encoder_precision)
fastvideo.models.loader.component_loader.ImageProcessorLoader
ImageProcessorLoader(device=None)

Bases: ComponentLoader

Loader for image processor.

Source code in fastvideo/models/loader/component_loader.py
def __init__(self, device=None) -> None:
    self.device = device
Functions
fastvideo.models.loader.component_loader.ImageProcessorLoader.load
load(model_path: str, fastvideo_args: FastVideoArgs)

Load the image processor based on the model path, and inference args.

Source code in fastvideo/models/loader/component_loader.py
def load(self, model_path: str, fastvideo_args: FastVideoArgs):
    """Load the image processor based on the model path, and inference args."""
    logger.info("Loading image processor from %s", model_path)

    image_processor = AutoImageProcessor.from_pretrained(model_path)
    logger.info("Loaded image processor: %s",
                image_processor.__class__.__name__)
    return image_processor
fastvideo.models.loader.component_loader.PipelineComponentLoader

Utility class for loading pipeline components. This replaces the chain of if-else statements in load_pipeline_module.

Functions
fastvideo.models.loader.component_loader.PipelineComponentLoader.load_module staticmethod
load_module(module_name: str, component_model_path: str, transformers_or_diffusers: str, fastvideo_args: FastVideoArgs)

Load a pipeline module.

Parameters:

Name Type Description Default
module_name str

Name of the module (e.g., "vae", "text_encoder", "transformer", "scheduler")

required
component_model_path str

Path to the component model

required
transformers_or_diffusers str

Whether the module is from transformers or diffusers

required
fastvideo_args FastVideoArgs

Inference arguments

required

Returns:

Type Description

The loaded module

Source code in fastvideo/models/loader/component_loader.py
@staticmethod
def load_module(module_name: str, component_model_path: str,
                transformers_or_diffusers: str,
                fastvideo_args: FastVideoArgs):
    """
    Load a pipeline module.

    Args:
        module_name: Name of the module (e.g., "vae", "text_encoder", "transformer", "scheduler")
        component_model_path: Path to the component model
        transformers_or_diffusers: Whether the module is from transformers or diffusers
        fastvideo_args: Inference arguments

    Returns:
        The loaded module
    """
    logger.info(
        "Loading %s using %s from %s",
        module_name,
        transformers_or_diffusers,
        component_model_path,
    )

    # Get the appropriate loader for this module type
    loader = ComponentLoader.for_module_type(module_name,
                                             transformers_or_diffusers)

    # Load the module
    return loader.load(component_model_path, fastvideo_args)
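
A minimal usage sketch; the import path for FastVideoArgs, its constructor arguments, and the local paths are assumptions for illustration:

from fastvideo.fastvideo_args import FastVideoArgs  # import path assumed
from fastvideo.models.loader.component_loader import PipelineComponentLoader

# Hypothetical inference arguments and component directory.
args = FastVideoArgs(model_path="/models/my-pipeline")

vae = PipelineComponentLoader.load_module(
    module_name="vae",
    component_model_path="/models/my-pipeline/vae",
    transformers_or_diffusers="diffusers",
    fastvideo_args=args,
)
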
fastvideo.models.loader.component_loader.SchedulerLoader
SchedulerLoader(device=None)

Bases: ComponentLoader

Loader for scheduler.

Source code in fastvideo/models/loader/component_loader.py
def __init__(self, device=None) -> None:
    self.device = device
Functions
fastvideo.models.loader.component_loader.SchedulerLoader.load
load(model_path: str, fastvideo_args: FastVideoArgs)

Load the scheduler based on the model path, and inference args.

Source code in fastvideo/models/loader/component_loader.py
def load(self, model_path: str, fastvideo_args: FastVideoArgs):
    """Load the scheduler based on the model path, and inference args."""
    config = get_diffusers_config(model=model_path)

    class_name = config.pop("_class_name")
    assert class_name is not None, "Model config does not contain a _class_name attribute. Only diffusers format is supported."

    scheduler_cls, _ = ModelRegistry.resolve_model_cls(class_name)

    scheduler = scheduler_cls(**config)
    if fastvideo_args.pipeline_config.flow_shift is not None:
        scheduler.set_shift(fastvideo_args.pipeline_config.flow_shift)
    if fastvideo_args.pipeline_config.timesteps_scale is not None:
        scheduler.set_timesteps_scale(
            fastvideo_args.pipeline_config.timesteps_scale)
    return scheduler
fastvideo.models.loader.component_loader.TextEncoderLoader
TextEncoderLoader(device=None)

Bases: ComponentLoader

Loader for text encoders.

Source code in fastvideo/models/loader/component_loader.py
def __init__(self, device=None) -> None:
    self.device = device
Classes
fastvideo.models.loader.component_loader.TextEncoderLoader.Source dataclass
Source(model_or_path: str, prefix: str = '', fall_back_to_pt: bool = True, allow_patterns_overrides: list[str] | None = None)

A source for weights.

Attributes
fastvideo.models.loader.component_loader.TextEncoderLoader.Source.allow_patterns_overrides class-attribute instance-attribute
allow_patterns_overrides: list[str] | None = None

If defined, weights will load exclusively using these patterns.

fastvideo.models.loader.component_loader.TextEncoderLoader.Source.fall_back_to_pt class-attribute instance-attribute
fall_back_to_pt: bool = True

Whether .pt weights can be used.

fastvideo.models.loader.component_loader.TextEncoderLoader.Source.model_or_path instance-attribute
model_or_path: str

The model ID or path.

fastvideo.models.loader.component_loader.TextEncoderLoader.Source.prefix class-attribute instance-attribute
prefix: str = ''

A prefix to prepend to all weights.

Functions
fastvideo.models.loader.component_loader.TextEncoderLoader.load
load(model_path: str, fastvideo_args: FastVideoArgs)

Load the text encoders based on the model path, and inference args.

Source code in fastvideo/models/loader/component_loader.py
def load(self, model_path: str, fastvideo_args: FastVideoArgs):
    """Load the text encoders based on the model path, and inference args."""
    # model_config: PretrainedConfig = get_hf_config(
    #     model=model_path,
    #     trust_remote_code=fastvideo_args.trust_remote_code,
    #     revision=fastvideo_args.revision,
    #     model_override_args=None,
    # )
    model_config = get_diffusers_config(model=model_path)
    model_config.pop("_name_or_path", None)
    model_config.pop("transformers_version", None)
    model_config.pop("model_type", None)
    model_config.pop("tokenizer_class", None)
    model_config.pop("torch_dtype", None)
    logger.info("HF Model config: %s", model_config)

    # @TODO(Wei): Better way to handle this?
    try:
        encoder_config = fastvideo_args.pipeline_config.text_encoder_configs[
            0]
        encoder_config.update_model_arch(model_config)
        encoder_precision = fastvideo_args.pipeline_config.text_encoder_precisions[
            0]
    except Exception:
        encoder_config = fastvideo_args.pipeline_config.text_encoder_configs[
            1]
        encoder_config.update_model_arch(model_config)
        encoder_precision = fastvideo_args.pipeline_config.text_encoder_precisions[
            1]

    target_device = get_local_torch_device()
    # TODO(will): add support for other dtypes
    return self.load_model(model_path, encoder_config, target_device,
                           fastvideo_args, encoder_precision)
fastvideo.models.loader.component_loader.TokenizerLoader
TokenizerLoader(device=None)

Bases: ComponentLoader

Loader for tokenizers.

Source code in fastvideo/models/loader/component_loader.py
def __init__(self, device=None) -> None:
    self.device = device
Functions
fastvideo.models.loader.component_loader.TokenizerLoader.load
load(model_path: str, fastvideo_args: FastVideoArgs)

Load the tokenizer based on the model path, and inference args.

Source code in fastvideo/models/loader/component_loader.py
def load(self, model_path: str, fastvideo_args: FastVideoArgs):
    """Load the tokenizer based on the model path, and inference args."""
    logger.info("Loading tokenizer from %s", model_path)

    tokenizer = AutoTokenizer.from_pretrained(
        model_path,  # "<path to model>/tokenizer"
        # in v0, this was same string as encoder_name "ClipTextModel"
        # TODO(will): pass these tokenizer kwargs from inference args? Maybe
        # other method of config?
        padding_side='right',
    )
    logger.info("Loaded tokenizer: %s", tokenizer.__class__.__name__)
    return tokenizer
fastvideo.models.loader.component_loader.TransformerLoader
TransformerLoader(device=None)

Bases: ComponentLoader

Loader for transformer.

Source code in fastvideo/models/loader/component_loader.py
def __init__(self, device=None) -> None:
    self.device = device
Functions
fastvideo.models.loader.component_loader.TransformerLoader.load
load(model_path: str, fastvideo_args: FastVideoArgs)

Load the transformer based on the model path, and inference args.

Source code in fastvideo/models/loader/component_loader.py
def load(self, model_path: str, fastvideo_args: FastVideoArgs):
    """Load the transformer based on the model path, and inference args."""
    config = get_diffusers_config(model=model_path)
    hf_config = deepcopy(config)
    cls_name = config.pop("_class_name")
    if cls_name is None:
        raise ValueError(
            "Model config does not contain a _class_name attribute. "
            "Only diffusers format is supported.")

    logger.info("transformer cls_name: %s", cls_name)
    if fastvideo_args.override_transformer_cls_name is not None:
        cls_name = fastvideo_args.override_transformer_cls_name
        logger.info("Overriding transformer cls_name to %s", cls_name)

    fastvideo_args.model_paths["transformer"] = model_path

    # Config from Diffusers supersedes fastvideo's model config
    dit_config = fastvideo_args.pipeline_config.dit_config
    dit_config.update_model_arch(config)

    model_cls, _ = ModelRegistry.resolve_model_cls(cls_name)

    # Find all safetensors files
    safetensors_list = glob.glob(
        os.path.join(str(model_path), "*.safetensors"))
    if not safetensors_list:
        raise ValueError(f"No safetensors files found in {model_path}")

    # Check if we should use custom initialization weights
    custom_weights_path = getattr(fastvideo_args, 'init_weights_from_safetensors', None)
    use_custom_weights = (custom_weights_path and os.path.exists(custom_weights_path) and 
                        not hasattr(fastvideo_args, '_loading_teacher_critic_model'))

    if use_custom_weights:
        if 'transformer_2' in model_path:
            custom_weights_path = getattr(fastvideo_args, 'init_weights_from_safetensors_2', None)
        assert custom_weights_path is not None, "Custom initialization weights must be provided"
        if os.path.isdir(custom_weights_path):
            safetensors_list = glob.glob(
                os.path.join(str(custom_weights_path), "*.safetensors"))
        else:
            assert custom_weights_path.endswith(".safetensors"), "Custom initialization weights must be a safetensors file"
            safetensors_list = [custom_weights_path]

    logger.info("Loading model from %s safetensors files: %s",
                len(safetensors_list), safetensors_list)

    default_dtype = PRECISION_TO_TYPE[
        fastvideo_args.pipeline_config.dit_precision]

    # Load the model using FSDP loader
    logger.info("Loading model from %s, default_dtype: %s", cls_name,
                default_dtype)
    assert fastvideo_args.hsdp_shard_dim is not None
    model = maybe_load_fsdp_model(
        model_cls=model_cls,
        init_params={
            "config": dit_config,
            "hf_config": hf_config
        },
        weight_dir_list=safetensors_list,
        device=get_local_torch_device(),
        hsdp_replicate_dim=fastvideo_args.hsdp_replicate_dim,
        hsdp_shard_dim=fastvideo_args.hsdp_shard_dim,
        cpu_offload=fastvideo_args.dit_cpu_offload,
        pin_cpu_memory=fastvideo_args.pin_cpu_memory,
        fsdp_inference=fastvideo_args.use_fsdp_inference,
        # TODO(will): make these configurable
        default_dtype=default_dtype,
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
        output_dtype=None,
        training_mode=fastvideo_args.training_mode,
        enable_torch_compile=fastvideo_args.enable_torch_compile,
        torch_compile_kwargs=fastvideo_args.torch_compile_kwargs)


    total_params = sum(p.numel() for p in model.parameters())
    logger.info("Loaded model with %.2fB parameters", total_params / 1e9)

    assert next(model.parameters()).dtype == default_dtype, "Model dtype does not match default dtype"

    model = model.eval()
    return model
fastvideo.models.loader.component_loader.VAELoader
VAELoader(device=None)

Bases: ComponentLoader

Loader for VAE.

Source code in fastvideo/models/loader/component_loader.py
def __init__(self, device=None) -> None:
    self.device = device
Functions
fastvideo.models.loader.component_loader.VAELoader.load
load(model_path: str, fastvideo_args: FastVideoArgs)

Load the VAE based on the model path, and inference args.

Source code in fastvideo/models/loader/component_loader.py
def load(self, model_path: str, fastvideo_args: FastVideoArgs):
    """Load the VAE based on the model path, and inference args."""
    config = get_diffusers_config(model=model_path)
    class_name = config.pop("_class_name")
    assert class_name is not None, "Model config does not contain a _class_name attribute. Only diffusers format is supported."
    fastvideo_args.model_paths["vae"] = model_path

    vae_config = fastvideo_args.pipeline_config.vae_config
    vae_config.update_model_arch(config)

    from fastvideo.platforms import current_platform

    if fastvideo_args.vae_cpu_offload:
        target_device = torch.device("mps") if current_platform.is_mps() else torch.device("cpu")
    else:
        target_device = get_local_torch_device()

    with set_default_torch_dtype(PRECISION_TO_TYPE[
            fastvideo_args.pipeline_config.vae_precision]):
        vae_cls, _ = ModelRegistry.resolve_model_cls(class_name)
        vae = vae_cls(vae_config).to(target_device)

    # Find all safetensors files
    safetensors_list = glob.glob(
        os.path.join(str(model_path), "*.safetensors"))
    # TODO(PY)
    assert len(
        safetensors_list
    ) == 1, f"Found {len(safetensors_list)} safetensors files in {model_path}"
    loaded = safetensors_load_file(safetensors_list[0])
    vae.load_state_dict(
        loaded, strict=False)  # We might only load encoder or decoder

    return vae.eval()
Functions
fastvideo.models.loader.fsdp_load
Functions
fastvideo.models.loader.fsdp_load.load_model_from_full_model_state_dict
load_model_from_full_model_state_dict(model: FSDPModule | Module, full_sd_iterator: Generator[tuple[str, Tensor], None, None], device: device, param_dtype: dtype, strict: bool = False, cpu_offload: bool = False, param_names_mapping: Callable[[str], tuple[str, Any, Any]] | None = None, training_mode: bool = True) -> _IncompatibleKeys

Converts a full state dict into a sharded state dict and loads it into an FSDP model (if training) or a regular Hugging Face model.

Parameters:

Name Type Description Default
model FSDPModule | Module

Model to generate fully qualified names for cpu_state_dict

required
full_sd_iterator Generator

An iterator yielding (param_name, tensor) pairs

required
device device

Device used to move full state dict tensors

required
param_dtype dtype

Dtype used to move full state dict tensors

required
strict bool

Whether to load the model in strict mode

False
cpu_offload bool

Whether FSDP CPU offload is enabled

False
param_names_mapping Callable[[str], tuple[str, Any, Any]] | None

A function that maps a full param name to a sharded param name

None
training_mode bool

Apply FSDP only for training

True

Returns:

Type Description
_IncompatibleKeys

NamedTuple with missing_keys and unexpected_keys fields: missing_keys is a list of str containing the missing keys; unexpected_keys is a list of str containing the unexpected keys

Raises:

Type Description
NotImplementedError

If got FSDP with more than 1D.

Source code in fastvideo/models/loader/fsdp_load.py
def load_model_from_full_model_state_dict(
    model: FSDPModule | torch.nn.Module,
    full_sd_iterator: Generator[tuple[str, torch.Tensor], None, None],
    device: torch.device,
    param_dtype: torch.dtype,
    strict: bool = False,
    cpu_offload: bool = False,
    param_names_mapping: Callable[[str], tuple[str, Any, Any]] | None = None,
    training_mode: bool = True,
) -> _IncompatibleKeys:
    """
    Converting full state dict into a sharded state dict
    and loading it into FSDP model (if training) or normal huggingface model
    Args:
        model (Union[FSDPModule, torch.nn.Module]): Model to generate fully qualified names for cpu_state_dict
        full_sd_iterator (Generator): an iterator yielding (param_name, tensor) pairs
        device (torch.device): device used to move full state dict tensors
        param_dtype (torch.dtype): dtype used to move full state dict tensors
        strict (bool): flag to check if to load the model in strict mode
        cpu_offload (bool): flag to check if FSDP offload is enabled
        param_names_mapping (Optional[Callable[[str], str]]): a function that maps full param name to sharded param name
        training_mode (bool): apply FSDP only for training
    Returns:
        ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
            * **missing_keys** is a list of str containing the missing keys
            * **unexpected_keys** is a list of str containing the unexpected keys

    Raises:
        NotImplementedError: If got FSDP with more than 1D.
    """
    meta_sd = model.state_dict()
    sharded_sd = {}
    custom_param_sd, reverse_param_names_mapping = hf_to_custom_state_dict(
        full_sd_iterator, param_names_mapping)  # type: ignore
    for target_param_name, full_tensor in custom_param_sd.items():
        meta_sharded_param = meta_sd.get(target_param_name)
        if meta_sharded_param is None:
            raise ValueError(
                f"Parameter {target_param_name} not found in custom model state dict. The hf to custom mapping may be incorrect."
            )
        if not hasattr(meta_sharded_param, "device_mesh"):
            full_tensor = full_tensor.to(device=device, dtype=param_dtype)
            # In cases where parts of the model aren't sharded, some parameters will be plain tensors
            sharded_tensor = full_tensor
        else:
            full_tensor = full_tensor.to(device=device, dtype=param_dtype)
            sharded_tensor = distribute_tensor(
                full_tensor,
                meta_sharded_param.device_mesh,
                meta_sharded_param.placements,
            )
            if cpu_offload:
                sharded_tensor = sharded_tensor.cpu()
        sharded_sd[target_param_name] = nn.Parameter(sharded_tensor)

    model.reverse_param_names_mapping = reverse_param_names_mapping
    unused_keys = set(meta_sd.keys()) - set(sharded_sd.keys())
    if unused_keys:
        logger.warning("Found unloaded parameters in meta state dict: %s",
                       unused_keys)

    # List of allowed parameter name patterns
    ALLOWED_NEW_PARAM_PATTERNS = ["gate_compress"]  # Can be extended as needed
    for new_param_name in unused_keys:
        if not any(pattern in new_param_name
                   for pattern in ALLOWED_NEW_PARAM_PATTERNS):
            logger.error("Unsupported new parameter: %s. Allowed patterns: %s",
                         new_param_name, ALLOWED_NEW_PARAM_PATTERNS)
            raise ValueError(
                f"New parameter '{new_param_name}' is not supported. "
                f"Currently only parameters containing {ALLOWED_NEW_PARAM_PATTERNS} are allowed."
            )
        meta_sharded_param = meta_sd.get(new_param_name)
        if not hasattr(meta_sharded_param, "device_mesh"):
            # Initialize with zeros
            sharded_tensor = torch.zeros_like(meta_sharded_param,
                                              device=device,
                                              dtype=param_dtype)
        else:
            # Initialize with zeros and distribute
            full_tensor = torch.zeros_like(meta_sharded_param,
                                           device=device,
                                           dtype=param_dtype)
            sharded_tensor = distribute_tensor(
                full_tensor,
                meta_sharded_param.device_mesh,
                meta_sharded_param.placements,
            )
            if cpu_offload:
                sharded_tensor = sharded_tensor.cpu()
        sharded_sd[new_param_name] = nn.Parameter(sharded_tensor)

    # choose `assign=True` since we cannot call `copy_` on meta tensor
    return model.load_state_dict(sharded_sd, strict=strict, assign=True)
fastvideo.models.loader.fsdp_load.maybe_load_fsdp_model
maybe_load_fsdp_model(model_cls: type[Module], init_params: dict[str, Any], weight_dir_list: list[str], device: device, hsdp_replicate_dim: int, hsdp_shard_dim: int, default_dtype: dtype, param_dtype: dtype, reduce_dtype: dtype, cpu_offload: bool = False, fsdp_inference: bool = False, output_dtype: dtype | None = None, training_mode: bool = True, pin_cpu_memory: bool = True, enable_torch_compile: bool = False, torch_compile_kwargs: dict[str, Any] | None = None) -> Module

Load the model with FSDP if training; otherwise load the model without FSDP.

Source code in fastvideo/models/loader/fsdp_load.py
def maybe_load_fsdp_model(
    model_cls: type[nn.Module],
    init_params: dict[str, Any],
    weight_dir_list: list[str],
    device: torch.device,
    hsdp_replicate_dim: int,
    hsdp_shard_dim: int,
    default_dtype: torch.dtype,
    param_dtype: torch.dtype,
    reduce_dtype: torch.dtype,
    cpu_offload: bool = False,
    fsdp_inference: bool = False,
    output_dtype: torch.dtype | None = None,
    training_mode: bool = True,
    pin_cpu_memory: bool = True,
    enable_torch_compile: bool = False,
    torch_compile_kwargs: dict[str, Any] | None = None,
) -> torch.nn.Module:
    """
    Load the model with FSDP if training; otherwise load the model without FSDP.
    """
    # NOTE(will): cast_forward_inputs=True shouldn't be needed as we are
    # manually casting the inputs to the model
    mp_policy = MixedPrecisionPolicy(param_dtype,
                                     reduce_dtype,
                                     output_dtype,
                                     cast_forward_inputs=False)

    set_mixed_precision_policy(
        param_dtype=param_dtype,
        reduce_dtype=reduce_dtype,
        output_dtype=output_dtype,
        mp_policy=mp_policy,
    )

    logger.info("Loading model with default_dtype: %s", default_dtype)
    with set_default_dtype(default_dtype), torch.device("meta"):
        model = model_cls(**init_params)

    # Check if we should use FSDP
    use_fsdp = training_mode or fsdp_inference

    # Disable FSDP for MPS as it's not compatible
    from fastvideo.platforms import current_platform
    if current_platform.is_mps():
        use_fsdp = False
        logger.info("Disabling FSDP for MPS platform as it's not compatible")

    if use_fsdp:
        world_size = hsdp_replicate_dim * hsdp_shard_dim
        if not training_mode and not fsdp_inference:
            hsdp_replicate_dim = world_size
            hsdp_shard_dim = 1

        if current_platform.is_npu():
            with torch.device("cpu"):
                device_mesh = init_device_mesh(
                    "npu",
                    # (Replicate(), Shard(dim=0))
                    mesh_shape=(hsdp_replicate_dim, hsdp_shard_dim),
                    mesh_dim_names=("replicate", "shard"),
                )
        else:
            device_mesh = init_device_mesh(
                "cuda",
                # (Replicate(), Shard(dim=0))
                mesh_shape=(hsdp_replicate_dim, hsdp_shard_dim),
                mesh_dim_names=("replicate", "shard"),
            )
        shard_model(model,
                    cpu_offload=cpu_offload,
                    reshard_after_forward=True,
                    mp_policy=mp_policy,
                    mesh=device_mesh,
                    fsdp_shard_conditions=model._fsdp_shard_conditions,
                    pin_cpu_memory=pin_cpu_memory)

    weight_iterator = safetensors_weights_iterator(weight_dir_list)
    param_names_mapping_fn = get_param_names_mapping(model.param_names_mapping)
    load_model_from_full_model_state_dict(
        model,
        weight_iterator,
        device,
        default_dtype,
        strict=True,
        cpu_offload=cpu_offload,
        param_names_mapping=param_names_mapping_fn,
    )
    for n, p in chain(model.named_parameters(), model.named_buffers()):
        if p.is_meta:
            raise RuntimeError(
                f"Unexpected param or buffer {n} on meta device.")
        # Avoid unintended computation graph accumulation during inference
        if isinstance(p, torch.nn.Parameter):
            p.requires_grad = False

    compile_in_loader = enable_torch_compile and training_mode
    if compile_in_loader:
        compile_kwargs = torch_compile_kwargs or {}
        logger.info("Enabling torch.compile for FSDP training module with kwargs=%s",
                    compile_kwargs)
        model = torch.compile(model, **compile_kwargs)
        logger.info("torch.compile enabled for %s", type(model).__name__)
    return model
fastvideo.models.loader.fsdp_load.set_default_dtype
set_default_dtype(dtype: dtype) -> Generator[None, None, None]

Context manager to set torch's default dtype.

Parameters:

Name Type Description Default
dtype dtype

The desired default dtype inside the context manager.

required

Returns:

Name Type Description
ContextManager None

context manager for setting default dtype.

Example

>>> with set_default_dtype(torch.bfloat16):
...     x = torch.tensor([1, 2, 3])
...     x.dtype
torch.bfloat16

Source code in fastvideo/models/loader/fsdp_load.py
@contextlib.contextmanager
def set_default_dtype(dtype: torch.dtype) -> Generator[None, None, None]:
    """
    Context manager to set torch's default dtype.

    Args:
        dtype (torch.dtype): The desired default dtype inside the context manager.

    Returns:
        ContextManager: context manager for setting default dtype.

    Example:
        >>> with set_default_dtype(torch.bfloat16):
        >>>     x = torch.tensor([1, 2, 3])
        >>>     x.dtype
        torch.bfloat16


    """
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(old_dtype)
fastvideo.models.loader.fsdp_load.shard_model
shard_model(model, *, cpu_offload: bool, reshard_after_forward: bool = True, mp_policy: MixedPrecisionPolicy | None = MixedPrecisionPolicy(), mesh: DeviceMesh | None = None, fsdp_shard_conditions: list[Callable[[str, Module], bool]] = [], pin_cpu_memory: bool = True) -> None

Utility to shard a model with FSDP using the PyTorch Distributed fully_shard API.

This method iterates over the model's named modules from the bottom up and shards modules based on whether they meet any of the criteria from shard_conditions.

Parameters:

Name Type Description Default
model Module

Model to shard with FSDP.

required
cpu_offload bool

If set to True, FSDP will offload parameters, gradients, and optimizer states to CPU.

required
reshard_after_forward bool

Whether to reshard parameters and buffers after the forward pass. Setting this to True corresponds to the FULL_SHARD sharding strategy from FSDP1, while setting it to False corresponds to the SHARD_GRAD_OP sharding strategy.

True
mesh Optional[DeviceMesh]

Device mesh to use for FSDP sharding under multiple parallelism. Default to None.

None
fsdp_shard_conditions List[Callable[[str, Module], bool]]

A list of functions to determine which modules to shard with FSDP. Each function takes the module name (relative to root) and the module itself, returning True if FSDP should shard the module and False otherwise. If any of the conditions returns True for a given module, it will be sharded by FSDP.

[]
pin_cpu_memory bool

If set to True, FSDP will pin the CPU memory of the offloaded parameters.

True

Raises:

Type Description
ValueError

If no layer modules were sharded, indicating that no shard_condition was triggered.

Source code in fastvideo/models/loader/fsdp_load.py
def shard_model(
    model,
    *,
    cpu_offload: bool,
    reshard_after_forward: bool = True,
    mp_policy: MixedPrecisionPolicy | None = MixedPrecisionPolicy(),  # noqa
    mesh: DeviceMesh | None = None,
    fsdp_shard_conditions: list[Callable[[str, nn.Module], bool]] = [],  # noqa
    pin_cpu_memory: bool = True,
) -> None:
    """
    Utility to shard a model with FSDP using the PyTorch Distributed fully_shard API.

    This method iterates over the model's named modules from the bottom up and shards modules
    based on whether they meet any of the criteria from shard_conditions.

    Args:
        model (TransformerDecoder): Model to shard with FSDP.
        shard_conditions (List[Callable[[str, nn.Module], bool]]): A list of functions to determine
            which modules to shard with FSDP. Each function should take module name (relative to root)
            and the module itself, returning True if FSDP should shard the module and False otherwise.
            If any of shard_conditions return True for a given module, it will be sharded by FSDP.
        cpu_offload (bool): If set to True, FSDP will offload parameters, gradients, and optimizer
            states to CPU.
        reshard_after_forward (bool): Whether to reshard parameters and buffers after
            the forward pass. Setting this to True corresponds to the FULL_SHARD sharding strategy
            from FSDP1, while setting it to False corresponds to the SHARD_GRAD_OP sharding strategy.
        mesh (Optional[DeviceMesh]): Device mesh to use for FSDP sharding under multiple parallelism.
            Default to None.
        fsdp_shard_conditions (List[Callable[[str, nn.Module], bool]]): A list of functions to determine
            which modules to shard with FSDP.
        pin_cpu_memory (bool): If set to True, FSDP will pin the CPU memory of the offloaded parameters.

    Raises:
        ValueError: If no layer modules were sharded, indicating that no shard_condition was triggered.
    """
    # Check if we should use size-based filtering
    use_size_filtering = os.environ.get("FASTVIDEO_FSDP2_AUTOWRAP", "0") == "1"

    if not fsdp_shard_conditions:
        logger.warning("No FSDP shard conditions provided; nothing will be sharded.")
        return

    fsdp_kwargs = {
        "reshard_after_forward": reshard_after_forward,
        "mesh": mesh,
        "mp_policy": mp_policy,
    }
    if cpu_offload:
        fsdp_kwargs["offload_policy"] = CPUOffloadPolicy(
            pin_memory=pin_cpu_memory)

    # iterating in reverse to start with
    # lowest-level modules first
    num_layers_sharded = 0

    if use_size_filtering:
        # Size-based filtering mode
        min_params = int(os.environ.get("FASTVIDEO_FSDP2_MIN_PARAMS", "10000000"))
        logger.info("Using size-based filtering with threshold: %.2fM", min_params / 1e6)

        for n, m in reversed(list(model.named_modules())):
            if any([shard_condition(n, m) for shard_condition in fsdp_shard_conditions]):
                # Count all parameters
                param_count = sum(p.numel() for p in m.parameters(recurse=True))

                # Skip small modules
                if param_count < min_params:
                    logger.info("Skipping module %s (%.2fM params < %.2fM threshold)", 
                               n, param_count / 1e6, min_params / 1e6)
                    continue

                # Shard this module
                logger.info("Sharding module %s (%.2fM params)", n, param_count / 1e6)
                fully_shard(m, **fsdp_kwargs)
                num_layers_sharded += 1
    else:
        # Shard all modules matching conditions        
        for n, m in reversed(list(model.named_modules())):
            if any([shard_condition(n, m) for shard_condition in fsdp_shard_conditions]):
                fully_shard(m, **fsdp_kwargs)
                num_layers_sharded += 1

        if num_layers_sharded == 0:
            raise ValueError(
                "No layer modules were sharded. Please check if shard conditions are working as expected."
            )

    # Finally shard the entire model to account for any stragglers
    fully_shard(model, **fsdp_kwargs)
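
Size-based filtering is controlled by environment variables read inside shard_model; a minimal sketch (the threshold value here is an example, the built-in default is 10M parameters):

import os

# Opt in to size-based filtering and raise the parameter threshold to 50M.
os.environ["FASTVIDEO_FSDP2_AUTOWRAP"] = "1"
os.environ["FASTVIDEO_FSDP2_MIN_PARAMS"] = str(50_000_000)
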
fastvideo.models.loader.utils

Utilities for selecting and loading models.

Functions
fastvideo.models.loader.utils.get_param_names_mapping
get_param_names_mapping(mapping_dict: dict[str, str]) -> Callable[[str], tuple[str, Any, Any]]

Creates a mapping function that transforms parameter names using regex patterns.

Parameters:

Name Type Description Default
mapping_dict Dict[str, str]

Dictionary mapping regex patterns to replacement patterns

required

Returns:

Type Description
Callable[[str], tuple[str, Any, Any]]

A function that maps a parameter name from the source format to the target format, returning the new name together with an optional merge index and the total number of split parameters.

Source code in fastvideo/models/loader/utils.py
def get_param_names_mapping(
        mapping_dict: dict[str, str]) -> Callable[[str], tuple[str, Any, Any]]:
    """
    Creates a mapping function that transforms parameter names using regex patterns.

    Args:
        mapping_dict (Dict[str, str]): Dictionary mapping regex patterns to replacement patterns
        param_name (str): The parameter name to be transformed

    Returns:
        Callable[[str], str]: A function that maps parameter names from source to target format
    """

    def mapping_fn(name: str) -> tuple[str, Any, Any]:
        # Try to match and transform the name using the regex patterns in mapping_dict
        for pattern, replacement in mapping_dict.items():
            match = re.match(pattern, name)
            if match:
                merge_index = None
                total_splitted_params = None
                if isinstance(replacement, tuple):
                    merge_index = replacement[1]
                    total_splitted_params = replacement[2]
                    replacement = replacement[0]
                name = re.sub(pattern, replacement, name)
                return name, merge_index, total_splitted_params

        # If no pattern matches, return the original name
        return name, None, None

    return mapping_fn
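
A minimal usage sketch with hypothetical patterns; plain renames map to a replacement string, while fused parameters map to a (replacement, merge_index, total_splitted_params) tuple:

from fastvideo.models.loader.utils import get_param_names_mapping

mapping_fn = get_param_names_mapping({
    r"^encoder\.(.*)$": r"text_encoder.\1",
    r"^blocks\.(\d+)\.q_proj\.(.*)$": (r"blocks.\1.qkv_proj.\2", 0, 3),
    r"^blocks\.(\d+)\.k_proj\.(.*)$": (r"blocks.\1.qkv_proj.\2", 1, 3),
    r"^blocks\.(\d+)\.v_proj\.(.*)$": (r"blocks.\1.qkv_proj.\2", 2, 3),
})

print(mapping_fn("encoder.layer_norm.weight"))
# ('text_encoder.layer_norm.weight', None, None)
print(mapping_fn("blocks.0.k_proj.weight"))
# ('blocks.0.qkv_proj.weight', 1, 3)
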
fastvideo.models.loader.utils.hf_to_custom_state_dict
hf_to_custom_state_dict(hf_param_sd: dict[str, Tensor] | Iterator[tuple[str, Tensor]], param_names_mapping: Callable[[str], tuple[str, Any, Any]]) -> tuple[dict[str, Tensor], dict[str, tuple[str, Any, Any]]]

Converts a Hugging Face parameter state dictionary to a custom parameter state dictionary.

Parameters:

Name Type Description Default
hf_param_sd Dict[str, Tensor]

The Hugging Face parameter state dictionary

required
param_names_mapping Callable[[str], tuple[str, Any, Any]]

A function that maps parameter names from source to target format

required

Returns:

Name Type Description
custom_param_sd Dict[str, Tensor]

The custom formatted parameter state dict

reverse_param_names_mapping Dict[str, Tuple[str, Any, Any]]

Maps back from custom to hf

Source code in fastvideo/models/loader/utils.py
def hf_to_custom_state_dict(
    hf_param_sd: dict[str, torch.Tensor] | Iterator[tuple[str, torch.Tensor]],
    param_names_mapping: Callable[[str], tuple[str, Any, Any]]
) -> tuple[dict[str, torch.Tensor], dict[str, tuple[str, Any, Any]]]:
    """
    Converts a Hugging Face parameter state dictionary to a custom parameter state dictionary.

    Args:
        hf_param_sd (Dict[str, torch.Tensor]): The Hugging Face parameter state dictionary
        param_names_mapping (Callable[[str], tuple[str, Any, Any]]): A function that maps parameter names from source to target format

    Returns:
        custom_param_sd (Dict[str, torch.Tensor]): The custom formatted parameter state dict
        reverse_param_names_mapping (Dict[str, Tuple[str, Any, Any]]): Maps back from custom to hf
    """
    custom_param_sd = {}
    to_merge_params = defaultdict(dict)  # type: ignore
    reverse_param_names_mapping = {}
    if isinstance(hf_param_sd, dict):
        hf_param_sd = hf_param_sd.items()  # type: ignore
    for source_param_name, full_tensor in hf_param_sd:  # type: ignore
        target_param_name, merge_index, num_params_to_merge = param_names_mapping(
            source_param_name)
        reverse_param_names_mapping[target_param_name] = (source_param_name,
                                                          merge_index,
                                                          num_params_to_merge)
        if merge_index is not None:
            to_merge_params[target_param_name][merge_index] = full_tensor
            if len(to_merge_params[target_param_name]) == num_params_to_merge:
                # cat at output dim according to the merge_index order
                sorted_tensors = [
                    to_merge_params[target_param_name][i]
                    for i in range(num_params_to_merge)
                ]
                full_tensor = torch.cat(sorted_tensors, dim=0)
                del to_merge_params[target_param_name]
            else:
                continue
        custom_param_sd[target_param_name] = full_tensor
    return custom_param_sd, reverse_param_names_mapping
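
A minimal sketch showing how split q/k/v weights are merged into one fused parameter (the names and shapes are hypothetical):

import torch

from fastvideo.models.loader.utils import (get_param_names_mapping,
                                            hf_to_custom_state_dict)

mapping_fn = get_param_names_mapping({
    r"^blocks\.(\d+)\.q_proj\.(.*)$": (r"blocks.\1.qkv_proj.\2", 0, 3),
    r"^blocks\.(\d+)\.k_proj\.(.*)$": (r"blocks.\1.qkv_proj.\2", 1, 3),
    r"^blocks\.(\d+)\.v_proj\.(.*)$": (r"blocks.\1.qkv_proj.\2", 2, 3),
})

hf_sd = {
    "blocks.0.q_proj.weight": torch.zeros(8, 16),
    "blocks.0.k_proj.weight": torch.zeros(8, 16),
    "blocks.0.v_proj.weight": torch.zeros(8, 16),
}

custom_sd, reverse_mapping = hf_to_custom_state_dict(hf_sd, mapping_fn)
print(custom_sd["blocks.0.qkv_proj.weight"].shape)  # torch.Size([24, 16])
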
fastvideo.models.loader.utils.set_default_torch_dtype
set_default_torch_dtype(dtype: dtype)

Sets the default torch dtype to the given dtype.

Source code in fastvideo/models/loader/utils.py
@contextlib.contextmanager
def set_default_torch_dtype(dtype: torch.dtype):
    """Sets the default torch dtype to the given dtype."""
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    yield
    torch.set_default_dtype(old_dtype)
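
A minimal usage sketch:

import torch

from fastvideo.models.loader.utils import set_default_torch_dtype

with set_default_torch_dtype(torch.float16):
    # Modules created inside the context use float16 parameters.
    layer = torch.nn.Linear(4, 4)

print(layer.weight.dtype)         # torch.float16
print(torch.get_default_dtype())  # restored after the context exits
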
fastvideo.models.loader.weight_utils

Utilities for downloading and initializing model weights.

Functions
fastvideo.models.loader.weight_utils.default_weight_loader
default_weight_loader(param: Tensor, loaded_weight: Tensor) -> None

Default weight loader.

Source code in fastvideo/models/loader/weight_utils.py
def default_weight_loader(param: torch.Tensor,
                          loaded_weight: torch.Tensor) -> None:
    """Default weight loader."""
    try:
        if param.numel() == 1 and loaded_weight.numel() == 1:
            # Sometimes scalar values aren't considered tensors with shapes
            # so if both param and loaded_weight are a scalar,
            # "broadcast" instead of copy
            param.data.fill_(loaded_weight.item())
        else:
            assert param.size() == loaded_weight.size(), (
                f"Attempted to load weight ({loaded_weight.size()}) "
                f"into parameter ({param.size()})")

            param.data.copy_(loaded_weight)
    except Exception:
        # NOTE: This exception is added for the purpose of setting breakpoint to
        # debug weight loading issues.
        raise
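
A minimal usage sketch:

import torch

from fastvideo.models.loader.weight_utils import default_weight_loader

# Matching shapes are copied directly.
param = torch.nn.Parameter(torch.empty(4, 4))
default_weight_loader(param, torch.randn(4, 4))

# Scalar values are "broadcast" with fill_ instead of copy_.
scale = torch.nn.Parameter(torch.empty(()))
default_weight_loader(scale, torch.tensor(0.5))
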
fastvideo.models.loader.weight_utils.enable_hf_transfer
enable_hf_transfer() -> None

Automatically activates hf_transfer if it is available.

Source code in fastvideo/models/loader/weight_utils.py
def enable_hf_transfer() -> None:
    """automatically activates hf_transfer
    """
    if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ:
        try:
            # enable hf hub transfer if available
            import hf_transfer  # type: ignore # noqa
            huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True
        except ImportError:
            pass
fastvideo.models.loader.weight_utils.filter_files_not_needed_for_inference
filter_files_not_needed_for_inference(hf_weights_files: list[str]) -> list[str]

Exclude files that are not needed for inference.

See https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233

Source code in fastvideo/models/loader/weight_utils.py
def filter_files_not_needed_for_inference(
        hf_weights_files: list[str]) -> list[str]:
    """
    Exclude files that are not needed for inference.

    See https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
    """
    blacklist = [
        "training_args.bin",
        "optimizer.bin",
        "optimizer.pt",
        "scheduler.pt",
        "scaler.pt",
    ]
    hf_weights_files = [
        f for f in hf_weights_files
        if not any(f.endswith(x) for x in blacklist)
    ]
    return hf_weights_files
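
A minimal usage sketch:

from fastvideo.models.loader.weight_utils import (
    filter_files_not_needed_for_inference)

files = [
    "model-00001-of-00002.safetensors",
    "model-00002-of-00002.safetensors",
    "optimizer.pt",
    "scheduler.pt",
]
print(filter_files_not_needed_for_inference(files))
# ['model-00001-of-00002.safetensors', 'model-00002-of-00002.safetensors']
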
fastvideo.models.loader.weight_utils.maybe_remap_kv_scale_name
maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None

Remap the name of FP8 k/v_scale parameters.

This function handles the remapping of FP8 k/v_scale parameter names. It detects if the given name ends with a suffix and attempts to remap it to the expected name format in the model. If the remapped name is not found in the params_dict, a warning is printed and None is returned.

Parameters:

Name Type Description Default
name str

The original loaded checkpoint parameter name.

required
params_dict dict

Dictionary containing the model's named parameters.

required

Returns:

Name Type Description
str str | None

The remapped parameter name if successful, or the original name if no remapping is needed.

None str | None

If the remapped name is not found in params_dict.

Source code in fastvideo/models/loader/weight_utils.py
def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None:
    """Remap the name of FP8 k/v_scale parameters.

    This function handles the remapping of FP8 k/v_scale parameter names.
    It detects if the given name ends with a suffix and attempts to remap
    it to the expected name format in the model. If the remapped name is not
    found in the params_dict, a warning is printed and None is returned.

    Args:
        name (str): The original loaded checkpoint parameter name.
        params_dict (dict): Dictionary containing the model's named parameters.

    Returns:
        str: The remapped parameter name if successful, or the original name
             if no remapping is needed.
        None: If the remapped name is not found in params_dict.
    """
    if name.endswith(".kv_scale"):
        logger.warning_once(
            "DEPRECATED. Found kv_scale in the checkpoint. "
            "This format is deprecated in favor of separate k_scale and "
            "v_scale tensors and will be removed in a future release. "
            "Functionally, we will remap kv_scale to k_scale and duplicate "
            "k_scale to v_scale")
        # NOTE: we remap the deprecated kv_scale to k_scale
        remapped_name = name.replace(".kv_scale", ".attn.k_scale")
        if remapped_name not in params_dict:
            logger.warning_once(
                f"Found kv_scale in the checkpoint (e.g. {name}), "
                "but not found the expected name in the model "
                f"(e.g. {remapped_name}). kv_scale is "
                "not loaded.")
            return None
        return remapped_name

    possible_scale_names = [".k_scale", ".v_scale"]
    modelopt_scale_names = [
        ".self_attn.k_proj.k_scale", ".self_attn.v_proj.v_scale"
    ]
    for scale_name in possible_scale_names:
        if name.endswith(scale_name):
            if any(mo_scale_name in name
                   for mo_scale_name in modelopt_scale_names):
                remapped_name = name.replace(
                    f".self_attn.{scale_name[1]}_proj{scale_name}",
                    f".self_attn.attn{scale_name}")
            else:
                remapped_name = name.replace(scale_name, f".attn{scale_name}")
            if remapped_name not in params_dict:
                logger.warning_once(
                    f"Found {scale_name} in the checkpoint (e.g. {name}), "
                    "but not found the expected name in the model "
                    f"(e.g. {remapped_name}). {scale_name} is "
                    "not loaded.")
                return None
            return remapped_name

    # If there were no matches, return the untouched param name
    return name
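
A minimal sketch with hypothetical parameter names:

from fastvideo.models.loader.weight_utils import maybe_remap_kv_scale_name

params_dict = {
    "layers.0.self_attn.attn.k_scale": None,
    "layers.0.self_attn.attn.v_scale": None,
}

# The deprecated fused kv_scale name is remapped onto k_scale.
print(maybe_remap_kv_scale_name("layers.0.self_attn.kv_scale", params_dict))
# layers.0.self_attn.attn.k_scale
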
fastvideo.models.loader.weight_utils.pt_weights_iterator
pt_weights_iterator(hf_weights_files: list[str], to_cpu: bool = True) -> Generator[tuple[str, Tensor], None, None]

Iterate over the weights in the model bin/pt files.

Source code in fastvideo/models/loader/weight_utils.py
def pt_weights_iterator(
    hf_weights_files: list[str],
    to_cpu: bool = True,
) -> Generator[tuple[str, torch.Tensor], None, None]:
    """Iterate over the weights in the model bin/pt files."""
    device = "cpu" if to_cpu else str(get_local_torch_device())
    enable_tqdm = not torch.distributed.is_initialized(
    ) or torch.distributed.get_rank() == 0
    for bin_file in tqdm(
            hf_weights_files,
            desc="Loading pt checkpoint shards",
            disable=not enable_tqdm,
            bar_format=_BAR_FORMAT,
    ):
        state = torch.load(bin_file, map_location=device, weights_only=True)
        yield from state.items()
        del state
fastvideo.models.loader.weight_utils.safetensors_weights_iterator
safetensors_weights_iterator(hf_weights_files: list[str], to_cpu: bool = True) -> Generator[tuple[str, Tensor], None, None]

Iterate over the weights in the model safetensor files.

Source code in fastvideo/models/loader/weight_utils.py
def safetensors_weights_iterator(
    hf_weights_files: list[str],
    to_cpu: bool = True,
) -> Generator[tuple[str, torch.Tensor], None, None]:
    """Iterate over the weights in the model safetensor files."""
    enable_tqdm = not torch.distributed.is_initialized(
    ) or torch.distributed.get_rank() == 0
    device = "cpu" if to_cpu else str(get_local_torch_device())
    for st_file in tqdm(
            hf_weights_files,
            desc="Loading safetensors checkpoint shards",
            disable=not enable_tqdm,
            bar_format=_BAR_FORMAT,
    ):
        with safe_open(st_file, framework="pt", device=device) as f:
            for name in f.keys():  # noqa: SIM118
                param = f.get_tensor(name)
                yield name, param
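
A minimal usage sketch (the checkpoint directory is hypothetical):

import glob

from fastvideo.models.loader.weight_utils import safetensors_weights_iterator

shards = sorted(glob.glob("/models/my-pipeline/transformer/*.safetensors"))

for name, tensor in safetensors_weights_iterator(shards, to_cpu=True):
    print(name, tuple(tensor.shape))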

fastvideo.models.parameter

Classes

fastvideo.models.parameter.BasevLLMParameter
BasevLLMParameter(data: Tensor, weight_loader: Callable)

Bases: Parameter

Base parameter for vLLM linear layers. Extends torch.nn.Parameter by taking in a linear weight loader. Will copy the loaded weight into the parameter when the provided weight loader is called.

Initialize the BasevLLMParameter

:param data: torch tensor with the parameter data
:param weight_loader: weight loader callable

:returns: a torch.nn.parameter

Source code in fastvideo/models/parameter.py
def __init__(self, data: torch.Tensor, weight_loader: Callable):
    """
    Initialize the BasevLLMParameter

    :param data: torch tensor with the parameter data
    :param weight_loader: weight loader callable

    :returns: a torch.nn.parameter
    """

    # During weight loading, we often do something like:
    # narrowed_tensor = param.data.narrow(0, offset, len)
    # narrowed_tensor.copy_(real_weight)
    # expecting narrowed_tensor and param.data to share the same storage.
    # However, on TPUs, narrowed_tensor will lazily propagate to the base
    # tensor, which is param.data, leading to the redundant memory usage.
    # This sometimes causes OOM errors during model loading. To avoid this,
    # we sync the param tensor after its weight loader is called.
    from fastvideo.platforms import current_platform
    if current_platform.is_tpu():
        weight_loader = _make_synced_weight_loader(weight_loader)

    self._weight_loader = weight_loader
Functions
fastvideo.models.parameter.BlockQuantScaleParameter
BlockQuantScaleParameter(output_dim: int, **kwargs)

Bases: _ColumnvLLMParameter, RowvLLMParameter

Parameter class for weight scales loaded for weights with block-wise quantization. Uses both column and row parallelism.

Source code in fastvideo/models/parameter.py
def __init__(self, output_dim: int, **kwargs):
    self._output_dim = output_dim
    super().__init__(**kwargs)
fastvideo.models.parameter.ChannelQuantScaleParameter
ChannelQuantScaleParameter(output_dim: int, **kwargs)

Bases: _ColumnvLLMParameter

Parameter class for weight scales loaded for weights with channel-wise quantization. Equivalent to _ColumnvLLMParameter.

Source code in fastvideo/models/parameter.py
def __init__(self, output_dim: int, **kwargs):
    self._output_dim = output_dim
    super().__init__(**kwargs)
fastvideo.models.parameter.GroupQuantScaleParameter
GroupQuantScaleParameter(output_dim: int, **kwargs)

Bases: _ColumnvLLMParameter, RowvLLMParameter

Parameter class for weight scales loaded for weights with grouped quantization. Uses both column and row parallelism.

Source code in fastvideo/models/parameter.py
def __init__(self, output_dim: int, **kwargs):
    self._output_dim = output_dim
    super().__init__(**kwargs)
fastvideo.models.parameter.ModelWeightParameter
ModelWeightParameter(output_dim: int, **kwargs)

Bases: _ColumnvLLMParameter, RowvLLMParameter

Parameter class for linear layer weights. Uses both column and row parallelism.

Source code in fastvideo/models/parameter.py
def __init__(self, output_dim: int, **kwargs):
    self._output_dim = output_dim
    super().__init__(**kwargs)
fastvideo.models.parameter.PackedColumnParameter
PackedColumnParameter(packed_factor: int | Fraction, packed_dim: int, **kwargs)

Bases: _ColumnvLLMParameter

Parameter for model parameters which are packed on disk and support column parallelism only. See PackedvLLMParameter for more details on the packed properties.

Source code in fastvideo/models/parameter.py
def __init__(self, packed_factor: int | Fraction, packed_dim: int,
             **kwargs):
    self._packed_factor = packed_factor
    self._packed_dim = packed_dim
    super().__init__(**kwargs)
fastvideo.models.parameter.PackedvLLMParameter
PackedvLLMParameter(packed_factor: int | Fraction, packed_dim: int, **kwargs)

Bases: ModelWeightParameter

Parameter for model weights which are packed on disk. Example: GPTQ Marlin weights are int4 or int8, packed into int32. Extends the ModelWeightParameter to take in the packed factor, the packed dimension, and optionally, marlin tile size for marlin kernels. Adjusts the shard_size and shard_offset for fused linear layers model weight loading by accounting for packing and optionally, marlin tile size.

Source code in fastvideo/models/parameter.py
def __init__(self, packed_factor: int | Fraction, packed_dim: int,
             **kwargs):
    self._packed_factor = packed_factor
    self._packed_dim = packed_dim
    super().__init__(**kwargs)
fastvideo.models.parameter.PerTensorScaleParameter
PerTensorScaleParameter(**kwargs)

Bases: BasevLLMParameter

Parameter class for scales where the number of scales is equivalent to the number of logical matrices in fused linear layers (e.g. for QKV, there are 3 scales loaded from disk). This is relevant to weights with per-tensor quantization. Adds functionality to map the scalers to a shard during weight loading.

Note: additional parameter manipulation may be handled for each quantization config specifically, within process_weights_after_loading

Source code in fastvideo/models/parameter.py
def __init__(self, **kwargs):
    self.qkv_idxs = {"q": 0, "k": 1, "v": 2}
    super().__init__(**kwargs)
fastvideo.models.parameter.RowvLLMParameter
RowvLLMParameter(input_dim: int, **kwargs)

Bases: BasevLLMParameter

Parameter class defining weight_loading functionality (load_row_parallel_weight) for parameters being loaded into linear layers with row parallel functionality. Requires an input_dim to be defined.

Source code in fastvideo/models/parameter.py
def __init__(self, input_dim: int, **kwargs):
    self._input_dim = input_dim
    super().__init__(**kwargs)

Functions

fastvideo.models.parameter.permute_param_layout_
permute_param_layout_(param: BasevLLMParameter, input_dim: int, output_dim: int, **kwargs) -> BasevLLMParameter

Permute a parameter's layout to the specified input and output dimensions. This is useful for forcing the parameter into a known layout; for example, if a packed (quantized) weight matrix needs to be in the layout {input_dim = 0, output_dim = 1, packed_dim = 0}, calling permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) ensures x is in that layout (permuting it if required, and asserting if it cannot be brought into it).

Source code in fastvideo/models/parameter.py
def permute_param_layout_(param: BasevLLMParameter, input_dim: int,
                          output_dim: int, **kwargs) -> BasevLLMParameter:
    """
    Permute a parameter's layout to the specified input and output dimensions, 
    useful for forcing the parameter into a known layout, for example, if I need
    a packed (quantized) weight matrix to be in the layout 
        {input_dim = 0, output_dim = 1, packed_dim = 0}
    then I can call:
        permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
    to ensure x is in the correct layout (permuting it to the correct layout if 
    required, asserting if it cannot get it to the correct layout)
    """

    curr_input_dim = getattr(param, "input_dim", None)
    curr_output_dim = getattr(param, "output_dim", None)

    if curr_input_dim is None or curr_output_dim is None:
        assert param.data.dim() == 2,\
            "permute_param_layout_ only supports 2D parameters when either "\
            "input_dim or output_dim is not set"

    # if one of the dimensions is not set, set it to the opposite of the other
    #  we can only do this since we asserted the parameter is 2D above
    if curr_input_dim is None:
        assert curr_output_dim is not None,\
            "either input or output dim must be set"
        curr_input_dim = (curr_output_dim + 1) % 2
    if curr_output_dim is None:
        assert curr_input_dim is not None,\
            "either input or output dim must be set"
        curr_output_dim = (curr_input_dim + 1) % 2

    # create permutation from the current layout to the layout with
    # self.input_dim at input_dim and self.output_dim at output_dim preserving
    # other dimensions
    perm = [
        i for i in range(param.data.dim())
        if i not in [curr_input_dim, curr_output_dim]
    ]
    perm.insert(input_dim, curr_input_dim)
    perm.insert(output_dim, curr_output_dim)

    if "packed_dim" in kwargs:
        assert hasattr(param, "packed_dim") and\
            param.packed_dim == perm[kwargs["packed_dim"]],\
            "permute_param_layout_ currently doesn't support repacking"

    param.data = param.data.permute(*perm)
    if hasattr(param, "_input_dim"):
        param._input_dim = input_dim
    if hasattr(param, "_output_dim"):
        param._output_dim = output_dim
    if "packed_dim" in kwargs and hasattr(param, "_packed_dim"):
        param._packed_dim = kwargs["packed_dim"]

    return param
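
A hedged sketch of how this can be used with ModelWeightParameter (the weight_loader lambda is a placeholder): a 2D weight stored as [output_features, input_features] is permuted into the {input_dim = 0, output_dim = 1} layout.

import torch

from fastvideo.models.parameter import ModelWeightParameter, permute_param_layout_

# Weight stored as [output_features, input_features].
w = ModelWeightParameter(
    data=torch.empty(128, 64),
    input_dim=1,
    output_dim=0,
    weight_loader=lambda param, loaded_weight: param.data.copy_(loaded_weight),
)
permute_param_layout_(w, input_dim=0, output_dim=1)
# The data is now laid out as [input_features, output_features].
assert w.data.shape == (64, 128)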

fastvideo.models.registry

Functions

fastvideo.models.utils

Utils for model executor.

Functions

fastvideo.models.utils.auto_attributes
auto_attributes(init_func)

Decorator that automatically adds all initialization arguments as object attributes.

Example

@auto_attributes
def __init__(self, a=1, b=2):
    pass

This will automatically set:
- self.a = 1 and self.b = 2
- self.config.a = 1 and self.config.b = 2
Source code in fastvideo/models/utils.py
def auto_attributes(init_func):
    """
    Decorator that automatically adds all initialization arguments as object attributes.

    Example:
        @auto_attributes
        def __init__(self, a=1, b=2):
            pass

        # This will automatically set:
        # - self.a = 1 and self.b = 2
        # - self.config.a = 1 and self.config.b = 2
    """

    def wrapper(self, *args, **kwargs):
        # Get the function signature
        import inspect
        signature = inspect.signature(init_func)
        parameters = signature.parameters

        # Get parameter names (excluding 'self')
        param_names = list(parameters.keys())[1:]

        # Bind arguments to parameters
        bound_args = signature.bind(self, *args, **kwargs)
        bound_args.apply_defaults()

        # Create config object if it doesn't exist
        if not hasattr(self, 'config'):
            self.config = type('Config', (), {})()

        # Set attributes on self and self.config
        for name in param_names:
            if name in bound_args.arguments:
                value = bound_args.arguments[name]
                setattr(self, name, value)
                setattr(self.config, name, value)

        # Call the original __init__ function
        return init_func(self, *args, **kwargs)

    return wrapper
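
A short usage sketch (the class name is illustrative):

from fastvideo.models.utils import auto_attributes

class Block:
    @auto_attributes
    def __init__(self, hidden_size=256, num_layers=4):
        pass

blk = Block(hidden_size=512)
# Arguments are mirrored onto both the instance and its auto-created config.
assert blk.hidden_size == 512 and blk.num_layers == 4
assert blk.config.hidden_size == 512 and blk.config.num_layers == 4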
fastvideo.models.utils.extract_layer_index
extract_layer_index(layer_name: str) -> int

Extract the layer index from the module name. Examples:
- "encoder.layers.0" -> 0
- "encoder.layers.1.self_attn" -> 1
- "2.self_attn" -> 2
- "model.encoder.layers.0.sub.1" -> ValueError

Source code in fastvideo/models/utils.py
def extract_layer_index(layer_name: str) -> int:
    """
    Extract the layer index from the module name.
    Examples:
    - "encoder.layers.0" -> 0
    - "encoder.layers.1.self_attn" -> 1
    - "2.self_attn" -> 2
    - "model.encoder.layers.0.sub.1" -> ValueError
    """
    subnames = layer_name.split(".")
    int_vals: list[int] = []
    for subname in subnames:
        try:
            int_vals.append(int(subname))
        except ValueError:
            continue
    assert len(int_vals) == 1, (f"layer name {layer_name} should"
                                " only contain one integer")
    return int_vals[0]
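
For example:

from fastvideo.models.utils import extract_layer_index

# Exactly one integer component is expected in the module name.
assert extract_layer_index("encoder.layers.7.self_attn") == 7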
fastvideo.models.utils.modulate
modulate(x: Tensor, shift: Tensor | None = None, scale: Tensor | None = None) -> Tensor

modulate by shift and scale

Parameters:

Name Type Description Default
x Tensor

input tensor.

required
shift Tensor

shift tensor. Defaults to None.

None
scale Tensor

scale tensor. Defaults to None.

None

Returns:

Type Description
Tensor

torch.Tensor: the output tensor after modulate.

Source code in fastvideo/models/utils.py
def modulate(x: torch.Tensor,
             shift: torch.Tensor | None = None,
             scale: torch.Tensor | None = None) -> torch.Tensor:
    """modulate by shift and scale

    Args:
        x (torch.Tensor): input tensor.
        shift (torch.Tensor, optional): shift tensor. Defaults to None.
        scale (torch.Tensor, optional): scale tensor. Defaults to None.

    Returns:
        torch.Tensor: the output tensor after modulate.
    """
    if scale is None and shift is None:
        return x
    elif shift is None:
        return x * (1 + scale.unsqueeze(1))  # type: ignore[union-attr]
    elif scale is None:
        return x + shift.unsqueeze(1)  # type: ignore[union-attr]
    else:
        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(
            1)  # type: ignore[union-attr]
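
A minimal sketch applying AdaLN-style shift/scale modulation to a [batch, tokens, hidden] activation; both conditioning tensors are broadcast over the token dimension:

import torch

from fastvideo.models.utils import modulate

x = torch.randn(2, 16, 64)      # [batch, tokens, hidden]
shift = torch.randn(2, 64)      # per-sample shift, broadcast over tokens
scale = torch.randn(2, 64)      # per-sample scale, broadcast over tokens
out = modulate(x, shift=shift, scale=scale)
assert out.shape == x.shape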
fastvideo.models.utils.pred_noise_to_pred_video
pred_noise_to_pred_video(pred_noise: Tensor, noise_input_latent: Tensor, timestep: Tensor, scheduler: Any) -> Tensor

Convert predicted noise to clean latent.

pred_noise: the predicted noise with shape [B, C, H, W], where B is batch_size or batch_size * num_frames.
noise_input_latent: the noisy latent with shape [B, C, H, W].
timestep: the timestep with shape [1], [bs * num_frames], or [bs, num_frames].
scheduler: the scheduler.

Returns:

Type Description
Tensor

the predicted video with shape [B, C, H, W]

Source code in fastvideo/models/utils.py
def pred_noise_to_pred_video(pred_noise: torch.Tensor,
                             noise_input_latent: torch.Tensor,
                             timestep: torch.Tensor,
                             scheduler: Any) -> torch.Tensor:
    """
    Convert predicted noise to clean latent.

    Args:
    pred_noise: the predicted noise with shape [B, C, H, W]
        where B is batch_size or batch_size * num_frames
    noise_input_latent: the noisy latent with shape [B, C, H, W],
    timestep: the timestep with shape [1] or [bs * num_frames] or [bs, num_frames]
    scheduler: the scheduler

    Returns:
        the predicted video with shape [B, C, H, W]
    """
    # If timestep is [bs, num_frames]
    if timestep.ndim == 2:
        timestep = timestep.flatten(0, 1)
        assert timestep.numel() == noise_input_latent.shape[0]
    elif timestep.ndim == 1:
        # If timestep is [1]
        if timestep.shape[0] == 1:
            timestep = timestep.expand(noise_input_latent.shape[0])
        else:
            assert timestep.numel() == noise_input_latent.shape[0]
    else:
        raise ValueError(f"[pred_noise_to_pred_video] Invalid timestep shape: {timestep.shape}")
    # timestep shape should be [B]
    dtype = pred_noise.dtype
    device = pred_noise.device
    pred_noise = pred_noise.double().to(device)
    noise_input_latent = noise_input_latent.double().to(device)
    sigmas = scheduler.sigmas.double().to(device)
    timesteps = scheduler.timesteps.double().to(device)
    timestep_id = torch.argmin(
        (timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1)
    sigma_t = sigmas[timestep_id].reshape(-1, 1, 1, 1)
    pred_video = noise_input_latent - sigma_t * pred_noise
    return pred_video.to(dtype)
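
A hedged sketch using a minimal stand-in scheduler that exposes only the two attributes this function reads (sigmas and timesteps); a real flow-matching scheduler would supply the same fields:

from types import SimpleNamespace

import torch

from fastvideo.models.utils import pred_noise_to_pred_video

scheduler = SimpleNamespace(
    timesteps=torch.tensor([1000.0, 750.0, 500.0, 250.0]),
    sigmas=torch.tensor([1.0, 0.75, 0.5, 0.25]),
)
latents = torch.randn(2, 16, 32, 32)      # [B, C, H, W]
noise_pred = torch.randn_like(latents)
t = torch.tensor([750.0])                 # single timestep, expanded to the batch
video = pred_noise_to_pred_video(noise_pred, latents, t, scheduler)
# x_0 prediction: noise_input_latent - sigma_t * pred_noise
assert video.shape == latents.shape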
fastvideo.models.utils.set_weight_attrs
set_weight_attrs(weight: Tensor, weight_attrs: dict[str, Any] | None)

Set attributes on a weight tensor.

This method is used to set attributes on a weight tensor. This method will not overwrite existing attributes.

Parameters:

Name Type Description Default
weight Tensor

The weight tensor.

required
weight_attrs dict[str, Any] | None

A dictionary of attributes to set on the weight tensor.

required
Source code in fastvideo/models/utils.py
def set_weight_attrs(
    weight: torch.Tensor,
    weight_attrs: dict[str, Any] | None,
):
    """Set attributes on a weight tensor.

    This method is used to set attributes on a weight tensor. This method
    will not overwrite existing attributes.

    Args:
        weight: The weight tensor.
        weight_attrs: A dictionary of attributes to set on the weight tensor.
    """
    if weight_attrs is None:
        return
    for key, value in weight_attrs.items():
        assert not hasattr(
            weight, key), (f"Overwriting existing tensor attribute: {key}")

        # NOTE(woosuk): During weight loading, we often do something like:
        # narrowed_tensor = param.data.narrow(0, offset, len)
        # narrowed_tensor.copy_(real_weight)
        # expecting narrowed_tensor and param.data to share the same storage.
        # However, on TPUs, narrowed_tensor will lazily propagate to the base
        # tensor, which is param.data, leading to the redundant memory usage.
        # This sometimes causes OOM errors during model loading. To avoid this,
        # we sync the param tensor after its weight loader is called.
        # TODO(woosuk): Remove this hack once we have a better solution.
        from fastvideo.platforms import current_platform
        if current_platform.is_tpu() and key == "weight_loader":
            value = _make_synced_weight_loader(value)
        setattr(weight, key, value)
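
For example, attaching loader metadata to a freshly created weight (the attribute names here are illustrative):

import torch

from fastvideo.models.utils import set_weight_attrs

weight = torch.nn.Parameter(torch.empty(128, 64), requires_grad=False)
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
assert weight.output_dim == 0 and weight.input_dim == 1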

fastvideo.models.vision_utils

Functions

fastvideo.models.vision_utils.create_default_image
create_default_image(width: int = 512, height: int = 512, color: tuple[int, int, int] = (0, 0, 0)) -> Image

Create a default black PIL image.

Parameters:

Name Type Description Default
width int

Image width in pixels

512
height int

Image height in pixels

512
color tuple[int, int, int]

RGB color tuple

(0, 0, 0)

Returns:

Type Description
Image

PIL.Image.Image: A new PIL image with specified dimensions and color

Source code in fastvideo/models/vision_utils.py
def create_default_image(width: int = 512, height: int = 512, color: tuple[int, int, int] = (0, 0, 0)) -> PIL.Image.Image:
    """
    Create a default black PIL image.

    Args:
        width: Image width in pixels
        height: Image height in pixels
        color: RGB color tuple

    Returns:
        PIL.Image.Image: A new PIL image with specified dimensions and color
    """
    return PIL.Image.new("RGB", (width, height), color=color)
fastvideo.models.vision_utils.get_default_height_width
get_default_height_width(image: Image | ndarray | Tensor, vae_scale_factor: int, height: int | None = None, width: int | None = None) -> tuple[int, int]

Returns the height and width of the image, downscaled to the next integer multiple of vae_scale_factor.

Parameters:

Name Type Description Default
image `Union[PIL.Image.Image, np.ndarray, torch.Tensor]`

The image input, which can be a PIL image, NumPy array, or PyTorch tensor. If it is a NumPy array, it should have shape [batch, height, width] or [batch, height, width, channels]. If it is a PyTorch tensor, it should have shape [batch, channels, height, width].

required
height `Optional[int]`, *optional*, defaults to `None`

The height of the preprocessed image. If None, the height of the image input will be used.

None
width `Optional[int]`, *optional*, defaults to `None`

The width of the preprocessed image. If None, the width of the image input will be used.

None

Returns:

Type Description
tuple[int, int]

Tuple[int, int]: A tuple containing the height and width, both resized to the nearest integer multiple of vae_scale_factor.

Source code in fastvideo/models/vision_utils.py
def get_default_height_width(
    image: PIL.Image.Image | np.ndarray | torch.Tensor,
    vae_scale_factor: int,
    height: int | None = None,
    width: int | None = None,
) -> tuple[int, int]:
    r"""
    Returns the height and width of the image, downscaled to the next integer multiple of `vae_scale_factor`.

    Args:
        image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`):
            The image input, which can be a PIL image, NumPy array, or PyTorch tensor. If it is a NumPy array, it
            should have shape `[batch, height, width]` or `[batch, height, width, channels]`. If it is a PyTorch
            tensor, it should have shape `[batch, channels, height, width]`.
        height (`Optional[int]`, *optional*, defaults to `None`):
            The height of the preprocessed image. If `None`, the height of the `image` input will be used.
        width (`Optional[int]`, *optional*, defaults to `None`):
            The width of the preprocessed image. If `None`, the width of the `image` input will be used.

    Returns:
        `Tuple[int, int]`:
            A tuple containing the height and width, both resized to the nearest integer multiple of
            `vae_scale_factor`.
    """

    if height is None:
        if isinstance(image, PIL.Image.Image):
            height = image.height
        elif isinstance(image, torch.Tensor):
            height = image.shape[2]
        else:
            height = image.shape[1]

    if width is None:
        if isinstance(image, PIL.Image.Image):
            width = image.width
        elif isinstance(image, torch.Tensor):
            width = image.shape[3]
        else:
            width = image.shape[2]

    width, height = (x - x % vae_scale_factor for x in (width, height)
                     )  # resize to integer multiple of vae_scale_factor

    return height, width
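
For example, with a [batch, channels, height, width] tensor and a VAE scale factor of 8, odd dimensions are snapped down to the nearest multiple:

import torch

from fastvideo.models.vision_utils import get_default_height_width

frames = torch.randn(1, 3, 481, 641)   # [batch, channels, height, width]
height, width = get_default_height_width(frames, vae_scale_factor=8)
assert (height, width) == (480, 640)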
fastvideo.models.vision_utils.load_image
load_image(image: str | Image, convert_method: Callable[[Image], Image] | None = None) -> Image

Loads image to a PIL Image.

Parameters:

Name Type Description Default
image `str` or `PIL.Image.Image`

The image to convert to the PIL Image format.

required
convert_method Callable[[PIL.Image.Image], PIL.Image.Image], *optional*

A conversion method to apply to the image after loading it. When set to None the image will be converted to "RGB".

None

Returns:

Type Description
Image

PIL.Image.Image: A PIL Image.

Source code in fastvideo/models/vision_utils.py
def load_image(
    image: str | PIL.Image.Image,
    convert_method: Callable[[PIL.Image.Image], PIL.Image.Image] | None = None
) -> PIL.Image.Image:
    """
    Loads `image` to a PIL Image.

    Args:
        image (`str` or `PIL.Image.Image`):
            The image to convert to the PIL Image format.
        convert_method (Callable[[PIL.Image.Image], PIL.Image.Image], *optional*):
            A conversion method to apply to the image after loading it. When set to `None` the image will be converted
            "RGB".

    Returns:
        `PIL.Image.Image`:
            A PIL Image.
    """
    if isinstance(image, str):
        if image.startswith("http://") or image.startswith("https://"):
            image = PIL.Image.open(requests.get(image, stream=True).raw)
        elif os.path.isfile(image):
            image = PIL.Image.open(image)
        else:
            raise ValueError(
                f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {image} is not a valid path."
            )
    elif isinstance(image, PIL.Image.Image):
        image = image
    else:
        raise ValueError(
            "Incorrect format used for the image. Should be a URL linking to an image, a local path, or a PIL image."
        )

    image = PIL.ImageOps.exif_transpose(image)

    if convert_method is not None:
        image = convert_method(image)
    else:
        image = image.convert("RGB")

    return image
fastvideo.models.vision_utils.load_video
load_video(video: str, convert_method: Callable[[list[Image]], list[Image]] | None = None, return_fps: bool = False) -> tuple[list[Image], float | Any] | list[Image]

Loads video to a list of PIL Image.

Args:
    video (str): A URL or Path to a video to convert to a list of PIL Image format.
    convert_method (Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]], optional): A conversion method to apply to the video after loading it. When set to None the images will be converted to "RGB".
    return_fps (bool, optional, defaults to False): Whether to return the FPS of the video. If True, returns a tuple of (images, fps). If False, returns only the list of images.

Returns:
    List[PIL.Image.Image] or Tuple[List[PIL.Image.Image], float | None]: The video as a list of PIL images. If return_fps is True, also returns the original FPS.

Source code in fastvideo/models/vision_utils.py
def load_video(
    video: str,
    convert_method: Callable[[list[PIL.Image.Image]], list[PIL.Image.Image]]
    | None = None,
    return_fps: bool = False,
) -> tuple[list[PIL.Image.Image], float | Any] | list[PIL.Image.Image]:
    """
    Loads `video` to a list of PIL Image.
    Args:
        video (`str`):
            A URL or Path to a video to convert to a list of PIL Image format.
        convert_method (Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]], *optional*):
            A conversion method to apply to the video after loading it. When set to `None` the images will be converted
            to "RGB".
        return_fps (`bool`, *optional*, defaults to `False`):
            Whether to return the FPS of the video. If `True`, returns a tuple of (images, fps).
            If `False`, returns only the list of images.
    Returns:
        `List[PIL.Image.Image]` or `Tuple[List[PIL.Image.Image], float | None]`:
            The video as a list of PIL images. If `return_fps` is True, also returns the original FPS.
    """
    is_url = video.startswith("http://") or video.startswith("https://")
    is_file = os.path.isfile(video)
    was_tempfile_created = False

    if not (is_url or is_file):
        raise ValueError(
            f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {video} is not a valid path."
        )

    if is_url:
        response = requests.get(video, stream=True)
        if response.status_code != 200:
            raise ValueError(
                f"Failed to download video. Status code: {response.status_code}"
            )

        parsed_url = urlparse(video)
        file_name = os.path.basename(unquote(parsed_url.path))

        suffix = os.path.splitext(file_name)[1] or ".mp4"
        with tempfile.NamedTemporaryFile(suffix=suffix,
                                         delete=False) as temp_file:
            video_path = temp_file.name
            video_data = response.iter_content(chunk_size=8192)
            for chunk in video_data:
                temp_file.write(chunk)
        was_tempfile_created = True
    else:
        video_path = video

    pil_images = []
    original_fps = None

    try:
        if video_path.endswith(".gif"):
            pil_images, original_fps = _load_gif(video_path)
        else:
            pil_images, original_fps = _load_video_with_ffmpeg(video_path)
    finally:
        # Clean up temporary file if it was created
        if was_tempfile_created and os.path.exists(video_path):
            os.remove(video_path)

    if convert_method is not None:
        pil_images = convert_method(pil_images)

    return (pil_images, original_fps) if return_fps else pil_images
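
A hedged usage sketch ("clip.mp4" is a placeholder path; decoding relies on ffmpeg being available):

from fastvideo.models.vision_utils import load_video

frames, fps = load_video("clip.mp4", return_fps=True)
print(f"decoded {len(frames)} frames at {fps} fps")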
fastvideo.models.vision_utils.normalize
normalize(images: ndarray | Tensor) -> ndarray | Tensor

Normalize an image array to [-1,1].

Parameters:

Name Type Description Default
images `np.ndarray` or `torch.Tensor`

The image array to normalize.

required

Returns:

Type Description
ndarray | Tensor

np.ndarray or torch.Tensor: The normalized image array.

Source code in fastvideo/models/vision_utils.py
def normalize(images: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor:
    r"""
    Normalize an image array to [-1,1].

    Args:
        images (`np.ndarray` or `torch.Tensor`):
            The image array to normalize.

    Returns:
        `np.ndarray` or `torch.Tensor`:
            The normalized image array.
    """
    return 2.0 * images - 1.0
fastvideo.models.vision_utils.numpy_to_pt
numpy_to_pt(images: ndarray) -> Tensor

Convert a NumPy image to a PyTorch tensor.

Parameters:

Name Type Description Default
images `np.ndarray`

The NumPy image array to convert to PyTorch format.

required

Returns:

Type Description
Tensor

torch.Tensor: A PyTorch tensor representation of the images.

Source code in fastvideo/models/vision_utils.py
def numpy_to_pt(images: np.ndarray) -> torch.Tensor:
    r"""
    Convert a NumPy image to a PyTorch tensor.

    Args:
        images (`np.ndarray`):
            The NumPy image array to convert to PyTorch format.

    Returns:
        `torch.Tensor`:
            A PyTorch tensor representation of the images.
    """
    if images.ndim == 3:
        images = images[..., None]

    images = torch.from_numpy(images.transpose(0, 3, 1, 2))
    return images
fastvideo.models.vision_utils.pil_to_numpy
pil_to_numpy(images: list[Image] | Image) -> ndarray

Convert a PIL image or a list of PIL images to NumPy arrays.

Parameters:

Name Type Description Default
images `PIL.Image.Image` or `List[PIL.Image.Image]`

The PIL image or list of images to convert to NumPy format.

required

Returns:

Type Description
ndarray

np.ndarray: A NumPy array representation of the images.

Source code in fastvideo/models/vision_utils.py
def pil_to_numpy(images: list[PIL.Image.Image] | PIL.Image.Image) -> np.ndarray:
    r"""
    Convert a PIL image or a list of PIL images to NumPy arrays.

    Args:
        images (`PIL.Image.Image` or `List[PIL.Image.Image]`):
            The PIL image or list of images to convert to NumPy format.

    Returns:
        `np.ndarray`:
            A NumPy array representation of the images.
    """
    if not isinstance(images, list):
        images = [images]
    images = [np.array(image).astype(np.float32) / 255.0 for image in images]
    images_arr: np.ndarray = np.stack(images, axis=0)

    return images_arr
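
These helpers compose into the usual PIL -> NumPy -> normalized tensor path, for example:

from fastvideo.models.vision_utils import (create_default_image, normalize,
                                           numpy_to_pt, pil_to_numpy)

image = create_default_image(width=64, height=64, color=(255, 255, 255))
array = pil_to_numpy(image)        # [1, 64, 64, 3], values in [0, 1]
tensor = numpy_to_pt(array)        # [1, 3, 64, 64]
tensor = normalize(tensor)         # values in [-1, 1]
assert tensor.shape == (1, 3, 64, 64)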
fastvideo.models.vision_utils.preprocess_reference_image_for_clip
preprocess_reference_image_for_clip(image: Image, device: device) -> Image

Preprocess reference image to match CLIP encoder requirements.

Applies normalization, resizing to 224x224, and denormalization to ensure the image is in the correct format for CLIP processing.

Parameters:

Name Type Description Default
image Image

Input PIL image

required
device device

Target device for tensor operations

required

Returns:

Type Description
Image

Preprocessed PIL image ready for CLIP encoder

Source code in fastvideo/models/vision_utils.py
def preprocess_reference_image_for_clip(image: PIL.Image.Image, device: torch.device) -> PIL.Image.Image:
    """
    Preprocess reference image to match CLIP encoder requirements.

    Applies normalization, resizing to 224x224, and denormalization to ensure
    the image is in the correct format for CLIP processing.

    Args:
        image: Input PIL image
        device: Target device for tensor operations

    Returns:
        Preprocessed PIL image ready for CLIP encoder
    """
    # Convert PIL to tensor and normalize to [-1, 1] range
    image_tensor = TF.to_tensor(image).sub_(0.5).div_(0.5).to(device)

    # Resize to CLIP's expected input size (224x224) using bicubic interpolation
    resized_tensor = F.interpolate(
        image_tensor.unsqueeze(0),
        size=(224, 224),
        mode='bicubic',
        align_corners=False
    ).squeeze(0)

    # Denormalize back to [0, 1] range
    denormalized_tensor = resized_tensor.mul_(0.5).add_(0.5)

    return TF.to_pil_image(denormalized_tensor)
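
A short sketch (CPU device chosen for illustration):

import torch

from fastvideo.models.vision_utils import (create_default_image,
                                           preprocess_reference_image_for_clip)

reference = create_default_image(width=640, height=360)
clip_image = preprocess_reference_image_for_clip(reference, torch.device("cpu"))
assert clip_image.size == (224, 224)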
fastvideo.models.vision_utils.resize
resize(image: Image | ndarray | Tensor, height: int, width: int, resize_mode: str = 'default', resample: str = 'lanczos') -> Image | ndarray | Tensor

Resize image.

Parameters:

Name Type Description Default
image `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`

The image input, can be a PIL image, numpy array or pytorch tensor.

required
height `int`

The height to resize to.

required
width `int`

The width to resize to.

required
resize_mode `str`, *optional*, defaults to `default`

The resize mode to use; can be one of default, fill, or crop. If default, will resize the image to fit within the specified width and height, and it may not maintain the original aspect ratio. If fill, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty areas with data from the image. If crop, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess. Note that resize_mode fill and crop are only supported for PIL image input.

'default'

Returns:

Type Description
Image | ndarray | Tensor

PIL.Image.Image, np.ndarray or torch.Tensor: The resized image.

Source code in fastvideo/models/vision_utils.py
def resize(
    image: PIL.Image.Image | np.ndarray | torch.Tensor,
    height: int,
    width: int,
    resize_mode: str = "default",  # "default", "fill", "crop"
    resample: str = "lanczos",
) -> PIL.Image.Image | np.ndarray | torch.Tensor:
    """
    Resize image.

    Args:
        image (`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
            The image input, can be a PIL image, numpy array or pytorch tensor.
        height (`int`):
            The height to resize to.
        width (`int`):
            The width to resize to.
        resize_mode (`str`, *optional*, defaults to `default`):
            The resize mode to use, can be one of `default` or `fill`. If `default`, will resize the image to fit
            within the specified width and height, and it may not maintaining the original aspect ratio. If `fill`,
            will resize the image to fit within the specified width and height, maintaining the aspect ratio, and
            then center the image within the dimensions, filling empty with data from image. If `crop`, will resize
            the image to fit within the specified width and height, maintaining the aspect ratio, and then center
            the image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only
            supported for PIL image input.

    Returns:
        `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
            The resized image.
    """
    if resize_mode != "default" and not isinstance(image, PIL.Image.Image):
        raise ValueError(
            f"Only PIL image input is supported for resize_mode {resize_mode}")
    assert isinstance(image, PIL.Image.Image)
    if resize_mode == "default":
        image = image.resize((width, height),
                             resample=PIL_INTERPOLATION[resample])
    else:
        raise ValueError(f"resize_mode {resize_mode} is not supported")
    return image
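
A minimal sketch of the default resize mode on a PIL image:

from fastvideo.models.vision_utils import create_default_image, resize

image = create_default_image(width=1280, height=720)
resized = resize(image, height=480, width=832)   # resize_mode="default"
assert resized.size == (832, 480)                # PIL size is (width, height)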