Skip to content

component_loader

Classes

fastvideo.models.loader.component_loader.ComponentLoader

ComponentLoader(device=None)

Bases: ABC

Base class for loading a specific type of model component.

Source code in fastvideo/models/loader/component_loader.py
def __init__(self, device=None) -> None:
    """Initialize the loader.

    Args:
        device: Optional target device for the loaded component; ``None``
            lets the concrete loader pick a device at load time.
    """
    self.device = device

Functions

fastvideo.models.loader.component_loader.ComponentLoader.for_module_type classmethod
for_module_type(module_type: str, transformers_or_diffusers: str) -> ComponentLoader

Factory method to create a component loader for a specific module type.

Parameters:

Name Type Description Default
module_type str

Type of module (e.g., "vae", "text_encoder", "transformer", "scheduler")

required
transformers_or_diffusers str

Whether the module is from transformers or diffusers

required

Returns:

Type Description
ComponentLoader

A component loader for the specified module type

Source code in fastvideo/models/loader/component_loader.py
@classmethod
def for_module_type(
    cls, module_type: str, transformers_or_diffusers: str
) -> "ComponentLoader":
    """
    Factory method to create a component loader for a specific module type.

    Args:
        module_type: Type of module (e.g., "vae", "text_encoder", "transformer", "scheduler")
        transformers_or_diffusers: Whether the module is from transformers or diffusers

    Returns:
        A component loader for the specified module type

    Raises:
        ValueError: If a known module type declares a library other than
            the one expected for it (and is not a fastvideo.* override).
    """
    # Map of module types to their loader classes and expected library
    module_loaders = {
        "scheduler": (SchedulerLoader, "diffusers"),
        "transformer": (TransformerLoader, "diffusers"),
        "transformer_2": (TransformerLoader, "diffusers"),
        "vae": (VAELoader, "diffusers"),
        "text_encoder": (TextEncoderLoader, "transformers"),
        "text_encoder_2": (TextEncoderLoader, "transformers"),
        "tokenizer": (TokenizerLoader, "transformers"),
        "tokenizer_2": (TokenizerLoader, "transformers"),
        "image_processor": (ImageProcessorLoader, "transformers"),
        "image_encoder": (ImageEncoderLoader, "transformers"),
    }

    if module_type in module_loaders:
        loader_cls, expected_library = module_loaders[module_type]
        # Allow fastvideo.* libraries for custom implementations (e.g. Cosmos2_5Pipeline)
        # that aren't available in diffusers/transformers yet
        is_fastvideo_module = transformers_or_diffusers.startswith("fastvideo.")
        if not is_fastvideo_module and transformers_or_diffusers != expected_library:
            # Raise (not assert) so the validation survives `python -O`,
            # which strips assert statements.
            raise ValueError(
                f"{module_type} must be loaded from {expected_library}, "
                f"got {transformers_or_diffusers}"
            )
        return loader_cls()

    # For unknown module types, use a generic loader
    logger.warning(
        "No specific loader found for module type: %s. Using generic loader.",
        module_type,
    )
    return GenericComponentLoader(transformers_or_diffusers)
fastvideo.models.loader.component_loader.ComponentLoader.load abstractmethod
load(model_path: str, fastvideo_args: FastVideoArgs)

Load the component based on the model path, architecture, and inference args.

Parameters:

Name Type Description Default
model_path str

Path to the component model

required
fastvideo_args FastVideoArgs

FastVideoArgs

required

Returns:

Type Description

The loaded component

Source code in fastvideo/models/loader/component_loader.py
@abstractmethod
def load(self, model_path: str, fastvideo_args: FastVideoArgs):
    """
    Load the component based on the model path, architecture, and inference args.

    Args:
        model_path: Path to the component model
        fastvideo_args: FastVideoArgs

    Returns:
        The loaded component
    """
    # Abstract hook: every concrete loader must override this.
    raise NotImplementedError

fastvideo.models.loader.component_loader.GenericComponentLoader

GenericComponentLoader(library='transformers')

Bases: ComponentLoader

Generic loader for components that don't have a specific loader.

Source code in fastvideo/models/loader/component_loader.py
def __init__(self, library="transformers") -> None:
    """Initialize the generic loader.

    Args:
        library: Source library of the component ("transformers" or
            "diffusers"); selects the loading strategy used by ``load``.
    """
    super().__init__()
    self.library = library

Functions

fastvideo.models.loader.component_loader.GenericComponentLoader.load
load(model_path: str, fastvideo_args: FastVideoArgs)

Load a generic component based on the model path, and inference args.

Source code in fastvideo/models/loader/component_loader.py
def load(self, model_path: str, fastvideo_args: FastVideoArgs):
    """Load a generic component based on the model path, and inference args."""
    logger.warning(
        "Using generic loader for %s with library %s",
        model_path,
        self.library,
    )

    # Reject anything outside the two supported libraries up front.
    if self.library not in ("transformers", "diffusers"):
        raise ValueError(f"Unsupported library: {self.library}")

    if self.library == "diffusers":
        logger.warning(
            "Generic loading for diffusers components is not fully implemented"
        )
        model_config = get_diffusers_config(model=model_path)
        logger.info("Diffusers Model config: %s", model_config)
        # Placeholder: a real implementation would instantiate a component
        # from this config.
        return None

    # transformers path — import lazily so diffusers-only flows skip it.
    from transformers import AutoModel

    loaded_model = AutoModel.from_pretrained(
        model_path,
        trust_remote_code=fastvideo_args.trust_remote_code,
        revision=fastvideo_args.revision,
    )
    logger.info(
        "Loaded generic transformers model: %s",
        loaded_model.__class__.__name__,
    )
    return loaded_model

fastvideo.models.loader.component_loader.ImageEncoderLoader

ImageEncoderLoader(device=None)

Bases: TextEncoderLoader

Source code in fastvideo/models/loader/component_loader.py
def __init__(self, device=None) -> None:
    # Stores the optional target device.
    # NOTE(review): this overrides the base __init__ with an identical body;
    # calling super().__init__(device) would avoid the duplication.
    self.device = device

Functions

fastvideo.models.loader.component_loader.ImageEncoderLoader.load
load(model_path: str, fastvideo_args: FastVideoArgs)

Load the image encoder based on the model path, and inference args.

Source code in fastvideo/models/loader/component_loader.py
def load(self, model_path: str, fastvideo_args: FastVideoArgs):
    """Load the image encoder based on the model path, and inference args."""
    config_path = os.path.join(model_path, "config.json")
    with open(config_path) as f:
        model_config = json.load(f)
    # Drop HF bookkeeping keys so only architecture fields remain.
    for key in ("_name_or_path", "transformers_version", "torch_dtype",
                "model_type"):
        model_config.pop(key, None)
    logger.info("HF Model config: %s", model_config)

    encoder_config = fastvideo_args.pipeline_config.image_encoder_config
    encoder_config.update_model_arch(model_config)

    from fastvideo.platforms import current_platform

    # CPU offload keeps the encoder off the accelerator; on Apple
    # silicon the offload target is MPS rather than plain CPU.
    if fastvideo_args.image_encoder_cpu_offload:
        offload_target = "mps" if current_platform.is_mps() else "cpu"
        target_device = torch.device(offload_target)
    else:
        target_device = get_local_torch_device()
    # TODO(will): add support for other dtypes
    return self.load_model(
        model_path,
        encoder_config,
        target_device,
        fastvideo_args,
        fastvideo_args.pipeline_config.image_encoder_precision,
    )

fastvideo.models.loader.component_loader.ImageProcessorLoader

ImageProcessorLoader(device=None)

Bases: ComponentLoader

Loader for image processor.

Source code in fastvideo/models/loader/component_loader.py
def __init__(self, device=None) -> None:
    # Device hint; stored but not referenced by this loader's `load`.
    self.device = device

Functions

fastvideo.models.loader.component_loader.ImageProcessorLoader.load
load(model_path: str, fastvideo_args: FastVideoArgs)

Load the image processor based on the model path, and inference args.

Source code in fastvideo/models/loader/component_loader.py
def load(self, model_path: str, fastvideo_args: FastVideoArgs):
    """Load the image processor based on the model path, and inference args."""
    logger.info("Loading image processor from %s", model_path)

    # Resolution of the concrete processor class is delegated to HF.
    processor = AutoImageProcessor.from_pretrained(model_path)
    logger.info(
        "Loaded image processor: %s", processor.__class__.__name__
    )
    return processor

fastvideo.models.loader.component_loader.PipelineComponentLoader

Utility class for loading pipeline components. This replaces the chain of if-else statements in load_pipeline_module.

Functions

fastvideo.models.loader.component_loader.PipelineComponentLoader.load_module staticmethod
load_module(module_name: str, component_model_path: str, transformers_or_diffusers: str, fastvideo_args: FastVideoArgs)

Load a pipeline module.

Parameters:

Name Type Description Default
module_name str

Name of the module (e.g., "vae", "text_encoder", "transformer", "scheduler")

required
component_model_path str

Path to the component model

required
transformers_or_diffusers str

Whether the module is from transformers or diffusers

required
fastvideo_args FastVideoArgs

Inference arguments

required

Returns:

Type Description

The loaded module

Source code in fastvideo/models/loader/component_loader.py
@staticmethod
def load_module(
    module_name: str,
    component_model_path: str,
    transformers_or_diffusers: str,
    fastvideo_args: FastVideoArgs,
):
    """
    Load a pipeline module.

    Args:
        module_name: Name of the module (e.g., "vae", "text_encoder", "transformer", "scheduler")
        component_model_path: Path to the component model
        transformers_or_diffusers: Whether the module is from transformers or diffusers
        fastvideo_args: Inference arguments

    Returns:
        The loaded module
    """
    logger.info(
        "Loading %s using %s from %s",
        module_name,
        transformers_or_diffusers,
        component_model_path,
    )

    # Get the appropriate loader for this module type
    loader = ComponentLoader.for_module_type(
        module_name, transformers_or_diffusers
    )

    # Load the module
    return loader.load(component_model_path, fastvideo_args)

fastvideo.models.loader.component_loader.SchedulerLoader

SchedulerLoader(device=None)

Bases: ComponentLoader

Loader for scheduler.

Source code in fastvideo/models/loader/component_loader.py
def __init__(self, device=None) -> None:
    # Device hint; stored but not referenced by this loader's `load`.
    self.device = device

Functions

fastvideo.models.loader.component_loader.SchedulerLoader.load
load(model_path: str, fastvideo_args: FastVideoArgs)

Load the scheduler based on the model path, and inference args.

Source code in fastvideo/models/loader/component_loader.py
def load(self, model_path: str, fastvideo_args: FastVideoArgs):
    """Load the scheduler based on the model path, and inference args.

    Raises:
        ValueError: If the diffusers config has no ``_class_name`` entry.
    """
    config = get_diffusers_config(model=model_path)

    # Pop with a default: previously a missing key raised a bare KeyError
    # before the "does not contain" message could ever be reported.
    class_name = config.pop("_class_name", None)
    if class_name is None:
        # Explicit raise (matches TransformerLoader.load) instead of assert.
        raise ValueError(
            "Model config does not contain a _class_name attribute. "
            "Only diffusers format is supported."
        )

    scheduler_cls, _ = ModelRegistry.resolve_model_cls(class_name)

    # Remaining config entries are passed straight to the scheduler ctor.
    scheduler = scheduler_cls(**config)
    # Optional pipeline-level overrides.
    if fastvideo_args.pipeline_config.flow_shift is not None:
        scheduler.set_shift(fastvideo_args.pipeline_config.flow_shift)
    if fastvideo_args.pipeline_config.timesteps_scale is not None:
        scheduler.set_timesteps_scale(
            fastvideo_args.pipeline_config.timesteps_scale
        )
    return scheduler

fastvideo.models.loader.component_loader.TextEncoderLoader

TextEncoderLoader(device=None)

Bases: ComponentLoader

Loader for text encoders.

Source code in fastvideo/models/loader/component_loader.py
def __init__(self, device=None) -> None:
    # Device hint for the text encoder; actual placement is decided in `load`.
    self.device = device

Classes

fastvideo.models.loader.component_loader.TextEncoderLoader.Source dataclass
Source(model_or_path: str, prefix: str = '', fall_back_to_pt: bool = True, allow_patterns_overrides: list[str] | None = None)

A source for weights.

Attributes
fastvideo.models.loader.component_loader.TextEncoderLoader.Source.allow_patterns_overrides class-attribute instance-attribute
allow_patterns_overrides: list[str] | None = None

If defined, weights will load exclusively using these patterns.

fastvideo.models.loader.component_loader.TextEncoderLoader.Source.fall_back_to_pt class-attribute instance-attribute
fall_back_to_pt: bool = True

Whether .pt weights can be used.

fastvideo.models.loader.component_loader.TextEncoderLoader.Source.model_or_path instance-attribute
model_or_path: str

The model ID or path.

fastvideo.models.loader.component_loader.TextEncoderLoader.Source.prefix class-attribute instance-attribute
prefix: str = ''

A prefix to prepend to all weights.

Functions

fastvideo.models.loader.component_loader.TextEncoderLoader.load
load(model_path: str, fastvideo_args: FastVideoArgs)

Load the text encoders based on the model path, and inference args.

Source code in fastvideo/models/loader/component_loader.py
def load(self, model_path: str, fastvideo_args: FastVideoArgs):
    """Load the text encoders based on the model path, and inference args."""
    # model_config: PretrainedConfig = get_hf_config(
    #     model=model_path,
    #     trust_remote_code=fastvideo_args.trust_remote_code,
    #     revision=fastvideo_args.revision,
    #     model_override_args=None,
    # )
    model_config = get_diffusers_config(model=model_path)
    # Drop HF bookkeeping keys so only architecture fields remain.
    model_config.pop("_name_or_path", None)
    model_config.pop("transformers_version", None)
    model_config.pop("model_type", None)
    model_config.pop("tokenizer_class", None)
    model_config.pop("torch_dtype", None)
    logger.info("HF Model config: %s", model_config)

    # @TODO(Wei): Better way to handle this?
    # Try the primary text-encoder config/precision (index 0); on ANY
    # failure fall back to the secondary (index 1).
    # NOTE(review): the broad `except Exception` also masks unrelated
    # errors raised inside update_model_arch — consider narrowing to the
    # expected IndexError.
    try:
        encoder_config = (
            fastvideo_args.pipeline_config.text_encoder_configs[0]
        )
        encoder_config.update_model_arch(model_config)
        encoder_precision = (
            fastvideo_args.pipeline_config.text_encoder_precisions[0]
        )
    except Exception:
        encoder_config = (
            fastvideo_args.pipeline_config.text_encoder_configs[1]
        )
        encoder_config.update_model_arch(model_config)
        encoder_precision = (
            fastvideo_args.pipeline_config.text_encoder_precisions[1]
        )

    target_device = get_local_torch_device()
    # TODO(will): add support for other dtypes
    return self.load_model(
        model_path,
        encoder_config,
        target_device,
        fastvideo_args,
        encoder_precision,
        use_text_encoder_override=True,
    )

fastvideo.models.loader.component_loader.TokenizerLoader

TokenizerLoader(device=None)

Bases: ComponentLoader

Loader for tokenizers.

Source code in fastvideo/models/loader/component_loader.py
def __init__(self, device=None) -> None:
    # Device hint; tokenizers are CPU objects, so `load` never uses this.
    self.device = device

Functions

fastvideo.models.loader.component_loader.TokenizerLoader.load
load(model_path: str, fastvideo_args: FastVideoArgs)

Load the tokenizer based on the model path, and inference args.

Source code in fastvideo/models/loader/component_loader.py
def load(self, model_path: str, fastvideo_args: FastVideoArgs):
    """Load the tokenizer based on the model path, and inference args.

    Prefers a Cosmos2.5-style AutoProcessor redirect when the tokenizer
    folder carries a processor config; otherwise loads an AutoTokenizer
    directly from the folder.
    """
    logger.info("Loading tokenizer from %s", model_path)

    # Cosmos2.5 stores an AutoProcessor config in `tokenizer/config.json` (not a tokenizer
    # config). Use its `_name_or_path` (e.g. Qwen/Qwen2.5-VL-7B-Instruct) as the source.
    tokenizer_cfg_path = os.path.join(model_path, "config.json")
    if os.path.exists(tokenizer_cfg_path):
        try:
            with open(tokenizer_cfg_path, "r") as f:
                tokenizer_cfg = json.load(f)
            if isinstance(tokenizer_cfg, dict) and (
                tokenizer_cfg.get("_class_name") == "AutoProcessor"
                or "processor_type" in tokenizer_cfg
            ):
                src = tokenizer_cfg.get("_name_or_path", "")
                if isinstance(src, str) and src.strip():
                    processor = AutoProcessor.from_pretrained(
                        src.strip(),
                        trust_remote_code=True,
                    )
                    logger.info(
                        "Loaded tokenizer/processor from %s: %s",
                        src,
                        processor.__class__.__name__,
                    )
                    return processor
        except Exception:
            # If parsing fails, fall through to AutoTokenizer below.
            pass

    tokenizer = AutoTokenizer.from_pretrained(
        model_path,  # "<path to model>/tokenizer"
        # in v0, this was same string as encoder_name "ClipTextModel"
        # TODO(will): pass these tokenizer kwargs from inference args? Maybe
        # other method of config?
        # Fix: the HF kwarg is `padding_side`; the previous `padding_size`
        # was an unknown kwarg and silently ignored, so the intended
        # right-padding never took effect.
        padding_side="right",
    )
    logger.info("Loaded tokenizer: %s", tokenizer.__class__.__name__)
    return tokenizer

fastvideo.models.loader.component_loader.TransformerLoader

TransformerLoader(device=None)

Bases: ComponentLoader

Loader for transformer.

Source code in fastvideo/models/loader/component_loader.py
def __init__(self, device=None) -> None:
    # Device hint; `load` resolves the real device via get_local_torch_device().
    self.device = device

Functions

fastvideo.models.loader.component_loader.TransformerLoader.load
load(model_path: str, fastvideo_args: FastVideoArgs)

Load the transformer based on the model path, and inference args.

Source code in fastvideo/models/loader/component_loader.py
def load(self, model_path: str, fastvideo_args: FastVideoArgs):
    """Load the transformer based on the model path, and inference args."""
    config = get_diffusers_config(model=model_path)
    # Keep an untouched copy; `config` is mutated (pop) below.
    hf_config = deepcopy(config)
    cls_name = config.pop("_class_name")
    if cls_name is None:
        raise ValueError(
            "Model config does not contain a _class_name attribute. "
            "Only diffusers format is supported."
        )

    logger.info("transformer cls_name: %s", cls_name)
    if fastvideo_args.override_transformer_cls_name is not None:
        cls_name = fastvideo_args.override_transformer_cls_name
        logger.info("Overriding transformer cls_name to %s", cls_name)

    fastvideo_args.model_paths["transformer"] = model_path

    # Config from Diffusers supersedes fastvideo's model config
    dit_config = fastvideo_args.pipeline_config.dit_config
    dit_config.update_model_arch(config)

    model_cls, _ = ModelRegistry.resolve_model_cls(cls_name)

    # Find all safetensors files
    safetensors_list = glob.glob(
        os.path.join(str(model_path), "*.safetensors")
    )
    if not safetensors_list:
        raise ValueError(f"No safetensors files found in {model_path}")

    # Check if we should use custom initialization weights
    custom_weights_path = getattr(
        fastvideo_args, "init_weights_from_safetensors", None
    )
    # NOTE(review): presence of `_loading_teacher_critic_model` disables
    # custom weights — presumably set by distillation/training code;
    # confirm with the caller that sets it.
    use_custom_weights = (
        custom_weights_path
        and os.path.exists(custom_weights_path)
        and not hasattr(fastvideo_args, "_loading_teacher_critic_model")
    )

    if use_custom_weights:
        # A second transformer gets its own custom-weights source.
        if "transformer_2" in model_path:
            custom_weights_path = getattr(
                fastvideo_args, "init_weights_from_safetensors_2", None
            )
        assert custom_weights_path is not None, (
            "Custom initialization weights must be provided"
        )
        # Accept either a directory of shards or a single .safetensors file.
        if os.path.isdir(custom_weights_path):
            safetensors_list = glob.glob(
                os.path.join(str(custom_weights_path), "*.safetensors")
            )
        else:
            assert custom_weights_path.endswith(".safetensors"), (
                "Custom initialization weights must be a safetensors file"
            )
            safetensors_list = [custom_weights_path]

    logger.info(
        "Loading model from %s safetensors files: %s",
        len(safetensors_list),
        safetensors_list,
    )

    default_dtype = PRECISION_TO_TYPE[
        fastvideo_args.pipeline_config.dit_precision
    ]

    # Load the model using FSDP loader
    logger.info("Loading model from %s, default_dtype: %s", cls_name,
                default_dtype)
    assert fastvideo_args.hsdp_shard_dim is not None
    # Cosmos2.5 checkpoints can include extra entries not present in the
    # instantiated model (e.g. pos_embedder ranges / *_extra_state). Load
    # non-strictly for Cosmos2.5 only; keep upstream strict behavior for others.
    strict_load = not (
        cls_name.startswith("Cosmos25")
        or cls_name == "Cosmos25Transformer3DModel"
        or getattr(fastvideo_args.pipeline_config, "prefix", "") == "Cosmos25"
    )
    model = maybe_load_fsdp_model(
        model_cls=model_cls,
        init_params={"config": dit_config, "hf_config": hf_config},
        weight_dir_list=safetensors_list,
        device=get_local_torch_device(),
        hsdp_replicate_dim=fastvideo_args.hsdp_replicate_dim,
        hsdp_shard_dim=fastvideo_args.hsdp_shard_dim,
        strict=strict_load,
        cpu_offload=fastvideo_args.dit_cpu_offload,
        pin_cpu_memory=fastvideo_args.pin_cpu_memory,
        fsdp_inference=fastvideo_args.use_fsdp_inference,
        # TODO(will): make these configurable
        default_dtype=default_dtype,
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
        output_dtype=None,
        training_mode=fastvideo_args.training_mode,
        enable_torch_compile=fastvideo_args.enable_torch_compile,
        torch_compile_kwargs=fastvideo_args.torch_compile_kwargs,
    )

    total_params = sum(p.numel() for p in model.parameters())
    logger.info("Loaded model with %.2fB parameters", total_params / 1e9)

    assert next(model.parameters()).dtype == default_dtype, (
        "Model dtype does not match default dtype"
    )

    model = model.eval()

    if fastvideo_args.dit_layerwise_offload and hasattr(model, "blocks"):
        # Check if this is a Wan model (only Wan models support layerwise offload)
        is_wan_model = "Wan" in cls_name
        if not is_wan_model:
            logger.warning(
                "Layerwise offload is currently only supported for Wan models. "
                "Model class '%s' does not support layerwise offload. "
                "Disabling layerwise offload for this model.",
                cls_name
            )
        else:
            # `blocks` may not be sized; tolerate len() failing.
            try:
                num_layers = len(getattr(model, "blocks"))
            except TypeError:
                num_layers = None
            if isinstance(num_layers, int) and num_layers > 0:
                # Ensure model is on the correct device (CUDA) before initializing manager
                # This ensures non-managed parameters (embeddings, final norms) are on GPU
                model = model.to(get_local_torch_device())
                mgr = LayerwiseOffloadManager(
                    model,
                    module_list_attr="blocks",
                    num_layers=num_layers,
                    enabled=True,
                    pin_cpu_memory=fastvideo_args.pin_cpu_memory,
                    auto_initialize=True,
                )
                setattr(model, "_layerwise_offload_manager", mgr)

    return model

fastvideo.models.loader.component_loader.VAELoader

VAELoader(device=None)

Bases: ComponentLoader

Loader for VAE.

Source code in fastvideo/models/loader/component_loader.py
def __init__(self, device=None) -> None:
    # Device hint; `load` resolves the actual target device itself.
    self.device = device

Functions

fastvideo.models.loader.component_loader.VAELoader.load
load(model_path: str, fastvideo_args: FastVideoArgs)

Load the VAE based on the model path, and inference args.

Source code in fastvideo/models/loader/component_loader.py
def load(self, model_path: str, fastvideo_args: FastVideoArgs):
    """Load the VAE based on the model path, and inference args.

    Raises:
        ValueError: If the diffusers config has no ``_class_name`` entry,
            or no safetensors weights are found.
        FileNotFoundError: If the Cosmos2.5 VAE weights file is missing.
    """
    config = get_diffusers_config(model=model_path)
    # Pop with a default: previously a missing key raised a bare KeyError
    # before the "does not contain" message could ever be reported.
    class_name = config.pop("_class_name", None)
    if class_name is None:
        # Explicit raise (matches TransformerLoader.load) instead of assert.
        raise ValueError(
            "Model config does not contain a _class_name attribute. "
            "Only diffusers format is supported."
        )
    fastvideo_args.model_paths["vae"] = model_path

    vae_config = fastvideo_args.pipeline_config.vae_config
    vae_config.update_model_arch(config)

    from fastvideo.platforms import current_platform

    # CPU offload targets MPS on Apple silicon, plain CPU elsewhere.
    if fastvideo_args.vae_cpu_offload:
        target_device = (
            torch.device("mps")
            if current_platform.is_mps()
            else torch.device("cpu")
        )
    else:
        target_device = get_local_torch_device()

    with set_default_torch_dtype(
        PRECISION_TO_TYPE[fastvideo_args.pipeline_config.vae_precision]
        if fastvideo_args.pipeline_config.vae_precision
        else torch.bfloat16
    ):
        # Cosmos2.5 uses a Wan2.1 VAE stored as `tokenizer.safetensors` under the VAE folder.
        is_cosmos25 = fastvideo_args.pipeline_config.__class__.__name__ == "Cosmos25Config"
        if class_name == "AutoencoderKLWan" and is_cosmos25:
            from fastvideo.models.vaes.cosmos25wanvae import Cosmos25WanVAE

            dtype = PRECISION_TO_TYPE[fastvideo_args.pipeline_config.vae_precision]
            vae = Cosmos25WanVAE(device=target_device, dtype=dtype)

            weight_path = os.path.join(model_path, "tokenizer.safetensors")
            if not os.path.exists(weight_path):
                raise FileNotFoundError(
                    f"Missing Cosmos2.5 VAE weights: {weight_path}"
                )
            sd = safetensors_load_file(weight_path)
            # Non-strict: checkpoint may carry extra entries.
            vae.load_state_dict(sd, strict=False)
            return vae.eval()

        vae_cls, _ = ModelRegistry.resolve_model_cls(class_name)
        vae = vae_cls(vae_config).to(target_device)

    # Find all safetensors files
    safetensors_list = glob.glob(
        os.path.join(str(model_path), "*.safetensors"))
    if not safetensors_list:
        raise ValueError(f"No safetensors files found in {model_path}")
    # Common case: a single `.safetensors` checkpoint file.
    # Some models may be sharded into multiple files; in that case we merge.
    if len(safetensors_list) == 1:
        loaded = safetensors_load_file(safetensors_list[0])
    else:
        loaded = {}
        for sf_file in safetensors_list:
            loaded.update(safetensors_load_file(sf_file))
    vae.load_state_dict(loaded, strict=False)

    return vae.eval()

Functions