Skip to content

fastvideo_args

Argument definitions for FastVideo inference and training.

Classes

fastvideo.fastvideo_args.ExecutionMode

Bases: str, Enum

Enumeration for different pipeline modes.

Inherits from str to allow string comparison for backward compatibility.

Functions

fastvideo.fastvideo_args.ExecutionMode.choices classmethod
choices() -> list[str]

Get all available choices as strings for argparse.

Source code in fastvideo/fastvideo_args.py
@classmethod
def choices(cls) -> list[str]:
    """Return every member's string value, e.g. for an argparse ``choices=`` list."""
    member_values = [member.value for member in cls]
    return member_values
fastvideo.fastvideo_args.ExecutionMode.from_string classmethod
from_string(value: str) -> ExecutionMode

Convert string to ExecutionMode enum.

Source code in fastvideo/fastvideo_args.py
@classmethod
def from_string(cls, value: str) -> "ExecutionMode":
    """Look up the enum member whose value equals ``value`` lower-cased.

    Raises:
        ValueError: If no member matches; the message lists all valid values.
    """
    try:
        member = cls(value.lower())
    except ValueError:
        # `from None` hides the enum's own lookup error so only this message shows.
        valid_values = ', '.join([m.value for m in cls])
        raise ValueError(
            f"Invalid mode: {value}. Must be one of: {valid_values}"
        ) from None
    else:
        return member

fastvideo.fastvideo_args.FastVideoArgs dataclass

FastVideoArgs(model_path: str, mode: ExecutionMode = INFERENCE, workload_type: WorkloadType = T2V, cache_strategy: str = 'none', distributed_executor_backend: str = 'mp', ray_placement_group: PlacementGroup | None = None, ray_runtime_env: RuntimeEnv | None = None, inference_mode: bool = True, trust_remote_code: bool = False, revision: str | None = None, num_gpus: int = 1, tp_size: int = -1, sp_size: int = -1, hsdp_replicate_dim: int = 1, hsdp_shard_dim: int = -1, dist_timeout: int | None = None, pipeline_config: PipelineConfig = PipelineConfig(), preprocess_config: PreprocessConfig | None = None, lora_path: str | None = None, lora_nickname: str = 'default', lora_target_modules: list[str] | None = None, output_type: str = 'pil', dit_cpu_offload: bool = True, use_fsdp_inference: bool = True, text_encoder_cpu_offload: bool = True, image_encoder_cpu_offload: bool = True, vae_cpu_offload: bool = True, pin_cpu_memory: bool = True, mask_strategy_file_path: str | None = None, STA_mode: STA_Mode = STA_INFERENCE, skip_time_steps: int = 15, enable_torch_compile: bool = False, torch_compile_kwargs: dict[str, Any] = dict(), disable_autocast: bool = False, VSA_sparsity: float = 0.0, moba_config_path: str | None = None, moba_config: dict[str, Any] = dict(), master_port: int | None = None, enable_stage_verification: bool = True, prompt_txt: str | None = None, model_paths: dict[str, str] = dict(), model_loaded: dict[str, bool] = (lambda: {'transformer': True, 'vae': True})(), override_transformer_cls_name: str | None = None, init_weights_from_safetensors: str = '', init_weights_from_safetensors_2: str = '', boundary_ratio: float | None = 0.875)

Functions

fastvideo.fastvideo_args.FastVideoArgs.check_fastvideo_args
check_fastvideo_args() -> None

Validate inference arguments for consistency

Source code in fastvideo/fastvideo_args.py
def check_fastvideo_args(self) -> None:
    """Validate the arguments for consistency and fill in derived defaults.

    This method mutates ``self`` in place. In order, it:

    * turns off ``use_fsdp_inference`` when the current platform reports MPS;
    * checks that ``mode`` and ``workload_type`` are the expected enum types;
    * forces ``inference_mode`` to agree with ``mode`` (logging a warning
      when it has to flip the flag);
    * when not in inference mode, requires ``hsdp_replicate_dim``,
      ``hsdp_shard_dim`` and ``sp_size`` to be set explicitly (not ``-1``);
    * replaces ``-1`` placeholders: ``tp_size`` becomes 1, ``sp_size`` and
      ``hsdp_shard_dim`` become ``num_gpus``;
    * checks that ``num_gpus`` is at least, and divisible by, each of
      ``sp_size``, ``hsdp_replicate_dim`` and ``hsdp_shard_dim``;
    * raises ``num_gpus`` to ``max(tp_size, sp_size)`` if it is smaller;
    * delegates to ``pipeline_config.check_pipeline_config()`` and, in
      PREPROCESS mode, to ``preprocess_config.check_preprocess_config()``.

    Raises:
        AssertionError: If a type, "must be set for training", or
            divisibility check fails.
        ValueError: If ``pipeline_config`` is None, or if ``mode`` is
            PREPROCESS and ``preprocess_config`` is None.
    """
    # NOTE(review): most checks below use `assert`, which Python skips when
    # run with -O; under -O only the explicit ValueError checks remain.
    from fastvideo.platforms import current_platform

    # FSDP inference is switched off unconditionally on MPS.
    if current_platform.is_mps():
        self.use_fsdp_inference = False

    # Validate mode and inference_mode consistency
    assert isinstance(
        self.mode, ExecutionMode
    ), f"Mode must be an ExecutionMode enum, got {type(self.mode)}"
    assert self.mode in ExecutionMode.choices(
    ), f"Invalid execution mode: {self.mode}"

    # Validate workload type
    assert isinstance(
        self.workload_type, WorkloadType
    ), f"Workload type must be a WorkloadType enum, got {type(self.workload_type)}"
    assert self.workload_type in WorkloadType.choices(
    ), f"Invalid workload type: {self.workload_type}"

    # `mode` wins over `inference_mode`: DISTILLATION/FINETUNING force it to
    # False, INFERENCE/PREPROCESS force it to True.
    # NOTE(review): the first warning hard-codes 'training' rather than
    # reporting the actual mode the way the second warning does.
    if self.mode in [ExecutionMode.DISTILLATION, ExecutionMode.FINETUNING
                     ] and self.inference_mode:
        logger.warning(
            "Mode is 'training' but inference_mode is True. Setting inference_mode to False."
        )
        self.inference_mode = False
    elif self.mode in [ExecutionMode.INFERENCE, ExecutionMode.PREPROCESS
                       ] and not self.inference_mode:
        logger.warning(
            "Mode is '%s' but inference_mode is False. Setting inference_mode to True.",
            self.mode)
        self.inference_mode = True

    # Training must set the parallelism sizes explicitly; the -1 placeholder
    # is only auto-filled (below) for inference.
    if not self.inference_mode:
        assert self.hsdp_replicate_dim != -1, "hsdp_replicate_dim must be set for training"
        assert self.hsdp_shard_dim != -1, "hsdp_shard_dim must be set for training"
        assert self.sp_size != -1, "sp_size must be set for training"

    # Resolve remaining -1 placeholders to concrete sizes.
    if self.tp_size == -1:
        self.tp_size = 1
    if self.sp_size == -1:
        self.sp_size = self.num_gpus
    if self.hsdp_shard_dim == -1:
        self.hsdp_shard_dim = self.num_gpus

    # Each parallel dimension must evenly divide the GPU count.
    assert self.sp_size <= self.num_gpus and self.num_gpus % self.sp_size == 0, "num_gpus must >= and be divisible by sp_size"
    assert self.hsdp_replicate_dim <= self.num_gpus and self.num_gpus % self.hsdp_replicate_dim == 0, "num_gpus must >= and be divisible by hsdp_replicate_dim"
    assert self.hsdp_shard_dim <= self.num_gpus and self.num_gpus % self.hsdp_shard_dim == 0, "num_gpus must >= and be divisible by hsdp_shard_dim"

    # With the asserts above active, sp_size <= num_gpus already holds, so
    # this bump only takes effect when tp_size exceeds num_gpus.
    if self.num_gpus < max(self.tp_size, self.sp_size):
        self.num_gpus = max(self.tp_size, self.sp_size)

    if self.pipeline_config is None:
        raise ValueError("pipeline_config is not set in FastVideoArgs")

    self.pipeline_config.check_pipeline_config()

    # Add preprocessing config validation if needed
    if self.mode == ExecutionMode.PREPROCESS:
        if self.preprocess_config is None:
            raise ValueError(
                "preprocess_config is not set in FastVideoArgs when mode is PREPROCESS"
            )
        # An empty preprocess model_path falls back to the top-level model_path.
        if self.preprocess_config.model_path == "":
            self.preprocess_config.model_path = self.model_path
        # PREPROCESS mode always turns on loading of the VAE encoder.
        if not self.pipeline_config.vae_config.load_encoder:
            self.pipeline_config.vae_config.load_encoder = True
        self.preprocess_config.check_preprocess_config()

fastvideo.fastvideo_args.TrainingArgs dataclass

TrainingArgs(model_path: str, mode: ExecutionMode = INFERENCE, workload_type: WorkloadType = T2V, cache_strategy: str = 'none', distributed_executor_backend: str = 'mp', ray_placement_group: PlacementGroup | None = None, ray_runtime_env: RuntimeEnv | None = None, inference_mode: bool = True, trust_remote_code: bool = False, revision: str | None = None, num_gpus: int = 1, tp_size: int = -1, sp_size: int = -1, hsdp_replicate_dim: int = 1, hsdp_shard_dim: int = -1, dist_timeout: int | None = None, pipeline_config: PipelineConfig = PipelineConfig(), preprocess_config: PreprocessConfig | None = None, lora_path: str | None = None, lora_nickname: str = 'default', lora_target_modules: list[str] | None = None, output_type: str = 'pil', dit_cpu_offload: bool = True, use_fsdp_inference: bool = True, text_encoder_cpu_offload: bool = True, image_encoder_cpu_offload: bool = True, vae_cpu_offload: bool = True, pin_cpu_memory: bool = True, mask_strategy_file_path: str | None = None, STA_mode: STA_Mode = STA_INFERENCE, skip_time_steps: int = 15, enable_torch_compile: bool = False, torch_compile_kwargs: dict[str, Any] = dict(), disable_autocast: bool = False, VSA_sparsity: float = 0.0, moba_config_path: str | None = None, moba_config: dict[str, Any] = dict(), master_port: int | None = None, enable_stage_verification: bool = True, prompt_txt: str | None = None, model_paths: dict[str, str] = dict(), model_loaded: dict[str, bool] = (lambda: {'transformer': True, 'vae': True})(), override_transformer_cls_name: str | None = None, init_weights_from_safetensors: str = '', init_weights_from_safetensors_2: str = '', boundary_ratio: float | None = 0.875, data_path: str = '', dataloader_num_workers: int = 0, num_height: int = 0, num_width: int = 0, num_frames: int = 0, train_batch_size: int = 0, num_latent_t: int = 0, group_frame: bool = False, group_resolution: bool = False, pretrained_model_name_or_path: str = '', real_score_model_path: str = '', fake_score_model_path: str = '', ema_decay: 
float = 0.0, ema_start_step: int = 0, training_cfg_rate: float = 0.0, precondition_outputs: bool = False, validation_dataset_file: str = '', validation_preprocessed_path: str = '', validation_sampling_steps: str = '', validation_guidance_scale: str = '', validation_steps: float = 0.0, log_validation: bool = False, trackers: list[str] = list(), tracker_project_name: str = '', wandb_run_name: str = '', seed: int | None = None, output_dir: str = '', checkpoints_total_limit: int = 0, resume_from_checkpoint: str = '', num_train_epochs: int = 0, max_train_steps: int = 0, gradient_accumulation_steps: int = 0, learning_rate: float = 0.0, scale_lr: bool = False, lr_scheduler: str = 'constant', lr_warmup_steps: int = 0, max_grad_norm: float = 0.0, enable_gradient_checkpointing_type: str | None = None, selective_checkpointing: float = 0.0, mixed_precision: str = '', train_sp_batch_size: int = 0, fsdp_sharding_startegy: str = '', weighting_scheme: str = '', logit_mean: float = 0.0, logit_std: float = 1.0, mode_scale: float = 0.0, num_euler_timesteps: int = 0, lr_num_cycles: int = 0, lr_power: float = 0.0, min_lr_ratio: float = 0.5, not_apply_cfg_solver: bool = False, distill_cfg: float = 0.0, scheduler_type: str = '', linear_quadratic_threshold: float = 0.0, linear_range: float = 0.0, weight_decay: float = 0.0, betas: str = '0.9,0.999', use_ema: bool = False, multi_phased_distill_schedule: str = '', pred_decay_weight: float = 0.0, pred_decay_type: str = '', hunyuan_teacher_disable_cfg: bool = False, master_weight_type: str = '', VSA_decay_rate: float = 0.01, VSA_decay_interval_steps: int = 1, lora_rank: int | None = None, lora_alpha: int | None = None, lora_training: bool = False, generator_update_interval: int = 5, dfake_gen_update_ratio: int = 5, min_timestep_ratio: float = 0.2, max_timestep_ratio: float = 0.98, real_score_guidance_scale: float = 3.5, fake_score_learning_rate: float = 0.0, fake_score_lr_scheduler: str = 'constant', fake_score_betas: str = '0.9,0.999', 
training_state_checkpointing_steps: int = 0, weight_only_checkpointing_steps: int = 0, log_visualization: bool = False, simulate_generator_forward: bool = False, warp_denoising_step: bool = False, num_frame_per_block: int = 3, independent_first_frame: bool = False, enable_gradient_masking: bool = True, gradient_mask_last_n_frames: int = 21, same_step_across_blocks: bool = False, last_step_only: bool = False, context_noise: int = 0)

Bases: FastVideoArgs

Training arguments. Inherits from FastVideoArgs and adds training-specific arguments. If there are any conflicts, the training arguments will take precedence.

fastvideo.fastvideo_args.WorkloadType

Bases: str, Enum

Enumeration for different workload types.

Inherits from str to allow string comparison for backward compatibility.

Functions

fastvideo.fastvideo_args.WorkloadType.choices classmethod
choices() -> list[str]

Get all available choices as strings for argparse.

Source code in fastvideo/fastvideo_args.py
@classmethod
def choices(cls) -> list[str]:
    """Return every workload type's string value, e.g. for an argparse ``choices=`` list."""
    workload_values = [member.value for member in cls]
    return workload_values
fastvideo.fastvideo_args.WorkloadType.from_string classmethod
from_string(value: str) -> WorkloadType

Convert string to WorkloadType enum.

Source code in fastvideo/fastvideo_args.py
@classmethod
def from_string(cls, value: str) -> "WorkloadType":
    """Look up the enum member whose value equals ``value`` lower-cased.

    Raises:
        ValueError: If no member matches; the message lists all valid values.
    """
    try:
        member = cls(value.lower())
    except ValueError:
        # `from None` hides the enum's own lookup error so only this message shows.
        valid_values = ', '.join([m.value for m in cls])
        raise ValueError(
            f"Invalid workload type: {value}. Must be one of: {valid_values}"
        ) from None
    else:
        return member

Functions

fastvideo.fastvideo_args.prepare_fastvideo_args

prepare_fastvideo_args(argv: list[str]) -> FastVideoArgs

Prepare the inference arguments from the command line arguments.

Parameters:

Name Type Description Default
argv list[str]

The command line arguments. Typically, it should be sys.argv[1:] to ensure compatibility with parse_args when no arguments are passed.

required

Returns:

Type Description
FastVideoArgs

The inference arguments.

Source code in fastvideo/fastvideo_args.py
def prepare_fastvideo_args(argv: list[str]) -> FastVideoArgs:
    """
    Build a FastVideoArgs from command line arguments and record it globally.

    The parsed result is also stored in the module-level
    `_current_fastvideo_args` before being returned.

    Args:
        argv: The command line arguments. Typically, it should be `sys.argv[1:]`
            to ensure compatibility with `parse_args` when no arguments are passed.

    Returns:
        The inference arguments.
    """
    global _current_fastvideo_args

    cli_parser = FlexibleArgumentParser()
    FastVideoArgs.add_cli_args(cli_parser)
    namespace = cli_parser.parse_args(argv)

    parsed_args = FastVideoArgs.from_cli_args(namespace)
    _current_fastvideo_args = parsed_args
    return parsed_args

fastvideo.fastvideo_args.set_current_fastvideo_args

set_current_fastvideo_args(fastvideo_args: FastVideoArgs)

Temporarily set the current fastvideo config. Used during model initialization. We save the current fastvideo config in a global variable, so that all modules can access it, e.g. custom ops can access the fastvideo config to determine how to dispatch.

Source code in fastvideo/fastvideo_args.py
@contextmanager
def set_current_fastvideo_args(fastvideo_args: FastVideoArgs):
    """
    Context manager that swaps in ``fastvideo_args`` as the current config.

    While the ``with`` body runs, the module-level `_current_fastvideo_args`
    holds ``fastvideo_args`` so other modules (e.g. custom ops deciding how
    to dispatch) can read it. Used during model initialization. The previous
    value is put back on exit, including when the body raises.
    """
    global _current_fastvideo_args
    previous_args = _current_fastvideo_args
    _current_fastvideo_args = fastvideo_args
    try:
        yield
    finally:
        # Always restore the previous value so nested uses unwind correctly.
        _current_fastvideo_args = previous_args