Skip to content

component_loader

Classes

fastvideo.models.loader.component_loader.AudioDecoderLoader

AudioDecoderLoader(device=None)

Bases: ComponentLoader

Loader for LTX-2 audio decoder (audio_vae component).

Source code in fastvideo/models/loader/component_loader.py
def __init__(self, device=None) -> None:
    """Initialize the audio-decoder loader.

    Args:
        device: Optional target device; ``None`` defers device selection
            to the load step.
    """
    self.device = device

fastvideo.models.loader.component_loader.ComponentLoader

ComponentLoader(device=None)

Bases: ABC

Base class for loading a specific type of model component.

Source code in fastvideo/models/loader/component_loader.py
def __init__(self, device=None) -> None:
    """Initialize the component loader.

    Args:
        device: Optional target device; ``None`` defers device selection
            to the load step.
    """
    self.device = device

Functions

fastvideo.models.loader.component_loader.ComponentLoader.for_module_type classmethod
for_module_type(module_type: str, transformers_or_diffusers: str) -> ComponentLoader

Factory method to create a component loader for a specific module type.

Parameters:

Name Type Description Default
module_type str

Type of module (e.g., "vae", "text_encoder", "transformer", "scheduler")

required
transformers_or_diffusers str

Whether the module is from transformers or diffusers

required

Returns:

Type Description
ComponentLoader

A component loader for the specified module type

Source code in fastvideo/models/loader/component_loader.py
@classmethod
def for_module_type(
    cls, module_type: str, transformers_or_diffusers: str
) -> "ComponentLoader":
    """
    Factory method to create a component loader for a specific module type.

    Args:
        module_type: Type of module (e.g., "vae", "text_encoder", "transformer", "scheduler")
        transformers_or_diffusers: Whether the module is from transformers or diffusers

    Returns:
        A component loader for the specified module type

    Raises:
        ValueError: If a known module type is declared under the wrong
            library (e.g. a "vae" declared under "transformers").
    """
    # Map of module types to their loader classes and expected library
    module_loaders = {
        "scheduler": (SchedulerLoader, "diffusers"),
        "transformer": (TransformerLoader, "diffusers"),
        "transformer_2": (TransformerLoader, "diffusers"),
        "vae": (VAELoader, "diffusers"),
        "audio_vae": (AudioDecoderLoader, "diffusers"),
        "audio_decoder": (AudioDecoderLoader, "diffusers"),
        "vocoder": (VocoderLoader, "diffusers"),
        "text_encoder": (TextEncoderLoader, "transformers"),
        "text_encoder_2": (TextEncoderLoader, "transformers"),
        "tokenizer": (TokenizerLoader, "transformers"),
        "tokenizer_2": (TokenizerLoader, "transformers"),
        "image_processor": (ImageProcessorLoader, "transformers"),
        "feature_extractor": (ImageProcessorLoader, "transformers"),
        "image_encoder": (ImageEncoderLoader, "transformers"),
    }

    if module_type in module_loaders:
        loader_cls, expected_library = module_loaders[module_type]
        # Allow fastvideo.* libraries for custom implementations (e.g. Cosmos2_5Pipeline)
        # that aren't available in diffusers/transformers yet
        is_fastvideo_module = transformers_or_diffusers.startswith("fastvideo.")
        if not is_fastvideo_module and transformers_or_diffusers != expected_library:
            # Raise instead of assert: asserts are stripped under
            # ``python -O``, which would silently skip this validation.
            raise ValueError(
                f"{module_type} must be loaded from {expected_library}, "
                f"got {transformers_or_diffusers}"
            )
        return loader_cls()

    # For unknown module types, use a generic loader
    logger.warning(
        "No specific loader found for module type: %s. Using generic loader.",
        module_type,
    )
    return GenericComponentLoader(transformers_or_diffusers)
fastvideo.models.loader.component_loader.ComponentLoader.load abstractmethod
load(model_path: str, fastvideo_args: FastVideoArgs)

Load the component based on the model path, architecture, and inference args.

Parameters:

Name Type Description Default
model_path str

Path to the component model

required
fastvideo_args FastVideoArgs

FastVideoArgs

required

Returns:

Type Description

The loaded component

Source code in fastvideo/models/loader/component_loader.py
@abstractmethod
def load(self, model_path: str, fastvideo_args: FastVideoArgs):
    """
    Load the component based on the model path, architecture, and inference args.

    Subclasses must override this; the base implementation always raises.

    Args:
        model_path: Path to the component model
        fastvideo_args: FastVideoArgs

    Returns:
        The loaded component

    Raises:
        NotImplementedError: Always, on the base class.
    """
    raise NotImplementedError

fastvideo.models.loader.component_loader.GenericComponentLoader

GenericComponentLoader(library='transformers')

Bases: ComponentLoader

Generic loader for components that don't have a specific loader.

Source code in fastvideo/models/loader/component_loader.py
def __init__(self, library="transformers") -> None:
    """Initialize the generic loader.

    Args:
        library: Source library of the component ("transformers" or
            "diffusers").
    """
    super().__init__()
    self.library = library

Functions

fastvideo.models.loader.component_loader.GenericComponentLoader.load
load(model_path: str, fastvideo_args: FastVideoArgs)

Load a generic component based on the model path, and inference args.

Source code in fastvideo/models/loader/component_loader.py
def load(self, model_path: str, fastvideo_args: FastVideoArgs):
    """Load a generic component based on the model path, and inference args."""
    logger.warning(
        "Using generic loader for %s with library %s",
        model_path,
        self.library,
    )

    # Guard clause: reject anything we don't know how to handle.
    if self.library not in ("transformers", "diffusers"):
        raise ValueError(f"Unsupported library: {self.library}")

    if self.library == "diffusers":
        logger.warning(
            "Generic loading for diffusers components is not fully implemented"
        )

        model_config = get_diffusers_config(model=model_path)
        logger.info("Diffusers Model config: %s", model_config)
        # This is a placeholder - in a real implementation, you'd need to handle this properly
        return None

    # Remaining case: self.library == "transformers".
    from transformers import AutoModel

    loaded_model = AutoModel.from_pretrained(
        model_path,
        trust_remote_code=fastvideo_args.trust_remote_code,
        revision=fastvideo_args.revision,
    )
    logger.info(
        "Loaded generic transformers model: %s",
        loaded_model.__class__.__name__,
    )
    return loaded_model

fastvideo.models.loader.component_loader.ImageEncoderLoader

ImageEncoderLoader(device=None)

Bases: TextEncoderLoader

Source code in fastvideo/models/loader/component_loader.py
def __init__(self, device=None) -> None:
    """Initialize the image-encoder loader.

    Args:
        device: Optional target device; ``None`` defers device selection
            to the load step.
    """
    self.device = device

Functions

fastvideo.models.loader.component_loader.ImageEncoderLoader.load
load(model_path: str, fastvideo_args: FastVideoArgs)

Load the image encoder based on the model path, and inference args.

Source code in fastvideo/models/loader/component_loader.py
def load(self, model_path: str, fastvideo_args: FastVideoArgs):
    """Load the image encoder based on the model path, and inference args.

    Reads the HF ``config.json`` next to the weights, strips metadata keys
    that are not architecture fields, merges the result into the pipeline's
    image-encoder config, then loads the model onto the target device.

    Args:
        model_path: Path to the image encoder component directory.
        fastvideo_args: FastVideoArgs with pipeline/offload settings.

    Returns:
        The loaded image encoder model.
    """
    # (Fixed) The original docstring said "text encoders" — this loader
    # handles the image encoder.
    with open(os.path.join(model_path, "config.json"), encoding="utf-8") as f:
        model_config = json.load(f)
    # Drop HF bookkeeping keys so only architecture fields are merged below.
    for meta_key in (
        "_name_or_path",
        "transformers_version",
        "torch_dtype",
        "model_type",
    ):
        model_config.pop(meta_key, None)
    logger.info("HF Model config: %s", model_config)

    encoder_config = fastvideo_args.pipeline_config.image_encoder_config
    encoder_config.update_model_arch(model_config)

    from fastvideo.platforms import current_platform

    if fastvideo_args.image_encoder_cpu_offload:
        # On Apple Silicon, "CPU offload" still targets the MPS device.
        target_device = (
            torch.device("mps")
            if current_platform.is_mps()
            else torch.device("cpu")
        )
    else:
        target_device = get_local_torch_device()
    # TODO(will): add support for other dtypes
    return self.load_model(
        model_path,
        encoder_config,
        target_device,
        fastvideo_args,
        fastvideo_args.pipeline_config.image_encoder_precision,
    )

fastvideo.models.loader.component_loader.ImageProcessorLoader

ImageProcessorLoader(device=None)

Bases: ComponentLoader

Loader for image processor.

Source code in fastvideo/models/loader/component_loader.py
def __init__(self, device=None) -> None:
    """Initialize the image-processor loader.

    Args:
        device: Optional target device; ``None`` defers device selection
            to the load step.
    """
    self.device = device

Functions

fastvideo.models.loader.component_loader.ImageProcessorLoader.load
load(model_path: str, fastvideo_args: FastVideoArgs)

Load the image processor based on the model path, and inference args.

Source code in fastvideo/models/loader/component_loader.py
def load(self, model_path: str, fastvideo_args: FastVideoArgs):
    """Load the image processor based on the model path, and inference args."""
    logger.info("Loading image processor from %s", model_path)

    # Delegate entirely to the HF auto class; no extra kwargs needed.
    processor = AutoImageProcessor.from_pretrained(model_path)
    logger.info(
        "Loaded image processor: %s", processor.__class__.__name__
    )
    return processor

fastvideo.models.loader.component_loader.PipelineComponentLoader

Utility class for loading pipeline components. This replaces the chain of if-else statements in load_pipeline_module.

Functions

fastvideo.models.loader.component_loader.PipelineComponentLoader.load_module staticmethod
load_module(module_name: str, component_model_path: str, transformers_or_diffusers: str, fastvideo_args: FastVideoArgs)

Load a pipeline module.

Parameters:

Name Type Description Default
module_name str

Name of the module (e.g., "vae", "text_encoder", "transformer", "scheduler")

required
component_model_path str

Path to the component model

required
transformers_or_diffusers str

Whether the module is from transformers or diffusers

required
fastvideo_args FastVideoArgs

Inference arguments

required

Returns:

Type Description

The loaded module

Source code in fastvideo/models/loader/component_loader.py
@staticmethod
def load_module(
    module_name: str,
    component_model_path: str,
    transformers_or_diffusers: str,
    fastvideo_args: FastVideoArgs,
):
    """
    Load a pipeline module.

    Args:
        module_name: Name of the module (e.g., "vae", "text_encoder", "transformer", "scheduler")
        component_model_path: Path to the component model
        transformers_or_diffusers: Whether the module is from transformers or diffusers
        fastvideo_args: Inference arguments

    Returns:
        The loaded module
    """
    logger.info(
        "Loading %s using %s from %s",
        module_name,
        transformers_or_diffusers,
        component_model_path,
    )

    # Get the appropriate loader for this module type
    loader = ComponentLoader.for_module_type(
        module_name, transformers_or_diffusers
    )

    # Load the module
    return loader.load(component_model_path, fastvideo_args)

fastvideo.models.loader.component_loader.SchedulerLoader

SchedulerLoader(device=None)

Bases: ComponentLoader

Loader for scheduler.

Source code in fastvideo/models/loader/component_loader.py
def __init__(self, device=None) -> None:
    """Initialize the scheduler loader.

    Args:
        device: Optional target device; ``None`` defers device selection
            to the load step.
    """
    self.device = device

Functions

fastvideo.models.loader.component_loader.SchedulerLoader.load
load(model_path: str, fastvideo_args: FastVideoArgs)

Load the scheduler based on the model path, and inference args.

Source code in fastvideo/models/loader/component_loader.py
def load(self, model_path: str, fastvideo_args: FastVideoArgs):
    """Load the scheduler based on the model path, and inference args.

    Args:
        model_path: Path to the scheduler component directory.
        fastvideo_args: FastVideoArgs carrying pipeline overrides
            (flow shift, timesteps scale).

    Returns:
        The instantiated scheduler.

    Raises:
        ValueError: If the config has no ``_class_name`` (only the
            diffusers format is supported).
    """
    config = get_diffusers_config(model=model_path)

    # Pop with a default: a bare ``pop`` would raise KeyError before the
    # original assert could ever run, and asserts vanish under ``-O``.
    class_name = config.pop("_class_name", None)
    if class_name is None:
        raise ValueError(
            "Model config does not contain a _class_name attribute. Only diffusers format is supported."
        )

    scheduler_cls, _ = ModelRegistry.resolve_model_cls(class_name)

    # Remaining config entries are the scheduler's constructor kwargs.
    scheduler = scheduler_cls(**config)
    if fastvideo_args.pipeline_config.flow_shift is not None:
        scheduler.set_shift(fastvideo_args.pipeline_config.flow_shift)
    if fastvideo_args.pipeline_config.timesteps_scale is not None:
        scheduler.set_timesteps_scale(
            fastvideo_args.pipeline_config.timesteps_scale
        )
    return scheduler

fastvideo.models.loader.component_loader.TextEncoderLoader

TextEncoderLoader(device=None)

Bases: ComponentLoader

Loader for text encoders.

Source code in fastvideo/models/loader/component_loader.py
def __init__(self, device=None) -> None:
    """Initialize the text-encoder loader.

    Args:
        device: Optional target device; ``None`` defers device selection
            to the load step.
    """
    self.device = device

Classes

fastvideo.models.loader.component_loader.TextEncoderLoader.Source dataclass
Source(model_or_path: str, prefix: str = '', fall_back_to_pt: bool = True, allow_patterns_overrides: list[str] | None = None)

A source for weights.

Attributes
fastvideo.models.loader.component_loader.TextEncoderLoader.Source.allow_patterns_overrides class-attribute instance-attribute
allow_patterns_overrides: list[str] | None = None

If defined, weights will load exclusively using these patterns.

fastvideo.models.loader.component_loader.TextEncoderLoader.Source.fall_back_to_pt class-attribute instance-attribute
fall_back_to_pt: bool = True

Whether .pt weights can be used.

fastvideo.models.loader.component_loader.TextEncoderLoader.Source.model_or_path instance-attribute
model_or_path: str

The model ID or path.

fastvideo.models.loader.component_loader.TextEncoderLoader.Source.prefix class-attribute instance-attribute
prefix: str = ''

A prefix to prepend to all weights.

Functions

fastvideo.models.loader.component_loader.TextEncoderLoader.load
load(model_path: str, fastvideo_args: FastVideoArgs)

Load the text encoders based on the model path, and inference args.

Source code in fastvideo/models/loader/component_loader.py
def load(self, model_path: str, fastvideo_args: FastVideoArgs):
    """Load the text encoders based on the model path, and inference args.

    Resolves an optional Gemma weights path (from the repo-level
    ``model_index.json`` or a bundled ``gemma/`` folder) and mirrors
    selected transformer RoPE settings into the encoder config before
    loading the model weights.
    """
    # model_config: PretrainedConfig = get_hf_config(
    #     model=model_path,
    #     trust_remote_code=fastvideo_args.trust_remote_code,
    #     revision=fastvideo_args.revision,
    #     model_override_args=None,
    # )
    model_config = get_diffusers_config(model=model_path)
    # Strip HF bookkeeping keys so only architecture fields are merged below.
    model_config.pop("_name_or_path", None)
    model_config.pop("transformers_version", None)
    model_config.pop("model_type", None)
    model_config.pop("tokenizer_class", None)
    model_config.pop("torch_dtype", None)
    # The repo root (parent of this component dir) may carry a
    # model_index.json that points at external Gemma weights.
    repo_root = os.path.dirname(model_path)
    index_path = os.path.join(repo_root, "model_index.json")
    gemma_path = ""
    gemma_path_from_candidate = False
    if os.path.isfile(index_path):
        try:
            with open(index_path, encoding="utf-8") as f:
                model_index = json.load(f)
            gemma_path = model_index.get("gemma_model_path", "")
        except json.JSONDecodeError:
            # Unreadable index: fall back to the local candidate below.
            gemma_path = ""
    if not gemma_path:
        # Fallback: a `gemma` directory bundled inside the component dir.
        candidate = os.path.normpath(os.path.join(model_path, "gemma"))
        if os.path.isdir(candidate):
            gemma_path = candidate
            gemma_path_from_candidate = True
            model_config["gemma_model_path"] = gemma_path
    if gemma_path and not gemma_path_from_candidate:
        # Index-supplied relative paths are resolved against the repo root.
        if not os.path.isabs(gemma_path):
            model_config["gemma_model_path"] = os.path.normpath(
                os.path.join(repo_root, gemma_path)
            )
    # Mirror RoPE-related settings from the transformer's config so the
    # text-encoder connector matches, unless already set in model_config.
    transformer_config_path = os.path.join(
        repo_root, "transformer", "config.json"
    )
    if os.path.isfile(transformer_config_path):
        try:
            with open(transformer_config_path, encoding="utf-8") as f:
                transformer_config = json.load(f)
            if (
                "connector_double_precision_rope" not in model_config
                or not model_config["connector_double_precision_rope"]
            ):
                if transformer_config.get("double_precision_rope") is True:
                    model_config["connector_double_precision_rope"] = True
            if "connector_rope_type" not in model_config:
                rope_type = transformer_config.get("rope_type")
                if rope_type is not None:
                    model_config["connector_rope_type"] = rope_type
        except json.JSONDecodeError:
            # Malformed transformer config: skip the mirroring entirely.
            pass
    logger.info("HF Model config: %s", model_config)

    # @TODO(Wei): Better way to handle this?
    # Try text-encoder slot 0 first; on any failure fall back to slot 1.
    # NOTE(review): the broad ``except Exception`` also masks genuine
    # errors from update_model_arch — confirm this fallback is intended.
    try:
        encoder_config = (
            fastvideo_args.pipeline_config.text_encoder_configs[0]
        )
        encoder_config.update_model_arch(model_config)
        encoder_precision = (
            fastvideo_args.pipeline_config.text_encoder_precisions[0]
        )
    except Exception:
        encoder_config = (
            fastvideo_args.pipeline_config.text_encoder_configs[1]
        )
        encoder_config.update_model_arch(model_config)
        encoder_precision = (
            fastvideo_args.pipeline_config.text_encoder_precisions[1]
        )

    target_device = get_local_torch_device()
    # TODO(will): add support for other dtypes
    return self.load_model(
        model_path,
        encoder_config,
        target_device,
        fastvideo_args,
        encoder_precision,
        use_text_encoder_override=True,
    )

fastvideo.models.loader.component_loader.TokenizerLoader

TokenizerLoader(device=None)

Bases: ComponentLoader

Loader for tokenizers.

Source code in fastvideo/models/loader/component_loader.py
def __init__(self, device=None) -> None:
    """Initialize the tokenizer loader.

    Args:
        device: Optional target device; ``None`` defers device selection
            to the load step.
    """
    self.device = device

Functions

fastvideo.models.loader.component_loader.TokenizerLoader.load
load(model_path: str, fastvideo_args: FastVideoArgs)

Load the tokenizer based on the model path, and inference args.

Source code in fastvideo/models/loader/component_loader.py
def load(self, model_path: str, fastvideo_args: FastVideoArgs):
    """Load the tokenizer based on the model path, and inference args.

    First tries the Cosmos2.5 AutoProcessor indirection; otherwise loads
    an AutoTokenizer directly from ``model_path``.
    """
    logger.info("Loading tokenizer from %s", model_path)

    # Cosmos2.5 stores an AutoProcessor config in `tokenizer/config.json` (not a tokenizer
    # config). Use its `_name_or_path` (e.g. Qwen/Qwen2.5-VL-7B-Instruct) as the source.
    tokenizer_cfg_path = os.path.join(model_path, "config.json")
    if os.path.exists(tokenizer_cfg_path):
        try:
            with open(tokenizer_cfg_path, "r") as f:
                tokenizer_cfg = json.load(f)
            if isinstance(tokenizer_cfg, dict) and (
                tokenizer_cfg.get("_class_name") == "AutoProcessor"
                or "processor_type" in tokenizer_cfg
            ):
                src = tokenizer_cfg.get("_name_or_path", "")
                if isinstance(src, str) and src.strip():
                    # NOTE(review): trust_remote_code=True executes code
                    # from the referenced repo — confirm `src` is trusted.
                    processor = AutoProcessor.from_pretrained(
                        src.strip(),
                        trust_remote_code=True,
                    )
                    logger.info(
                        "Loaded tokenizer/processor from %s: %s",
                        src,
                        processor.__class__.__name__,
                    )
                    return processor
        except Exception:
            # If parsing fails, fall through to AutoTokenizer below.
            # NOTE(review): this also swallows from_pretrained failures
            # (e.g. network errors), not just JSON parse errors.
            pass

    tokenizer = AutoTokenizer.from_pretrained(
        model_path,  # "<path to model>/tokenizer"
        # in v0, this was same string as encoder_name "ClipTextModel"
        # TODO(will): pass these tokenizer kwargs from inference args? Maybe
        # other method of config?
    )
    # Honor an encoder-specific padding side if the first text-encoder
    # config declares one; otherwise keep the tokenizer's default.
    padding_side = None
    if hasattr(fastvideo_args.pipeline_config, "text_encoder_configs"):
        try:
            arch_config = fastvideo_args.pipeline_config.text_encoder_configs[
                0
            ].arch_config
            padding_side = getattr(arch_config, "padding_side", None)
        except Exception:
            padding_side = None
    if padding_side:
        tokenizer.padding_side = padding_side
    # Some tokenizers ship without a pad token; reuse EOS so padding works.
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token
    logger.info("Loaded tokenizer: %s", tokenizer.__class__.__name__)
    return tokenizer

fastvideo.models.loader.component_loader.TransformerLoader

TransformerLoader(device=None)

Bases: ComponentLoader

Loader for transformer.

Source code in fastvideo/models/loader/component_loader.py
def __init__(self, device=None) -> None:
    """Initialize the transformer loader.

    Args:
        device: Optional target device; ``None`` defers device selection
            to the load step.
    """
    self.device = device

Functions

fastvideo.models.loader.component_loader.TransformerLoader.load
load(model_path: str, fastvideo_args: FastVideoArgs)

Load the transformer based on the model path, and inference args.

Source code in fastvideo/models/loader/component_loader.py
def load(self, model_path: str, fastvideo_args: FastVideoArgs):
    """Load the transformer based on the model path, and inference args.

    Resolves the model class from the diffusers config, collects the
    safetensors shards (optionally replaced by custom init weights), and
    loads the model through the FSDP loader, then optionally enables
    layerwise offload for inference.
    """
    config = get_diffusers_config(model=model_path)
    # Keep an untouched copy: `config` is mutated below but the model may
    # still need the original HF config.
    hf_config = deepcopy(config)
    # NOTE(review): pop() without a default raises KeyError when the key
    # is absent, so the None check below only catches an explicit null.
    cls_name = config.pop("_class_name")
    if cls_name is None:
        raise ValueError(
            "Model config does not contain a _class_name attribute. "
            "Only diffusers format is supported."
        )

    logger.info("transformer cls_name: %s", cls_name)
    if fastvideo_args.override_transformer_cls_name is not None:
        cls_name = fastvideo_args.override_transformer_cls_name
        logger.info("Overriding transformer cls_name to %s", cls_name)

    # Record where the transformer weights came from for later lookup.
    fastvideo_args.model_paths["transformer"] = model_path

    # Config from Diffusers supersedes fastvideo's model config
    dit_config = fastvideo_args.pipeline_config.dit_config
    dit_config.update_model_arch(config)

    model_cls, _ = ModelRegistry.resolve_model_cls(cls_name)

    # Find all safetensors files
    safetensors_list = glob.glob(
        os.path.join(str(model_path), "*.safetensors")
    )
    if not safetensors_list:
        raise ValueError(f"No safetensors files found in {model_path}")

    # Check if we should use custom initialization weights
    custom_weights_path = getattr(
        fastvideo_args, "init_weights_from_safetensors", None
    )
    # Teacher/critic loads keep the original checkpoint weights.
    use_custom_weights = (
        custom_weights_path
        and os.path.exists(custom_weights_path)
        and not hasattr(fastvideo_args, "_loading_teacher_critic_model")
    )

    if use_custom_weights:
        # The second transformer has its own override weights path.
        if "transformer_2" in model_path:
            custom_weights_path = getattr(
                fastvideo_args, "init_weights_from_safetensors_2", None
            )
        assert custom_weights_path is not None, (
            "Custom initialization weights must be provided"
        )
        if os.path.isdir(custom_weights_path):
            safetensors_list = glob.glob(
                os.path.join(str(custom_weights_path), "*.safetensors")
            )
        else:
            assert custom_weights_path.endswith(".safetensors"), (
                "Custom initialization weights must be a safetensors file"
            )
            safetensors_list = [custom_weights_path]

    logger.info(
        "Loading model from %s safetensors files: %s",
        len(safetensors_list),
        safetensors_list,
    )

    default_dtype = PRECISION_TO_TYPE[
        fastvideo_args.pipeline_config.dit_precision
    ]

    # Load the model using FSDP loader
    logger.info("Loading model from %s, default_dtype: %s", cls_name,
                default_dtype)
    assert fastvideo_args.hsdp_shard_dim is not None
    # Cosmos2.5 checkpoints can include extra entries not present in the
    # instantiated model (e.g. pos_embedder ranges / *_extra_state). Load
    # non-strictly for Cosmos2.5 only; keep upstream strict behavior for others.
    strict_load = not (
        cls_name.startswith("Cosmos25")
        or cls_name == "Cosmos25Transformer3DModel"
        or getattr(fastvideo_args.pipeline_config, "prefix", "") == "Cosmos25"
    )
    model = maybe_load_fsdp_model(
        model_cls=model_cls,
        init_params={"config": dit_config, "hf_config": hf_config},
        weight_dir_list=safetensors_list,
        device=get_local_torch_device(),
        hsdp_replicate_dim=fastvideo_args.hsdp_replicate_dim,
        hsdp_shard_dim=fastvideo_args.hsdp_shard_dim,
        strict=strict_load,
        cpu_offload=fastvideo_args.dit_cpu_offload,
        pin_cpu_memory=fastvideo_args.pin_cpu_memory,
        fsdp_inference=fastvideo_args.use_fsdp_inference,
        # TODO(will): make these configurable
        default_dtype=default_dtype,
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
        output_dtype=None,
        training_mode=fastvideo_args.training_mode,
        enable_torch_compile=fastvideo_args.enable_torch_compile,
        torch_compile_kwargs=fastvideo_args.torch_compile_kwargs,
    )

    total_params = sum(p.numel() for p in model.parameters())
    logger.info("Loaded model with %.2fB parameters", total_params / 1e9)

    # Sanity check: the first parameter should carry the requested dtype.
    assert next(model.parameters()).dtype == default_dtype, (
        "Model dtype does not match default dtype"
    )

    model = model.eval()

    if fastvideo_args.inference_mode and fastvideo_args.dit_layerwise_offload:
        # Check if model has nn.ModuleList for layerwise offload compatibility
        has_module_list = any(
            isinstance(m, nn.ModuleList) for m in model.children()
        )
        if has_module_list:
            enable_layerwise_offload(model)
        else:
            logger.warning(
                "Layerwise offload requested but model %s does not have "
                "nn.ModuleList structure. Skipping layerwise offload.",
                cls_name
            )
    return model

fastvideo.models.loader.component_loader.VAELoader

VAELoader(device=None)

Bases: ComponentLoader

Loader for VAE.

Source code in fastvideo/models/loader/component_loader.py
def __init__(self, device=None) -> None:
    """Initialize the VAE loader.

    Args:
        device: Optional target device; ``None`` defers device selection
            to the load step.
    """
    self.device = device

Functions

fastvideo.models.loader.component_loader.VAELoader.load
load(model_path: str, fastvideo_args: FastVideoArgs)

Load the VAE based on the model path, and inference args.

Source code in fastvideo/models/loader/component_loader.py
def load(self, model_path: str, fastvideo_args: FastVideoArgs):
    """Load the VAE based on the model path, and inference args.

    Handles three layouts: the Cosmos2.5 Wan VAE (weights in
    ``tokenizer.safetensors``), the LTX-2 CausalVideoAutoencoder (nested
    "vae" config plus per-channel-statistics key remapping), and the
    generic diffusers VAE path.
    """
    config = get_diffusers_config(model=model_path)
    class_name = config.get("_class_name")
    assert class_name is not None, (
        "Model config does not contain a _class_name attribute. Only diffusers format is supported."
    )
    # Record where the VAE weights came from for later lookup.
    fastvideo_args.model_paths["vae"] = model_path

    from fastvideo.platforms import current_platform

    if fastvideo_args.vae_cpu_offload:
        # On Apple Silicon, "CPU offload" still targets the MPS device.
        target_device = (
            torch.device("mps")
            if current_platform.is_mps()
            else torch.device("cpu")
        )
    else:
        target_device = get_local_torch_device()

    # Construct the VAE under the configured default dtype so freshly
    # created parameters match the requested precision.
    with set_default_torch_dtype(
        PRECISION_TO_TYPE[fastvideo_args.pipeline_config.vae_precision]
        if fastvideo_args.pipeline_config.vae_precision
        else torch.bfloat16
    ):
        # Cosmos2.5 uses a Wan2.1 VAE stored as `tokenizer.safetensors` under the VAE folder.
        is_cosmos25 = fastvideo_args.pipeline_config.__class__.__name__ == "Cosmos25Config"
        if class_name == "AutoencoderKLWan" and is_cosmos25:
            from fastvideo.models.vaes.cosmos25wanvae import Cosmos25WanVAE

            # NOTE(review): unlike the guarded expression above, this
            # indexes PRECISION_TO_TYPE directly and would KeyError if
            # vae_precision were None — confirm it is always set here.
            dtype = PRECISION_TO_TYPE[fastvideo_args.pipeline_config.vae_precision]
            vae = Cosmos25WanVAE(device=target_device, dtype=dtype)

            weight_path = os.path.join(model_path, "tokenizer.safetensors")
            if not os.path.exists(weight_path):
                raise FileNotFoundError(
                    f"Missing Cosmos2.5 VAE weights: {weight_path}"
                )
            sd = safetensors_load_file(weight_path)
            vae.load_state_dict(sd, strict=False)
            # Early return: Cosmos2.5 skips the shared loading path below.
            return vae.eval()

        # LTX-2 uses CausalVideoAutoencoder with nested "vae" config
        if class_name == "CausalVideoAutoencoder" and "vae" in config:
            vae_cls, _ = ModelRegistry.resolve_model_cls(class_name)
            vae = vae_cls(config).to(target_device)
            if hasattr(vae, "set_tiling_config"):
                # Tiling parameters fall back to LTX-2 defaults when the
                # pipeline's VAE config does not specify them.
                vae_config = fastvideo_args.pipeline_config.vae_config
                vae.set_tiling_config(
                    spatial_tile_size_in_pixels=getattr(
                        vae_config, "ltx2_spatial_tile_size_in_pixels", 512),
                    spatial_tile_overlap_in_pixels=getattr(
                        vae_config, "ltx2_spatial_tile_overlap_in_pixels", 64),
                    temporal_tile_size_in_frames=getattr(
                        vae_config, "ltx2_temporal_tile_size_in_frames", 64),
                    temporal_tile_overlap_in_frames=getattr(
                        vae_config,
                        "ltx2_temporal_tile_overlap_in_frames", 24),
                )
        else:
            # Generic diffusers path: merge the config into the pipeline's
            # VAE config and instantiate via the model registry.
            config.pop("_class_name", None)
            vae_config = fastvideo_args.pipeline_config.vae_config
            vae_config.update_model_arch(config)
            vae_cls, _ = ModelRegistry.resolve_model_cls(class_name)
            vae = vae_cls(vae_config).to(target_device)

    # Find all safetensors files
    safetensors_list = glob.glob(
        os.path.join(str(model_path), "*.safetensors"))
    if not safetensors_list:
        raise ValueError(f"No safetensors files found in {model_path}")
    # Common case: a single `.safetensors` checkpoint file.
    # Some models may be sharded into multiple files; in that case we merge.
    loaded = {}
    for sf_file in safetensors_list:
        loaded.update(safetensors_load_file(sf_file))

    # LTX-2 CausalVideoAutoencoder needs per_channel_statistics remapping
    if class_name == "CausalVideoAutoencoder" and "vae" in config:
        per_channel_prefixes = (
            "per_channel_statistics.",
            "vae.per_channel_statistics.",
        )
        # Duplicate each per-channel-statistics tensor under both the
        # encoder and decoder key prefixes; originals are kept too.
        remapped = {}
        for key, tensor in loaded.items():
            remapped[key] = tensor
            for prefix in per_channel_prefixes:
                if key.startswith(prefix):
                    suffix = key[len(prefix):]
                    remapped.setdefault(
                        f"encoder.per_channel_statistics.{suffix}",
                        tensor,
                    )
                    remapped.setdefault(
                        f"decoder.per_channel_statistics.{suffix}",
                        tensor,
                    )
                    break
        loaded = remapped

    # Non-strict load: checkpoints may carry extra or missing keys.
    vae.load_state_dict(loaded, strict=False)

    return vae.eval()

fastvideo.models.loader.component_loader.VocoderLoader

VocoderLoader(device=None)

Bases: ComponentLoader

Loader for LTX-2 vocoder.

Source code in fastvideo/models/loader/component_loader.py
def __init__(self, device=None) -> None:
    """Initialize the vocoder loader.

    Args:
        device: Optional target device; ``None`` defers device selection
            to the load step.
    """
    self.device = device

Functions