Skip to content

component_loader

Classes

fastvideo.models.loader.component_loader.ComponentLoader

ComponentLoader(device=None)

Bases: ABC

Base class for loading a specific type of model component.

Source code in fastvideo/models/loader/component_loader.py
def __init__(self, device=None) -> None:
    """Initialize the component loader.

    Args:
        device: Optional target device for the loaded component; stored
            as-is on the instance (``None`` by default).
    """
    self.device = device

Functions

fastvideo.models.loader.component_loader.ComponentLoader.for_module_type classmethod
for_module_type(module_type: str, transformers_or_diffusers: str) -> ComponentLoader

Factory method to create a component loader for a specific module type.

Parameters:

Name Type Description Default
module_type str

Type of module (e.g., "vae", "text_encoder", "transformer", "scheduler")

required
transformers_or_diffusers str

Whether the module is from transformers or diffusers

required

Returns:

Type Description
ComponentLoader

A component loader for the specified module type

Source code in fastvideo/models/loader/component_loader.py
@classmethod
def for_module_type(cls, module_type: str,
                    transformers_or_diffusers: str) -> 'ComponentLoader':
    """
    Factory method to create a component loader for a specific module type.

    Args:
        module_type: Type of module (e.g., "vae", "text_encoder", "transformer", "scheduler")
        transformers_or_diffusers: Whether the module is from transformers or diffusers

    Returns:
        A component loader for the specified module type
    """
    # Known module types mapped to (loader class, expected source library).
    known_loaders = {
        "scheduler": (SchedulerLoader, "diffusers"),
        "transformer": (TransformerLoader, "diffusers"),
        "transformer_2": (TransformerLoader, "diffusers"),
        "vae": (VAELoader, "diffusers"),
        "text_encoder": (TextEncoderLoader, "transformers"),
        "text_encoder_2": (TextEncoderLoader, "transformers"),
        "tokenizer": (TokenizerLoader, "transformers"),
        "tokenizer_2": (TokenizerLoader, "transformers"),
        "image_processor": (ImageProcessorLoader, "transformers"),
        "image_encoder": (ImageEncoderLoader, "transformers"),
    }

    entry = known_loaders.get(module_type)
    if entry is None:
        # Unknown module type: fall back to the generic loader.
        logger.warning(
            "No specific loader found for module type: %s. Using generic loader.",
            module_type)
        return GenericComponentLoader(transformers_or_diffusers)

    loader_cls, expected_library = entry
    # The declared source library must match what this module type requires.
    assert transformers_or_diffusers == expected_library, f"{module_type} must be loaded from {expected_library}, got {transformers_or_diffusers}"
    return loader_cls()
fastvideo.models.loader.component_loader.ComponentLoader.load abstractmethod
load(model_path: str, fastvideo_args: FastVideoArgs)

Load the component based on the model path, architecture, and inference args.

Parameters:

Name Type Description Default
model_path str

Path to the component model

required
fastvideo_args FastVideoArgs

FastVideoArgs

required

Returns:

Type Description

The loaded component

Source code in fastvideo/models/loader/component_loader.py
@abstractmethod
def load(self, model_path: str, fastvideo_args: FastVideoArgs):
    """
    Load the component based on the model path, architecture, and inference args.

    Args:
        model_path: Path to the component model
        fastvideo_args: FastVideoArgs

    Returns:
        The loaded component
    """
    # Abstract hook: every concrete loader subclass must override this.
    raise NotImplementedError

fastvideo.models.loader.component_loader.GenericComponentLoader

GenericComponentLoader(library='transformers')

Bases: ComponentLoader

Generic loader for components that don't have a specific loader.

Source code in fastvideo/models/loader/component_loader.py
def __init__(self, library="transformers") -> None:
    """Initialize the generic loader.

    Args:
        library: Source library of the component ("transformers" or
            "diffusers"); ``load`` branches on this value.
    """
    super().__init__()
    self.library = library

Functions

fastvideo.models.loader.component_loader.GenericComponentLoader.load
load(model_path: str, fastvideo_args: FastVideoArgs)

Load a generic component based on the model path, and inference args.

Source code in fastvideo/models/loader/component_loader.py
def load(self, model_path: str, fastvideo_args: FastVideoArgs):
    """Load a generic component based on the model path, and inference args."""
    logger.warning("Using generic loader for %s with library %s",
                   model_path, self.library)

    if self.library == "transformers":
        from transformers import AutoModel

        loaded_model = AutoModel.from_pretrained(
            model_path,
            trust_remote_code=fastvideo_args.trust_remote_code,
            revision=fastvideo_args.revision,
        )
        logger.info("Loaded generic transformers model: %s",
                    loaded_model.__class__.__name__)
        return loaded_model

    if self.library == "diffusers":
        logger.warning(
            "Generic loading for diffusers components is not fully implemented"
        )
        diffusers_config = get_diffusers_config(model=model_path)
        logger.info("Diffusers Model config: %s", diffusers_config)
        # Placeholder: actually constructing the diffusers component is
        # not implemented yet, so callers receive None here.
        return None

    raise ValueError(f"Unsupported library: {self.library}")

fastvideo.models.loader.component_loader.ImageEncoderLoader

ImageEncoderLoader(device=None)

Bases: TextEncoderLoader

Source code in fastvideo/models/loader/component_loader.py
def __init__(self, device=None) -> None:
    """Initialize the image encoder loader.

    Args:
        device: Optional target device for the loaded component.
    """
    # NOTE(review): does not call super().__init__(); the parent
    # initializer also only stores ``device``, so this is equivalent today —
    # revisit if the base class initializer ever gains logic.
    self.device = device

Functions

fastvideo.models.loader.component_loader.ImageEncoderLoader.load
load(model_path: str, fastvideo_args: FastVideoArgs)

Load the image encoder based on the model path, and inference args.

Source code in fastvideo/models/loader/component_loader.py
def load(self, model_path: str, fastvideo_args: FastVideoArgs):
    """Load the image encoder based on the model path, and inference args."""
    with open(os.path.join(model_path, "config.json")) as f:
        model_config = json.load(f)

    # Strip HF bookkeeping fields so only architecture fields remain.
    for stale_key in ("_name_or_path", "transformers_version",
                      "torch_dtype", "model_type"):
        model_config.pop(stale_key, None)
    logger.info("HF Model config: %s", model_config)

    encoder_config = fastvideo_args.pipeline_config.image_encoder_config
    encoder_config.update_model_arch(model_config)

    from fastvideo.platforms import current_platform

    # CPU offload places the encoder off the accelerator; on MPS hosts the
    # offload target is "mps" rather than "cpu".
    if fastvideo_args.image_encoder_cpu_offload:
        target_device = (torch.device("mps")
                         if current_platform.is_mps() else torch.device("cpu"))
    else:
        target_device = get_local_torch_device()
    # TODO(will): add support for other dtypes
    return self.load_model(
        model_path, encoder_config, target_device, fastvideo_args,
        fastvideo_args.pipeline_config.image_encoder_precision)

fastvideo.models.loader.component_loader.ImageProcessorLoader

ImageProcessorLoader(device=None)

Bases: ComponentLoader

Loader for image processor.

Source code in fastvideo/models/loader/component_loader.py
def __init__(self, device=None) -> None:
    """Initialize the image processor loader.

    Args:
        device: Optional target device; stored as-is on the instance.
    """
    self.device = device

Functions

fastvideo.models.loader.component_loader.ImageProcessorLoader.load
load(model_path: str, fastvideo_args: FastVideoArgs)

Load the image processor based on the model path, and inference args.

Source code in fastvideo/models/loader/component_loader.py
def load(self, model_path: str, fastvideo_args: FastVideoArgs):
    """Load the image processor based on the model path, and inference args."""
    logger.info("Loading image processor from %s", model_path)

    processor = AutoImageProcessor.from_pretrained(model_path)
    logger.info("Loaded image processor: %s",
                processor.__class__.__name__)
    return processor

fastvideo.models.loader.component_loader.PipelineComponentLoader

Utility class for loading pipeline components. This replaces the chain of if-else statements in load_pipeline_module.

Functions

fastvideo.models.loader.component_loader.PipelineComponentLoader.load_module staticmethod
load_module(module_name: str, component_model_path: str, transformers_or_diffusers: str, fastvideo_args: FastVideoArgs)

Load a pipeline module.

Parameters:

Name Type Description Default
module_name str

Name of the module (e.g., "vae", "text_encoder", "transformer", "scheduler")

required
component_model_path str

Path to the component model

required
transformers_or_diffusers str

Whether the module is from transformers or diffusers

required
fastvideo_args FastVideoArgs

Inference arguments

required

Returns:

Type Description

The loaded module

Source code in fastvideo/models/loader/component_loader.py
@staticmethod
def load_module(module_name: str, component_model_path: str,
                transformers_or_diffusers: str,
                fastvideo_args: FastVideoArgs):
    """
    Load a pipeline module.

    Args:
        module_name: Name of the module (e.g., "vae", "text_encoder", "transformer", "scheduler")
        component_model_path: Path to the component model
        transformers_or_diffusers: Whether the module is from transformers or diffusers
        fastvideo_args: Inference arguments

    Returns:
        The loaded module
    """
    logger.info(
        "Loading %s using %s from %s",
        module_name,
        transformers_or_diffusers,
        component_model_path,
    )

    # Get the appropriate loader for this module type
    loader = ComponentLoader.for_module_type(module_name,
                                             transformers_or_diffusers)

    # Load the module
    return loader.load(component_model_path, fastvideo_args)

fastvideo.models.loader.component_loader.SchedulerLoader

SchedulerLoader(device=None)

Bases: ComponentLoader

Loader for scheduler.

Source code in fastvideo/models/loader/component_loader.py
def __init__(self, device=None) -> None:
    """Initialize the scheduler loader.

    Args:
        device: Optional target device; stored as-is on the instance.
    """
    self.device = device

Functions

fastvideo.models.loader.component_loader.SchedulerLoader.load
load(model_path: str, fastvideo_args: FastVideoArgs)

Load the scheduler based on the model path, and inference args.

Source code in fastvideo/models/loader/component_loader.py
def load(self, model_path: str, fastvideo_args: FastVideoArgs):
    """Load the scheduler based on the model path, and inference args."""
    scheduler_config = get_diffusers_config(model=model_path)

    cls_name = scheduler_config.pop("_class_name")
    assert cls_name is not None, "Model config does not contain a _class_name attribute. Only diffusers format is supported."

    # Resolve the concrete scheduler class and build it from the remaining
    # config entries.
    resolved_cls, _ = ModelRegistry.resolve_model_cls(cls_name)
    scheduler = resolved_cls(**scheduler_config)

    # Apply optional pipeline-level overrides when they are configured.
    pipeline_config = fastvideo_args.pipeline_config
    if pipeline_config.flow_shift is not None:
        scheduler.set_shift(pipeline_config.flow_shift)
    if pipeline_config.timesteps_scale is not None:
        scheduler.set_timesteps_scale(pipeline_config.timesteps_scale)
    return scheduler

fastvideo.models.loader.component_loader.TextEncoderLoader

TextEncoderLoader(device=None)

Bases: ComponentLoader

Loader for text encoders.

Source code in fastvideo/models/loader/component_loader.py
def __init__(self, device=None) -> None:
    """Initialize the text encoder loader.

    Args:
        device: Optional target device; stored as-is on the instance.
    """
    self.device = device

Classes

fastvideo.models.loader.component_loader.TextEncoderLoader.Source dataclass
Source(model_or_path: str, prefix: str = '', fall_back_to_pt: bool = True, allow_patterns_overrides: list[str] | None = None)

A source for weights.

Attributes
fastvideo.models.loader.component_loader.TextEncoderLoader.Source.allow_patterns_overrides class-attribute instance-attribute
allow_patterns_overrides: list[str] | None = None

If defined, weights will load exclusively using these patterns.

fastvideo.models.loader.component_loader.TextEncoderLoader.Source.fall_back_to_pt class-attribute instance-attribute
fall_back_to_pt: bool = True

Whether .pt weights can be used.

fastvideo.models.loader.component_loader.TextEncoderLoader.Source.model_or_path instance-attribute
model_or_path: str

The model ID or path.

fastvideo.models.loader.component_loader.TextEncoderLoader.Source.prefix class-attribute instance-attribute
prefix: str = ''

A prefix to prepend to all weights.

Functions

fastvideo.models.loader.component_loader.TextEncoderLoader.load
load(model_path: str, fastvideo_args: FastVideoArgs)

Load the text encoders based on the model path, and inference args.

Source code in fastvideo/models/loader/component_loader.py
def load(self, model_path: str, fastvideo_args: FastVideoArgs):
    """Load the text encoders based on the model path, and inference args."""
    # model_config: PretrainedConfig = get_hf_config(
    #     model=model_path,
    #     trust_remote_code=fastvideo_args.trust_remote_code,
    #     revision=fastvideo_args.revision,
    #     model_override_args=None,
    # )
    model_config = get_diffusers_config(model=model_path)
    # Drop HF bookkeeping keys so only architecture fields remain.
    model_config.pop("_name_or_path", None)
    model_config.pop("transformers_version", None)
    model_config.pop("model_type", None)
    model_config.pop("tokenizer_class", None)
    model_config.pop("torch_dtype", None)
    logger.info("HF Model config: %s", model_config)

    # @TODO(Wei): Better way to handle this?
    # NOTE(review): a broad ``except Exception`` is used as control flow —
    # if applying the first text-encoder config fails for any reason, the
    # second config/precision pair is tried instead. Confirm which
    # exception update_model_arch actually raises on a mismatched config;
    # note the first config may have been partially mutated before the
    # fallback runs.
    try:
        encoder_config = fastvideo_args.pipeline_config.text_encoder_configs[
            0]
        encoder_config.update_model_arch(model_config)
        encoder_precision = fastvideo_args.pipeline_config.text_encoder_precisions[
            0]
    except Exception:
        encoder_config = fastvideo_args.pipeline_config.text_encoder_configs[
            1]
        encoder_config.update_model_arch(model_config)
        encoder_precision = fastvideo_args.pipeline_config.text_encoder_precisions[
            1]

    target_device = get_local_torch_device()
    # TODO(will): add support for other dtypes
    return self.load_model(model_path, encoder_config, target_device,
                           fastvideo_args, encoder_precision)

fastvideo.models.loader.component_loader.TokenizerLoader

TokenizerLoader(device=None)

Bases: ComponentLoader

Loader for tokenizers.

Source code in fastvideo/models/loader/component_loader.py
def __init__(self, device=None) -> None:
    """Initialize the tokenizer loader.

    Args:
        device: Optional target device; stored as-is on the instance.
    """
    self.device = device

Functions

fastvideo.models.loader.component_loader.TokenizerLoader.load
load(model_path: str, fastvideo_args: FastVideoArgs)

Load the tokenizer based on the model path, and inference args.

Source code in fastvideo/models/loader/component_loader.py
def load(self, model_path: str, fastvideo_args: FastVideoArgs):
    """Load the tokenizer based on the model path, and inference args.

    Args:
        model_path: Path to the tokenizer directory (e.g.
            "<path to model>/tokenizer").
        fastvideo_args: Inference arguments (currently unused here).

    Returns:
        The instantiated tokenizer.
    """
    logger.info("Loading tokenizer from %s", model_path)

    tokenizer = AutoTokenizer.from_pretrained(
        model_path,  # "<path to model>/tokenizer"
        # in v0, this was same string as encoder_name "ClipTextModel"
        # TODO(will): pass these tokenizer kwargs from inference args? Maybe
        # other method of config?
        # BUG FIX: the kwarg was previously misspelled ``padding_size``,
        # which tokenizers silently ignore; ``padding_side`` is the real
        # parameter controlling which side padding tokens are added on.
        padding_side='right',
    )
    logger.info("Loaded tokenizer: %s", tokenizer.__class__.__name__)
    return tokenizer

fastvideo.models.loader.component_loader.TransformerLoader

TransformerLoader(device=None)

Bases: ComponentLoader

Loader for transformer.

Source code in fastvideo/models/loader/component_loader.py
def __init__(self, device=None) -> None:
    """Initialize the transformer loader.

    Args:
        device: Optional target device; stored as-is on the instance.
    """
    self.device = device

Functions

fastvideo.models.loader.component_loader.TransformerLoader.load
load(model_path: str, fastvideo_args: FastVideoArgs)

Load the transformer based on the model path, and inference args.

Source code in fastvideo/models/loader/component_loader.py
def load(self, model_path: str, fastvideo_args: FastVideoArgs):
    """Load the transformer based on the model path, and inference args."""
    config = get_diffusers_config(model=model_path)
    # Keep an untouched copy: ``config`` is mutated below (pop / arch update).
    hf_config = deepcopy(config)
    cls_name = config.pop("_class_name")
    if cls_name is None:
        raise ValueError(
            "Model config does not contain a _class_name attribute. "
            "Only diffusers format is supported.")

    logger.info("transformer cls_name: %s", cls_name)
    if fastvideo_args.override_transformer_cls_name is not None:
        cls_name = fastvideo_args.override_transformer_cls_name
        logger.info("Overriding transformer cls_name to %s", cls_name)

    # Side effect: record the resolved path so later stages can find it.
    fastvideo_args.model_paths["transformer"] = model_path

    # Config from Diffusers supersedes fastvideo's model config
    dit_config = fastvideo_args.pipeline_config.dit_config
    dit_config.update_model_arch(config)

    model_cls, _ = ModelRegistry.resolve_model_cls(cls_name)

    # Find all safetensors files
    safetensors_list = glob.glob(
        os.path.join(str(model_path), "*.safetensors"))
    if not safetensors_list:
        raise ValueError(f"No safetensors files found in {model_path}")

    # Check if we should use custom initialization weights
    # NOTE(review): presence of the ``_loading_teacher_critic_model``
    # attribute (value irrelevant) disables custom weights — confirm with
    # the code that sets it.
    custom_weights_path = getattr(fastvideo_args, 'init_weights_from_safetensors', None)
    use_custom_weights = (custom_weights_path and os.path.exists(custom_weights_path) and 
                        not hasattr(fastvideo_args, '_loading_teacher_critic_model'))

    if use_custom_weights:
        # A second transformer ("transformer_2" in the path) reads its own
        # override; asserted non-None below, so it must be set in that case.
        if 'transformer_2' in model_path:
            custom_weights_path = getattr(fastvideo_args, 'init_weights_from_safetensors_2', None)
        assert custom_weights_path is not None, "Custom initialization weights must be provided"
        if os.path.isdir(custom_weights_path):
            safetensors_list = glob.glob(
                os.path.join(str(custom_weights_path), "*.safetensors"))
        else:
            assert custom_weights_path.endswith(".safetensors"), "Custom initialization weights must be a safetensors file"
            safetensors_list = [custom_weights_path]

    logger.info("Loading model from %s safetensors files: %s",
                len(safetensors_list), safetensors_list)

    default_dtype = PRECISION_TO_TYPE[
        fastvideo_args.pipeline_config.dit_precision]

    # Load the model using FSDP loader
    logger.info("Loading model from %s, default_dtype: %s", cls_name,
                default_dtype)
    assert fastvideo_args.hsdp_shard_dim is not None
    model = maybe_load_fsdp_model(
        model_cls=model_cls,
        init_params={
            "config": dit_config,
            "hf_config": hf_config
        },
        weight_dir_list=safetensors_list,
        device=get_local_torch_device(),
        hsdp_replicate_dim=fastvideo_args.hsdp_replicate_dim,
        hsdp_shard_dim=fastvideo_args.hsdp_shard_dim,
        cpu_offload=fastvideo_args.dit_cpu_offload,
        pin_cpu_memory=fastvideo_args.pin_cpu_memory,
        fsdp_inference=fastvideo_args.use_fsdp_inference,
        # TODO(will): make these configurable
        default_dtype=default_dtype,
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
        output_dtype=None,
        training_mode=fastvideo_args.training_mode,
        enable_torch_compile=fastvideo_args.enable_torch_compile,
        torch_compile_kwargs=fastvideo_args.torch_compile_kwargs)


    total_params = sum(p.numel() for p in model.parameters())
    logger.info("Loaded model with %.2fB parameters", total_params / 1e9)

    # Sanity check: loaded parameters must carry the configured precision.
    assert next(model.parameters()).dtype == default_dtype, "Model dtype does not match default dtype"

    model = model.eval()
    return model

fastvideo.models.loader.component_loader.VAELoader

VAELoader(device=None)

Bases: ComponentLoader

Loader for VAE.

Source code in fastvideo/models/loader/component_loader.py
def __init__(self, device=None) -> None:
    """Initialize the VAE loader.

    Args:
        device: Optional target device; stored as-is on the instance.
    """
    self.device = device

Functions

fastvideo.models.loader.component_loader.VAELoader.load
load(model_path: str, fastvideo_args: FastVideoArgs)

Load the VAE based on the model path, and inference args.

Source code in fastvideo/models/loader/component_loader.py
def load(self, model_path: str, fastvideo_args: FastVideoArgs):
    """Load the VAE based on the model path, and inference args."""
    config = get_diffusers_config(model=model_path)
    class_name = config.pop("_class_name")
    assert class_name is not None, "Model config does not contain a _class_name attribute. Only diffusers format is supported."
    # Side effect: record the resolved path so later stages can find it.
    fastvideo_args.model_paths["vae"] = model_path

    vae_config = fastvideo_args.pipeline_config.vae_config
    vae_config.update_model_arch(config)

    from fastvideo.platforms import current_platform

    # CPU offload places the VAE off the accelerator; on MPS hosts the
    # offload target is "mps" rather than "cpu".
    if fastvideo_args.vae_cpu_offload:
        target_device = torch.device("mps") if current_platform.is_mps() else torch.device("cpu")
    else:
        target_device = get_local_torch_device()

    # Construct the VAE under the configured default dtype so its
    # parameters are created with the requested precision.
    with set_default_torch_dtype(PRECISION_TO_TYPE[
            fastvideo_args.pipeline_config.vae_precision]):
        vae_cls, _ = ModelRegistry.resolve_model_cls(class_name)
        vae = vae_cls(vae_config).to(target_device)

    # Find all safetensors files
    safetensors_list = glob.glob(
        os.path.join(str(model_path), "*.safetensors"))
    # TODO(PY)
    assert len(
        safetensors_list
    ) == 1, f"Found {len(safetensors_list)} safetensors files in {model_path}"
    loaded = safetensors_load_file(safetensors_list[0])
    vae.load_state_dict(
        loaded, strict=False)  # We might only load encoder or decoder

    return vae.eval()

Functions