training

Classes

fastvideo.training.DistillationPipeline

DistillationPipeline(model_path: str, fastvideo_args: TrainingArgs, required_config_modules: list[str] | None = None, loaded_modules: dict[str, Module] | None = None)

Bases: TrainingPipeline

A distillation pipeline for training a 3-step model. Inherits from TrainingPipeline to reuse training infrastructure.

Source code in fastvideo/training/training_pipeline.py
def __init__(
        self,
        model_path: str,
        fastvideo_args: TrainingArgs,
        required_config_modules: list[str] | None = None,
        loaded_modules: dict[str, torch.nn.Module] | None = None) -> None:
    fastvideo_args.inference_mode = False
    self.lora_training = fastvideo_args.lora_training
    if self.lora_training and fastvideo_args.lora_rank is None:
        raise ValueError("lora rank must be set when using lora training")

    set_random_seed(fastvideo_args.seed)  # for lora param init
    super().__init__(model_path, fastvideo_args, required_config_modules,
                     loaded_modules)  # type: ignore
    self.tracker = DummyTracker()
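
Because initialize_validation_pipeline is abstract, DistillationPipeline is not instantiated directly. The snippet below is a minimal sketch of a concrete subclass; the class name, the import path for TrainingArgs, and the body of initialize_validation_pipeline are illustrative placeholders rather than FastVideo code.

from fastvideo.fastvideo_args import TrainingArgs  # assumed import path
from fastvideo.training import DistillationPipeline


class MyDistillationPipeline(DistillationPipeline):
    """Hypothetical subclass used only to illustrate the API surface."""

    def initialize_validation_pipeline(self, training_args: TrainingArgs):
        # Build whatever inference pipeline this subclass needs for
        # periodic validation runs; omitted in this sketch.
        ...


# pipeline = MyDistillationPipeline(model_path, training_args)
# pipeline.train()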

Functions

fastvideo.training.DistillationPipeline.apply_ema_to_model
apply_ema_to_model(model)

Apply EMA weights to the model for validation or inference.

Source code in fastvideo/training/distillation_pipeline.py
def apply_ema_to_model(self, model):
    """Apply EMA weights to the model for validation or inference."""
    if model is self.transformer and self.generator_ema is not None:
        with self.generator_ema.apply_to_model(model):
            return model
    elif model is self.transformer_2 and self.generator_ema_2 is not None:
        with self.generator_ema_2.apply_to_model(model):
            return model
    return model
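
A hedged usage sketch: during validation, pass the generator transformer through this helper so EMA weights are used when an EMA tracker exists for it. pipeline below stands for an already-constructed concrete DistillationPipeline and is illustrative.

# Illustrative only -- if EMA was never enabled (ema_decay <= 0), the
# original module is returned unchanged.
model_for_eval = pipeline.apply_ema_to_model(pipeline.transformer)
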
fastvideo.training.DistillationPipeline.get_ema_2_model_copy
get_ema_2_model_copy() -> Module | None

Get a copy of the transformer_2 model with EMA weights applied.

Source code in fastvideo/training/distillation_pipeline.py
def get_ema_2_model_copy(self) -> torch.nn.Module | None:
    """Get a copy of the transformer_2 model with EMA weights applied."""
    if self.generator_ema_2 is not None and self.transformer_2 is not None:
        ema_2_model = copy.deepcopy(self.transformer_2)
        self.generator_ema_2.copy_to_unwrapped(ema_2_model)
        return ema_2_model
    return None
fastvideo.training.DistillationPipeline.get_ema_model_copy
get_ema_model_copy() -> Module | None

Get a copy of the model with EMA weights applied.

Source code in fastvideo/training/distillation_pipeline.py
def get_ema_model_copy(self) -> torch.nn.Module | None:
    """Get a copy of the model with EMA weights applied."""
    if self.generator_ema is not None:
        ema_model = copy.deepcopy(self.transformer)
        self.generator_ema.copy_to_unwrapped(ema_model)
        return ema_model
    return None
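
A hedged sketch of the copy-based alternative to apply_ema_to_model: take a deep copy with EMA weights applied, use it for evaluation, then free it. pipeline is again an illustrative, already-constructed pipeline.

ema_model = pipeline.get_ema_model_copy()  # None if EMA is not initialized
if ema_model is not None:
    ema_model.eval()
    # ... run validation or export with ema_model ...
    del ema_model  # the deep copy can be large; release it promptly
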
fastvideo.training.DistillationPipeline.get_ema_stats
get_ema_stats() -> dict[str, Any]

Get EMA statistics for monitoring.

Source code in fastvideo/training/distillation_pipeline.py
def get_ema_stats(self) -> dict[str, Any]:
    """Get EMA statistics for monitoring."""
    ema_enabled = self.generator_ema is not None
    ema_2_enabled = self.generator_ema_2 is not None

    if not ema_enabled and not ema_2_enabled:
        return {
            "ema_enabled": False,
            "ema_2_enabled": False,
            "ema_decay": None,
            "ema_start_step": self.training_args.ema_start_step,
            "ema_ready": False,
            "ema_2_ready": False,
            "ema_step": self.current_trainstep,
        }

    return {
        "ema_enabled": ema_enabled,
        "ema_2_enabled": ema_2_enabled,
        "ema_decay": self.training_args.ema_decay,
        "ema_start_step": self.training_args.ema_start_step,
        "ema_ready": self.is_ema_ready() if ema_enabled else False,
        "ema_2_ready": self.is_ema_ready() if ema_2_enabled else False,
        "ema_step": self.current_trainstep,
    }
fastvideo.training.DistillationPipeline.initialize_training_pipeline
initialize_training_pipeline(training_args: TrainingArgs)

Initialize the distillation training pipeline with multiple models.

Source code in fastvideo/training/distillation_pipeline.py
def initialize_training_pipeline(self, training_args: TrainingArgs):
    """Initialize the distillation training pipeline with multiple models."""
    logger.info("Initializing distillation pipeline...")

    super().initialize_training_pipeline(training_args)

    self.noise_scheduler = self.get_module("scheduler")
    self.vae = self.get_module("vae")
    self.vae.requires_grad_(False)

    self.timestep_shift = self.training_args.pipeline_config.flow_shift
    self.noise_scheduler = FlowMatchEulerDiscreteScheduler(
        shift=self.timestep_shift)

    if self.training_args.boundary_ratio is not None:
        self.boundary_timestep = self.training_args.boundary_ratio * self.noise_scheduler.num_train_timesteps
    else:
        self.boundary_timestep = None

    if training_args.real_score_model_path:
        logger.info("Loading real score transformer from: %s",
                    training_args.real_score_model_path)
        training_args.override_transformer_cls_name = "WanTransformer3DModel"
        self.real_score_transformer = self.load_module_from_path(
            training_args.real_score_model_path, "transformer",
            training_args)
        try:
            self.real_score_transformer_2 = self.load_module_from_path(
                training_args.real_score_model_path, "transformer_2",
                training_args)
            logger.info("Loaded real score transformer_2 for MoE support")
        except Exception:
            logger.info(
                "real score transformer_2 not found, using single transformer"
            )
            self.real_score_transformer_2 = None
    else:
        self.real_score_transformer = self.get_module(
            "real_score_transformer")
        self.real_score_transformer_2 = self.get_module(
            "real_score_transformer_2")

    if training_args.fake_score_model_path:
        logger.info("Loading fake score transformer from: %s",
                    training_args.fake_score_model_path)
        training_args.override_transformer_cls_name = "WanTransformer3DModel"
        self.fake_score_transformer = self.load_module_from_path(
            training_args.fake_score_model_path, "transformer",
            training_args)
        try:
            self.fake_score_transformer_2 = self.load_module_from_path(
                training_args.fake_score_model_path, "transformer_2",
                training_args)
            logger.info("Loaded fake score transformer_2 for MoE support")
        except Exception:
            logger.info(
                "fake score transformer_2 not found, using single transformer"
            )
            self.fake_score_transformer_2 = None
    else:
        self.fake_score_transformer = self.get_module(
            "fake_score_transformer")
        self.fake_score_transformer_2 = self.get_module(
            "fake_score_transformer_2")

    self.real_score_transformer.requires_grad_(False)
    self.real_score_transformer.eval()
    if self.real_score_transformer_2 is not None:
        self.real_score_transformer_2.requires_grad_(False)
        self.real_score_transformer_2.eval()

    # Set training modes for fake score transformers (trainable)
    self.fake_score_transformer.requires_grad_(True)
    self.fake_score_transformer.train()
    if self.fake_score_transformer_2 is not None:
        self.fake_score_transformer_2.requires_grad_(True)
        self.fake_score_transformer_2.train()

    if training_args.enable_gradient_checkpointing_type is not None:
        self.fake_score_transformer = apply_activation_checkpointing(
            self.fake_score_transformer,
            checkpointing_type=training_args.
            enable_gradient_checkpointing_type)
        if self.fake_score_transformer_2 is not None:
            self.fake_score_transformer_2 = apply_activation_checkpointing(
                self.fake_score_transformer_2,
                checkpointing_type=training_args.
                enable_gradient_checkpointing_type)

        self.real_score_transformer = apply_activation_checkpointing(
            self.real_score_transformer,
            checkpointing_type=training_args.
            enable_gradient_checkpointing_type)
        if self.real_score_transformer_2 is not None:
            self.real_score_transformer_2 = apply_activation_checkpointing(
                self.real_score_transformer_2,
                checkpointing_type=training_args.
                enable_gradient_checkpointing_type)

    # Initialize optimizers
    fake_score_params = list(
        filter(lambda p: p.requires_grad,
               self.fake_score_transformer.parameters()))

    # Use separate learning rate for fake_score_transformer if specified
    fake_score_lr = training_args.fake_score_learning_rate
    if fake_score_lr == 0.0:
        fake_score_lr = training_args.learning_rate

    betas_str = training_args.fake_score_betas
    betas = tuple(float(x.strip()) for x in betas_str.split(","))

    self.fake_score_optimizer = torch.optim.AdamW(
        fake_score_params,
        lr=fake_score_lr,
        betas=betas,
        weight_decay=training_args.weight_decay,
        eps=1e-8,
    )

    self.fake_score_lr_scheduler = get_scheduler(
        training_args.fake_score_lr_scheduler,
        optimizer=self.fake_score_optimizer,
        num_warmup_steps=training_args.lr_warmup_steps,
        num_training_steps=training_args.max_train_steps,
        num_cycles=training_args.lr_num_cycles,
        power=training_args.lr_power,
        min_lr_ratio=training_args.min_lr_ratio,
        last_epoch=self.init_steps - 1,
    )

    if self.fake_score_transformer_2 is not None:
        fake_score_params_2 = list(
            filter(lambda p: p.requires_grad,
                   self.fake_score_transformer_2.parameters()))
        self.fake_score_optimizer_2 = torch.optim.AdamW(
            fake_score_params_2,
            lr=fake_score_lr,
            betas=betas,
            weight_decay=training_args.weight_decay,
            eps=1e-8,
        )
        self.fake_score_lr_scheduler_2 = get_scheduler(
            training_args.fake_score_lr_scheduler,
            optimizer=self.fake_score_optimizer_2,
            num_warmup_steps=training_args.lr_warmup_steps,
            num_training_steps=training_args.max_train_steps,
            num_cycles=training_args.lr_num_cycles,
            power=training_args.lr_power,
            min_lr_ratio=training_args.min_lr_ratio,
            last_epoch=self.init_steps - 1,
        )

    logger.info(
        "Distillation optimizers initialized: generator and fake_score")

    self.generator_update_interval = self.training_args.generator_update_interval
    logger.info(
        "Distillation pipeline initialized with generator_update_interval=%s",
        self.generator_update_interval)

    self.denoising_step_list = torch.tensor(
        self.training_args.pipeline_config.dmd_denoising_steps,
        dtype=torch.long,
        device=get_local_torch_device())

    if training_args.warp_denoising_step:  # Warp the denoising step according to the scheduler time shift
        timesteps = torch.cat((self.noise_scheduler.timesteps.cpu(),
                               torch.tensor([0],
                                            dtype=torch.float32))).cuda()
        self.denoising_step_list = timesteps[1000 -
                                             self.denoising_step_list]
        logger.info("Warping denoising_step_list")

    self.denoising_step_list = self.denoising_step_list.to(
        get_local_torch_device())
    logger.info("Distillation generator model to %s denoising steps: %s",
                len(self.denoising_step_list), self.denoising_step_list)
    self.num_train_timestep = self.noise_scheduler.num_train_timesteps

    self.min_timestep = int(self.training_args.min_timestep_ratio *
                            self.num_train_timestep)
    self.max_timestep = int(self.training_args.max_timestep_ratio *
                            self.num_train_timestep)

    self.real_score_guidance_scale = self.training_args.real_score_guidance_scale

    self.generator_ema: EMA_FSDP | None = None
    self.generator_ema_2: EMA_FSDP | None = None
    if (self.training_args.ema_decay
            is not None) and (self.training_args.ema_decay > 0.0):
        self.generator_ema = EMA_FSDP(self.transformer,
                                      decay=self.training_args.ema_decay)
        logger.info("Initialized generator EMA with decay=%s",
                    self.training_args.ema_decay)

        # Initialize EMA for transformer_2 if it exists
        if self.transformer_2 is not None:
            self.generator_ema_2 = EMA_FSDP(
                self.transformer_2, decay=self.training_args.ema_decay)
            logger.info("Initialized generator EMA_2 with decay=%s",
                        self.training_args.ema_decay)
    else:
        logger.info("Generator EMA disabled (ema_decay <= 0.0)")
fastvideo.training.DistillationPipeline.initialize_validation_pipeline abstractmethod
initialize_validation_pipeline(training_args: TrainingArgs)

Initialize validation pipeline - must be implemented by subclasses.

Source code in fastvideo/training/distillation_pipeline.py
@abstractmethod
def initialize_validation_pipeline(self, training_args: TrainingArgs):
    """Initialize validation pipeline - must be implemented by subclasses."""
    raise NotImplementedError(
        "Distillation pipelines must implement this method")
fastvideo.training.DistillationPipeline.is_ema_ready
is_ema_ready(current_step: int | None = None)

Check if EMA is ready for use (after ema_start_step).

Source code in fastvideo/training/distillation_pipeline.py
def is_ema_ready(self, current_step: int | None = None):
    """Check if EMA is ready for use (after ema_start_step)."""
    if current_step is None:
        current_step = getattr(self, 'current_trainstep', 0)
    return (self.generator_ema is not None
            and current_step >= self.training_args.ema_start_step)
fastvideo.training.DistillationPipeline.load_module_from_path
load_module_from_path(model_path: str, module_type: str, training_args: TrainingArgs)

Load a module from a specific path using the same loading logic as the pipeline.

Parameters:

    model_path (str): Path to the model. Required.
    module_type (str): Type of module to load (e.g., "transformer"). Required.
    training_args (TrainingArgs): Training arguments. Required.

Returns:

    The loaded module.

Source code in fastvideo/training/distillation_pipeline.py
def load_module_from_path(self, model_path: str, module_type: str,
                          training_args: "TrainingArgs"):
    """
    Load a module from a specific path using the same loading logic as the pipeline.

    Args:
        model_path: Path to the model
        module_type: Type of module to load (e.g., "transformer")
        training_args: Training arguments

    Returns:
        The loaded module
    """
    logger.info("Loading %s from custom path: %s", module_type, model_path)
    # Set flag to prevent custom weight loading for teacher/critic models
    training_args._loading_teacher_critic_model = True

    try:
        from fastvideo.models.loader.component_loader import (
            PipelineComponentLoader)

        # Download the model if it's a Hugging Face model ID
        local_model_path = maybe_download_model(model_path)
        logger.info("Model downloaded/found at: %s", local_model_path)
        config = verify_model_config_and_directory(local_model_path)

        if module_type not in config:
            if hasattr(self, '_extra_config_module_map'
                       ) and module_type in self._extra_config_module_map:
                extra_module = self._extra_config_module_map[module_type]
                if extra_module in config:
                    logger.info("Using %s for %s", extra_module,
                                module_type)
                    module_type = extra_module
                else:
                    raise ValueError(
                        f"Module {module_type} not found in config at {local_model_path}"
                    )
            else:
                raise ValueError(
                    f"Module {module_type} not found in config at {local_model_path}"
                )

        module_info = config[module_type]
        if module_info is None:
            raise ValueError(
                f"Module {module_type} has null value in config at {local_model_path}"
            )

        transformers_or_diffusers, architecture = module_info
        component_path = os.path.join(local_model_path, module_type)
        module = PipelineComponentLoader.load_module(
            module_name=module_type,
            component_model_path=component_path,
            transformers_or_diffusers=transformers_or_diffusers,
            fastvideo_args=training_args,
        )

        logger.info("Successfully loaded %s from %s", module_type,
                    component_path)
        return module
    finally:
        # Always clean up the flag
        if hasattr(training_args, '_loading_teacher_critic_model'):
            delattr(training_args, '_loading_teacher_critic_model')
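
A hedged usage sketch: loading a teacher ("real score") transformer from a separate checkpoint. The model path below is a placeholder, and pipeline / training_args refer to an already-constructed concrete pipeline and its TrainingArgs.

teacher = pipeline.load_module_from_path(
    model_path="some-org/teacher-model",  # local directory or HF model ID (placeholder)
    module_type="transformer",
    training_args=training_args,
)
teacher.requires_grad_(False)
teacher.eval()
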
fastvideo.training.DistillationPipeline.reset_ema
reset_ema()

Reset EMA to current model weights.

Source code in fastvideo/training/distillation_pipeline.py
def reset_ema(self):
    """Reset EMA to current model weights."""
    if self.generator_ema is not None:
        logger.info("Resetting EMA to current model weights")
        self.generator_ema.update(self.transformer)
        # Force update to current weights by setting decay to 0 temporarily
        original_decay = self.generator_ema.decay
        self.generator_ema.decay = 0.0
        self.generator_ema.update(self.transformer)
        self.generator_ema.decay = original_decay
        logger.info("EMA reset completed")
    else:
        logger.warning("Cannot reset EMA: EMA not initialized")

    if self.generator_ema_2 is not None:
        logger.info("Resetting EMA_2 to current model weights")
        self.generator_ema_2.update(self.transformer_2)
        # Force update to current weights by setting decay to 0 temporarily
        original_decay_2 = self.generator_ema_2.decay
        self.generator_ema_2.decay = 0.0
        self.generator_ema_2.update(self.transformer_2)
        self.generator_ema_2.decay = original_decay_2
        logger.info("EMA_2 reset completed")
fastvideo.training.DistillationPipeline.save_ema_weights
save_ema_weights(output_dir: str, step: int)

Save EMA weights separately for inference purposes.

Source code in fastvideo/training/distillation_pipeline.py
def save_ema_weights(self, output_dir: str, step: int):
    """Save EMA weights separately for inference purposes."""
    if self.generator_ema is None and self.generator_ema_2 is None:
        logger.warning("Cannot save EMA weights: No EMA initialized")
        return

    if not self.is_ema_ready():
        logger.warning(
            "Cannot save EMA weights: EMA not ready yet (step < ema_start_step)"
        )
        return

    try:
        # Save main transformer EMA
        if self.generator_ema is not None:
            ema_model = self.get_ema_model_copy()
            if ema_model is None:
                logger.warning("Failed to create EMA model copy")
            else:
                ema_save_dir = os.path.join(output_dir,
                                            f"ema_checkpoint-{step}")
                os.makedirs(ema_save_dir, exist_ok=True)

                # save as diffusers format
                from safetensors.torch import save_file

                from fastvideo.training.training_utils import (
                    custom_to_hf_state_dict, gather_state_dict_on_cpu_rank0)
                cpu_state = gather_state_dict_on_cpu_rank0(ema_model,
                                                           device=None)

                if self.global_rank == 0:
                    weight_path = os.path.join(
                        ema_save_dir, "diffusion_pytorch_model.safetensors")
                    diffusers_state_dict = custom_to_hf_state_dict(
                        cpu_state, ema_model.reverse_param_names_mapping)
                    save_file(diffusers_state_dict, weight_path)

                    config_dict = ema_model.hf_config
                    if "dtype" in config_dict:
                        del config_dict["dtype"]
                    config_path = os.path.join(ema_save_dir, "config.json")
                    with open(config_path, "w") as f:
                        json.dump(config_dict, f, indent=4)

                    logger.info("EMA weights saved to %s", weight_path)

                del ema_model

        # Save transformer_2 EMA
        if self.generator_ema_2 is not None:
            ema_2_model = self.get_ema_2_model_copy()
            if ema_2_model is None:
                logger.warning("Failed to create EMA_2 model copy")
            else:
                ema_2_save_dir = os.path.join(output_dir,
                                              f"ema_2_checkpoint-{step}")
                os.makedirs(ema_2_save_dir, exist_ok=True)

                # save as diffusers format
                from safetensors.torch import save_file

                from fastvideo.training.training_utils import (
                    custom_to_hf_state_dict, gather_state_dict_on_cpu_rank0)
                cpu_state_2 = gather_state_dict_on_cpu_rank0(ema_2_model,
                                                             device=None)

                if self.global_rank == 0:
                    weight_path_2 = os.path.join(
                        ema_2_save_dir,
                        "diffusion_pytorch_model.safetensors")
                    diffusers_state_dict_2 = custom_to_hf_state_dict(
                        cpu_state_2,
                        ema_2_model.reverse_param_names_mapping)
                    save_file(diffusers_state_dict_2, weight_path_2)

                    config_dict_2 = ema_2_model.hf_config
                    if "dtype" in config_dict_2:
                        del config_dict_2["dtype"]
                    config_path_2 = os.path.join(ema_2_save_dir,
                                                 "config.json")
                    with open(config_path_2, "w") as f:
                        json.dump(config_dict_2, f, indent=4)

                    logger.info("EMA_2 weights saved to %s", weight_path_2)

                del ema_2_model

    except Exception as e:
        logger.error("Failed to save EMA weights: %s", str(e))
fastvideo.training.DistillationPipeline.train
train() -> None

Main training loop with distillation-specific logging.

Source code in fastvideo/training/distillation_pipeline.py
def train(self) -> None:
    """Main training loop with distillation-specific logging."""
    assert self.training_args.seed is not None, "seed must be set"
    seed = self.training_args.seed

    # Set the same seed within each SP group to ensure reproducibility
    if self.sp_world_size > 1:
        # Use the same seed for all processes within the same SP group
        sp_group_seed = seed + (self.global_rank // self.sp_world_size)
        set_random_seed(sp_group_seed)
        logger.info("Rank %s: Using SP group seed %s", self.global_rank,
                    sp_group_seed)
    else:
        set_random_seed(seed + self.global_rank)

    # Set random seeds for deterministic training
    self.noise_random_generator = torch.Generator(device="cpu").manual_seed(
        self.seed)
    self.noise_gen_cuda = torch.Generator(device="cuda").manual_seed(
        self.seed)
    self.validation_random_generator = torch.Generator(
        device="cpu").manual_seed(self.seed)
    logger.info("Initialized random seeds with seed: %s", seed)

    # Initialize current_trainstep for EMA ready checks
    #TODO: check if needed
    self.current_trainstep = self.init_steps

    # Resume from checkpoint if specified (this will restore random states)
    if self.training_args.resume_from_checkpoint:
        self._resume_from_checkpoint()
        logger.info("Resumed from checkpoint, random states restored")
    else:
        logger.info("Starting training from scratch")

    self.train_loader_iter = iter(self.train_dataloader)

    step_times: deque[float] = deque(maxlen=100)

    self._log_training_info()
    self._log_validation(self.transformer, self.training_args,
                         self.init_steps)

    progress_bar = tqdm(
        range(0, self.training_args.max_train_steps),
        initial=self.init_steps,
        desc="Steps",
        disable=self.local_rank > 0,
    )

    use_vsa = vsa_available and envs.FASTVIDEO_ATTENTION_BACKEND == "VIDEO_SPARSE_ATTN"
    for step in range(self.init_steps + 1,
                      self.training_args.max_train_steps + 1):
        start_time = time.perf_counter()
        if use_vsa:
            vsa_sparsity = self.training_args.VSA_sparsity
            vsa_decay_rate = self.training_args.VSA_decay_rate
            vsa_decay_interval_steps = self.training_args.VSA_decay_interval_steps
            if vsa_decay_interval_steps > 1:
                current_decay_times = min(step // vsa_decay_interval_steps,
                                          vsa_sparsity // vsa_decay_rate)
                current_vsa_sparsity = current_decay_times * vsa_decay_rate
            else:
                current_vsa_sparsity = vsa_sparsity
        else:
            current_vsa_sparsity = 0.0

        training_batch = TrainingBatch()
        self.current_trainstep = step
        training_batch.current_vsa_sparsity = current_vsa_sparsity

        if (step >= self.training_args.ema_start_step) and \
                (self.generator_ema is None) and (self.training_args.ema_decay > 0):
            self.generator_ema = EMA_FSDP(
                self.transformer, decay=self.training_args.ema_decay)
            logger.info("Created generator EMA at step %s with decay=%s",
                        step, self.training_args.ema_decay)

            # Create EMA for transformer_2 if it exists
            if self.transformer_2 is not None and self.generator_ema_2 is None:
                self.generator_ema_2 = EMA_FSDP(
                    self.transformer_2, decay=self.training_args.ema_decay)
                logger.info(
                    "Created generator EMA_2 at step %s with decay=%s",
                    step, self.training_args.ema_decay)

        with torch.autocast("cuda", dtype=torch.bfloat16):
            training_batch = self.train_one_step(training_batch)

        total_loss = training_batch.total_loss
        generator_loss = training_batch.generator_loss
        fake_score_loss = training_batch.fake_score_loss
        grad_norm = training_batch.grad_norm

        step_time = time.perf_counter() - start_time
        step_times.append(step_time)
        avg_step_time = sum(step_times) / len(step_times)

        progress_bar.set_postfix({
            "total_loss":
            f"{total_loss:.4f}",
            "generator_loss":
            f"{generator_loss:.4f}",
            "fake_score_loss":
            f"{fake_score_loss:.4f}",
            "step_time":
            f"{step_time:.2f}s",
            "grad_norm":
            grad_norm,
            "ema":
            "✓" if (self.generator_ema is not None and self.is_ema_ready())
            else "✗",
            "ema2":
            "✓" if (self.generator_ema_2 is not None
                    and self.is_ema_ready()) else "✗",
        })
        progress_bar.update(1)

        if self.global_rank == 0:
            # Prepare logging data
            log_data = {
                "train_total_loss":
                total_loss,
                "train_fake_score_loss":
                fake_score_loss,
                "learning_rate":
                self.lr_scheduler.get_last_lr()[0],
                "fake_score_learning_rate":
                self.fake_score_lr_scheduler.get_last_lr()[0],
                "step_time":
                step_time,
                "avg_step_time":
                avg_step_time,
                "grad_norm":
                grad_norm,
            }
            # Only log generator loss when generator is actually trained
            if (step % self.generator_update_interval == 0):
                log_data["train_generator_loss"] = generator_loss
            if use_vsa:
                log_data["VSA_train_sparsity"] = current_vsa_sparsity

            if self.generator_ema is not None or self.generator_ema_2 is not None:
                log_data["ema_enabled"] = self.generator_ema is not None
                log_data["ema_2_enabled"] = self.generator_ema_2 is not None
                log_data["ema_decay"] = self.training_args.ema_decay
            else:
                log_data["ema_enabled"] = False
                log_data["ema_2_enabled"] = False

            ema_stats = self.get_ema_stats()
            log_data.update(ema_stats)

            if training_batch.dmd_latent_vis_dict:
                dmd_additional_logs = {
                    "generator_timestep":
                    training_batch.
                    dmd_latent_vis_dict["generator_timestep"].item(),
                    "dmd_timestep":
                    training_batch.dmd_latent_vis_dict["dmd_timestep"].item(
                    ),
                }
                log_data.update(dmd_additional_logs)

            faker_score_additional_logs = {
                "fake_score_timestep":
                training_batch.
                fake_score_latent_vis_dict["fake_score_timestep"].item(),
            }
            log_data.update(faker_score_additional_logs)

            self.tracker.log(log_data, step)

        # Save training state checkpoint (for resuming training)
        if (self.training_args.training_state_checkpointing_steps > 0
                and step %
                self.training_args.training_state_checkpointing_steps == 0):
            print("rank", self.global_rank,
                  "save training state checkpoint at step", step)
            save_distillation_checkpoint(
                self.transformer,
                self.fake_score_transformer,
                self.global_rank,
                self.training_args.output_dir,
                step,
                self.optimizer,
                self.fake_score_optimizer,
                self.train_dataloader,
                self.lr_scheduler,
                self.fake_score_lr_scheduler,
                self.noise_random_generator,
                self.generator_ema,
                # MoE support
                generator_transformer_2=getattr(self, 'transformer_2',
                                                None),
                real_score_transformer_2=getattr(
                    self, 'real_score_transformer_2', None),
                fake_score_transformer_2=getattr(
                    self, 'fake_score_transformer_2', None),
                generator_optimizer_2=getattr(self, 'optimizer_2', None),
                fake_score_optimizer_2=getattr(self,
                                               'fake_score_optimizer_2',
                                               None),
                generator_scheduler_2=getattr(self, 'lr_scheduler_2', None),
                fake_score_scheduler_2=getattr(self,
                                               'fake_score_lr_scheduler_2',
                                               None),
                generator_ema_2=getattr(self, 'generator_ema_2', None))

            if self.transformer:
                self.transformer.train()
            self.sp_group.barrier()

        # Save weight-only checkpoint
        if (self.training_args.weight_only_checkpointing_steps > 0
                and step %
                self.training_args.weight_only_checkpointing_steps == 0):
            print("rank", self.global_rank,
                  "save weight-only checkpoint at step", step)
            save_distillation_checkpoint(
                self.transformer,
                self.fake_score_transformer,
                self.global_rank,
                self.training_args.output_dir,
                f"{step}_weight_only",
                only_save_generator_weight=True,
                generator_ema=self.generator_ema,
                # MoE support
                generator_transformer_2=getattr(self, 'transformer_2',
                                                None),
                real_score_transformer_2=getattr(
                    self, 'real_score_transformer_2', None),
                fake_score_transformer_2=getattr(
                    self, 'fake_score_transformer_2', None),
                generator_optimizer_2=getattr(self, 'optimizer_2', None),
                fake_score_optimizer_2=getattr(self,
                                               'fake_score_optimizer_2',
                                               None),
                generator_scheduler_2=getattr(self, 'lr_scheduler_2', None),
                fake_score_scheduler_2=getattr(self,
                                               'fake_score_lr_scheduler_2',
                                               None),
                generator_ema_2=getattr(self, 'generator_ema_2', None))

            if self.training_args.use_ema and self.is_ema_ready():
                self.save_ema_weights(self.training_args.output_dir, step)

        if self.training_args.log_validation and step % self.training_args.validation_steps == 0:
            if self.training_args.log_visualization:
                self.visualize_intermediate_latents(training_batch,
                                                    self.training_args,
                                                    step)
            self._log_validation(self.transformer, self.training_args, step)

    self.tracker.finish()

    # Save final training state checkpoint
    print("rank", self.global_rank,
          "save final training state checkpoint at step",
          self.training_args.max_train_steps)
    save_distillation_checkpoint(
        self.transformer,
        self.fake_score_transformer,
        self.global_rank,
        self.training_args.output_dir,
        self.training_args.max_train_steps,
        self.optimizer,
        self.fake_score_optimizer,
        self.train_dataloader,
        self.lr_scheduler,
        self.fake_score_lr_scheduler,
        self.noise_random_generator,
        self.generator_ema,
        # MoE support
        generator_transformer_2=getattr(self, 'transformer_2', None),
        real_score_transformer_2=getattr(self, 'real_score_transformer_2',
                                         None),
        fake_score_transformer_2=getattr(self, 'fake_score_transformer_2',
                                         None),
        generator_optimizer_2=getattr(self, 'optimizer_2', None),
        fake_score_optimizer_2=getattr(self, 'fake_score_optimizer_2',
                                       None),
        generator_scheduler_2=getattr(self, 'lr_scheduler_2', None),
        fake_score_scheduler_2=getattr(self, 'fake_score_lr_scheduler_2',
                                       None),
        generator_ema_2=getattr(self, 'generator_ema_2', None))

    if self.training_args.use_ema and self.is_ema_ready():
        self.save_ema_weights(self.training_args.output_dir,
                              self.training_args.max_train_steps)

    if envs.FASTVIDEO_TORCH_PROFILER_DIR:
        logger.info("Stopping profiler...")
        self.profiler_controller.stop()
        logger.info("Profiler stopped.")

    if get_sp_group():
        cleanup_dist_env_and_memory()
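
The VSA sparsity warm-up in the loop above increases sparsity in discrete increments: every VSA_decay_interval_steps steps it grows by VSA_decay_rate until it reaches VSA_sparsity. A standalone sketch of the same arithmetic with illustrative argument values:

def vsa_sparsity_at(step: int,
                    target_sparsity: float = 0.5,
                    decay_rate: float = 0.25,
                    decay_interval_steps: int = 100) -> float:
    """Mirrors the schedule in train(); the default values are illustrative."""
    if decay_interval_steps > 1:
        decay_times = min(step // decay_interval_steps,
                          target_sparsity // decay_rate)
        return decay_times * decay_rate
    return target_sparsity

# vsa_sparsity_at(0) == 0.0, vsa_sparsity_at(150) == 0.25,
# vsa_sparsity_at(10_000) == 0.5 (capped at the target sparsity)
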
fastvideo.training.DistillationPipeline.visualize_intermediate_latents
visualize_intermediate_latents(training_batch: TrainingBatch, training_args: TrainingArgs, step: int)

Add visualization data to tracker logging and save frames to disk.

Source code in fastvideo/training/distillation_pipeline.py
def visualize_intermediate_latents(self, training_batch: TrainingBatch,
                                   training_args: TrainingArgs, step: int):
    """Add visualization data to tracker logging and save frames to disk."""
    tracker_loss_dict: dict[str, Any] = {}
    dmd_latents_vis_dict = training_batch.dmd_latent_vis_dict
    fake_score_latents_vis_dict = training_batch.fake_score_latent_vis_dict
    fake_score_log_keys = ['generator_pred_video']
    dmd_log_keys = ['faker_score_pred_video', 'real_score_pred_video']

    for latent_key in fake_score_log_keys:
        latents = fake_score_latents_vis_dict[latent_key]
        latents = latents.permute(0, 2, 1, 3, 4)

        if isinstance(self.vae.scaling_factor, torch.Tensor):
            latents = latents / self.vae.scaling_factor.to(
                latents.device, latents.dtype)
        else:
            latents = latents / self.vae.scaling_factor

        # Apply shifting if needed
        if (hasattr(self.vae, "shift_factor")
                and self.vae.shift_factor is not None):
            if isinstance(self.vae.shift_factor, torch.Tensor):
                latents += self.vae.shift_factor.to(latents.device,
                                                    latents.dtype)
            else:
                latents += self.vae.shift_factor
        with torch.autocast("cuda", dtype=torch.bfloat16):
            video = self.vae.decode(latents)
        video = (video / 2 + 0.5).clamp(0, 1)
        video = video.cpu().float()
        video = video.permute(0, 2, 1, 3, 4)
        video = (video * 255).numpy().astype(np.uint8)
        video_artifact = self.tracker.video(
            video, fps=24, format="mp4")  # change to 16 for Wan2.1
        if video_artifact is not None:
            tracker_loss_dict[latent_key] = video_artifact
        # Clean up references
        del video, latents

    # Process DMD training data if available
    if 'generator_pred_video' in dmd_latents_vis_dict:
        for latent_key in dmd_log_keys:
            latents = dmd_latents_vis_dict[latent_key]
            latents = latents.permute(0, 2, 1, 3, 4)
            # decoded_latent = decode_stage(ForwardBatch(data_type="video", latents=latents), training_args)
            if isinstance(self.vae.scaling_factor, torch.Tensor):
                latents = latents / self.vae.scaling_factor.to(
                    latents.device, latents.dtype)
            else:
                latents = latents / self.vae.scaling_factor

            # Apply shifting if needed
            if (hasattr(self.vae, "shift_factor")
                    and self.vae.shift_factor is not None):
                if isinstance(self.vae.shift_factor, torch.Tensor):
                    latents += self.vae.shift_factor.to(
                        latents.device, latents.dtype)
                else:
                    latents += self.vae.shift_factor
            with torch.autocast("cuda", dtype=torch.bfloat16):
                video = self.vae.decode(latents)
            video = (video / 2 + 0.5).clamp(0, 1)
            video = video.cpu().float()
            video = video.permute(0, 2, 1, 3, 4)
            video = (video * 255).numpy().astype(np.uint8)
            video_artifact = self.tracker.video(
                video, fps=24, format="mp4")  # change to 16 for Wan2.1
            if video_artifact is not None:
                tracker_loss_dict[latent_key] = video_artifact
            # Clean up references
            del video, latents

    # Log to tracker
    if self.global_rank == 0 and tracker_loss_dict:
        self.tracker.log_artifacts(tracker_loss_dict, step)
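
Before decoding, latents are mapped back from the VAE's normalized space: divide by scaling_factor, then add shift_factor when the VAE defines one. A standalone sketch of that normalization; the tensor shape and factor values are illustrative.

import torch

latents = torch.randn(1, 16, 4, 32, 32)  # illustrative (B, C, T, H, W) latent
scaling_factor = 0.476                    # illustrative value
shift_factor = None                       # many VAEs define no shift

latents = latents / scaling_factor
if shift_factor is not None:
    latents = latents + shift_factor
# latents are now in the range expected by vae.decode(...)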

fastvideo.training.TrainingPipeline

TrainingPipeline(model_path: str, fastvideo_args: TrainingArgs, required_config_modules: list[str] | None = None, loaded_modules: dict[str, Module] | None = None)

Bases: LoRAPipeline, ABC

A pipeline for training a model. All training pipelines should inherit from this class. All reusable components and code should be implemented in this class.

Source code in fastvideo/training/training_pipeline.py
def __init__(
        self,
        model_path: str,
        fastvideo_args: TrainingArgs,
        required_config_modules: list[str] | None = None,
        loaded_modules: dict[str, torch.nn.Module] | None = None) -> None:
    fastvideo_args.inference_mode = False
    self.lora_training = fastvideo_args.lora_training
    if self.lora_training and fastvideo_args.lora_rank is None:
        raise ValueError("lora rank must be set when using lora training")

    set_random_seed(fastvideo_args.seed)  # for lora param init
    super().__init__(model_path, fastvideo_args, required_config_modules,
                     loaded_modules)  # type: ignore
    self.tracker = DummyTracker()

Functions

fastvideo.training.TrainingPipeline.visualize_intermediate_latents
visualize_intermediate_latents(training_batch: TrainingBatch, training_args: TrainingArgs, step: int)

Add visualization data to tracker logging and save frames to disk.

Source code in fastvideo/training/training_pipeline.py
def visualize_intermediate_latents(self, training_batch: TrainingBatch,
                                   training_args: TrainingArgs, step: int):
    """Add visualization data to tracker logging and save frames to disk."""
    raise NotImplementedError(
        "Visualize intermediate latents is not implemented for training pipeline"
    )

fastvideo.training.WanTrainingPipeline

WanTrainingPipeline(model_path: str, fastvideo_args: TrainingArgs, required_config_modules: list[str] | None = None, loaded_modules: dict[str, Module] | None = None)

Bases: TrainingPipeline

A training pipeline for Wan.

Source code in fastvideo/training/training_pipeline.py
def __init__(
        self,
        model_path: str,
        fastvideo_args: TrainingArgs,
        required_config_modules: list[str] | None = None,
        loaded_modules: dict[str, torch.nn.Module] | None = None) -> None:
    fastvideo_args.inference_mode = False
    self.lora_training = fastvideo_args.lora_training
    if self.lora_training and fastvideo_args.lora_rank is None:
        raise ValueError("lora rank must be set when using lora training")

    set_random_seed(fastvideo_args.seed)  # for lora param init
    super().__init__(model_path, fastvideo_args, required_config_modules,
                     loaded_modules)  # type: ignore
    self.tracker = DummyTracker()

Functions

fastvideo.training.WanTrainingPipeline.create_training_stages
create_training_stages(training_args: TrainingArgs)

May be used in future refactors.

Source code in fastvideo/training/wan_training_pipeline.py
def create_training_stages(self, training_args: TrainingArgs):
    """
    May be used in future refactors.
    """
    pass

Modules

fastvideo.training.distillation_pipeline

Classes

fastvideo.training.distillation_pipeline.DistillationPipeline
DistillationPipeline(model_path: str, fastvideo_args: TrainingArgs, required_config_modules: list[str] | None = None, loaded_modules: dict[str, Module] | None = None)

Bases: TrainingPipeline

A distillation pipeline for training a 3-step model. Inherits from TrainingPipeline to reuse training infrastructure.

Source code in fastvideo/training/training_pipeline.py
def __init__(
        self,
        model_path: str,
        fastvideo_args: TrainingArgs,
        required_config_modules: list[str] | None = None,
        loaded_modules: dict[str, torch.nn.Module] | None = None) -> None:
    fastvideo_args.inference_mode = False
    self.lora_training = fastvideo_args.lora_training
    if self.lora_training and fastvideo_args.lora_rank is None:
        raise ValueError("lora rank must be set when using lora training")

    set_random_seed(fastvideo_args.seed)  # for lora param init
    super().__init__(model_path, fastvideo_args, required_config_modules,
                     loaded_modules)  # type: ignore
    self.tracker = DummyTracker()
Functions
fastvideo.training.distillation_pipeline.DistillationPipeline.apply_ema_to_model
apply_ema_to_model(model)

Apply EMA weights to the model for validation or inference.

Source code in fastvideo/training/distillation_pipeline.py
def apply_ema_to_model(self, model):
    """Apply EMA weights to the model for validation or inference."""
    if model is self.transformer and self.generator_ema is not None:
        with self.generator_ema.apply_to_model(model):
            return model
    elif model is self.transformer_2 and self.generator_ema_2 is not None:
        with self.generator_ema_2.apply_to_model(model):
            return model
    return model
fastvideo.training.distillation_pipeline.DistillationPipeline.get_ema_2_model_copy
get_ema_2_model_copy() -> Module | None

Get a copy of the transformer_2 model with EMA weights applied.

Source code in fastvideo/training/distillation_pipeline.py
def get_ema_2_model_copy(self) -> torch.nn.Module | None:
    """Get a copy of the transformer_2 model with EMA weights applied."""
    if self.generator_ema_2 is not None and self.transformer_2 is not None:
        ema_2_model = copy.deepcopy(self.transformer_2)
        self.generator_ema_2.copy_to_unwrapped(ema_2_model)
        return ema_2_model
    return None
fastvideo.training.distillation_pipeline.DistillationPipeline.get_ema_model_copy
get_ema_model_copy() -> Module | None

Get a copy of the model with EMA weights applied.

Source code in fastvideo/training/distillation_pipeline.py
def get_ema_model_copy(self) -> torch.nn.Module | None:
    """Get a copy of the model with EMA weights applied."""
    if self.generator_ema is not None:
        ema_model = copy.deepcopy(self.transformer)
        self.generator_ema.copy_to_unwrapped(ema_model)
        return ema_model
    return None
fastvideo.training.distillation_pipeline.DistillationPipeline.get_ema_stats
get_ema_stats() -> dict[str, Any]

Get EMA statistics for monitoring.

Source code in fastvideo/training/distillation_pipeline.py
def get_ema_stats(self) -> dict[str, Any]:
    """Get EMA statistics for monitoring."""
    ema_enabled = self.generator_ema is not None
    ema_2_enabled = self.generator_ema_2 is not None

    if not ema_enabled and not ema_2_enabled:
        return {
            "ema_enabled": False,
            "ema_2_enabled": False,
            "ema_decay": None,
            "ema_start_step": self.training_args.ema_start_step,
            "ema_ready": False,
            "ema_2_ready": False,
            "ema_step": self.current_trainstep,
        }

    return {
        "ema_enabled": ema_enabled,
        "ema_2_enabled": ema_2_enabled,
        "ema_decay": self.training_args.ema_decay,
        "ema_start_step": self.training_args.ema_start_step,
        "ema_ready": self.is_ema_ready() if ema_enabled else False,
        "ema_2_ready": self.is_ema_ready() if ema_2_enabled else False,
        "ema_step": self.current_trainstep,
    }
fastvideo.training.distillation_pipeline.DistillationPipeline.initialize_training_pipeline
initialize_training_pipeline(training_args: TrainingArgs)

Initialize the distillation training pipeline with multiple models.

Source code in fastvideo/training/distillation_pipeline.py
def initialize_training_pipeline(self, training_args: TrainingArgs):
    """Initialize the distillation training pipeline with multiple models."""
    logger.info("Initializing distillation pipeline...")

    super().initialize_training_pipeline(training_args)

    self.noise_scheduler = self.get_module("scheduler")
    self.vae = self.get_module("vae")
    self.vae.requires_grad_(False)

    self.timestep_shift = self.training_args.pipeline_config.flow_shift
    self.noise_scheduler = FlowMatchEulerDiscreteScheduler(
        shift=self.timestep_shift)

    if self.training_args.boundary_ratio is not None:
        self.boundary_timestep = self.training_args.boundary_ratio * self.noise_scheduler.num_train_timesteps
    else:
        self.boundary_timestep = None

    if training_args.real_score_model_path:
        logger.info("Loading real score transformer from: %s",
                    training_args.real_score_model_path)
        training_args.override_transformer_cls_name = "WanTransformer3DModel"
        self.real_score_transformer = self.load_module_from_path(
            training_args.real_score_model_path, "transformer",
            training_args)
        try:
            self.real_score_transformer_2 = self.load_module_from_path(
                training_args.real_score_model_path, "transformer_2",
                training_args)
            logger.info("Loaded real score transformer_2 for MoE support")
        except Exception:
            logger.info(
                "real score transformer_2 not found, using single transformer"
            )
            self.real_score_transformer_2 = None
    else:
        self.real_score_transformer = self.get_module(
            "real_score_transformer")
        self.real_score_transformer_2 = self.get_module(
            "real_score_transformer_2")

    if training_args.fake_score_model_path:
        logger.info("Loading fake score transformer from: %s",
                    training_args.fake_score_model_path)
        training_args.override_transformer_cls_name = "WanTransformer3DModel"
        self.fake_score_transformer = self.load_module_from_path(
            training_args.fake_score_model_path, "transformer",
            training_args)
        try:
            self.fake_score_transformer_2 = self.load_module_from_path(
                training_args.fake_score_model_path, "transformer_2",
                training_args)
            logger.info("Loaded fake score transformer_2 for MoE support")
        except Exception:
            logger.info(
                "fake score transformer_2 not found, using single transformer"
            )
            self.fake_score_transformer_2 = None
    else:
        self.fake_score_transformer = self.get_module(
            "fake_score_transformer")
        self.fake_score_transformer_2 = self.get_module(
            "fake_score_transformer_2")

    self.real_score_transformer.requires_grad_(False)
    self.real_score_transformer.eval()
    if self.real_score_transformer_2 is not None:
        self.real_score_transformer_2.requires_grad_(False)
        self.real_score_transformer_2.eval()

    # Set training modes for fake score transformers (trainable)
    self.fake_score_transformer.requires_grad_(True)
    self.fake_score_transformer.train()
    if self.fake_score_transformer_2 is not None:
        self.fake_score_transformer_2.requires_grad_(True)
        self.fake_score_transformer_2.train()

    if training_args.enable_gradient_checkpointing_type is not None:
        self.fake_score_transformer = apply_activation_checkpointing(
            self.fake_score_transformer,
            checkpointing_type=training_args.
            enable_gradient_checkpointing_type)
        if self.fake_score_transformer_2 is not None:
            self.fake_score_transformer_2 = apply_activation_checkpointing(
                self.fake_score_transformer_2,
                checkpointing_type=training_args.
                enable_gradient_checkpointing_type)

        self.real_score_transformer = apply_activation_checkpointing(
            self.real_score_transformer,
            checkpointing_type=training_args.
            enable_gradient_checkpointing_type)
        if self.real_score_transformer_2 is not None:
            self.real_score_transformer_2 = apply_activation_checkpointing(
                self.real_score_transformer_2,
                checkpointing_type=training_args.
                enable_gradient_checkpointing_type)

    # Initialize optimizers
    fake_score_params = list(
        filter(lambda p: p.requires_grad,
               self.fake_score_transformer.parameters()))

    # Use separate learning rate for fake_score_transformer if specified
    fake_score_lr = training_args.fake_score_learning_rate
    if fake_score_lr == 0.0:
        fake_score_lr = training_args.learning_rate

    betas_str = training_args.fake_score_betas
    betas = tuple(float(x.strip()) for x in betas_str.split(","))

    self.fake_score_optimizer = torch.optim.AdamW(
        fake_score_params,
        lr=fake_score_lr,
        betas=betas,
        weight_decay=training_args.weight_decay,
        eps=1e-8,
    )

    self.fake_score_lr_scheduler = get_scheduler(
        training_args.fake_score_lr_scheduler,
        optimizer=self.fake_score_optimizer,
        num_warmup_steps=training_args.lr_warmup_steps,
        num_training_steps=training_args.max_train_steps,
        num_cycles=training_args.lr_num_cycles,
        power=training_args.lr_power,
        min_lr_ratio=training_args.min_lr_ratio,
        last_epoch=self.init_steps - 1,
    )

    if self.fake_score_transformer_2 is not None:
        fake_score_params_2 = list(
            filter(lambda p: p.requires_grad,
                   self.fake_score_transformer_2.parameters()))
        self.fake_score_optimizer_2 = torch.optim.AdamW(
            fake_score_params_2,
            lr=fake_score_lr,
            betas=betas,
            weight_decay=training_args.weight_decay,
            eps=1e-8,
        )
        self.fake_score_lr_scheduler_2 = get_scheduler(
            training_args.fake_score_lr_scheduler,
            optimizer=self.fake_score_optimizer_2,
            num_warmup_steps=training_args.lr_warmup_steps,
            num_training_steps=training_args.max_train_steps,
            num_cycles=training_args.lr_num_cycles,
            power=training_args.lr_power,
            min_lr_ratio=training_args.min_lr_ratio,
            last_epoch=self.init_steps - 1,
        )

    logger.info(
        "Distillation optimizers initialized: generator and fake_score")

    self.generator_update_interval = self.training_args.generator_update_interval
    logger.info(
        "Distillation pipeline initialized with generator_update_interval=%s",
        self.generator_update_interval)

    self.denoising_step_list = torch.tensor(
        self.training_args.pipeline_config.dmd_denoising_steps,
        dtype=torch.long,
        device=get_local_torch_device())

    if training_args.warp_denoising_step:  # Warp the denoising step according to the scheduler time shift
        timesteps = torch.cat((self.noise_scheduler.timesteps.cpu(),
                               torch.tensor([0],
                                            dtype=torch.float32))).cuda()
        self.denoising_step_list = timesteps[1000 -
                                             self.denoising_step_list]
        logger.info("Warping denoising_step_list")

    self.denoising_step_list = self.denoising_step_list.to(
        get_local_torch_device())
    logger.info("Distillation generator model to %s denoising steps: %s",
                len(self.denoising_step_list), self.denoising_step_list)
    self.num_train_timestep = self.noise_scheduler.num_train_timesteps

    self.min_timestep = int(self.training_args.min_timestep_ratio *
                            self.num_train_timestep)
    self.max_timestep = int(self.training_args.max_timestep_ratio *
                            self.num_train_timestep)

    self.real_score_guidance_scale = self.training_args.real_score_guidance_scale

    self.generator_ema: EMA_FSDP | None = None
    self.generator_ema_2: EMA_FSDP | None = None
    if (self.training_args.ema_decay
            is not None) and (self.training_args.ema_decay > 0.0):
        self.generator_ema = EMA_FSDP(self.transformer,
                                      decay=self.training_args.ema_decay)
        logger.info("Initialized generator EMA with decay=%s",
                    self.training_args.ema_decay)

        # Initialize EMA for transformer_2 if it exists
        if self.transformer_2 is not None:
            self.generator_ema_2 = EMA_FSDP(
                self.transformer_2, decay=self.training_args.ema_decay)
            logger.info("Initialized generator EMA_2 with decay=%s",
                        self.training_args.ema_decay)
    else:
        logger.info("Generator EMA disabled (ema_decay <= 0.0)")
fastvideo.training.distillation_pipeline.DistillationPipeline.initialize_validation_pipeline abstractmethod
initialize_validation_pipeline(training_args: TrainingArgs)

Initialize validation pipeline - must be implemented by subclasses.

Source code in fastvideo/training/distillation_pipeline.py
@abstractmethod
def initialize_validation_pipeline(self, training_args: TrainingArgs):
    """Initialize validation pipeline - must be implemented by subclasses."""
    raise NotImplementedError(
        "Distillation pipelines must implement this method")
fastvideo.training.distillation_pipeline.DistillationPipeline.is_ema_ready
is_ema_ready(current_step: int | None = None)

Check if EMA is ready for use (after ema_start_step).

Source code in fastvideo/training/distillation_pipeline.py
def is_ema_ready(self, current_step: int | None = None):
    """Check if EMA is ready for use (after ema_start_step)."""
    if current_step is None:
        current_step = getattr(self, 'current_trainstep', 0)
    return (self.generator_ema is not None
            and current_step >= self.training_args.ema_start_step)
fastvideo.training.distillation_pipeline.DistillationPipeline.load_module_from_path
load_module_from_path(model_path: str, module_type: str, training_args: TrainingArgs)

Load a module from a specific path using the same loading logic as the pipeline.

Parameters:

Name            Type            Description                                     Default
model_path      str             Path to the model                               required
module_type     str             Type of module to load (e.g., "transformer")    required
training_args   TrainingArgs    Training arguments                              required

Returns:

The loaded module

Source code in fastvideo/training/distillation_pipeline.py
def load_module_from_path(self, model_path: str, module_type: str,
                          training_args: "TrainingArgs"):
    """
    Load a module from a specific path using the same loading logic as the pipeline.

    Args:
        model_path: Path to the model
        module_type: Type of module to load (e.g., "transformer")
        training_args: Training arguments

    Returns:
        The loaded module
    """
    logger.info("Loading %s from custom path: %s", module_type, model_path)
    # Set flag to prevent custom weight loading for teacher/critic models
    training_args._loading_teacher_critic_model = True

    try:
        from fastvideo.models.loader.component_loader import (
            PipelineComponentLoader)

        # Download the model if it's a Hugging Face model ID
        local_model_path = maybe_download_model(model_path)
        logger.info("Model downloaded/found at: %s", local_model_path)
        config = verify_model_config_and_directory(local_model_path)

        if module_type not in config:
            if hasattr(self, '_extra_config_module_map'
                       ) and module_type in self._extra_config_module_map:
                extra_module = self._extra_config_module_map[module_type]
                if extra_module in config:
                    logger.info("Using %s for %s", extra_module,
                                module_type)
                    module_type = extra_module
                else:
                    raise ValueError(
                        f"Module {module_type} not found in config at {local_model_path}"
                    )
            else:
                raise ValueError(
                    f"Module {module_type} not found in config at {local_model_path}"
                )

        module_info = config[module_type]
        if module_info is None:
            raise ValueError(
                f"Module {module_type} has null value in config at {local_model_path}"
            )

        transformers_or_diffusers, architecture = module_info
        component_path = os.path.join(local_model_path, module_type)
        module = PipelineComponentLoader.load_module(
            module_name=module_type,
            component_model_path=component_path,
            transformers_or_diffusers=transformers_or_diffusers,
            fastvideo_args=training_args,
        )

        logger.info("Successfully loaded %s from %s", module_type,
                    component_path)
        return module
    finally:
        # Always clean up the flag
        if hasattr(training_args, '_loading_teacher_critic_model'):
            delattr(training_args, '_loading_teacher_critic_model')
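
A hypothetical call from inside a DistillationPipeline subclass; the model path is illustrative, not a real checkpoint:

teacher_transformer = self.load_module_from_path(
    model_path="some-org/some-teacher-model",  # local dir or HF model id
    module_type="transformer",
    training_args=training_args,
)
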
fastvideo.training.distillation_pipeline.DistillationPipeline.reset_ema
reset_ema()

Reset EMA to current model weights.

Source code in fastvideo/training/distillation_pipeline.py
def reset_ema(self):
    """Reset EMA to current model weights."""
    if self.generator_ema is not None:
        logger.info("Resetting EMA to current model weights")
        self.generator_ema.update(self.transformer)
        # Force update to current weights by setting decay to 0 temporarily
        original_decay = self.generator_ema.decay
        self.generator_ema.decay = 0.0
        self.generator_ema.update(self.transformer)
        self.generator_ema.decay = original_decay
        logger.info("EMA reset completed")
    else:
        logger.warning("Cannot reset EMA: EMA not initialized")

    if self.generator_ema_2 is not None:
        logger.info("Resetting EMA_2 to current model weights")
        self.generator_ema_2.update(self.transformer_2)
        # Force update to current weights by setting decay to 0 temporarily
        original_decay_2 = self.generator_ema_2.decay
        self.generator_ema_2.decay = 0.0
        self.generator_ema_2.update(self.transformer_2)
        self.generator_ema_2.decay = original_decay_2
        logger.info("EMA_2 reset completed")
fastvideo.training.distillation_pipeline.DistillationPipeline.save_ema_weights
save_ema_weights(output_dir: str, step: int)

Save EMA weights separately for inference purposes.

Source code in fastvideo/training/distillation_pipeline.py
def save_ema_weights(self, output_dir: str, step: int):
    """Save EMA weights separately for inference purposes."""
    if self.generator_ema is None and self.generator_ema_2 is None:
        logger.warning("Cannot save EMA weights: No EMA initialized")
        return

    if not self.is_ema_ready():
        logger.warning(
            "Cannot save EMA weights: EMA not ready yet (step < ema_start_step)"
        )
        return

    try:
        # Save main transformer EMA
        if self.generator_ema is not None:
            ema_model = self.get_ema_model_copy()
            if ema_model is None:
                logger.warning("Failed to create EMA model copy")
            else:
                ema_save_dir = os.path.join(output_dir,
                                            f"ema_checkpoint-{step}")
                os.makedirs(ema_save_dir, exist_ok=True)

                # save as diffusers format
                from safetensors.torch import save_file

                from fastvideo.training.training_utils import (
                    custom_to_hf_state_dict, gather_state_dict_on_cpu_rank0)
                cpu_state = gather_state_dict_on_cpu_rank0(ema_model,
                                                           device=None)

                if self.global_rank == 0:
                    weight_path = os.path.join(
                        ema_save_dir, "diffusion_pytorch_model.safetensors")
                    diffusers_state_dict = custom_to_hf_state_dict(
                        cpu_state, ema_model.reverse_param_names_mapping)
                    save_file(diffusers_state_dict, weight_path)

                    config_dict = ema_model.hf_config
                    if "dtype" in config_dict:
                        del config_dict["dtype"]
                    config_path = os.path.join(ema_save_dir, "config.json")
                    with open(config_path, "w") as f:
                        json.dump(config_dict, f, indent=4)

                    logger.info("EMA weights saved to %s", weight_path)

                del ema_model

        # Save transformer_2 EMA
        if self.generator_ema_2 is not None:
            ema_2_model = self.get_ema_2_model_copy()
            if ema_2_model is None:
                logger.warning("Failed to create EMA_2 model copy")
            else:
                ema_2_save_dir = os.path.join(output_dir,
                                              f"ema_2_checkpoint-{step}")
                os.makedirs(ema_2_save_dir, exist_ok=True)

                # save as diffusers format
                from safetensors.torch import save_file

                from fastvideo.training.training_utils import (
                    custom_to_hf_state_dict, gather_state_dict_on_cpu_rank0)
                cpu_state_2 = gather_state_dict_on_cpu_rank0(ema_2_model,
                                                             device=None)

                if self.global_rank == 0:
                    weight_path_2 = os.path.join(
                        ema_2_save_dir,
                        "diffusion_pytorch_model.safetensors")
                    diffusers_state_dict_2 = custom_to_hf_state_dict(
                        cpu_state_2,
                        ema_2_model.reverse_param_names_mapping)
                    save_file(diffusers_state_dict_2, weight_path_2)

                    config_dict_2 = ema_2_model.hf_config
                    if "dtype" in config_dict_2:
                        del config_dict_2["dtype"]
                    config_path_2 = os.path.join(ema_2_save_dir,
                                                 "config.json")
                    with open(config_path_2, "w") as f:
                        json.dump(config_dict_2, f, indent=4)

                    logger.info("EMA_2 weights saved to %s", weight_path_2)

                del ema_2_model

    except Exception as e:
        logger.error("Failed to save EMA weights: %s", str(e))
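
Because save_ema_weights writes a diffusers-style layout (config.json plus diffusion_pytorch_model.safetensors), the checkpoint can be inspected directly. A minimal sketch, assuming an output directory and step number from your own run:

import json
import os
from safetensors.torch import load_file

ema_dir = os.path.join("outputs", "ema_checkpoint-1000")  # hypothetical path
state_dict = load_file(
    os.path.join(ema_dir, "diffusion_pytorch_model.safetensors"))
with open(os.path.join(ema_dir, "config.json")) as f:
    config = json.load(f)
print(len(state_dict), "tensors; config keys:", list(config)[:5])
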
fastvideo.training.distillation_pipeline.DistillationPipeline.train
train() -> None

Main training loop with distillation-specific logging.

Source code in fastvideo/training/distillation_pipeline.py
def train(self) -> None:
    """Main training loop with distillation-specific logging."""
    assert self.training_args.seed is not None, "seed must be set"
    seed = self.training_args.seed

    # Set the same seed within each SP group to ensure reproducibility
    if self.sp_world_size > 1:
        # Use the same seed for all processes within the same SP group
        sp_group_seed = seed + (self.global_rank // self.sp_world_size)
        set_random_seed(sp_group_seed)
        logger.info("Rank %s: Using SP group seed %s", self.global_rank,
                    sp_group_seed)
    else:
        set_random_seed(seed + self.global_rank)

    # Set random seeds for deterministic training
    self.noise_random_generator = torch.Generator(device="cpu").manual_seed(
        self.seed)
    self.noise_gen_cuda = torch.Generator(device="cuda").manual_seed(
        self.seed)
    self.validation_random_generator = torch.Generator(
        device="cpu").manual_seed(self.seed)
    logger.info("Initialized random seeds with seed: %s", seed)

    # Initialize current_trainstep for EMA ready checks
    #TODO: check if needed
    self.current_trainstep = self.init_steps

    # Resume from checkpoint if specified (this will restore random states)
    if self.training_args.resume_from_checkpoint:
        self._resume_from_checkpoint()
        logger.info("Resumed from checkpoint, random states restored")
    else:
        logger.info("Starting training from scratch")

    self.train_loader_iter = iter(self.train_dataloader)

    step_times: deque[float] = deque(maxlen=100)

    self._log_training_info()
    self._log_validation(self.transformer, self.training_args,
                         self.init_steps)

    progress_bar = tqdm(
        range(0, self.training_args.max_train_steps),
        initial=self.init_steps,
        desc="Steps",
        disable=self.local_rank > 0,
    )

    use_vsa = vsa_available and envs.FASTVIDEO_ATTENTION_BACKEND == "VIDEO_SPARSE_ATTN"
    for step in range(self.init_steps + 1,
                      self.training_args.max_train_steps + 1):
        start_time = time.perf_counter()
        if use_vsa:
            vsa_sparsity = self.training_args.VSA_sparsity
            vsa_decay_rate = self.training_args.VSA_decay_rate
            vsa_decay_interval_steps = self.training_args.VSA_decay_interval_steps
            if vsa_decay_interval_steps > 1:
                current_decay_times = min(step // vsa_decay_interval_steps,
                                          vsa_sparsity // vsa_decay_rate)
                current_vsa_sparsity = current_decay_times * vsa_decay_rate
            else:
                current_vsa_sparsity = vsa_sparsity
        else:
            current_vsa_sparsity = 0.0

        training_batch = TrainingBatch()
        self.current_trainstep = step
        training_batch.current_vsa_sparsity = current_vsa_sparsity

        if (step >= self.training_args.ema_start_step) and \
                (self.generator_ema is None) and (self.training_args.ema_decay > 0):
            self.generator_ema = EMA_FSDP(
                self.transformer, decay=self.training_args.ema_decay)
            logger.info("Created generator EMA at step %s with decay=%s",
                        step, self.training_args.ema_decay)

            # Create EMA for transformer_2 if it exists
            if self.transformer_2 is not None and self.generator_ema_2 is None:
                self.generator_ema_2 = EMA_FSDP(
                    self.transformer_2, decay=self.training_args.ema_decay)
                logger.info(
                    "Created generator EMA_2 at step %s with decay=%s",
                    step, self.training_args.ema_decay)

        with torch.autocast("cuda", dtype=torch.bfloat16):
            training_batch = self.train_one_step(training_batch)

        total_loss = training_batch.total_loss
        generator_loss = training_batch.generator_loss
        fake_score_loss = training_batch.fake_score_loss
        grad_norm = training_batch.grad_norm

        step_time = time.perf_counter() - start_time
        step_times.append(step_time)
        avg_step_time = sum(step_times) / len(step_times)

        progress_bar.set_postfix({
            "total_loss":
            f"{total_loss:.4f}",
            "generator_loss":
            f"{generator_loss:.4f}",
            "fake_score_loss":
            f"{fake_score_loss:.4f}",
            "step_time":
            f"{step_time:.2f}s",
            "grad_norm":
            grad_norm,
            "ema":
            "✓" if (self.generator_ema is not None and self.is_ema_ready())
            else "✗",
            "ema2":
            "✓" if (self.generator_ema_2 is not None
                    and self.is_ema_ready()) else "✗",
        })
        progress_bar.update(1)

        if self.global_rank == 0:
            # Prepare logging data
            log_data = {
                "train_total_loss":
                total_loss,
                "train_fake_score_loss":
                fake_score_loss,
                "learning_rate":
                self.lr_scheduler.get_last_lr()[0],
                "fake_score_learning_rate":
                self.fake_score_lr_scheduler.get_last_lr()[0],
                "step_time":
                step_time,
                "avg_step_time":
                avg_step_time,
                "grad_norm":
                grad_norm,
            }
            # Only log generator loss when generator is actually trained
            if (step % self.generator_update_interval == 0):
                log_data["train_generator_loss"] = generator_loss
            if use_vsa:
                log_data["VSA_train_sparsity"] = current_vsa_sparsity

            if self.generator_ema is not None or self.generator_ema_2 is not None:
                log_data["ema_enabled"] = self.generator_ema is not None
                log_data["ema_2_enabled"] = self.generator_ema_2 is not None
                log_data["ema_decay"] = self.training_args.ema_decay
            else:
                log_data["ema_enabled"] = False
                log_data["ema_2_enabled"] = False

            ema_stats = self.get_ema_stats()
            log_data.update(ema_stats)

            if training_batch.dmd_latent_vis_dict:
                dmd_additional_logs = {
                    "generator_timestep":
                    training_batch.
                    dmd_latent_vis_dict["generator_timestep"].item(),
                    "dmd_timestep":
                    training_batch.dmd_latent_vis_dict["dmd_timestep"].item(
                    ),
                }
                log_data.update(dmd_additional_logs)

            faker_score_additional_logs = {
                "fake_score_timestep":
                training_batch.
                fake_score_latent_vis_dict["fake_score_timestep"].item(),
            }
            log_data.update(faker_score_additional_logs)

            self.tracker.log(log_data, step)

        # Save training state checkpoint (for resuming training)
        if (self.training_args.training_state_checkpointing_steps > 0
                and step %
                self.training_args.training_state_checkpointing_steps == 0):
            print("rank", self.global_rank,
                  "save training state checkpoint at step", step)
            save_distillation_checkpoint(
                self.transformer,
                self.fake_score_transformer,
                self.global_rank,
                self.training_args.output_dir,
                step,
                self.optimizer,
                self.fake_score_optimizer,
                self.train_dataloader,
                self.lr_scheduler,
                self.fake_score_lr_scheduler,
                self.noise_random_generator,
                self.generator_ema,
                # MoE support
                generator_transformer_2=getattr(self, 'transformer_2',
                                                None),
                real_score_transformer_2=getattr(
                    self, 'real_score_transformer_2', None),
                fake_score_transformer_2=getattr(
                    self, 'fake_score_transformer_2', None),
                generator_optimizer_2=getattr(self, 'optimizer_2', None),
                fake_score_optimizer_2=getattr(self,
                                               'fake_score_optimizer_2',
                                               None),
                generator_scheduler_2=getattr(self, 'lr_scheduler_2', None),
                fake_score_scheduler_2=getattr(self,
                                               'fake_score_lr_scheduler_2',
                                               None),
                generator_ema_2=getattr(self, 'generator_ema_2', None))

            if self.transformer:
                self.transformer.train()
            self.sp_group.barrier()

        # Save weight-only checkpoint
        if (self.training_args.weight_only_checkpointing_steps > 0
                and step %
                self.training_args.weight_only_checkpointing_steps == 0):
            print("rank", self.global_rank,
                  "save weight-only checkpoint at step", step)
            save_distillation_checkpoint(
                self.transformer,
                self.fake_score_transformer,
                self.global_rank,
                self.training_args.output_dir,
                f"{step}_weight_only",
                only_save_generator_weight=True,
                generator_ema=self.generator_ema,
                # MoE support
                generator_transformer_2=getattr(self, 'transformer_2',
                                                None),
                real_score_transformer_2=getattr(
                    self, 'real_score_transformer_2', None),
                fake_score_transformer_2=getattr(
                    self, 'fake_score_transformer_2', None),
                generator_optimizer_2=getattr(self, 'optimizer_2', None),
                fake_score_optimizer_2=getattr(self,
                                               'fake_score_optimizer_2',
                                               None),
                generator_scheduler_2=getattr(self, 'lr_scheduler_2', None),
                fake_score_scheduler_2=getattr(self,
                                               'fake_score_lr_scheduler_2',
                                               None),
                generator_ema_2=getattr(self, 'generator_ema_2', None))

            if self.training_args.use_ema and self.is_ema_ready():
                self.save_ema_weights(self.training_args.output_dir, step)

        if self.training_args.log_validation and step % self.training_args.validation_steps == 0:
            if self.training_args.log_visualization:
                self.visualize_intermediate_latents(training_batch,
                                                    self.training_args,
                                                    step)
            self._log_validation(self.transformer, self.training_args, step)

    self.tracker.finish()

    # Save final training state checkpoint
    print("rank", self.global_rank,
          "save final training state checkpoint at step",
          self.training_args.max_train_steps)
    save_distillation_checkpoint(
        self.transformer,
        self.fake_score_transformer,
        self.global_rank,
        self.training_args.output_dir,
        self.training_args.max_train_steps,
        self.optimizer,
        self.fake_score_optimizer,
        self.train_dataloader,
        self.lr_scheduler,
        self.fake_score_lr_scheduler,
        self.noise_random_generator,
        self.generator_ema,
        # MoE support
        generator_transformer_2=getattr(self, 'transformer_2', None),
        real_score_transformer_2=getattr(self, 'real_score_transformer_2',
                                         None),
        fake_score_transformer_2=getattr(self, 'fake_score_transformer_2',
                                         None),
        generator_optimizer_2=getattr(self, 'optimizer_2', None),
        fake_score_optimizer_2=getattr(self, 'fake_score_optimizer_2',
                                       None),
        generator_scheduler_2=getattr(self, 'lr_scheduler_2', None),
        fake_score_scheduler_2=getattr(self, 'fake_score_lr_scheduler_2',
                                       None),
        generator_ema_2=getattr(self, 'generator_ema_2', None))

    if self.training_args.use_ema and self.is_ema_ready():
        self.save_ema_weights(self.training_args.output_dir,
                              self.training_args.max_train_steps)

    if envs.FASTVIDEO_TORCH_PROFILER_DIR:
        logger.info("Stopping profiler...")
        self.profiler_controller.stop()
        logger.info("Profiler stopped.")

    if get_sp_group():
        cleanup_dist_env_and_memory()
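
The loop above ramps VSA sparsity in discrete increments of VSA_decay_rate every VSA_decay_interval_steps. A standalone sketch of that schedule (hyperparameter values here are illustrative, not TrainingArgs defaults):

def vsa_sparsity_at(step: int,
                    target_sparsity: float = 0.9,
                    decay_rate: float = 0.1,
                    decay_interval_steps: int = 100) -> float:
    if decay_interval_steps <= 1:
        return target_sparsity
    decay_times = min(step // decay_interval_steps,
                      target_sparsity // decay_rate)
    return decay_times * decay_rate

# Sparsity grows stepwise toward target_sparsity as training progresses.
print([vsa_sparsity_at(s) for s in (0, 150, 450, 2000)])
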
fastvideo.training.distillation_pipeline.DistillationPipeline.visualize_intermediate_latents
visualize_intermediate_latents(training_batch: TrainingBatch, training_args: TrainingArgs, step: int)

Add visualization data to tracker logging and save frames to disk.

Source code in fastvideo/training/distillation_pipeline.py
def visualize_intermediate_latents(self, training_batch: TrainingBatch,
                                   training_args: TrainingArgs, step: int):
    """Add visualization data to tracker logging and save frames to disk."""
    tracker_loss_dict: dict[str, Any] = {}
    dmd_latents_vis_dict = training_batch.dmd_latent_vis_dict
    fake_score_latents_vis_dict = training_batch.fake_score_latent_vis_dict
    fake_score_log_keys = ['generator_pred_video']
    dmd_log_keys = ['faker_score_pred_video', 'real_score_pred_video']

    for latent_key in fake_score_log_keys:
        latents = fake_score_latents_vis_dict[latent_key]
        latents = latents.permute(0, 2, 1, 3, 4)

        if isinstance(self.vae.scaling_factor, torch.Tensor):
            latents = latents / self.vae.scaling_factor.to(
                latents.device, latents.dtype)
        else:
            latents = latents / self.vae.scaling_factor

        # Apply shifting if needed
        if (hasattr(self.vae, "shift_factor")
                and self.vae.shift_factor is not None):
            if isinstance(self.vae.shift_factor, torch.Tensor):
                latents += self.vae.shift_factor.to(latents.device,
                                                    latents.dtype)
            else:
                latents += self.vae.shift_factor
        with torch.autocast("cuda", dtype=torch.bfloat16):
            video = self.vae.decode(latents)
        video = (video / 2 + 0.5).clamp(0, 1)
        video = video.cpu().float()
        video = video.permute(0, 2, 1, 3, 4)
        video = (video * 255).numpy().astype(np.uint8)
        video_artifact = self.tracker.video(
            video, fps=24, format="mp4")  # change to 16 for Wan2.1
        if video_artifact is not None:
            tracker_loss_dict[latent_key] = video_artifact
        # Clean up references
        del video, latents

    # Process DMD training data if available - use decode_stage instead of self.vae.decode
    if 'generator_pred_video' in dmd_latents_vis_dict:
        for latent_key in dmd_log_keys:
            latents = dmd_latents_vis_dict[latent_key]
            latents = latents.permute(0, 2, 1, 3, 4)
            # decoded_latent = decode_stage(ForwardBatch(data_type="video", latents=latents), training_args)
            if isinstance(self.vae.scaling_factor, torch.Tensor):
                latents = latents / self.vae.scaling_factor.to(
                    latents.device, latents.dtype)
            else:
                latents = latents / self.vae.scaling_factor

            # Apply shifting if needed
            if (hasattr(self.vae, "shift_factor")
                    and self.vae.shift_factor is not None):
                if isinstance(self.vae.shift_factor, torch.Tensor):
                    latents += self.vae.shift_factor.to(
                        latents.device, latents.dtype)
                else:
                    latents += self.vae.shift_factor
            with torch.autocast("cuda", dtype=torch.bfloat16):
                video = self.vae.decode(latents)
            video = (video / 2 + 0.5).clamp(0, 1)
            video = video.cpu().float()
            video = video.permute(0, 2, 1, 3, 4)
            video = (video * 255).numpy().astype(np.uint8)
            video_artifact = self.tracker.video(
                video, fps=24, format="mp4")  # change to 16 for Wan2.1
            if video_artifact is not None:
                tracker_loss_dict[latent_key] = video_artifact
            # Clean up references
            del video, latents

    # Log to tracker
    if self.global_rank == 0 and tracker_loss_dict:
        self.tracker.log_artifacts(tracker_loss_dict, step)
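
The decoding pattern above (undo scaling_factor, optionally undo shift_factor, decode under bf16 autocast, map [-1, 1] to uint8) can be factored into a small helper. A hedged sketch; attribute names mirror the method body and are not a public FastVideo API:

import numpy as np
import torch

def latents_to_uint8_video(vae, latents: torch.Tensor) -> np.ndarray:
    latents = latents / vae.scaling_factor
    if getattr(vae, "shift_factor", None) is not None:
        latents = latents + vae.shift_factor
    with torch.autocast("cuda", dtype=torch.bfloat16):
        video = vae.decode(latents)
    video = (video / 2 + 0.5).clamp(0, 1)
    return (video.cpu().float() * 255).numpy().astype(np.uint8)
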

Functions

fastvideo.training.ode_causal_pipeline

Classes

fastvideo.training.ode_causal_pipeline.ODEInitTrainingPipeline
ODEInitTrainingPipeline(model_path: str, fastvideo_args: TrainingArgs, required_config_modules: list[str] | None = None, loaded_modules: dict[str, Module] | None = None)

Bases: TrainingPipeline

Training pipeline for ODE-init using precomputed denoising trajectories.

Supervision: predict the next latent in the stored trajectory by

- feeding the current latent at timestep t into the transformer to predict noise
- stepping the scheduler with the predicted noise
- minimizing MSE to the stored next latent at timestep t_next

Source code in fastvideo/training/training_pipeline.py
def __init__(
        self,
        model_path: str,
        fastvideo_args: TrainingArgs,
        required_config_modules: list[str] | None = None,
        loaded_modules: dict[str, torch.nn.Module] | None = None) -> None:
    fastvideo_args.inference_mode = False
    self.lora_training = fastvideo_args.lora_training
    if self.lora_training and fastvideo_args.lora_rank is None:
        raise ValueError("lora rank must be set when using lora training")

    set_random_seed(fastvideo_args.seed)  # for lora param init
    super().__init__(model_path, fastvideo_args, required_config_modules,
                     loaded_modules)  # type: ignore
    self.tracker = DummyTracker()
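
A hedged sketch of the supervision step described in the class docstring above; the transformer and scheduler interfaces are assumed placeholders, not the library's exact signatures:

import torch
import torch.nn.functional as F

def ode_init_loss(transformer, scheduler, latent_t, latent_t_next, t, cond):
    # 1) predict noise for the current latent at timestep t
    noise_pred = transformer(latent_t, t, cond)
    # 2) step the scheduler with the predicted noise toward the next timestep
    stepped = scheduler.step(noise_pred, t, latent_t).prev_sample
    # 3) match the stored next latent from the precomputed trajectory
    return F.mse_loss(stepped, latent_t_next)
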

Functions

fastvideo.training.self_forcing_distillation_pipeline

Classes

fastvideo.training.self_forcing_distillation_pipeline.SelfForcingDistillationPipeline
SelfForcingDistillationPipeline(model_path: str, fastvideo_args: TrainingArgs, required_config_modules: list[str] | None = None, loaded_modules: dict[str, Module] | None = None)

Bases: DistillationPipeline

A self-forcing distillation pipeline that alternates between training the generator and critic based on the self-forcing methodology.

This implementation follows the self-forcing approach where:

1. Generator and critic are trained in alternating steps
2. Generator loss uses a DMD-style loss with the critic as the fake score
3. Critic loss trains the fake score model to distinguish real vs fake

Source code in fastvideo/training/training_pipeline.py
def __init__(
        self,
        model_path: str,
        fastvideo_args: TrainingArgs,
        required_config_modules: list[str] | None = None,
        loaded_modules: dict[str, torch.nn.Module] | None = None) -> None:
    fastvideo_args.inference_mode = False
    self.lora_training = fastvideo_args.lora_training
    if self.lora_training and fastvideo_args.lora_rank is None:
        raise ValueError("lora rank must be set when using lora training")

    set_random_seed(fastvideo_args.seed)  # for lora param init
    super().__init__(model_path, fastvideo_args, required_config_modules,
                     loaded_modules)  # type: ignore
    self.tracker = DummyTracker()
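
A minimal sketch of the alternating update schedule described in the class docstring above; the ratio value mirrors the fallback used by initialize_training_pipeline and is not a tuned setting:

dfake_gen_update_ratio = 5
for step in range(1, 21):
    train_generator = (step % dfake_gen_update_ratio == 0)
    # The critic (fake score model) is updated every step; the generator only
    # when train_generator is True, mirroring train_one_step.
    print(step, "generator + critic" if train_generator else "critic only")
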
Functions
fastvideo.training.self_forcing_distillation_pipeline.SelfForcingDistillationPipeline.critic_loss
critic_loss(training_batch: TrainingBatch) -> tuple[Tensor, dict[str, Any]]

Compute critic loss using flow matching between noise and generator output. The critic learns to predict the flow from noise to the generator's output.

Source code in fastvideo/training/self_forcing_distillation_pipeline.py
def critic_loss(
        self, training_batch: TrainingBatch
) -> tuple[torch.Tensor, dict[str, Any]]:
    """
    Compute critic loss using flow matching between noise and generator output.
    The critic learns to predict the flow from noise to the generator's output.
    """
    updated_batch, flow_matching_loss = self.faker_score_forward(
        training_batch)
    training_batch.fake_score_latent_vis_dict = updated_batch.fake_score_latent_vis_dict
    log_dict: dict[str, Any] = {}

    return flow_matching_loss, log_dict
fastvideo.training.self_forcing_distillation_pipeline.SelfForcingDistillationPipeline.generate_and_sync_list
generate_and_sync_list(num_blocks: int, num_denoising_steps: int, device: device) -> list[int]

Generate and synchronize random exit flags across distributed processes.

Source code in fastvideo/training/self_forcing_distillation_pipeline.py
def generate_and_sync_list(self, num_blocks: int, num_denoising_steps: int,
                           device: torch.device) -> list[int]:
    """Generate and synchronize random exit flags across distributed processes."""
    logger.info(
        "RANK: %s, enter generate_and_sync_list blocks=%s steps=%s device=%s",
        self.global_rank,
        num_blocks,
        num_denoising_steps,
        str(device),
        local_main_process_only=False)
    rank = dist.get_rank() if dist.is_initialized() else 0

    if rank == 0:
        # Generate random indices
        indices = torch.randint(low=0,
                                high=num_denoising_steps,
                                size=(num_blocks, ),
                                device=device)
        if self.last_step_only:
            indices = torch.ones_like(indices) * (num_denoising_steps - 1)
    else:
        indices = torch.empty(num_blocks, dtype=torch.long, device=device)

    if dist.is_initialized():
        dist.broadcast(indices,
                       src=0)  # Broadcast the random indices to all ranks
    flags = indices.tolist()
    logger.info(
        "RANK: %s, exit generate_and_sync_list flags_len=%s first=%s",
        self.global_rank,
        len(flags),
        flags[0] if len(flags) > 0 else None,
        local_main_process_only=False)
    return flags
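
A hypothetical call; the block and step counts are illustrative:

# Returns one exit-step index per block, identical on every rank.
exit_flags = self.generate_and_sync_list(
    num_blocks=7, num_denoising_steps=4, device=get_local_torch_device())
# e.g. [3, 1, 0, 2, 3, 1, 2]; with last_step_only=True every entry would be 3.
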
fastvideo.training.self_forcing_distillation_pipeline.SelfForcingDistillationPipeline.generator_loss
generator_loss(training_batch: TrainingBatch) -> tuple[Tensor, dict[str, Any]]

Compute generator loss using DMD-style approach. The generator tries to fool the critic (fake_score_transformer).

Source code in fastvideo/training/self_forcing_distillation_pipeline.py
def generator_loss(
        self, training_batch: TrainingBatch
) -> tuple[torch.Tensor, dict[str, Any]]:
    """
    Compute generator loss using DMD-style approach.
    The generator tries to fool the critic (fake_score_transformer).
    """
    with set_forward_context(
            current_timestep=training_batch.timesteps,
            attn_metadata=training_batch.attn_metadata_vsa):
        generator_pred_video = self._generator_multi_step_simulation_forward(
            training_batch)

    with set_forward_context(current_timestep=training_batch.timesteps,
                             attn_metadata=training_batch.attn_metadata):
        dmd_loss = self._dmd_forward(
            generator_pred_video=generator_pred_video,
            training_batch=training_batch)

    log_dict = {
        "dmdtrain_gradient_norm": torch.tensor(0.0, device=self.device)
    }

    return dmd_loss, log_dict
fastvideo.training.self_forcing_distillation_pipeline.SelfForcingDistillationPipeline.initialize_training_pipeline
initialize_training_pipeline(training_args: TrainingArgs)

Initialize the self-forcing training pipeline.

Source code in fastvideo/training/self_forcing_distillation_pipeline.py
def initialize_training_pipeline(self, training_args: TrainingArgs):
    """Initialize the self-forcing training pipeline."""
    # Check if FSDP2 auto wrap is enabled - not supported for self-forcing distillation
    if os.environ.get("FASTVIDEO_FSDP2_AUTOWRAP", "0") == "1":
        raise NotImplementedError(
            "FASTVIDEO_FSDP2_AUTOWRAP is not implemented for self-forcing distillation. "
            "Please set FASTVIDEO_FSDP2_AUTOWRAP=0 or unset the environment variable."
        )

    logger.info("Initializing self-forcing distillation pipeline...")

    self.generator_ema: EMA_FSDP | None = None
    self.generator_ema_2: EMA_FSDP | None = None

    super().initialize_training_pipeline(training_args)
    try:
        logger.info("RANK: %s, entered initialize_training_pipeline",
                    self.global_rank,
                    local_main_process_only=False)
    except Exception:
        logger.info("Entered initialize_training_pipeline (rank unknown)")

    self.noise_scheduler = SelfForcingFlowMatchScheduler(
        num_inference_steps=1000,
        shift=5.0,
        sigma_min=0.0,
        extra_one_step=True,
        training=True)
    self.dfake_gen_update_ratio = getattr(training_args,
                                          'dfake_gen_update_ratio', 5)

    self.num_frame_per_block = getattr(training_args, 'num_frame_per_block',
                                       3)
    self.independent_first_frame = getattr(training_args,
                                           'independent_first_frame', False)
    self.same_step_across_blocks = getattr(training_args,
                                           'same_step_across_blocks', False)
    self.last_step_only = getattr(training_args, 'last_step_only', False)
    self.context_noise = getattr(training_args, 'context_noise', 0)

    self.kv_cache1: list[dict[str, Any]] | None = None
    self.crossattn_cache: list[dict[str, Any]] | None = None

    logger.info("Self-forcing generator update ratio: %s",
                self.dfake_gen_update_ratio)
    logger.info("RANK: %s, exiting initialize_training_pipeline",
                self.global_rank,
                local_main_process_only=False)
fastvideo.training.self_forcing_distillation_pipeline.SelfForcingDistillationPipeline.train
train() -> None

Main training loop with self-forcing specific logging.

Source code in fastvideo/training/self_forcing_distillation_pipeline.py
@profile_region("profiler_region_training_train")
def train(self) -> None:
    """Main training loop with self-forcing specific logging."""
    assert self.training_args.seed is not None, "seed must be set"
    seed = self.training_args.seed

    # Set the same seed within each SP group to ensure reproducibility
    if self.sp_world_size > 1:
        # Use the same seed for all processes within the same SP group
        sp_group_seed = seed + (self.global_rank // self.sp_world_size)
        set_random_seed(sp_group_seed)
    else:
        set_random_seed(seed + self.global_rank)

    self.noise_random_generator = torch.Generator(device="cpu").manual_seed(
        self.seed)
    self.noise_gen_cuda = torch.Generator(device="cuda").manual_seed(
        self.seed)
    self.validation_random_generator = torch.Generator(
        device="cpu").manual_seed(self.seed)
    logger.info("Initialized random seeds with seed: %s", seed)

    self.current_trainstep = self.init_steps

    if self.training_args.resume_from_checkpoint:
        self._resume_from_checkpoint()
        logger.info("Resumed from checkpoint, random states restored")
    else:
        logger.info("Starting training from scratch")

    self.train_loader_iter = iter(self.train_dataloader)

    step_times: deque[float] = deque(maxlen=100)

    self._log_training_info()
    self._log_validation(self.transformer, self.training_args,
                         self.init_steps)

    progress_bar = tqdm(
        range(0, self.training_args.max_train_steps),
        initial=self.init_steps,
        desc="Steps",
        disable=self.local_rank > 0,
    )

    use_vsa = vsa_available and envs.FASTVIDEO_ATTENTION_BACKEND == "VIDEO_SPARSE_ATTN"
    for step in range(self.init_steps + 1,
                      self.training_args.max_train_steps + 1):
        start_time = time.perf_counter()
        if use_vsa:
            vsa_sparsity = self.training_args.VSA_sparsity
            vsa_decay_rate = self.training_args.VSA_decay_rate
            vsa_decay_interval_steps = self.training_args.VSA_decay_interval_steps
            if vsa_decay_interval_steps > 1:
                current_decay_times = min(step // vsa_decay_interval_steps,
                                          vsa_sparsity // vsa_decay_rate)
                current_vsa_sparsity = current_decay_times * vsa_decay_rate
            else:
                current_vsa_sparsity = vsa_sparsity
        else:
            current_vsa_sparsity = 0.0

        training_batch = TrainingBatch()
        self.current_trainstep = step
        training_batch.current_vsa_sparsity = current_vsa_sparsity

        if (step >= self.training_args.ema_start_step) and \
                (self.generator_ema is None) and (self.training_args.ema_decay > 0):
            self.generator_ema = EMA_FSDP(
                self.transformer, decay=self.training_args.ema_decay)
            logger.info("Created generator EMA at step %s with decay=%s",
                        step, self.training_args.ema_decay)

            # Create EMA for transformer_2 if it exists
            if self.transformer_2 is not None and self.generator_ema_2 is None:
                self.generator_ema_2 = EMA_FSDP(
                    self.transformer_2, decay=self.training_args.ema_decay)
                logger.info(
                    "Created generator EMA_2 at step %s with decay=%s",
                    step, self.training_args.ema_decay)

        with torch.autocast("cuda", dtype=torch.bfloat16):
            training_batch = self.train_one_step(training_batch)

        total_loss = training_batch.total_loss
        generator_loss = training_batch.generator_loss
        fake_score_loss = training_batch.fake_score_loss
        grad_norm = training_batch.grad_norm

        step_time = time.perf_counter() - start_time
        step_times.append(step_time)
        avg_step_time = sum(step_times) / len(step_times)

        progress_bar.set_postfix({
            "total_loss":
            f"{total_loss:.4f}",
            "generator_loss":
            f"{generator_loss:.4f}",
            "fake_score_loss":
            f"{fake_score_loss:.4f}",
            "step_time":
            f"{step_time:.2f}s",
            "grad_norm":
            grad_norm,
            "ema":
            "✓" if (self.generator_ema is not None and self.is_ema_ready())
            else "✗",
            "ema2":
            "✓" if (self.generator_ema_2 is not None
                    and self.is_ema_ready()) else "✗",
        })
        progress_bar.update(1)

        if self.global_rank == 0:
            log_data = {
                "train_total_loss":
                total_loss,
                "train_fake_score_loss":
                fake_score_loss,
                "learning_rate":
                self.lr_scheduler.get_last_lr()[0],
                "fake_score_learning_rate":
                self.fake_score_lr_scheduler.get_last_lr()[0],
                "step_time":
                step_time,
                "avg_step_time":
                avg_step_time,
                "grad_norm":
                grad_norm,
            }
            if (step % self.dfake_gen_update_ratio == 0):
                log_data["train_generator_loss"] = generator_loss
            if use_vsa:
                log_data["VSA_train_sparsity"] = current_vsa_sparsity

            if self.generator_ema is not None or self.generator_ema_2 is not None:
                log_data["ema_enabled"] = self.generator_ema is not None
                log_data["ema_2_enabled"] = self.generator_ema_2 is not None
                log_data["ema_decay"] = self.training_args.ema_decay
            else:
                log_data["ema_enabled"] = False
                log_data["ema_2_enabled"] = False

            ema_stats = self.get_ema_stats()
            log_data.update(ema_stats)

            if training_batch.dmd_latent_vis_dict:
                dmd_additional_logs = {
                    "generator_timestep":
                    training_batch.
                    dmd_latent_vis_dict["generator_timestep"].item(),
                    "dmd_timestep":
                    training_batch.dmd_latent_vis_dict["dmd_timestep"].item(
                    ),
                }
                log_data.update(dmd_additional_logs)

            faker_score_additional_logs = {
                "fake_score_timestep":
                training_batch.
                fake_score_latent_vis_dict["fake_score_timestep"].item(),
            }
            log_data.update(faker_score_additional_logs)

            self.tracker.log(log_data, step)

            if self.training_args.log_validation and step % self.training_args.validation_steps == 0 and self.training_args.log_visualization:
                self.visualize_intermediate_latents(training_batch,
                                                    self.training_args,
                                                    step)

        if (self.training_args.training_state_checkpointing_steps > 0
                and step %
                self.training_args.training_state_checkpointing_steps == 0):
            print("rank", self.global_rank,
                  "save training state checkpoint at step", step)
            save_distillation_checkpoint(
                self.transformer,
                self.fake_score_transformer,
                self.global_rank,
                self.training_args.output_dir,
                step,
                self.optimizer,
                self.fake_score_optimizer,
                self.train_dataloader,
                self.lr_scheduler,
                self.fake_score_lr_scheduler,
                self.noise_random_generator,
                self.generator_ema,
                # MoE support
                generator_transformer_2=getattr(self, 'transformer_2',
                                                None),
                real_score_transformer_2=getattr(
                    self, 'real_score_transformer_2', None),
                fake_score_transformer_2=getattr(
                    self, 'fake_score_transformer_2', None),
                generator_optimizer_2=getattr(self, 'optimizer_2', None),
                fake_score_optimizer_2=getattr(self,
                                               'fake_score_optimizer_2',
                                               None),
                generator_scheduler_2=getattr(self, 'lr_scheduler_2', None),
                fake_score_scheduler_2=getattr(self,
                                               'fake_score_lr_scheduler_2',
                                               None),
                generator_ema_2=getattr(self, 'generator_ema_2', None))

            if self.transformer:
                self.transformer.train()
            self.sp_group.barrier()

        if (self.training_args.weight_only_checkpointing_steps > 0
                and step %
                self.training_args.weight_only_checkpointing_steps == 0):
            print("rank", self.global_rank,
                  "save weight-only checkpoint at step", step)
            save_distillation_checkpoint(
                self.transformer,
                self.fake_score_transformer,
                self.global_rank,
                self.training_args.output_dir,
                f"{step}_weight_only",
                only_save_generator_weight=True,
                generator_ema=self.generator_ema,
                # MoE support
                generator_transformer_2=getattr(self, 'transformer_2',
                                                None),
                real_score_transformer_2=getattr(
                    self, 'real_score_transformer_2', None),
                fake_score_transformer_2=getattr(
                    self, 'fake_score_transformer_2', None),
                generator_optimizer_2=getattr(self, 'optimizer_2', None),
                fake_score_optimizer_2=getattr(self,
                                               'fake_score_optimizer_2',
                                               None),
                generator_scheduler_2=getattr(self, 'lr_scheduler_2', None),
                fake_score_scheduler_2=getattr(self,
                                               'fake_score_lr_scheduler_2',
                                               None),
                generator_ema_2=getattr(self, 'generator_ema_2', None))

            if self.training_args.use_ema and self.is_ema_ready():
                self.save_ema_weights(self.training_args.output_dir, step)

        if self.training_args.log_validation and step % self.training_args.validation_steps == 0:
            self._log_validation(self.transformer, self.training_args, step)

    self.tracker.finish()

    print("rank", self.global_rank,
          "save final training state checkpoint at step",
          self.training_args.max_train_steps)
    save_distillation_checkpoint(
        self.transformer,
        self.fake_score_transformer,
        self.global_rank,
        self.training_args.output_dir,
        self.training_args.max_train_steps,
        self.optimizer,
        self.fake_score_optimizer,
        self.train_dataloader,
        self.lr_scheduler,
        self.fake_score_lr_scheduler,
        self.noise_random_generator,
        self.generator_ema,
        # MoE support
        generator_transformer_2=getattr(self, 'transformer_2', None),
        real_score_transformer_2=getattr(self, 'real_score_transformer_2',
                                         None),
        fake_score_transformer_2=getattr(self, 'fake_score_transformer_2',
                                         None),
        generator_optimizer_2=getattr(self, 'optimizer_2', None),
        fake_score_optimizer_2=getattr(self, 'fake_score_optimizer_2',
                                       None),
        generator_scheduler_2=getattr(self, 'lr_scheduler_2', None),
        fake_score_scheduler_2=getattr(self, 'fake_score_lr_scheduler_2',
                                       None),
        generator_ema_2=getattr(self, 'generator_ema_2', None))

    if self.training_args.use_ema and self.is_ema_ready():
        self.save_ema_weights(self.training_args.output_dir,
                              self.training_args.max_train_steps)

    if envs.FASTVIDEO_TORCH_PROFILER_DIR:
        logger.info("Stopping profiler...")
        self.profiler_controller.stop()
        logger.info("Profiler stopped.")

    if get_sp_group():
        cleanup_dist_env_and_memory()
fastvideo.training.self_forcing_distillation_pipeline.SelfForcingDistillationPipeline.train_one_step
train_one_step(training_batch: TrainingBatch) -> TrainingBatch

Self-forcing training step that alternates between generator and critic training.

Source code in fastvideo/training/self_forcing_distillation_pipeline.py
def train_one_step(self, training_batch: TrainingBatch) -> TrainingBatch:
    """
    Self-forcing training step that alternates between generator and critic training.
    """
    gradient_accumulation_steps = getattr(self.training_args,
                                          'gradient_accumulation_steps', 1)
    train_generator = (self.current_trainstep %
                       self.dfake_gen_update_ratio == 0)

    batches = []
    for _ in range(gradient_accumulation_steps):
        batch = self._prepare_distillation(training_batch)
        batch = self._get_next_batch(batch)
        batch = self._normalize_dit_input(batch)
        batch = self._prepare_dit_inputs(batch)
        batch = self._build_attention_metadata(batch)
        batch.attn_metadata_vsa = copy.deepcopy(batch.attn_metadata)
        if batch.attn_metadata is not None:
            batch.attn_metadata.VSA_sparsity = 0.0
        batches.append(batch)

    training_batch.dmd_latent_vis_dict = {}
    training_batch.fake_score_latent_vis_dict = {}

    if train_generator:
        logger.debug("Training generator at step %s",
                     self.current_trainstep)
        self.optimizer.zero_grad()
        if self.transformer_2 is not None:
            self.optimizer_2.zero_grad()
        total_generator_loss = 0.0
        generator_log_dict = {}

        for batch in batches:
            # Create a new batch with detached tensors
            batch_gen = TrainingBatch()
            for key, value in batch.__dict__.items():
                if isinstance(value, torch.Tensor):
                    setattr(batch_gen, key, value.detach().clone())
                elif isinstance(value, dict):
                    setattr(
                        batch_gen, key, {
                            k:
                            v.detach().clone() if isinstance(
                                v, torch.Tensor) else copy.deepcopy(v)
                            for k, v in value.items()
                        })
                else:
                    setattr(batch_gen, key, copy.deepcopy(value))

            generator_loss, gen_log_dict = self.generator_loss(batch_gen)
            with set_forward_context(current_timestep=batch_gen.timesteps,
                                     attn_metadata=batch_gen.attn_metadata):
                (generator_loss / gradient_accumulation_steps).backward()
            total_generator_loss += generator_loss.detach().item()
            generator_log_dict.update(gen_log_dict)
            # Store visualization data from generator training
            if hasattr(batch_gen, 'dmd_latent_vis_dict'):
                training_batch.dmd_latent_vis_dict.update(
                    batch_gen.dmd_latent_vis_dict)

        # Only clip gradients and step optimizer for the model that is currently training
        if hasattr(
                self, 'train_transformer_2'
        ) and self.train_transformer_2 and self.transformer_2 is not None:
            self._clip_model_grad_norm_(batch_gen, self.transformer_2)
            self.optimizer_2.step()
            self.lr_scheduler_2.step()
        else:
            self._clip_model_grad_norm_(batch_gen, self.transformer)
            self.optimizer.step()
            self.lr_scheduler.step()

        if self.generator_ema is not None:
            if hasattr(
                    self, 'train_transformer_2'
            ) and self.train_transformer_2 and self.transformer_2 is not None:
                # Update EMA for transformer_2 when training it
                if self.generator_ema_2 is not None:
                    self.generator_ema_2.update(self.transformer_2)
            else:
                self.generator_ema.update(self.transformer)

        avg_generator_loss = torch.tensor(total_generator_loss /
                                          gradient_accumulation_steps,
                                          device=self.device)
        world_group = get_world_group()
        world_group.all_reduce(avg_generator_loss,
                               op=torch.distributed.ReduceOp.AVG)
        training_batch.generator_loss = avg_generator_loss.item()
    else:
        training_batch.generator_loss = 0.0

    logger.debug("Training critic at step %s", self.current_trainstep)
    self.fake_score_optimizer.zero_grad()
    total_critic_loss = 0.0
    critic_log_dict = {}

    for batch in batches:
        # Create a new batch with detached tensors
        batch_critic = TrainingBatch()
        for key, value in batch.__dict__.items():
            if isinstance(value, torch.Tensor):
                setattr(batch_critic, key, value.detach().clone())
            elif isinstance(value, dict):
                setattr(
                    batch_critic, key, {
                        k:
                        v.detach().clone()
                        if isinstance(v, torch.Tensor) else copy.deepcopy(v)
                        for k, v in value.items()
                    })
            else:
                setattr(batch_critic, key, copy.deepcopy(value))

        critic_loss, crit_log_dict = self.critic_loss(batch_critic)
        with set_forward_context(current_timestep=batch_critic.timesteps,
                                 attn_metadata=batch_critic.attn_metadata):
            (critic_loss / gradient_accumulation_steps).backward()
        total_critic_loss += critic_loss.detach().item()
        critic_log_dict.update(crit_log_dict)
        # Store visualization data from critic training
        if hasattr(batch_critic, 'fake_score_latent_vis_dict'):
            training_batch.fake_score_latent_vis_dict.update(
                batch_critic.fake_score_latent_vis_dict)

    if self.train_fake_score_transformer_2 and self.fake_score_transformer_2 is not None:
        self._clip_model_grad_norm_(batch_critic,
                                    self.fake_score_transformer_2)
        self.fake_score_optimizer_2.step()
        self.fake_score_lr_scheduler_2.step()
    else:
        self._clip_model_grad_norm_(batch_critic,
                                    self.fake_score_transformer)
        self.fake_score_optimizer.step()
        self.fake_score_lr_scheduler.step()

    avg_critic_loss = torch.tensor(total_critic_loss /
                                   gradient_accumulation_steps,
                                   device=self.device)
    world_group = get_world_group()
    world_group.all_reduce(avg_critic_loss,
                           op=torch.distributed.ReduceOp.AVG)
    training_batch.fake_score_loss = avg_critic_loss.item()

    training_batch.total_loss = training_batch.generator_loss + training_batch.fake_score_loss
    return training_batch
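
The loss scaling above implements gradient accumulation: each micro-batch contributes loss / N to the gradients, so a single optimizer step after N backward passes approximates one larger-batch step. A standalone sketch with a toy model:

import torch

model = torch.nn.Linear(4, 1)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
grad_accum = 4

opt.zero_grad()
for _ in range(grad_accum):
    x, y = torch.randn(2, 4), torch.randn(2, 1)
    loss = torch.nn.functional.mse_loss(model(x), y)
    (loss / grad_accum).backward()  # scaled so gradients average over micro-batches
opt.step()
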
fastvideo.training.self_forcing_distillation_pipeline.SelfForcingDistillationPipeline.visualize_intermediate_latents
visualize_intermediate_latents(training_batch: TrainingBatch, training_args: TrainingArgs, step: int)

Add visualization data to tracker logging and save frames to disk.

Source code in fastvideo/training/self_forcing_distillation_pipeline.py
def visualize_intermediate_latents(self, training_batch: TrainingBatch,
                                   training_args: TrainingArgs, step: int):
    """Add visualization data to tracker logging and save frames to disk."""
    tracker_loss_dict: dict[str, Any] = {}

    # Debug logging
    if hasattr(training_batch, 'dmd_latent_vis_dict'):
        logger.info("DMD latent keys: %s",
                    list(training_batch.dmd_latent_vis_dict.keys()))
    if hasattr(training_batch, 'fake_score_latent_vis_dict'):
        logger.info("Fake score latent keys: %s",
                    list(training_batch.fake_score_latent_vis_dict.keys()))

    # Process generator predictions if available
    if hasattr(
            training_batch,
            'dmd_latent_vis_dict') and training_batch.dmd_latent_vis_dict:
        dmd_latents_vis_dict = training_batch.dmd_latent_vis_dict
        dmd_log_keys = [
            'generator_pred_video', 'real_score_pred_video',
            'faker_score_pred_video'
        ]

        for latent_key in dmd_log_keys:
            if latent_key in dmd_latents_vis_dict:
                logger.info("Processing DMD latent: %s", latent_key)
                latents = dmd_latents_vis_dict[latent_key]
                if not isinstance(latents, torch.Tensor):
                    logger.warning("Expected tensor for %s, got %s",
                                   latent_key, type(latents))
                    continue

                latents = latents.detach()
                latents = latents.permute(0, 2, 1, 3, 4)

                if isinstance(self.vae.scaling_factor, torch.Tensor):
                    latents = latents / self.vae.scaling_factor.to(
                        latents.device, latents.dtype)
                else:
                    latents = latents / self.vae.scaling_factor

                if (hasattr(self.vae, "shift_factor")
                        and self.vae.shift_factor is not None):
                    if isinstance(self.vae.shift_factor, torch.Tensor):
                        latents += self.vae.shift_factor.to(
                            latents.device, latents.dtype)
                    else:
                        latents += self.vae.shift_factor

                with torch.autocast("cuda", dtype=torch.bfloat16):
                    video = self.vae.decode(latents)
                video = (video / 2 + 0.5).clamp(0, 1)
                video = video.cpu().float()
                video = video.permute(0, 2, 1, 3, 4)
                video = (video * 255).numpy().astype(np.uint8)
                video_artifact = self.tracker.video(video,
                                                    fps=24,
                                                    format="mp4")
                if video_artifact is not None:
                    tracker_loss_dict[f"dmd_{latent_key}"] = video_artifact
                del video, latents

    # Process critic predictions
    if hasattr(training_batch, 'fake_score_latent_vis_dict'
               ) and training_batch.fake_score_latent_vis_dict:
        fake_score_latents_vis_dict = training_batch.fake_score_latent_vis_dict
        fake_score_log_keys = ['generator_pred_video']

        for latent_key in fake_score_log_keys:
            if latent_key in fake_score_latents_vis_dict:
                logger.info("Processing critic latent: %s", latent_key)
                latents = fake_score_latents_vis_dict[latent_key]
                if not isinstance(latents, torch.Tensor):
                    logger.warning("Expected tensor for %s, got %s",
                                   latent_key, type(latents))
                    continue

                latents = latents.detach()
                latents = latents.permute(0, 2, 1, 3, 4)

                if isinstance(self.vae.scaling_factor, torch.Tensor):
                    latents = latents / self.vae.scaling_factor.to(
                        latents.device, latents.dtype)
                else:
                    latents = latents / self.vae.scaling_factor

                if (hasattr(self.vae, "shift_factor")
                        and self.vae.shift_factor is not None):
                    if isinstance(self.vae.shift_factor, torch.Tensor):
                        latents += self.vae.shift_factor.to(
                            latents.device, latents.dtype)
                    else:
                        latents += self.vae.shift_factor

                with torch.autocast("cuda", dtype=torch.bfloat16):
                    video = self.vae.decode(latents)
                video = (video / 2 + 0.5).clamp(0, 1)
                video = video.cpu().float()
                video = video.permute(0, 2, 1, 3, 4)
                video = (video * 255).numpy().astype(np.uint8)
                video_artifact = self.tracker.video(video,
                                                    fps=24,
                                                    format="mp4")
                if video_artifact is not None:
                    tracker_loss_dict[
                        f"critic_{latent_key}"] = video_artifact
                del video, latents

    # Log metadata
    if hasattr(
            training_batch,
            'dmd_latent_vis_dict') and training_batch.dmd_latent_vis_dict:
        if "generator_timestep" in training_batch.dmd_latent_vis_dict:
            tracker_loss_dict[
                "generator_timestep"] = training_batch.dmd_latent_vis_dict[
                    "generator_timestep"].item()
        if "dmd_timestep" in training_batch.dmd_latent_vis_dict:
            tracker_loss_dict[
                "dmd_timestep"] = training_batch.dmd_latent_vis_dict[
                    "dmd_timestep"].item()

    if hasattr(
            training_batch, 'fake_score_latent_vis_dict'
    ) and training_batch.fake_score_latent_vis_dict and "fake_score_timestep" in training_batch.fake_score_latent_vis_dict:
        tracker_loss_dict[
            "fake_score_timestep"] = training_batch.fake_score_latent_vis_dict[
                "fake_score_timestep"].item()

    # Log final dict contents
    logger.info("Final tracker_loss_dict keys: %s",
                list(tracker_loss_dict.keys()))

    if self.global_rank == 0 and tracker_loss_dict:
        self.tracker.log_artifacts(tracker_loss_dict, step)
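
Both branches above apply the same latent-to-pixel conversion before handing the result to the tracker. A condensed sketch of that conversion, assuming a VAE object exposing the scaling_factor and shift_factor attributes used in the source (tensor-valued factors are omitted for brevity):

    import numpy as np
    import torch

    def latents_to_uint8_video(latents: torch.Tensor, vae) -> np.ndarray:
        # Undo latent normalization, decode with the VAE, and convert to uint8 frames.
        latents = latents.detach().permute(0, 2, 1, 3, 4)
        latents = latents / vae.scaling_factor
        if getattr(vae, "shift_factor", None) is not None:
            latents = latents + vae.shift_factor
        with torch.autocast("cuda", dtype=torch.bfloat16):
            video = vae.decode(latents)
        video = (video / 2 + 0.5).clamp(0, 1).cpu().float()
        video = video.permute(0, 2, 1, 3, 4)
        return (video * 255).numpy().astype(np.uint8)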

Functions

fastvideo.training.trackers

Utilities for logging metrics and artifacts to external trackers.

This module is inspired by the trackers implementation in https://github.com/huggingface/finetrainers and provides a minimal, shared interface that can be used across all FastVideo training pipelines.

Classes

fastvideo.training.trackers.BaseTracker
BaseTracker()

Base tracker implementation.

The default tracker stores timing information but does not emit any logs.

Source code in fastvideo/training/trackers.py
def __init__(self) -> None:
    self._timed_metrics: dict[str, float] = {}
Functions
fastvideo.training.trackers.BaseTracker.finish
finish() -> None

Finalize the tracker session.

Source code in fastvideo/training/trackers.py
def finish(self) -> None:  # pragma: no cover - interface
    """Finalize the tracker session."""
fastvideo.training.trackers.BaseTracker.log
log(metrics: dict[str, Any], step: int) -> None

Log metrics for the given step.

Source code in fastvideo/training/trackers.py
def log(self, metrics: dict[str, Any],
        step: int) -> None:  # pragma: no cover - interface
    """Log metrics for the given step."""
    # Merge timing metrics with provided metrics
    metrics = {**self._timed_metrics, **metrics}
    self._timed_metrics = {}
fastvideo.training.trackers.BaseTracker.log_artifacts
log_artifacts(artifacts: dict[str, Any], step: int) -> None

Log artifacts such as videos or images.

By default this is treated the same as log().

Source code in fastvideo/training/trackers.py
def log_artifacts(self, artifacts: dict[str, Any], step: int) -> None:
    """Log artifacts such as videos or images.

    By default this is treated the same as :meth:`log`.
    """

    if artifacts:
        self.log(artifacts, step)
fastvideo.training.trackers.BaseTracker.video
video(data: Any, *, caption: str | None = None, fps: int | None = None, format: str | None = None) -> Any | None

Create a tracker specific video artifact.

Trackers that do not support video artifacts should return None.

Source code in fastvideo/training/trackers.py
def video(
    self,
    data: Any,
    *,
    caption: str | None = None,
    fps: int | None = None,
    format: str | None = None,
) -> Any | None:
    """Create a tracker specific video artifact.

    Trackers that do not support video artifacts should return ``None``.
    """

    return None
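
Trackers that want to emit logs only need to override these hooks. A hypothetical sketch of a tracker that prints metrics to stdout (not part of fastvideo.training.trackers):

    from typing import Any

    class PrintTracker(BaseTracker):
        def log(self, metrics: dict[str, Any], step: int) -> None:
            # Merge pending timing metrics exactly as BaseTracker.log does, then print.
            metrics = {**self._timed_metrics, **metrics}
            self._timed_metrics = {}
            print(f"step {step}: {metrics}")
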
fastvideo.training.trackers.DummyTracker
DummyTracker()

Bases: BaseTracker

Tracker implementation used when logging is disabled.

Source code in fastvideo/training/trackers.py
def __init__(self) -> None:
    self._timed_metrics: dict[str, float] = {}
fastvideo.training.trackers.SequentialTracker
SequentialTracker(trackers: Iterable[BaseTracker])

Bases: BaseTracker

A tracker that forwards logging calls to a sequence of trackers.

Source code in fastvideo/training/trackers.py
def __init__(self, trackers: Iterable[BaseTracker]) -> None:
    super().__init__()
    self._trackers: list[BaseTracker] = list(trackers)
fastvideo.training.trackers.Timer dataclass
Timer(name: str, _start_time: float | None = None, _end_time: float | None = None)

Simple timer utility used by the trackers.

fastvideo.training.trackers.WandbTracker
WandbTracker(experiment_name: str, log_dir: str, *, config: dict[str, Any] | None = None, run_name: str | None = None)

Bases: BaseTracker

Tracker implementation for Weights & Biases.

Source code in fastvideo/training/trackers.py
def __init__(
    self,
    experiment_name: str,
    log_dir: str,
    *,
    config: dict[str, Any] | None = None,
    run_name: str | None = None,
) -> None:
    super().__init__()

    import wandb

    pathlib.Path(log_dir).mkdir(parents=True, exist_ok=True)

    self._wandb = wandb
    self._run = wandb.init(
        project=experiment_name,
        dir=log_dir,
        config=config,
        name=run_name,
    )
    logger.info("Initialized Weights & Biases tracker")

Functions

fastvideo.training.trackers.initialize_trackers
initialize_trackers(trackers: Iterable[str], *, experiment_name: str, config: dict[str, Any] | None, log_dir: str, run_name: str | None = None) -> BaseTracker

Create tracker instances based on trackers configuration.

Source code in fastvideo/training/trackers.py
def initialize_trackers(
    trackers: Iterable[str],
    *,
    experiment_name: str,
    config: dict[str, Any] | None,
    log_dir: str,
    run_name: str | None = None,
) -> BaseTracker:
    """Create tracker instances based on ``trackers`` configuration."""

    tracker_names = [tracker.lower() for tracker in trackers]
    if not tracker_names:
        return DummyTracker()

    unsupported = [
        name for name in tracker_names if name not in SUPPORTED_TRACKERS
    ]
    if unsupported:
        raise ValueError(
            f"Unsupported tracker(s) provided: {unsupported}. Supported trackers: {sorted(SUPPORTED_TRACKERS)}"
        )

    tracker_instances: list[BaseTracker] = []
    for tracker_name in tracker_names:
        if tracker_name == Trackers.NONE.value:
            tracker_instances.append(DummyTracker())
        elif tracker_name == Trackers.WANDB.value:
            tracker_instances.append(
                WandbTracker(
                    experiment_name,
                    os.path.abspath(log_dir),
                    config=config,
                    run_name=run_name,
                ))

    if not tracker_instances:
        return DummyTracker()

    if len(tracker_instances) == 1:
        return tracker_instances[0]

    return SequentialTracker(tracker_instances)
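
A minimal usage sketch, assuming wandb is installed when the "wandb" tracker is requested (experiment and run names are placeholders):

    tracker = initialize_trackers(
        ["wandb"],
        experiment_name="fastvideo-training",
        config={"learning_rate": 1e-5},
        log_dir="./outputs/tracker_logs",
        run_name="debug-run",
    )
    tracker.log({"train/loss": 0.42}, step=1)
    tracker.finish()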

fastvideo.training.training_pipeline

Classes

fastvideo.training.training_pipeline.TrainingPipeline
TrainingPipeline(model_path: str, fastvideo_args: TrainingArgs, required_config_modules: list[str] | None = None, loaded_modules: dict[str, Module] | None = None)

Bases: LoRAPipeline, ABC

A pipeline for training a model. All training pipelines should inherit from this class. All reusable components and code should be implemented in this class.

Source code in fastvideo/training/training_pipeline.py
def __init__(
        self,
        model_path: str,
        fastvideo_args: TrainingArgs,
        required_config_modules: list[str] | None = None,
        loaded_modules: dict[str, torch.nn.Module] | None = None) -> None:
    fastvideo_args.inference_mode = False
    self.lora_training = fastvideo_args.lora_training
    if self.lora_training and fastvideo_args.lora_rank is None:
        raise ValueError("lora rank must be set when using lora training")

    set_random_seed(fastvideo_args.seed)  # for lora param init
    super().__init__(model_path, fastvideo_args, required_config_modules,
                     loaded_modules)  # type: ignore
    self.tracker = DummyTracker()
Functions
fastvideo.training.training_pipeline.TrainingPipeline.visualize_intermediate_latents
visualize_intermediate_latents(training_batch: TrainingBatch, training_args: TrainingArgs, step: int)

Add visualization data to tracker logging and save frames to disk.

Source code in fastvideo/training/training_pipeline.py
def visualize_intermediate_latents(self, training_batch: TrainingBatch,
                                   training_args: TrainingArgs, step: int):
    """Add visualization data to tracker logging and save frames to disk."""
    raise NotImplementedError(
        "Visualize intermediate latents is not implemented for training pipeline"
    )

Functions

fastvideo.training.training_utils

Classes

fastvideo.training.training_utils.EMA_FSDP
EMA_FSDP(module, decay: float = 0.999, mode: str = 'local_shard')
FSDP2-friendly EMA with two modes
  • mode="local_shard" (default): maintain float32 CPU EMA of local parameter shards on every rank. Provides a context manager to temporarily swap EMA weights into the live model for teacher forward.
  • mode="rank0_full": maintain a consolidated float32 CPU EMA of full parameters on rank 0 only using gather_state_dict_on_cpu_rank0(). Useful for checkpoint export; not for teacher forward.

Usage (local_shard for CM teacher):

    ema = EMA_FSDP(model, decay=0.999, mode="local_shard")
    for step in ...:
        ema.update(model)
        with ema.apply_to_model(model):
            with torch.no_grad():
                y_teacher = model(...)

Usage (rank0_full for export):

    ema = EMA_FSDP(model, decay=0.999, mode="rank0_full")
    ema.update(model)
    ema.state_dict()  # on rank 0

Source code in fastvideo/training/training_utils.py
def __init__(self, module, decay: float = 0.999, mode: str = "local_shard"):
    self.decay = float(decay)
    self.mode = mode
    self.shadow: dict[str, torch.Tensor] = {}
    self.rank = dist.get_rank() if dist.is_initialized() else 0
    if self.mode not in {"local_shard", "rank0_full"}:
        raise ValueError(f"Unsupported EMA_FSDP mode: {self.mode}")
    self._init_shadow(module)
Functions
fastvideo.training.training_utils.EMA_FSDP.copy_to_unwrapped
copy_to_unwrapped(module) -> None

Copy EMA weights into a non-sharded (unwrapped) module. Intended for export/eval. For mode="rank0_full", only rank 0 has the full EMA state.

Source code in fastvideo/training/training_utils.py
@torch.no_grad()
def copy_to_unwrapped(self, module) -> None:
    """
    Copy EMA weights into a non-sharded (unwrapped) module. Intended for export/eval.
    For mode="rank0_full", only rank 0 has the full EMA state.
    """
    if self.mode == "rank0_full" and self.rank != 0:
        return
    name_to_param = dict(module.named_parameters())
    for n, w in self.shadow.items():
        if n in name_to_param:
            p = name_to_param[n]
            p.data.copy_(w.to(dtype=p.dtype, device=p.device))

Functions

fastvideo.training.training_utils.clip_grad_norm_
clip_grad_norm_(parameters: Tensor | list[Tensor], max_norm: float, norm_type: float = 2.0, error_if_nonfinite: bool = False, foreach: bool | None = None, pp_mesh: DeviceMesh | None = None) -> Tensor

Clip the gradient norm of parameters.

Gradient norm clipping requires computing the gradient norm over the entire model. torch.nn.utils.clip_grad_norm_ only computes gradient norm along DP/FSDP/TP dimensions. We need to manually reduce the gradient norm across PP stages. See https://github.com/pytorch/torchtitan/issues/596 for details.

Parameters:

Name Type Description Default
parameters `torch.Tensor` or `List[torch.Tensor]`

Tensors that will have gradients normalized.

required
max_norm `float`

Maximum norm of the gradients after clipping.

required
norm_type `float`, defaults to `2.0`

Type of p-norm to use. Can be inf for infinity norm.

2.0
error_if_nonfinite `bool`, defaults to `False`

If True, an error is thrown if the total norm of the gradients from parameters is nan, inf, or -inf.

False
foreach `bool`, defaults to `None`

Use the faster foreach-based implementation. If None, use the foreach implementation for CUDA and CPU native tensors and silently fall back to the slow implementation for other device types.

None
pp_mesh `torch.distributed.device_mesh.DeviceMesh`, defaults to `None`

Pipeline parallel device mesh. If not None, will reduce gradient norm across PP stages.

None

Returns:

Type Description
Tensor

torch.Tensor: Total norm of the gradients

Source code in fastvideo/training/training_utils.py
@torch.no_grad()
def clip_grad_norm_(
    parameters: torch.Tensor | list[torch.Tensor],
    max_norm: float,
    norm_type: float = 2.0,
    error_if_nonfinite: bool = False,
    foreach: bool | None = None,
    pp_mesh: torch.distributed.device_mesh.DeviceMesh | None = None,
) -> torch.Tensor:
    r"""
    Clip the gradient norm of parameters.

    Gradient norm clipping requires computing the gradient norm over the entire model.
    `torch.nn.utils.clip_grad_norm_` only computes gradient norm along DP/FSDP/TP dimensions.
    We need to manually reduce the gradient norm across PP stages.
    See https://github.com/pytorch/torchtitan/issues/596 for details.

    Args:
        parameters (`torch.Tensor` or `List[torch.Tensor]`):
            Tensors that will have gradients normalized.
        max_norm (`float`):
            Maximum norm of the gradients after clipping.
        norm_type (`float`, defaults to `2.0`):
            Type of p-norm to use. Can be `inf` for infinity norm.
        error_if_nonfinite (`bool`, defaults to `False`):
            If `True`, an error is thrown if the total norm of the gradients from `parameters` is `nan`, `inf`, or `-inf`.
        foreach (`bool`, defaults to `None`):
            Use the faster foreach-based implementation. If `None`, use the foreach implementation for CUDA and CPU native tensors
            and silently fall back to the slow implementation for other device types.
        pp_mesh (`torch.distributed.device_mesh.DeviceMesh`, defaults to `None`):
            Pipeline parallel device mesh. If not `None`, will reduce gradient norm across PP stages.

    Returns:
        `torch.Tensor`:
            Total norm of the gradients
    """
    grads = [p.grad for p in parameters if p.grad is not None]

    # TODO(aryan): Wait for next Pytorch release to use `torch.nn.utils.get_total_norm`
    # total_norm = torch.nn.utils.get_total_norm(grads, norm_type, error_if_nonfinite, foreach)
    total_norm = _get_total_norm(grads, norm_type, error_if_nonfinite, foreach)

    # If total_norm is a DTensor, the placements must be `torch.distributed._tensor.ops.math_ops._NormPartial`.
    # We can simply reduce the DTensor to get the total norm in this tensor's process group
    # and then convert it to a local tensor.
    # It has two purposes:
    #   1. to make sure the total norm is computed correctly when PP is used (see below)
    #   2. to return a reduced total_norm tensor whose .item() would return the correct value
    if isinstance(total_norm, torch.distributed.tensor.DTensor):
        # Will reach here if any non-PP parallelism is used.
        # If only using PP, total_norm will be a local tensor.
        total_norm = total_norm.full_tensor()

    if pp_mesh is not None:
        raise NotImplementedError("Pipeline parallel is not supported")
        if math.isinf(norm_type):
            dist.all_reduce(total_norm,
                            op=dist.ReduceOp.MAX,
                            group=pp_mesh.get_group())
        else:
            total_norm **= norm_type
            dist.all_reduce(total_norm,
                            op=dist.ReduceOp.SUM,
                            group=pp_mesh.get_group())
            total_norm **= 1.0 / norm_type

    _clip_grads_with_norm_(parameters, max_norm, total_norm, foreach)
    return total_norm
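
A usage sketch for a plain FSDP/DP setup (pp_mesh stays None since pipeline parallelism is not supported), assuming model and optimizer are already built:

    params = list(model.parameters())
    grad_norm = clip_grad_norm_(params, max_norm=1.0)
    if torch.isfinite(grad_norm):
        optimizer.step()
    optimizer.zero_grad()
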
fastvideo.training.training_utils.compute_density_for_timestep_sampling
compute_density_for_timestep_sampling(weighting_scheme: str, batch_size: int, generator, logit_mean: float | None = None, logit_std: float | None = None, mode_scale: float | None = None)

Compute the density for sampling the timesteps when doing SD3 training.

Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.

SD3 paper reference: https://arxiv.org/abs/2403.03206v1.

Source code in fastvideo/training/training_utils.py
def compute_density_for_timestep_sampling(
    weighting_scheme: str,
    batch_size: int,
    generator,
    logit_mean: float | None = None,
    logit_std: float | None = None,
    mode_scale: float | None = None,
):
    """
    Compute the density for sampling the timesteps when doing SD3 training.

    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.

    SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
    """
    if weighting_scheme == "logit_normal":
        # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
        u = torch.normal(
            mean=logit_mean,
            std=logit_std,
            size=(batch_size, ),
            device="cpu",
            generator=generator,
        )
        u = torch.nn.functional.sigmoid(u)
    elif weighting_scheme == "mode":
        u = torch.rand(size=(batch_size, ), device="cpu", generator=generator)
        u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2)**2 - 1 + u)
    else:
        u = torch.rand(size=(batch_size, ), device="cpu", generator=generator)
    return u
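
A usage sketch for the logit-normal scheme; turning u into integer timesteps for a 1000-step schedule is an illustrative assumption, the function itself only returns values in [0, 1]:

    generator = torch.Generator(device="cpu").manual_seed(42)
    u = compute_density_for_timestep_sampling(
        weighting_scheme="logit_normal",
        batch_size=4,
        generator=generator,
        logit_mean=0.0,
        logit_std=1.0,
    )
    timesteps = (u * 1000).long()
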
fastvideo.training.training_utils.custom_to_hf_state_dict
custom_to_hf_state_dict(state_dict: dict[str, Any] | Iterator[tuple[str, Tensor]], reverse_param_names_mapping: dict[str, tuple[str, int, int]]) -> dict[str, Any]

Convert fastvideo's custom model format to diffusers format using reverse_param_names_mapping.

Parameters:

Name Type Description Default
state_dict dict[str, Any] | Iterator[tuple[str, Tensor]]

State dict in fastvideo's custom format

required
reverse_param_names_mapping dict[str, tuple[str, int, int]]

Reverse mapping from fastvideo's custom format to diffusers format

required

Returns:

Type Description
dict[str, Any]

State dict in diffusers format

Source code in fastvideo/training/training_utils.py
def custom_to_hf_state_dict(
    state_dict: dict[str, Any] | Iterator[tuple[str, torch.Tensor]],
    reverse_param_names_mapping: dict[str, tuple[str, int,
                                                 int]]) -> dict[str, Any]:
    """
    Convert fastvideo's custom model format to diffusers format using reverse_param_names_mapping.

    Args:
        state_dict: State dict in fastvideo's custom format
        reverse_param_names_mapping: Reverse mapping from fastvideo's custom format to diffusers format

    Returns:
        State dict in diffusers format
    """
    assert len(
        reverse_param_names_mapping) > 0, "reverse_param_names_mapping is empty"
    if isinstance(state_dict, Iterator):
        state_dict = dict(state_dict)
    new_state_dict = {}
    # Group parameters that need to be split (merged parameters)
    merge_groups: dict[str, list[tuple[str, int, int]]] = {}

    # First pass: collect all merge groups
    for training_key, (
            diffusers_key, merge_index,
            num_params_to_merge) in reverse_param_names_mapping.items():
        if merge_index is not None:
            # This is a merged parameter that needs to be split
            if training_key not in merge_groups:
                merge_groups[training_key] = []
            merge_groups[training_key].append(
                (diffusers_key, merge_index, num_params_to_merge))

    # Second pass: handle merged parameters by splitting them
    used_keys = set()
    for training_key, splits in merge_groups.items():
        if training_key in state_dict:
            v = state_dict[training_key]
            # Sort by merge_index to ensure correct order
            splits.sort(key=lambda x: x[1])
            total = splits[0][2]
            split_size = v.shape[0] // total
            split_tensors = torch.split(v, split_size, dim=0)

            for diffusers_key, split_index, _ in splits:
                new_state_dict[diffusers_key] = split_tensors[split_index]
            used_keys.add(training_key)

    # Third pass: handle regular parameters (direct mappings)
    for training_key, v in state_dict.items():
        if training_key in used_keys:
            continue

        if training_key in reverse_param_names_mapping:
            diffusers_key, merge_index, _ = reverse_param_names_mapping[
                training_key]
            if merge_index is None:
                # Direct mapping
                new_state_dict[diffusers_key] = v
        else:
            # No mapping found, keep as is
            new_state_dict[training_key] = v

    return new_state_dict
fastvideo.training.training_utils.get_constant_schedule
get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1) -> LambdaLR

Create a schedule with a constant learning rate, using the learning rate set in optimizer.

Parameters:

Name Type Description Default
optimizer [`~torch.optim.Optimizer`]

The optimizer for which to schedule the learning rate.

required
last_epoch `int`, *optional*, defaults to -1

The index of the last epoch when resuming training.

-1
Return

torch.optim.lr_scheduler.LambdaLR with the appropriate schedule.

Source code in fastvideo/training/training_utils.py
def get_constant_schedule(optimizer: Optimizer,
                          last_epoch: int = -1) -> LambdaLR:
    """
    Create a schedule with a constant learning rate, using the learning rate set in optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch)
fastvideo.training.training_utils.get_constant_schedule_with_warmup
get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1) -> LambdaLR

Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate increases linearly between 0 and the initial lr set in the optimizer.

Parameters:

Name Type Description Default
optimizer [`~torch.optim.Optimizer`]

The optimizer for which to schedule the learning rate.

required
num_warmup_steps `int`

The number of steps for the warmup phase.

required
last_epoch `int`, *optional*, defaults to -1

The index of the last epoch when resuming training.

-1
Return

torch.optim.lr_scheduler.LambdaLR with the appropriate schedule.

Source code in fastvideo/training/training_utils.py
def get_constant_schedule_with_warmup(optimizer: Optimizer,
                                      num_warmup_steps: int,
                                      last_epoch: int = -1) -> LambdaLR:
    """
    Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate
    increases linearly between 0 and the initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1.0, num_warmup_steps))
        return 1.0

    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
fastvideo.training.training_utils.get_cosine_schedule_with_min_lr
get_cosine_schedule_with_min_lr(optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, min_lr_ratio: float = 0.1, num_cycles: float = 0.5, last_epoch: int = -1) -> LambdaLR

Create a schedule with a learning rate that decreases following the values of the cosine function between the initial lr set in the optimizer to a minimum lr (min_lr_ratio * initial_lr), after a warmup period during which it increases linearly between 0 and the initial lr set in the optimizer.

Parameters:

Name Type Description Default
optimizer [`~torch.optim.Optimizer`]

The optimizer for which to schedule the learning rate.

required
num_warmup_steps `int`

The number of steps for the warmup phase.

required
num_training_steps `int`

The total number of training steps.

required
min_lr_ratio `float`, *optional*, defaults to 0.1

The ratio of minimum learning rate to initial learning rate.

0.1
num_cycles `float`, *optional*, defaults to 0.5

The number of periods of the cosine function in a schedule.

0.5
last_epoch `int`, *optional*, defaults to -1

The index of the last epoch when resuming training.

-1
Return

torch.optim.lr_scheduler.LambdaLR with the appropriate schedule.

Source code in fastvideo/training/training_utils.py
def get_cosine_schedule_with_min_lr(optimizer: Optimizer,
                                    num_warmup_steps: int,
                                    num_training_steps: int,
                                    min_lr_ratio: float = 0.1,
                                    num_cycles: float = 0.5,
                                    last_epoch: int = -1) -> LambdaLR:
    """
    Create a schedule with a learning rate that decreases following the values of the cosine function between the
    initial lr set in the optimizer to a minimum lr (min_lr_ratio * initial_lr), after a warmup period during which 
    it increases linearly between 0 and the initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        min_lr_ratio (`float`, *optional*, defaults to 0.1):
            The ratio of minimum learning rate to initial learning rate.
        num_cycles (`float`, *optional*, defaults to 0.5):
            The number of periods of the cosine function in a schedule.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / float(
            max(1, num_training_steps - num_warmup_steps))
        # Cosine decay from 1.0 to min_lr_ratio over num_cycles periods
        # Use the same formula as standard cosine but ensure minimum is min_lr_ratio instead of 0
        cosine_value = 0.5 * (
            1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
        # Ensure the value doesn't go below min_lr_ratio
        return max(min_lr_ratio, cosine_value)

    return LambdaLR(optimizer, lr_lambda, last_epoch)
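
A small sketch of the floor behaviour, assuming a toy optimizer with an initial lr of 1e-4; after the cosine decay finishes, the lr is clamped at min_lr_ratio * 1e-4 = 1e-5:

    import torch

    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = torch.optim.AdamW(params, lr=1e-4)
    scheduler = get_cosine_schedule_with_min_lr(
        optimizer, num_warmup_steps=100, num_training_steps=1000, min_lr_ratio=0.1)
    for _ in range(1000):
        optimizer.step()
        scheduler.step()
    print(scheduler.get_last_lr())  # ~[1e-05]
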
fastvideo.training.training_utils.get_cosine_schedule_with_warmup
get_cosine_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1) -> LambdaLR

Create a schedule with a learning rate that decreases following the values of the cosine function between the initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the initial lr set in the optimizer.

Parameters:

Name Type Description Default
optimizer [`~torch.optim.Optimizer`]

The optimizer for which to schedule the learning rate.

required
num_warmup_steps `int`

The number of steps for the warmup phase.

required
num_training_steps `int`

The total number of training steps.

required
num_cycles `float`, *optional*, defaults to 0.5

The number of periods of the cosine function in the schedule (the default is to just decrease from the max value to 0 following a half-cosine).

0.5
last_epoch `int`, *optional*, defaults to -1

The index of the last epoch when resuming training.

-1
Return

torch.optim.lr_scheduler.LambdaLR with the appropriate schedule.

Source code in fastvideo/training/training_utils.py
def get_cosine_schedule_with_warmup(optimizer: Optimizer,
                                    num_warmup_steps: int,
                                    num_training_steps: int,
                                    num_cycles: float = 0.5,
                                    last_epoch: int = -1) -> LambdaLR:
    """
    Create a schedule with a learning rate that decreases following the values of the cosine function between the
    initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
    initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        num_cycles (`float`, *optional*, defaults to 0.5):
            The number of periods of the cosine function in a schedule (the default is to just decrease from the max
            value to 0 following a half-cosine).
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / float(
            max(1, num_training_steps - num_warmup_steps))
        return max(
            0.0, 0.5 *
            (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))

    return LambdaLR(optimizer, lr_lambda, last_epoch)
fastvideo.training.training_utils.get_cosine_with_hard_restarts_schedule_with_warmup
get_cosine_with_hard_restarts_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1) -> LambdaLR

Create a schedule with a learning rate that decreases following the values of the cosine function between the initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases linearly between 0 and the initial lr set in the optimizer.

Parameters:

Name Type Description Default
optimizer [`~torch.optim.Optimizer`]

The optimizer for which to schedule the learning rate.

required
num_warmup_steps `int`

The number of steps for the warmup phase.

required
num_training_steps `int`

The total number of training steps.

required
num_cycles `int`, *optional*, defaults to 1

The number of hard restarts to use.

1
last_epoch `int`, *optional*, defaults to -1

The index of the last epoch when resuming training.

-1
Return

torch.optim.lr_scheduler.LambdaLR with the appropriate schedule.

Source code in fastvideo/training/training_utils.py
def get_cosine_with_hard_restarts_schedule_with_warmup(
        optimizer: Optimizer,
        num_warmup_steps: int,
        num_training_steps: int,
        num_cycles: int = 1,
        last_epoch: int = -1) -> LambdaLR:
    """
    Create a schedule with a learning rate that decreases following the values of the cosine function between the
    initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases
    linearly between 0 and the initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        num_cycles (`int`, *optional*, defaults to 1):
            The number of hard restarts to use.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / float(
            max(1, num_training_steps - num_warmup_steps))
        if progress >= 1.0:
            return 0.0
        return max(
            0.0, 0.5 * (1.0 + math.cos(math.pi *
                                       ((float(num_cycles) * progress) % 1.0))))

    return LambdaLR(optimizer, lr_lambda, last_epoch)
fastvideo.training.training_utils.get_linear_schedule_with_warmup
get_linear_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, last_epoch: int = -1) -> LambdaLR

Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.

Parameters:

Name Type Description Default
optimizer [`~torch.optim.Optimizer`]

The optimizer for which to schedule the learning rate.

required
num_warmup_steps `int`

The number of steps for the warmup phase.

required
num_training_steps `int`

The total number of training steps.

required
last_epoch `int`, *optional*, defaults to -1

The index of the last epoch when resuming training.

-1
Return

torch.optim.lr_scheduler.LambdaLR with the appropriate schedule.

Source code in fastvideo/training/training_utils.py
def get_linear_schedule_with_warmup(optimizer: Optimizer,
                                    num_warmup_steps: int,
                                    num_training_steps: int,
                                    last_epoch: int = -1) -> LambdaLR:
    """
    Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
    a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        return max(
            0.0,
            float(num_training_steps - current_step) /
            float(max(1, num_training_steps - num_warmup_steps)))

    return LambdaLR(optimizer, lr_lambda, last_epoch)
fastvideo.training.training_utils.get_piecewise_constant_schedule
get_piecewise_constant_schedule(optimizer: Optimizer, step_rules: str, last_epoch: int = -1) -> LambdaLR

Create a schedule with a constant learning rate, using the learning rate set in optimizer.

Parameters:

Name Type Description Default
optimizer [`~torch.optim.Optimizer`]

The optimizer for which to schedule the learning rate.

required
step_rules `string`

The rules for the learning rate, e.g. step_rules="1:10,0.1:20,0.01:30,0.005" means that the learning rate is multiplied by 1 until step 10, by 0.1 until step 20, by 0.01 until step 30, and by 0.005 for all remaining steps.

required
last_epoch `int`, *optional*, defaults to -1

The index of the last epoch when resuming training.

-1
Return

torch.optim.lr_scheduler.LambdaLR with the appropriate schedule.

Source code in fastvideo/training/training_utils.py
def get_piecewise_constant_schedule(optimizer: Optimizer,
                                    step_rules: str,
                                    last_epoch: int = -1) -> LambdaLR:
    """
    Create a schedule with a constant learning rate, using the learning rate set in optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        step_rules (`string`):
            The rules for the learning rate. ex: step_rules="1:10,0.1:20,0.01:30,0.005" means that the learning rate
            is multiplied by 1 until step 10, by 0.1 until step 20, by 0.01 until step 30, and by 0.005 for all
            remaining steps.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    rules_dict = {}
    rule_list = step_rules.split(",")
    for rule_str in rule_list[:-1]:
        value_str, steps_str = rule_str.split(":")
        steps = int(steps_str)
        value = float(value_str)
        rules_dict[steps] = value
    last_lr_multiple = float(rule_list[-1])

    def create_rules_function(
            rules_dict: dict,
            last_lr_multiple: float) -> Callable[[int], float]:

        def rule_func(steps: int) -> float:
            for step_threshold, lr_multiple in sorted(rules_dict.items()):
                if steps < step_threshold:
                    return lr_multiple
            return last_lr_multiple

        return rule_func

    rules_func = create_rules_function(rules_dict, last_lr_multiple)

    return LambdaLR(optimizer, rules_func, last_epoch=last_epoch)
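
A sketch of how the rules string above is interpreted, assuming params is a list of model parameters: each value:threshold pair applies its multiplier while the global step is below that threshold, and the trailing value applies from the last threshold onwards.

    import torch

    optimizer = torch.optim.AdamW(params, lr=1e-4)
    scheduler = get_piecewise_constant_schedule(optimizer,
                                                step_rules="1:1000,0.1:2000,0.01")
    # multiplier 1.0  for steps 0-999
    # multiplier 0.1  for steps 1000-1999
    # multiplier 0.01 for step 2000 onwards
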
fastvideo.training.training_utils.get_polynomial_decay_schedule_with_warmup
get_polynomial_decay_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, lr_end: float = 1e-07, power: float = 1.0, last_epoch: int = -1) -> LambdaLR

Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the optimizer to end lr defined by lr_end, after a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.

Parameters:

Name Type Description Default
optimizer [`~torch.optim.Optimizer`]

The optimizer for which to schedule the learning rate.

required
num_warmup_steps `int`

The number of steps for the warmup phase.

required
num_training_steps `int`

The total number of training steps.

required
lr_end `float`, *optional*, defaults to 1e-7

The end LR.

1e-07
power `float`, *optional*, defaults to 1.0

Power factor.

1.0
last_epoch `int`, *optional*, defaults to -1

The index of the last epoch when resuming training.

-1

Note: power defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT implementation at https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37

Return

torch.optim.lr_scheduler.LambdaLR with the appropriate schedule.

Source code in fastvideo/training/training_utils.py
def get_polynomial_decay_schedule_with_warmup(
    optimizer: Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    lr_end: float = 1e-7,
    power: float = 1.0,
    last_epoch: int = -1,
) -> LambdaLR:
    """
    Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the
    optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the
    initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        lr_end (`float`, *optional*, defaults to 1e-7):
            The end LR.
        power (`float`, *optional*, defaults to 1.0):
            Power factor.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT
    implementation at
    https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.

    """

    lr_init = optimizer.defaults["lr"]
    if not (lr_init > lr_end):
        raise ValueError(
            f"lr_end ({lr_end}) must be smaller than initial lr ({lr_init})")

    def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        elif current_step > num_training_steps:
            return lr_end / lr_init  # as LambdaLR multiplies by lr_init
        else:
            lr_range = lr_init - lr_end
            decay_steps = num_training_steps - num_warmup_steps
            pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps
            decay = lr_range * pct_remaining**power + lr_end
            return decay / lr_init  # as LambdaLR multiplies by lr_init

    return LambdaLR(optimizer, lr_lambda, last_epoch)
fastvideo.training.training_utils.get_scheduler
get_scheduler(name: str | SchedulerType, optimizer: Optimizer, step_rules: str | None = None, num_warmup_steps: int | None = None, num_training_steps: int | None = None, num_cycles: int = 1, power: float = 1.0, min_lr_ratio: float = 0.1, last_epoch: int = -1) -> LambdaLR

Unified API to get any scheduler from its name.

Parameters:

Name Type Description Default
name `str` or `SchedulerType`

The name of the scheduler to use.

required
optimizer `torch.optim.Optimizer`

The optimizer that will be used during training.

required
step_rules `str`, *optional*

A string representing the step rules to use. This is only used by the PIECEWISE_CONSTANT scheduler.

None
num_warmup_steps `int`, *optional*

The number of warmup steps to do. This is not required by all schedulers (hence the argument being optional), the function will raise an error if it's unset and the scheduler type requires it.

None
num_training_steps `int`, *optional*

The number of training steps to do. This is not required by all schedulers (hence the argument being optional), the function will raise an error if it's unset and the scheduler type requires it.

None
num_cycles `int`, *optional*

The number of hard restarts used in COSINE_WITH_RESTARTS scheduler.

1
power `float`, *optional*, defaults to 1.0

Power factor. See POLYNOMIAL scheduler

1.0
min_lr_ratio `float`, *optional*, defaults to 0.1

The ratio of minimum learning rate to initial learning rate. Used in COSINE_WITH_MIN_LR scheduler.

0.1
last_epoch `int`, *optional*, defaults to -1

The index of the last epoch when resuming training.

-1
Source code in fastvideo/training/training_utils.py
def get_scheduler(
    name: str | SchedulerType,
    optimizer: Optimizer,
    step_rules: str | None = None,
    num_warmup_steps: int | None = None,
    num_training_steps: int | None = None,
    num_cycles: int = 1,
    power: float = 1.0,
    min_lr_ratio: float = 0.1,
    last_epoch: int = -1,
) -> LambdaLR:
    """
    Unified API to get any scheduler from its name.

    Args:
        name (`str` or `SchedulerType`):
            The name of the scheduler to use.
        optimizer (`torch.optim.Optimizer`):
            The optimizer that will be used during training.
        step_rules (`str`, *optional*):
            A string representing the step rules to use. This is only used by the `PIECEWISE_CONSTANT` scheduler.
        num_warmup_steps (`int`, *optional*):
            The number of warmup steps to do. This is not required by all schedulers (hence the argument being
            optional), the function will raise an error if it's unset and the scheduler type requires it.
        num_training_steps (`int``, *optional*):
            The number of training steps to do. This is not required by all schedulers (hence the argument being
            optional), the function will raise an error if it's unset and the scheduler type requires it.
        num_cycles (`int`, *optional*):
            The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
        power (`float`, *optional*, defaults to 1.0):
            Power factor. See `POLYNOMIAL` scheduler
        min_lr_ratio (`float`, *optional*, defaults to 0.1):
            The ratio of minimum learning rate to initial learning rate. Used in `COSINE_WITH_MIN_LR` scheduler.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.
    """
    name = SchedulerType(name)
    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
    if name == SchedulerType.CONSTANT:
        return schedule_func(optimizer, last_epoch=last_epoch)

    if name == SchedulerType.PIECEWISE_CONSTANT:
        return schedule_func(optimizer,
                             step_rules=step_rules,
                             last_epoch=last_epoch)

    # All other schedulers require `num_warmup_steps`
    if num_warmup_steps is None:
        raise ValueError(
            f"{name} requires `num_warmup_steps`, please provide that argument."
        )

    if name == SchedulerType.CONSTANT_WITH_WARMUP:
        return schedule_func(optimizer,
                             num_warmup_steps=num_warmup_steps,
                             last_epoch=last_epoch)

    # All other schedulers require `num_training_steps`
    if num_training_steps is None:
        raise ValueError(
            f"{name} requires `num_training_steps`, please provide that argument."
        )

    if name == SchedulerType.COSINE_WITH_RESTARTS:
        return schedule_func(
            optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
            num_cycles=num_cycles,
            last_epoch=last_epoch,
        )

    if name == SchedulerType.POLYNOMIAL:
        return schedule_func(
            optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
            power=power,
            last_epoch=last_epoch,
        )

    if name == SchedulerType.COSINE_WITH_MIN_LR:
        return schedule_func(
            optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
            min_lr_ratio=min_lr_ratio,
            last_epoch=last_epoch,
        )

    return schedule_func(optimizer,
                         num_warmup_steps=num_warmup_steps,
                         num_training_steps=num_training_steps,
                         last_epoch=last_epoch)
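
A usage sketch of the unified API, assuming SchedulerType and get_scheduler are importable from fastvideo.training.training_utils and an optimizer already exists:

    from fastvideo.training.training_utils import SchedulerType, get_scheduler

    lr_scheduler = get_scheduler(
        name=SchedulerType.COSINE_WITH_MIN_LR,
        optimizer=optimizer,
        num_warmup_steps=500,
        num_training_steps=30_000,
        min_lr_ratio=0.1,
    )
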
fastvideo.training.training_utils.load_checkpoint
load_checkpoint(transformer, rank, checkpoint_path, optimizer=None, dataloader=None, scheduler=None, noise_generator=None) -> int

Load checkpoint following finetrainer's distributed checkpoint approach. Returns the step number from which training should resume.

Source code in fastvideo/training/training_utils.py
def load_checkpoint(transformer,
                    rank,
                    checkpoint_path,
                    optimizer=None,
                    dataloader=None,
                    scheduler=None,
                    noise_generator=None) -> int:
    """
    Load checkpoint following finetrainer's distributed checkpoint approach.
    Returns the step number from which training should resume.
    """
    if not os.path.exists(checkpoint_path):
        logger.warning("Checkpoint path %s does not exist", checkpoint_path)
        return 0

    # Extract step number from checkpoint path
    step = int(os.path.basename(checkpoint_path).split('-')[-1])

    if rank == 0:
        logger.info("Loading checkpoint from step %s", step)

    dcp_dir = os.path.join(checkpoint_path, "distributed_checkpoint")

    if not os.path.exists(dcp_dir):
        logger.warning("Distributed checkpoint directory %s does not exist",
                       dcp_dir)
        return 0

    states = {
        "model": ModelWrapper(transformer),
        "random_state": RandomStateWrapper(noise_generator),
    }

    if optimizer is not None:
        states["optimizer"] = OptimizerWrapper(transformer, optimizer)

    if dataloader is not None:
        states["dataloader"] = dataloader

    if scheduler is not None:
        states["scheduler"] = SchedulerWrapper(scheduler)

    logger.info("rank: %s, loading distributed checkpoint from %s",
                rank,
                dcp_dir,
                local_main_process_only=False)

    begin_time = time.perf_counter()
    dcp.load(states, checkpoint_id=dcp_dir)
    end_time = time.perf_counter()

    logger.info("rank: %s, distributed checkpoint loaded in %.2f seconds",
                rank,
                end_time - begin_time,
                local_main_process_only=False)
    logger.info("--> checkpoint loaded from step %s", step)

    return step
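
A resume sketch, assuming the checkpoint directory follows the <name>-<step> convention the step parser above expects and that the training objects have already been constructed:

    resume_step = load_checkpoint(
        transformer,
        rank=global_rank,
        checkpoint_path="outputs/checkpoint-5000",
        optimizer=optimizer,
        dataloader=train_dataloader,
        scheduler=lr_scheduler,
        noise_generator=noise_generator,
    )
    # resume_step is 0 when nothing could be loaded, otherwise the step to resume from.
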
fastvideo.training.training_utils.load_distillation_checkpoint
load_distillation_checkpoint(generator_transformer, fake_score_transformer, rank, checkpoint_path, generator_optimizer=None, fake_score_optimizer=None, dataloader=None, generator_scheduler=None, fake_score_scheduler=None, noise_generator=None, generator_ema=None, generator_transformer_2=None, real_score_transformer_2=None, fake_score_transformer_2=None, generator_optimizer_2=None, fake_score_optimizer_2=None, generator_scheduler_2=None, fake_score_scheduler_2=None, generator_ema_2=None) -> int

Load distillation checkpoint with both generator and fake_score models. Supports MoE (Mixture of Experts) models with transformer_2 variants. Returns the step number from which training should resume.

Parameters:

Name Type Description Default
generator_transformer

Main generator transformer model

required
fake_score_transformer

Main fake score transformer model

required
generator_transformer_2

Secondary generator transformer for MoE (optional)

None
real_score_transformer_2

Secondary real score transformer for MoE (optional)

None
fake_score_transformer_2

Secondary fake score transformer for MoE (optional)

None
generator_optimizer_2

Optimizer for generator_transformer_2 (optional)

None
fake_score_optimizer_2

Optimizer for fake_score_transformer_2 (optional)

None
generator_scheduler_2

Scheduler for generator_transformer_2 (optional)

None
fake_score_scheduler_2

Scheduler for fake_score_transformer_2 (optional)

None
generator_ema_2

EMA for generator_transformer_2 (optional)

None
Source code in fastvideo/training/training_utils.py
def load_distillation_checkpoint(
        generator_transformer,
        fake_score_transformer,
        rank,
        checkpoint_path,
        generator_optimizer=None,
        fake_score_optimizer=None,
        dataloader=None,
        generator_scheduler=None,
        fake_score_scheduler=None,
        noise_generator=None,
        generator_ema=None,
        # MoE support
        generator_transformer_2=None,
        real_score_transformer_2=None,
        fake_score_transformer_2=None,
        generator_optimizer_2=None,
        fake_score_optimizer_2=None,
        generator_scheduler_2=None,
        fake_score_scheduler_2=None,
        generator_ema_2=None) -> int:
    """
    Load distillation checkpoint with both generator and fake_score models.
    Supports MoE (Mixture of Experts) models with transformer_2 variants.
    Returns the step number from which training should resume.

    Args:
        generator_transformer: Main generator transformer model
        fake_score_transformer: Main fake score transformer model
        generator_transformer_2: Secondary generator transformer for MoE (optional)
        real_score_transformer_2: Secondary real score transformer for MoE (optional)
        fake_score_transformer_2: Secondary fake score transformer for MoE (optional)
        generator_optimizer_2: Optimizer for generator_transformer_2 (optional)
        fake_score_optimizer_2: Optimizer for fake_score_transformer_2 (optional)
        generator_scheduler_2: Scheduler for generator_transformer_2 (optional)
        fake_score_scheduler_2: Scheduler for fake_score_transformer_2 (optional)
        generator_ema_2: EMA for generator_transformer_2 (optional)
    """
    if not os.path.exists(checkpoint_path):
        logger.warning("Distillation checkpoint path %s does not exist",
                       checkpoint_path)
        return 0

    # Extract step number from checkpoint path
    step = int(os.path.basename(checkpoint_path).split('-')[-1])

    if rank == 0:
        logger.info("Loading distillation checkpoint from step %s", step)

    # Load generator distributed checkpoint
    generator_dcp_dir = os.path.join(checkpoint_path, "distributed_checkpoint",
                                     "generator")
    if not os.path.exists(generator_dcp_dir):
        logger.warning(
            "Generator distributed checkpoint directory %s does not exist",
            generator_dcp_dir)
        return 0

    generator_states = {
        "model": ModelWrapper(generator_transformer),
    }

    if generator_optimizer is not None:
        generator_states["optimizer"] = OptimizerWrapper(
            generator_transformer, generator_optimizer)

    if dataloader is not None:
        generator_states["dataloader"] = dataloader

    if generator_scheduler is not None:
        generator_states["scheduler"] = SchedulerWrapper(generator_scheduler)

    logger.info("rank: %s, loading generator distributed checkpoint from %s",
                rank,
                generator_dcp_dir,
                local_main_process_only=False)

    begin_time = time.perf_counter()
    dcp.load(generator_states, checkpoint_id=generator_dcp_dir)
    end_time = time.perf_counter()

    logger.info(
        "rank: %s, generator distributed checkpoint loaded in %.2f seconds",
        rank,
        end_time - begin_time,
        local_main_process_only=False)

    # Load EMA state if available and generator_ema is provided
    if generator_ema is not None:
        try:
            ema_state = generator_states.get("ema")
            if ema_state is not None:
                generator_ema.load_state_dict(ema_state)
                logger.info("rank: %s, generator EMA state loaded successfully",
                            rank)
            else:
                logger.info("rank: %s, no EMA state found in checkpoint", rank)
        except Exception as e:
            logger.warning("rank: %s, failed to load EMA state: %s", rank,
                           str(e))

    # Load generator_2 distributed checkpoint (MoE support)
    if generator_transformer_2 is not None:
        generator_2_dcp_dir = os.path.join(checkpoint_path,
                                           "distributed_checkpoint",
                                           "generator_2")
        if os.path.exists(generator_2_dcp_dir):
            generator_2_states = {
                "model": ModelWrapper(generator_transformer_2),
            }

            if generator_optimizer_2 is not None:
                generator_2_states["optimizer"] = OptimizerWrapper(
                    generator_transformer_2, generator_optimizer_2)

            if dataloader is not None:
                generator_2_states["dataloader"] = dataloader

            if generator_scheduler_2 is not None:
                generator_2_states["scheduler"] = SchedulerWrapper(
                    generator_scheduler_2)

            logger.info(
                "rank: %s, loading generator_2 distributed checkpoint from %s",
                rank,
                generator_2_dcp_dir,
                local_main_process_only=False)

            begin_time = time.perf_counter()
            dcp.load(generator_2_states, checkpoint_id=generator_2_dcp_dir)
            end_time = time.perf_counter()

            logger.info(
                "rank: %s, generator_2 distributed checkpoint loaded in %.2f seconds",
                rank,
                end_time - begin_time,
                local_main_process_only=False)

            # Load EMA_2 state if available and generator_ema_2 is provided
            if generator_ema_2 is not None:
                try:
                    ema_2_state = generator_2_states.get("ema")
                    if ema_2_state is not None:
                        generator_ema_2.load_state_dict(ema_2_state)
                        logger.info(
                            "rank: %s, generator_2 EMA state loaded successfully",
                            rank)
                    else:
                        logger.info(
                            "rank: %s, no EMA_2 state found in checkpoint",
                            rank)
                except Exception as e:
                    logger.warning("rank: %s, failed to load EMA_2 state: %s",
                                   rank, str(e))
        else:
            logger.info("rank: %s, generator_2 checkpoint not found, skipping",
                        rank)

    # Load critic distributed checkpoint
    critic_dcp_dir = os.path.join(checkpoint_path, "distributed_checkpoint",
                                  "critic")
    if not os.path.exists(critic_dcp_dir):
        logger.warning(
            "Critic distributed checkpoint directory %s does not exist",
            critic_dcp_dir)
        return 0

    critic_states = {
        "model": ModelWrapper(fake_score_transformer),
    }

    if fake_score_optimizer is not None:
        critic_states["optimizer"] = OptimizerWrapper(fake_score_transformer,
                                                      fake_score_optimizer)

    if dataloader is not None:
        critic_states["dataloader"] = dataloader

    if fake_score_scheduler is not None:
        critic_states["scheduler"] = SchedulerWrapper(fake_score_scheduler)

    logger.info("rank: %s, loading critic distributed checkpoint from %s",
                rank,
                critic_dcp_dir,
                local_main_process_only=False)

    begin_time = time.perf_counter()
    dcp.load(critic_states, checkpoint_id=critic_dcp_dir)
    end_time = time.perf_counter()

    logger.info(
        "rank: %s, critic distributed checkpoint loaded in %.2f seconds",
        rank,
        end_time - begin_time,
        local_main_process_only=False)

    # Load critic_2 distributed checkpoint (MoE support)
    if fake_score_transformer_2 is not None:
        critic_2_dcp_dir = os.path.join(checkpoint_path,
                                        "distributed_checkpoint", "critic_2")
        if os.path.exists(critic_2_dcp_dir):
            critic_2_states = {
                "model": ModelWrapper(fake_score_transformer_2),
            }

            if fake_score_optimizer_2 is not None:
                critic_2_states["optimizer"] = OptimizerWrapper(
                    fake_score_transformer_2, fake_score_optimizer_2)

            if dataloader is not None:
                critic_2_states["dataloader"] = dataloader

            if fake_score_scheduler_2 is not None:
                critic_2_states["scheduler"] = SchedulerWrapper(
                    fake_score_scheduler_2)

            logger.info(
                "rank: %s, loading critic_2 distributed checkpoint from %s",
                rank,
                critic_2_dcp_dir,
                local_main_process_only=False)

            begin_time = time.perf_counter()
            dcp.load(critic_2_states, checkpoint_id=critic_2_dcp_dir)
            end_time = time.perf_counter()

            logger.info(
                "rank: %s, critic_2 distributed checkpoint loaded in %.2f seconds",
                rank,
                end_time - begin_time,
                local_main_process_only=False)
        else:
            logger.info("rank: %s, critic_2 checkpoint not found, skipping",
                        rank)

    # Load real_score_2 distributed checkpoint (MoE support)
    if real_score_transformer_2 is not None:
        real_score_2_dcp_dir = os.path.join(checkpoint_path,
                                            "distributed_checkpoint",
                                            "real_score_2")
        if os.path.exists(real_score_2_dcp_dir):
            real_score_2_states = {
                "model": ModelWrapper(real_score_transformer_2),
            }

            if dataloader is not None:
                real_score_2_states["dataloader"] = dataloader

            logger.info(
                "rank: %s, loading real_score_2 distributed checkpoint from %s",
                rank,
                real_score_2_dcp_dir,
                local_main_process_only=False)

            begin_time = time.perf_counter()
            dcp.load(real_score_2_states, checkpoint_id=real_score_2_dcp_dir)
            end_time = time.perf_counter()

            logger.info(
                "rank: %s, real_score_2 distributed checkpoint loaded in %.2f seconds",
                rank,
                end_time - begin_time,
                local_main_process_only=False)
        else:
            logger.info("rank: %s, real_score_2 checkpoint not found, skipping",
                        rank)

    # Load shared random state
    shared_dcp_dir = os.path.join(checkpoint_path, "distributed_checkpoint",
                                  "shared")
    if not os.path.exists(shared_dcp_dir):
        logger.warning("Shared random state directory %s does not exist",
                       shared_dcp_dir)
        return 0

    shared_states = {
        "random_state": RandomStateWrapper(noise_generator),
    }

    begin_time = time.perf_counter()
    dcp.load(shared_states, checkpoint_id=shared_dcp_dir)
    end_time = time.perf_counter()

    logger.info("rank: %s, shared random state loaded in %.2f seconds",
                rank,
                end_time - begin_time,
                local_main_process_only=False)
    logger.info("--> distillation checkpoint loaded from step %s", step)
    return step
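
A minimal usage sketch for resuming distillation training. The model, optimizer, scheduler, dataloader, and EMA handles are placeholders for objects the training pipeline has already constructed, and the checkpoint path is illustrative; only the function call itself comes from the source above.

from fastvideo.training.training_utils import load_distillation_checkpoint

# checkpoint_path must end in "-<step>"; the resume step is parsed from it.
resume_step = load_distillation_checkpoint(
    generator_transformer,
    fake_score_transformer,
    rank,
    checkpoint_path="outputs/wan_dmd/checkpoint-2000",
    generator_optimizer=generator_optimizer,
    fake_score_optimizer=fake_score_optimizer,
    dataloader=train_dataloader,
    generator_scheduler=generator_scheduler,
    fake_score_scheduler=fake_score_scheduler,
    noise_generator=noise_generator,
    generator_ema=generator_ema,
)
# A return value of 0 means the checkpoint (or one of its required
# sub-directories) was missing and training starts from scratch.
if rank == 0 and resume_step > 0:
    print(f"Resuming distillation training from step {resume_step}")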
fastvideo.training.training_utils.save_checkpoint
save_checkpoint(transformer, rank, output_dir, step, optimizer=None, dataloader=None, scheduler=None, noise_generator=None) -> None

Save a checkpoint following finetrainer's distributed checkpoint approach. Saves both the distributed checkpoint (for training resume) and the consolidated model weights (for inference). A usage sketch follows the source listing.

Source code in fastvideo/training/training_utils.py
def save_checkpoint(transformer,
                    rank,
                    output_dir,
                    step,
                    optimizer=None,
                    dataloader=None,
                    scheduler=None,
                    noise_generator=None) -> None:
    """
    Save checkpoint following finetrainer's distributed checkpoint approach.
    Saves both distributed checkpoint and consolidated model weights.
    """
    save_dir = os.path.join(output_dir, f"checkpoint-{step}")
    os.makedirs(save_dir, exist_ok=True)

    states = {
        "model": ModelWrapper(transformer),
        "random_state": RandomStateWrapper(noise_generator),
    }

    if optimizer is not None:
        states["optimizer"] = OptimizerWrapper(transformer, optimizer)

    if dataloader is not None:
        states["dataloader"] = dataloader

    if scheduler is not None:
        states["scheduler"] = SchedulerWrapper(scheduler)
    dcp_dir = os.path.join(save_dir, "distributed_checkpoint")
    logger.info("rank: %s, saving distributed checkpoint to %s",
                rank,
                dcp_dir,
                local_main_process_only=False)

    begin_time = time.perf_counter()
    dcp.save(states, checkpoint_id=dcp_dir)
    end_time = time.perf_counter()

    logger.info("rank: %s, distributed checkpoint saved in %.2f seconds",
                rank,
                end_time - begin_time,
                local_main_process_only=False)

    cpu_state = gather_state_dict_on_cpu_rank0(transformer, device=None)
    if rank == 0:
        # Save model weights (consolidated)
        transformer_save_dir = os.path.join(save_dir, "transformer")
        os.makedirs(transformer_save_dir, exist_ok=True)
        weight_path = os.path.join(transformer_save_dir,
                                   "diffusion_pytorch_model.safetensors")
        logger.info("rank: %s, saving consolidated checkpoint to %s",
                    rank,
                    weight_path,
                    local_main_process_only=False)

        # Convert training format to diffusers format and save
        diffusers_state_dict = custom_to_hf_state_dict(
            cpu_state, transformer.reverse_param_names_mapping)
        save_file(diffusers_state_dict, weight_path)

        logger.info("rank: %s, consolidated checkpoint saved to %s",
                    rank,
                    weight_path,
                    local_main_process_only=False)

        # Save model config
        config_dict = transformer.hf_config
        if "dtype" in config_dict:
            del config_dict["dtype"]  # TODO
        config_path = os.path.join(transformer_save_dir, "config.json")
        # save dict as json
        with open(config_path, "w") as f:
            json.dump(config_dict, f, indent=4)
        logger.info("--> checkpoint saved at step %s to %s", step, weight_path)
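
A minimal usage sketch for saving a checkpoint from a training loop. The transformer, optimizer, scheduler, dataloader, and noise generator are placeholders owned by the caller, and the output directory and checkpointing interval are illustrative assumptions.

from fastvideo.training.training_utils import save_checkpoint

if step % checkpointing_steps == 0:
    save_checkpoint(
        transformer,
        rank,
        output_dir="outputs/wan_finetune",
        step=step,
        optimizer=optimizer,
        dataloader=train_dataloader,
        scheduler=lr_scheduler,
        noise_generator=noise_generator,
    )
# Produces outputs/wan_finetune/checkpoint-<step>/ containing:
#   distributed_checkpoint/  (sharded state for training resume)
#   transformer/             (consolidated safetensors weights + config.json)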
fastvideo.training.training_utils.save_distillation_checkpoint
save_distillation_checkpoint(generator_transformer, fake_score_transformer, rank, output_dir, step, generator_optimizer=None, fake_score_optimizer=None, dataloader=None, generator_scheduler=None, fake_score_scheduler=None, noise_generator=None, generator_ema=None, only_save_generator_weight=False, generator_transformer_2=None, real_score_transformer_2=None, fake_score_transformer_2=None, generator_optimizer_2=None, fake_score_optimizer_2=None, generator_scheduler_2=None, fake_score_scheduler_2=None, generator_ema_2=None) -> None

Save a distillation checkpoint with both the generator and fake_score models. Supports MoE (Mixture of Experts) models with transformer_2 variants. Saves both the distributed checkpoint and consolidated model weights; only the generator is saved in consolidated form for inference. A usage sketch follows the source listing.

Parameters:

generator_transformer: Main generator transformer model (required)
fake_score_transformer: Main fake score transformer model (required)
only_save_generator_weight: If True, only save the generator model weights for inference, without saving the distributed checkpoint for training resume (default False)
generator_transformer_2: Secondary generator transformer for MoE (optional, default None)
real_score_transformer_2: Secondary real score transformer for MoE (optional, default None)
fake_score_transformer_2: Secondary fake score transformer for MoE (optional, default None)
generator_optimizer_2: Optimizer for generator_transformer_2 (optional, default None)
fake_score_optimizer_2: Optimizer for fake_score_transformer_2 (optional, default None)
generator_scheduler_2: Scheduler for generator_transformer_2 (optional, default None)
fake_score_scheduler_2: Scheduler for fake_score_transformer_2 (optional, default None)
generator_ema_2: EMA for generator_transformer_2 (optional, default None)
Source code in fastvideo/training/training_utils.py
def save_distillation_checkpoint(
        generator_transformer,
        fake_score_transformer,
        rank,
        output_dir,
        step,
        generator_optimizer=None,
        fake_score_optimizer=None,
        dataloader=None,
        generator_scheduler=None,
        fake_score_scheduler=None,
        noise_generator=None,
        generator_ema=None,
        only_save_generator_weight=False,
        # MoE support
        generator_transformer_2=None,
        real_score_transformer_2=None,
        fake_score_transformer_2=None,
        generator_optimizer_2=None,
        fake_score_optimizer_2=None,
        generator_scheduler_2=None,
        fake_score_scheduler_2=None,
        generator_ema_2=None) -> None:
    """
    Save distillation checkpoint with both generator and fake_score models.
    Supports MoE (Mixture of Experts) models with transformer_2 variants.
    Saves both distributed checkpoint and consolidated model weights.
    Only saves the generator model for inference (consolidated weights).

    Args:
        generator_transformer: Main generator transformer model
        fake_score_transformer: Main fake score transformer model
        only_save_generator_weight: If True, only save the generator model weights for inference
                                   without saving distributed checkpoint for training resume.
        generator_transformer_2: Secondary generator transformer for MoE (optional)
        real_score_transformer_2: Secondary real score transformer for MoE (optional) 
        fake_score_transformer_2: Secondary fake score transformer for MoE (optional)
        generator_optimizer_2: Optimizer for generator_transformer_2 (optional)
        fake_score_optimizer_2: Optimizer for fake_score_transformer_2 (optional)
        generator_scheduler_2: Scheduler for generator_transformer_2 (optional)
        fake_score_scheduler_2: Scheduler for fake_score_transformer_2 (optional)
        generator_ema_2: EMA for generator_transformer_2 (optional)
    """
    save_dir = os.path.join(output_dir, f"checkpoint-{step}")
    os.makedirs(save_dir, exist_ok=True)

    # Create directories for models
    inference_save_dir = os.path.join(save_dir,
                                      "generator_inference_transformer")

    # Only save distributed checkpoint if not only saving generator weight
    if not only_save_generator_weight:
        # Save generator distributed checkpoint
        generator_states = {
            "model": ModelWrapper(generator_transformer),
        }
        if generator_optimizer is not None:
            generator_states["optimizer"] = OptimizerWrapper(
                generator_transformer, generator_optimizer)
        if dataloader is not None:
            generator_states["dataloader"] = dataloader
        if generator_scheduler is not None:
            generator_states["scheduler"] = SchedulerWrapper(
                generator_scheduler)
        if generator_ema is not None:
            generator_states["ema"] = generator_ema.state_dict()

        generator_dcp_dir = os.path.join(save_dir, "distributed_checkpoint",
                                         "generator")
        logger.info("rank: %s, saving generator distributed checkpoint to %s",
                    rank,
                    generator_dcp_dir,
                    local_main_process_only=False)

        begin_time = time.perf_counter()
        dcp.save(generator_states, checkpoint_id=generator_dcp_dir)
        end_time = time.perf_counter()

        logger.info(
            "rank: %s, generator distributed checkpoint saved in %.2f seconds",
            rank,
            end_time - begin_time,
            local_main_process_only=False)

        # Save generator_2 distributed checkpoint (MoE support)
        if generator_transformer_2 is not None:
            generator_2_states = {
                "model": ModelWrapper(generator_transformer_2),
            }
            if generator_optimizer_2 is not None:
                generator_2_states["optimizer"] = OptimizerWrapper(
                    generator_transformer_2, generator_optimizer_2)
            if dataloader is not None:
                generator_2_states["dataloader"] = dataloader
            if generator_scheduler_2 is not None:
                generator_2_states["scheduler"] = SchedulerWrapper(
                    generator_scheduler_2)
            if generator_ema_2 is not None:
                generator_2_states["ema"] = generator_ema_2.state_dict()

            generator_2_dcp_dir = os.path.join(save_dir,
                                               "distributed_checkpoint",
                                               "generator_2")
            logger.info(
                "rank: %s, saving generator_2 distributed checkpoint to %s",
                rank,
                generator_2_dcp_dir,
                local_main_process_only=False)

            begin_time = time.perf_counter()
            dcp.save(generator_2_states, checkpoint_id=generator_2_dcp_dir)
            end_time = time.perf_counter()

            logger.info(
                "rank: %s, generator_2 distributed checkpoint saved in %.2f seconds",
                rank,
                end_time - begin_time,
                local_main_process_only=False)

        # Save critic distributed checkpoint
        critic_states = {
            "model": ModelWrapper(fake_score_transformer),
        }
        if fake_score_optimizer is not None:
            critic_states["optimizer"] = OptimizerWrapper(
                fake_score_transformer, fake_score_optimizer)
        if dataloader is not None:
            critic_states["dataloader"] = dataloader
        if fake_score_scheduler is not None:
            critic_states["scheduler"] = SchedulerWrapper(fake_score_scheduler)

        critic_dcp_dir = os.path.join(save_dir, "distributed_checkpoint",
                                      "critic")
        logger.info("rank: %s, saving critic distributed checkpoint to %s",
                    rank,
                    critic_dcp_dir,
                    local_main_process_only=False)

        begin_time = time.perf_counter()
        dcp.save(critic_states, checkpoint_id=critic_dcp_dir)
        end_time = time.perf_counter()

        logger.info(
            "rank: %s, critic distributed checkpoint saved in %.2f seconds",
            rank,
            end_time - begin_time,
            local_main_process_only=False)

        # Save critic_2 distributed checkpoint (MoE support)
        if fake_score_transformer_2 is not None:
            critic_2_states = {
                "model": ModelWrapper(fake_score_transformer_2),
            }
            if fake_score_optimizer_2 is not None:
                critic_2_states["optimizer"] = OptimizerWrapper(
                    fake_score_transformer_2, fake_score_optimizer_2)
            if dataloader is not None:
                critic_2_states["dataloader"] = dataloader
            if fake_score_scheduler_2 is not None:
                critic_2_states["scheduler"] = SchedulerWrapper(
                    fake_score_scheduler_2)

            critic_2_dcp_dir = os.path.join(save_dir, "distributed_checkpoint",
                                            "critic_2")
            logger.info(
                "rank: %s, saving critic_2 distributed checkpoint to %s",
                rank,
                critic_2_dcp_dir,
                local_main_process_only=False)

            begin_time = time.perf_counter()
            dcp.save(critic_2_states, checkpoint_id=critic_2_dcp_dir)
            end_time = time.perf_counter()

            logger.info(
                "rank: %s, critic_2 distributed checkpoint saved in %.2f seconds",
                rank,
                end_time - begin_time,
                local_main_process_only=False)

        # Save real_score_transformer_2 distributed checkpoint (MoE support)
        if real_score_transformer_2 is not None:
            real_score_2_states = {
                "model": ModelWrapper(real_score_transformer_2),
            }
            # Note: real_score_transformer_2 typically doesn't have optimizer/scheduler
            # since it's used for inference only, but we include dataloader for consistency
            if dataloader is not None:
                real_score_2_states["dataloader"] = dataloader

            real_score_2_dcp_dir = os.path.join(save_dir,
                                                "distributed_checkpoint",
                                                "real_score_2")
            logger.info(
                "rank: %s, saving real_score_2 distributed checkpoint to %s",
                rank,
                real_score_2_dcp_dir,
                local_main_process_only=False)

            begin_time = time.perf_counter()
            dcp.save(real_score_2_states, checkpoint_id=real_score_2_dcp_dir)
            end_time = time.perf_counter()

            logger.info(
                "rank: %s, real_score_2 distributed checkpoint saved in %.2f seconds",
                rank,
                end_time - begin_time,
                local_main_process_only=False)

        # Save shared random state separately
        shared_states = {
            "random_state": RandomStateWrapper(noise_generator),
        }
        shared_dcp_dir = os.path.join(save_dir, "distributed_checkpoint",
                                      "shared")

        dcp.save(shared_states, checkpoint_id=shared_dcp_dir)

    else:
        logger.info(
            "rank: %s, skipping distributed checkpoint save (only_save_generator_weight=True)",
            rank,
            local_main_process_only=False)

    # Save generator model weights (consolidated) for inference
    cpu_state = gather_state_dict_on_cpu_rank0(generator_transformer,
                                               device=None)

    if rank == 0:
        # Save generator model weights (consolidated) for inference
        os.makedirs(inference_save_dir, exist_ok=True)
        weight_path = os.path.join(inference_save_dir,
                                   "diffusion_pytorch_model.safetensors")
        logger.info(
            "rank: %s, saving consolidated generator inference checkpoint to %s",
            rank,
            weight_path,
            local_main_process_only=False)

        # Convert training format to diffusers format and save
        diffusers_state_dict = custom_to_hf_state_dict(
            cpu_state, generator_transformer.reverse_param_names_mapping)
        save_file(diffusers_state_dict, weight_path)

        logger.info(
            "rank: %s, consolidated generator inference checkpoint saved to %s",
            rank,
            weight_path,
            local_main_process_only=False)

        # Save model config
        config_dict = generator_transformer.hf_config
        if "dtype" in config_dict:
            del config_dict["dtype"]  # TODO
        config_path = os.path.join(inference_save_dir, "config.json")
        # save dict as json
        with open(config_path, "w") as f:
            json.dump(config_dict, f, indent=4)
        logger.info("--> distillation checkpoint saved at step %s to %s", step,
                    weight_path)

        # Save generator_2 model weights (consolidated) for inference (MoE support)
        if generator_transformer_2 is not None:
            inference_save_dir_2 = os.path.join(
                save_dir, "generator_2_inference_transformer")
            cpu_state_2 = gather_state_dict_on_cpu_rank0(
                generator_transformer_2, device=None)

            if rank == 0:
                os.makedirs(inference_save_dir_2, exist_ok=True)
                weight_path_2 = os.path.join(
                    inference_save_dir_2, "diffusion_pytorch_model.safetensors")
                logger.info(
                    "rank: %s, saving consolidated generator_2 inference checkpoint to %s",
                    rank,
                    weight_path_2,
                    local_main_process_only=False)

                # Convert training format to diffusers format and save
                diffusers_state_dict_2 = custom_to_hf_state_dict(
                    cpu_state_2,
                    generator_transformer_2.reverse_param_names_mapping)
                save_file(diffusers_state_dict_2, weight_path_2)

                logger.info(
                    "rank: %s, consolidated generator_2 inference checkpoint saved to %s",
                    rank,
                    weight_path_2,
                    local_main_process_only=False)

                # Save model config
                config_dict_2 = generator_transformer_2.hf_config
                if "dtype" in config_dict_2:
                    del config_dict_2["dtype"]  # TODO
                config_path_2 = os.path.join(inference_save_dir_2,
                                             "config.json")
                with open(config_path_2, "w") as f:
                    json.dump(config_dict_2, f, indent=4)
                logger.info(
                    "--> generator_2 distillation checkpoint saved at step %s to %s",
                    step, weight_path_2)
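
A minimal usage sketch for saving a distillation checkpoint at the end of an evaluation interval. All model, optimizer, scheduler, and dataloader handles are placeholders, the output directory is illustrative, and the directory layout shown in the comments is inferred from the source above.

from fastvideo.training.training_utils import save_distillation_checkpoint

save_distillation_checkpoint(
    generator_transformer,
    fake_score_transformer,
    rank,
    output_dir="outputs/wan_dmd",
    step=step,
    generator_optimizer=generator_optimizer,
    fake_score_optimizer=fake_score_optimizer,
    dataloader=train_dataloader,
    generator_scheduler=generator_scheduler,
    fake_score_scheduler=fake_score_scheduler,
    noise_generator=noise_generator,
    generator_ema=generator_ema,
    only_save_generator_weight=False,
)
# Expected layout under outputs/wan_dmd/checkpoint-<step>/:
#   distributed_checkpoint/{generator,critic,shared}/  (plus *_2 dirs for MoE)
#   generator_inference_transformer/                   (consolidated weights + config.json)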

fastvideo.training.wan_distillation_pipeline

Classes

fastvideo.training.wan_distillation_pipeline.WanDistillationPipeline
WanDistillationPipeline(model_path: str, fastvideo_args: TrainingArgs, required_config_modules: list[str] | None = None, loaded_modules: dict[str, Module] | None = None)

Bases: DistillationPipeline

A distillation pipeline for Wan that uses a single transformer model. The main transformer serves as the student model, and copies are made for teacher and critic.

Source code in fastvideo/training/training_pipeline.py
def __init__(
        self,
        model_path: str,
        fastvideo_args: TrainingArgs,
        required_config_modules: list[str] | None = None,
        loaded_modules: dict[str, torch.nn.Module] | None = None) -> None:
    fastvideo_args.inference_mode = False
    self.lora_training = fastvideo_args.lora_training
    if self.lora_training and fastvideo_args.lora_rank is None:
        raise ValueError("lora rank must be set when using lora training")

    set_random_seed(fastvideo_args.seed)  # for lora param init
    super().__init__(model_path, fastvideo_args, required_config_modules,
                     loaded_modules)  # type: ignore
    self.tracker = DummyTracker()
Functions
fastvideo.training.wan_distillation_pipeline.WanDistillationPipeline.create_training_stages
create_training_stages(training_args: TrainingArgs)

May be used in future refactors.

Source code in fastvideo/training/wan_distillation_pipeline.py
def create_training_stages(self, training_args: TrainingArgs):
    """
    May be used in future refactors.
    """
    pass
fastvideo.training.wan_distillation_pipeline.WanDistillationPipeline.initialize_pipeline
initialize_pipeline(fastvideo_args: FastVideoArgs)

Initialize Wan-specific scheduler.

Source code in fastvideo/training/wan_distillation_pipeline.py
def initialize_pipeline(self, fastvideo_args: FastVideoArgs):
    """Initialize Wan-specific scheduler."""
    self.modules["scheduler"] = FlowMatchEulerDiscreteScheduler(
        shift=fastvideo_args.pipeline_config.flow_shift)

Functions

fastvideo.training.wan_i2v_distillation_pipeline

Classes

fastvideo.training.wan_i2v_distillation_pipeline.WanI2VDistillationPipeline
WanI2VDistillationPipeline(model_path: str, fastvideo_args: TrainingArgs, required_config_modules: list[str] | None = None, loaded_modules: dict[str, Module] | None = None)

Bases: DistillationPipeline

A distillation pipeline for Wan that uses a single transformer model. The main transformer serves as the student model, and copies are made for teacher and critic.

Source code in fastvideo/training/training_pipeline.py
def __init__(
        self,
        model_path: str,
        fastvideo_args: TrainingArgs,
        required_config_modules: list[str] | None = None,
        loaded_modules: dict[str, torch.nn.Module] | None = None) -> None:
    fastvideo_args.inference_mode = False
    self.lora_training = fastvideo_args.lora_training
    if self.lora_training and fastvideo_args.lora_rank is None:
        raise ValueError("lora rank must be set when using lora training")

    set_random_seed(fastvideo_args.seed)  # for lora param init
    super().__init__(model_path, fastvideo_args, required_config_modules,
                     loaded_modules)  # type: ignore
    self.tracker = DummyTracker()
Functions
fastvideo.training.wan_i2v_distillation_pipeline.WanI2VDistillationPipeline.create_training_stages
create_training_stages(training_args: TrainingArgs)

May be used in future refactors.

Source code in fastvideo/training/wan_i2v_distillation_pipeline.py
def create_training_stages(self, training_args: TrainingArgs):
    """
    May be used in future refactors.
    """
    pass
fastvideo.training.wan_i2v_distillation_pipeline.WanI2VDistillationPipeline.initialize_pipeline
initialize_pipeline(fastvideo_args: FastVideoArgs)

Initialize Wan-specific scheduler.

Source code in fastvideo/training/wan_i2v_distillation_pipeline.py
def initialize_pipeline(self, fastvideo_args: FastVideoArgs):
    """Initialize Wan-specific scheduler."""
    self.modules["scheduler"] = FlowMatchEulerDiscreteScheduler(
        shift=fastvideo_args.pipeline_config.flow_shift)

Functions

fastvideo.training.wan_i2v_training_pipeline

Classes

fastvideo.training.wan_i2v_training_pipeline.WanI2VTrainingPipeline
WanI2VTrainingPipeline(model_path: str, fastvideo_args: TrainingArgs, required_config_modules: list[str] | None = None, loaded_modules: dict[str, Module] | None = None)

Bases: TrainingPipeline

A training pipeline for Wan.

Source code in fastvideo/training/training_pipeline.py
def __init__(
        self,
        model_path: str,
        fastvideo_args: TrainingArgs,
        required_config_modules: list[str] | None = None,
        loaded_modules: dict[str, torch.nn.Module] | None = None) -> None:
    fastvideo_args.inference_mode = False
    self.lora_training = fastvideo_args.lora_training
    if self.lora_training and fastvideo_args.lora_rank is None:
        raise ValueError("lora rank must be set when using lora training")

    set_random_seed(fastvideo_args.seed)  # for lora param init
    super().__init__(model_path, fastvideo_args, required_config_modules,
                     loaded_modules)  # type: ignore
    self.tracker = DummyTracker()
Functions
fastvideo.training.wan_i2v_training_pipeline.WanI2VTrainingPipeline.create_training_stages
create_training_stages(training_args: TrainingArgs)

May be used in future refactors.

Source code in fastvideo/training/wan_i2v_training_pipeline.py
def create_training_stages(self, training_args: TrainingArgs):
    """
    May be used in future refactors.
    """
    pass

Functions

fastvideo.training.wan_self_forcing_distillation_pipeline

Classes

fastvideo.training.wan_self_forcing_distillation_pipeline.WanSelfForcingDistillationPipeline
WanSelfForcingDistillationPipeline(model_path: str, fastvideo_args: TrainingArgs, required_config_modules: list[str] | None = None, loaded_modules: dict[str, Module] | None = None)

Bases: SelfForcingDistillationPipeline

A distillation pipeline for Wan that combines the self-forcing methodology with DMD (Distribution Matching Distillation) for video generation.

Source code in fastvideo/training/training_pipeline.py
def __init__(
        self,
        model_path: str,
        fastvideo_args: TrainingArgs,
        required_config_modules: list[str] | None = None,
        loaded_modules: dict[str, torch.nn.Module] | None = None) -> None:
    fastvideo_args.inference_mode = False
    self.lora_training = fastvideo_args.lora_training
    if self.lora_training and fastvideo_args.lora_rank is None:
        raise ValueError("lora rank must be set when using lora training")

    set_random_seed(fastvideo_args.seed)  # for lora param init
    super().__init__(model_path, fastvideo_args, required_config_modules,
                     loaded_modules)  # type: ignore
    self.tracker = DummyTracker()
Functions
fastvideo.training.wan_self_forcing_distillation_pipeline.WanSelfForcingDistillationPipeline.create_training_stages
create_training_stages(training_args: TrainingArgs)

May be used in future refactors.

Source code in fastvideo/training/wan_self_forcing_distillation_pipeline.py
def create_training_stages(self, training_args: TrainingArgs):
    """
    May be used in future refactors.
    """
    pass

Functions

fastvideo.training.wan_training_pipeline

Classes

fastvideo.training.wan_training_pipeline.WanTrainingPipeline
WanTrainingPipeline(model_path: str, fastvideo_args: TrainingArgs, required_config_modules: list[str] | None = None, loaded_modules: dict[str, Module] | None = None)

Bases: TrainingPipeline

A training pipeline for Wan.

Source code in fastvideo/training/training_pipeline.py
def __init__(
        self,
        model_path: str,
        fastvideo_args: TrainingArgs,
        required_config_modules: list[str] | None = None,
        loaded_modules: dict[str, torch.nn.Module] | None = None) -> None:
    fastvideo_args.inference_mode = False
    self.lora_training = fastvideo_args.lora_training
    if self.lora_training and fastvideo_args.lora_rank is None:
        raise ValueError("lora rank must be set when using lora training")

    set_random_seed(fastvideo_args.seed)  # for lora param init
    super().__init__(model_path, fastvideo_args, required_config_modules,
                     loaded_modules)  # type: ignore
    self.tracker = DummyTracker()
Functions
fastvideo.training.wan_training_pipeline.WanTrainingPipeline.create_training_stages
create_training_stages(training_args: TrainingArgs)

May be used in future refactors.

Source code in fastvideo/training/wan_training_pipeline.py
def create_training_stages(self, training_args: TrainingArgs):
    """
    May be used in future refactors.
    """
    pass

Functions