distillation_pipeline

Classes

fastvideo.training.distillation_pipeline.DistillationPipeline

DistillationPipeline(model_path: str, fastvideo_args: TrainingArgs, required_config_modules: list[str] | None = None, loaded_modules: dict[str, Module] | None = None)

Bases: TrainingPipeline

A distillation pipeline for training a 3-step model. Inherits from TrainingPipeline to reuse training infrastructure.

Source code in fastvideo/training/training_pipeline.py
def __init__(
        self,
        model_path: str,
        fastvideo_args: TrainingArgs,
        required_config_modules: list[str] | None = None,
        loaded_modules: dict[str, torch.nn.Module] | None = None) -> None:
    fastvideo_args.inference_mode = False
    self.lora_training = fastvideo_args.lora_training
    if self.lora_training and fastvideo_args.lora_rank is None:
        raise ValueError("lora rank must be set when using lora training")

    set_random_seed(fastvideo_args.seed)  # for lora param init
    super().__init__(model_path, fastvideo_args, required_config_modules,
                     loaded_modules)  # type: ignore
    self.tracker = DummyTracker()

Functions

fastvideo.training.distillation_pipeline.DistillationPipeline.apply_ema_to_model
apply_ema_to_model(model)

Apply EMA weights to the model for validation or inference.

Source code in fastvideo/training/distillation_pipeline.py
def apply_ema_to_model(self, model):
    """Apply EMA weights to the model for validation or inference."""
    if model is self.transformer and self.generator_ema is not None:
        with self.generator_ema.apply_to_model(model):
            return model
    elif model is self.transformer_2 and self.generator_ema_2 is not None:
        with self.generator_ema_2.apply_to_model(model):
            return model
    return model
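
Example

A minimal usage sketch, assuming pipeline is an already-initialized DistillationPipeline subclass. If EMA is not enabled for the given transformer, the model is returned unchanged.

# Pick the generator weights to evaluate: EMA weights when available,
# otherwise the current training weights.
eval_model = pipeline.apply_ema_to_model(pipeline.transformer)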
fastvideo.training.distillation_pipeline.DistillationPipeline.get_ema_2_model_copy
get_ema_2_model_copy() -> Module | None

Get a copy of the transformer_2 model with EMA weights applied.

Source code in fastvideo/training/distillation_pipeline.py
def get_ema_2_model_copy(self) -> torch.nn.Module | None:
    """Get a copy of the transformer_2 model with EMA weights applied."""
    if self.generator_ema_2 is not None and self.transformer_2 is not None:
        ema_2_model = copy.deepcopy(self.transformer_2)
        self.generator_ema_2.copy_to_unwrapped(ema_2_model)
        return ema_2_model
    return None
fastvideo.training.distillation_pipeline.DistillationPipeline.get_ema_model_copy
get_ema_model_copy() -> Module | None

Get a copy of the model with EMA weights applied.

Source code in fastvideo/training/distillation_pipeline.py
def get_ema_model_copy(self) -> torch.nn.Module | None:
    """Get a copy of the model with EMA weights applied."""
    if self.generator_ema is not None:
        ema_model = copy.deepcopy(self.transformer)
        self.generator_ema.copy_to_unwrapped(ema_model)
        return ema_model
    return None
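
Example

A minimal sketch of using the EMA copy for offline evaluation; the copy is a deepcopy detached from training state, so release it when done. The pipeline name is illustrative.

ema_model = pipeline.get_ema_model_copy()
if ema_model is not None:
    ema_model.eval()
    # ... run validation or export with ema_model ...
    del ema_model  # deepcopy of the transformer; free memory when finished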
fastvideo.training.distillation_pipeline.DistillationPipeline.get_ema_stats
get_ema_stats() -> dict[str, Any]

Get EMA statistics for monitoring.

Source code in fastvideo/training/distillation_pipeline.py
def get_ema_stats(self) -> dict[str, Any]:
    """Get EMA statistics for monitoring."""
    ema_enabled = self.generator_ema is not None
    ema_2_enabled = self.generator_ema_2 is not None

    if not ema_enabled and not ema_2_enabled:
        return {
            "ema_enabled": False,
            "ema_2_enabled": False,
            "ema_decay": None,
            "ema_start_step": self.training_args.ema_start_step,
            "ema_ready": False,
            "ema_2_ready": False,
            "ema_step": self.current_trainstep,
        }

    return {
        "ema_enabled": ema_enabled,
        "ema_2_enabled": ema_2_enabled,
        "ema_decay": self.training_args.ema_decay,
        "ema_start_step": self.training_args.ema_start_step,
        "ema_ready": self.is_ema_ready() if ema_enabled else False,
        "ema_2_ready": self.is_ema_ready() if ema_2_enabled else False,
        "ema_step": self.current_trainstep,
    }
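
Example

The returned dictionary is flat and tracker-friendly. A minimal sketch of forwarding it to the run tracker, mirroring how train() logs per-step data (values in the comment are illustrative):

stats = pipeline.get_ema_stats()
# e.g. {"ema_enabled": True, "ema_2_enabled": False, "ema_decay": 0.95,
#       "ema_start_step": 1000, "ema_ready": True, "ema_2_ready": False,
#       "ema_step": 1234}
pipeline.tracker.log(stats, stats["ema_step"])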
fastvideo.training.distillation_pipeline.DistillationPipeline.initialize_training_pipeline
initialize_training_pipeline(training_args: TrainingArgs)

Initialize the distillation training pipeline with multiple models.

Source code in fastvideo/training/distillation_pipeline.py
def initialize_training_pipeline(self, training_args: TrainingArgs):
    """Initialize the distillation training pipeline with multiple models."""
    logger.info("Initializing distillation pipeline...")

    super().initialize_training_pipeline(training_args)

    self.noise_scheduler = self.get_module("scheduler")
    self.vae = self.get_module("vae")
    self.vae.requires_grad_(False)

    self.timestep_shift = self.training_args.pipeline_config.flow_shift
    self.noise_scheduler = FlowMatchEulerDiscreteScheduler(
        shift=self.timestep_shift)

    if self.training_args.boundary_ratio is not None:
        self.boundary_timestep = self.training_args.boundary_ratio * self.noise_scheduler.num_train_timesteps
    else:
        self.boundary_timestep = None

    if training_args.real_score_model_path:
        logger.info("Loading real score transformer from: %s",
                    training_args.real_score_model_path)
        training_args.override_transformer_cls_name = "WanTransformer3DModel"
        self.real_score_transformer = self.load_module_from_path(
            training_args.real_score_model_path, "transformer",
            training_args)
        try:
            self.real_score_transformer_2 = self.load_module_from_path(
                training_args.real_score_model_path, "transformer_2",
                training_args)
            logger.info("Loaded real score transformer_2 for MoE support")
        except Exception:
            logger.info(
                "real score transformer_2 not found, using single transformer"
            )
            self.real_score_transformer_2 = None
    else:
        self.real_score_transformer = self.get_module(
            "real_score_transformer")
        self.real_score_transformer_2 = self.get_module(
            "real_score_transformer_2")

    if training_args.fake_score_model_path:
        logger.info("Loading fake score transformer from: %s",
                    training_args.fake_score_model_path)
        training_args.override_transformer_cls_name = "WanTransformer3DModel"
        self.fake_score_transformer = self.load_module_from_path(
            training_args.fake_score_model_path, "transformer",
            training_args)
        try:
            self.fake_score_transformer_2 = self.load_module_from_path(
                training_args.fake_score_model_path, "transformer_2",
                training_args)
            logger.info("Loaded fake score transformer_2 for MoE support")
        except Exception:
            logger.info(
                "fake score transformer_2 not found, using single transformer"
            )
            self.fake_score_transformer_2 = None
    else:
        self.fake_score_transformer = self.get_module(
            "fake_score_transformer")
        self.fake_score_transformer_2 = self.get_module(
            "fake_score_transformer_2")

    self.real_score_transformer.requires_grad_(False)
    self.real_score_transformer.eval()
    if self.real_score_transformer_2 is not None:
        self.real_score_transformer_2.requires_grad_(False)
        self.real_score_transformer_2.eval()

    # Set training modes for fake score transformers (trainable)
    self.fake_score_transformer.requires_grad_(True)
    self.fake_score_transformer.train()
    if self.fake_score_transformer_2 is not None:
        self.fake_score_transformer_2.requires_grad_(True)
        self.fake_score_transformer_2.train()

    if training_args.enable_gradient_checkpointing_type is not None:
        self.fake_score_transformer = apply_activation_checkpointing(
            self.fake_score_transformer,
            checkpointing_type=training_args.
            enable_gradient_checkpointing_type)
        if self.fake_score_transformer_2 is not None:
            self.fake_score_transformer_2 = apply_activation_checkpointing(
                self.fake_score_transformer_2,
                checkpointing_type=training_args.
                enable_gradient_checkpointing_type)

        self.real_score_transformer = apply_activation_checkpointing(
            self.real_score_transformer,
            checkpointing_type=training_args.
            enable_gradient_checkpointing_type)
        if self.real_score_transformer_2 is not None:
            self.real_score_transformer_2 = apply_activation_checkpointing(
                self.real_score_transformer_2,
                checkpointing_type=training_args.
                enable_gradient_checkpointing_type)

    # Initialize optimizers
    fake_score_params = list(
        filter(lambda p: p.requires_grad,
               self.fake_score_transformer.parameters()))

    # Use separate learning rate for fake_score_transformer if specified
    fake_score_lr = training_args.fake_score_learning_rate
    if fake_score_lr == 0.0:
        fake_score_lr = training_args.learning_rate

    betas_str = training_args.fake_score_betas
    betas = tuple(float(x.strip()) for x in betas_str.split(","))

    self.fake_score_optimizer = torch.optim.AdamW(
        fake_score_params,
        lr=fake_score_lr,
        betas=betas,
        weight_decay=training_args.weight_decay,
        eps=1e-8,
    )

    self.fake_score_lr_scheduler = get_scheduler(
        training_args.fake_score_lr_scheduler,
        optimizer=self.fake_score_optimizer,
        num_warmup_steps=training_args.lr_warmup_steps,
        num_training_steps=training_args.max_train_steps,
        num_cycles=training_args.lr_num_cycles,
        power=training_args.lr_power,
        min_lr_ratio=training_args.min_lr_ratio,
        last_epoch=self.init_steps - 1,
    )

    if self.fake_score_transformer_2 is not None:
        fake_score_params_2 = list(
            filter(lambda p: p.requires_grad,
                   self.fake_score_transformer_2.parameters()))
        self.fake_score_optimizer_2 = torch.optim.AdamW(
            fake_score_params_2,
            lr=fake_score_lr,
            betas=betas,
            weight_decay=training_args.weight_decay,
            eps=1e-8,
        )
        self.fake_score_lr_scheduler_2 = get_scheduler(
            training_args.fake_score_lr_scheduler,
            optimizer=self.fake_score_optimizer_2,
            num_warmup_steps=training_args.lr_warmup_steps,
            num_training_steps=training_args.max_train_steps,
            num_cycles=training_args.lr_num_cycles,
            power=training_args.lr_power,
            min_lr_ratio=training_args.min_lr_ratio,
            last_epoch=self.init_steps - 1,
        )

    logger.info(
        "Distillation optimizers initialized: generator and fake_score")

    self.generator_update_interval = self.training_args.generator_update_interval
    logger.info(
        "Distillation pipeline initialized with generator_update_interval=%s",
        self.generator_update_interval)

    self.denoising_step_list = torch.tensor(
        self.training_args.pipeline_config.dmd_denoising_steps,
        dtype=torch.long,
        device=get_local_torch_device())

    if training_args.warp_denoising_step:  # Warp the denoising step according to the scheduler time shift
        timesteps = torch.cat((self.noise_scheduler.timesteps.cpu(),
                               torch.tensor([0],
                                            dtype=torch.float32))).cuda()
        self.denoising_step_list = timesteps[1000 -
                                             self.denoising_step_list]
        logger.info("Warping denoising_step_list")

    self.denoising_step_list = self.denoising_step_list.to(
        get_local_torch_device())
    logger.info("Distillation generator model to %s denoising steps: %s",
                len(self.denoising_step_list), self.denoising_step_list)
    self.num_train_timestep = self.noise_scheduler.num_train_timesteps

    self.min_timestep = int(self.training_args.min_timestep_ratio *
                            self.num_train_timestep)
    self.max_timestep = int(self.training_args.max_timestep_ratio *
                            self.num_train_timestep)

    self.real_score_guidance_scale = self.training_args.real_score_guidance_scale

    self.generator_ema: EMA_FSDP | None = None
    self.generator_ema_2: EMA_FSDP | None = None
    if (self.training_args.ema_decay
            is not None) and (self.training_args.ema_decay > 0.0):
        self.generator_ema = EMA_FSDP(self.transformer,
                                      decay=self.training_args.ema_decay)
        logger.info("Initialized generator EMA with decay=%s",
                    self.training_args.ema_decay)

        # Initialize EMA for transformer_2 if it exists
        if self.transformer_2 is not None:
            self.generator_ema_2 = EMA_FSDP(
                self.transformer_2, decay=self.training_args.ema_decay)
            logger.info("Initialized generator EMA_2 with decay=%s",
                        self.training_args.ema_decay)
    else:
        logger.info("Generator EMA disabled (ema_decay <= 0.0)")
fastvideo.training.distillation_pipeline.DistillationPipeline.initialize_validation_pipeline abstractmethod
initialize_validation_pipeline(training_args: TrainingArgs)

Initialize validation pipeline - must be implemented by subclasses.

Source code in fastvideo/training/distillation_pipeline.py
@abstractmethod
def initialize_validation_pipeline(self, training_args: TrainingArgs):
    """Initialize validation pipeline - must be implemented by subclasses."""
    raise NotImplementedError(
        "Distillation pipelines must implement this method")
fastvideo.training.distillation_pipeline.DistillationPipeline.is_ema_ready
is_ema_ready(current_step: int | None = None)

Check if EMA is ready for use (after ema_start_step).

Source code in fastvideo/training/distillation_pipeline.py
def is_ema_ready(self, current_step: int | None = None):
    """Check if EMA is ready for use (after ema_start_step)."""
    if current_step is None:
        current_step = getattr(self, 'current_trainstep', 0)
    return (self.generator_ema is not None
            and current_step >= self.training_args.ema_start_step)
fastvideo.training.distillation_pipeline.DistillationPipeline.load_module_from_path
load_module_from_path(model_path: str, module_type: str, training_args: TrainingArgs)

Load a module from a specific path using the same loading logic as the pipeline.

Parameters:

    model_path (str): Path to the model. Required.
    module_type (str): Type of module to load (e.g., "transformer"). Required.
    training_args (TrainingArgs): Training arguments. Required.

Returns:

    The loaded module.

Source code in fastvideo/training/distillation_pipeline.py
def load_module_from_path(self, model_path: str, module_type: str,
                          training_args: "TrainingArgs"):
    """
    Load a module from a specific path using the same loading logic as the pipeline.

    Args:
        model_path: Path to the model
        module_type: Type of module to load (e.g., "transformer")
        training_args: Training arguments

    Returns:
        The loaded module
    """
    logger.info("Loading %s from custom path: %s", module_type, model_path)
    # Set flag to prevent custom weight loading for teacher/critic models
    training_args._loading_teacher_critic_model = True

    try:
        from fastvideo.models.loader.component_loader import (
            PipelineComponentLoader)

        # Download the model if it's a Hugging Face model ID
        local_model_path = maybe_download_model(model_path)
        logger.info("Model downloaded/found at: %s", local_model_path)
        config = verify_model_config_and_directory(local_model_path)

        if module_type not in config:
            if hasattr(self, '_extra_config_module_map'
                       ) and module_type in self._extra_config_module_map:
                extra_module = self._extra_config_module_map[module_type]
                if extra_module in config:
                    logger.info("Using %s for %s", extra_module,
                                module_type)
                    module_type = extra_module
                else:
                    raise ValueError(
                        f"Module {module_type} not found in config at {local_model_path}"
                    )
            else:
                raise ValueError(
                    f"Module {module_type} not found in config at {local_model_path}"
                )

        module_info = config[module_type]
        if module_info is None:
            raise ValueError(
                f"Module {module_type} has null value in config at {local_model_path}"
            )

        transformers_or_diffusers, architecture = module_info
        component_path = os.path.join(local_model_path, module_type)
        module = PipelineComponentLoader.load_module(
            module_name=module_type,
            component_model_path=component_path,
            transformers_or_diffusers=transformers_or_diffusers,
            fastvideo_args=training_args,
        )

        logger.info("Successfully loaded %s from %s", module_type,
                    component_path)
        return module
    finally:
        # Always clean up the flag
        if hasattr(training_args, '_loading_teacher_critic_model'):
            delattr(training_args, '_loading_teacher_critic_model')
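
Example

A minimal usage sketch. The model path is a hypothetical diffusers-style directory or hub id, and the freeze/eval calls mirror how initialize_training_pipeline treats the real-score (teacher) transformer.

teacher = pipeline.load_module_from_path(
    model_path="path/to/teacher",   # hypothetical local dir or Hugging Face model id
    module_type="transformer",
    training_args=pipeline.training_args,
)
teacher.requires_grad_(False)
teacher.eval()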
fastvideo.training.distillation_pipeline.DistillationPipeline.reset_ema
reset_ema()

Reset EMA to current model weights.

Source code in fastvideo/training/distillation_pipeline.py
def reset_ema(self):
    """Reset EMA to current model weights."""
    if self.generator_ema is not None:
        logger.info("Resetting EMA to current model weights")
        self.generator_ema.update(self.transformer)
        # Force update to current weights by setting decay to 0 temporarily
        original_decay = self.generator_ema.decay
        self.generator_ema.decay = 0.0
        self.generator_ema.update(self.transformer)
        self.generator_ema.decay = original_decay
        logger.info("EMA reset completed")
    else:
        logger.warning("Cannot reset EMA: EMA not initialized")

    if self.generator_ema_2 is not None:
        logger.info("Resetting EMA_2 to current model weights")
        self.generator_ema_2.update(self.transformer_2)
        # Force update to current weights by setting decay to 0 temporarily
        original_decay_2 = self.generator_ema_2.decay
        self.generator_ema_2.decay = 0.0
        self.generator_ema_2.update(self.transformer_2)
        self.generator_ema_2.decay = original_decay_2
        logger.info("EMA_2 reset completed")
fastvideo.training.distillation_pipeline.DistillationPipeline.save_ema_weights
save_ema_weights(output_dir: str, step: int)

Save EMA weights separately for inference purposes.

Source code in fastvideo/training/distillation_pipeline.py
def save_ema_weights(self, output_dir: str, step: int):
    """Save EMA weights separately for inference purposes."""
    if self.generator_ema is None and self.generator_ema_2 is None:
        logger.warning("Cannot save EMA weights: No EMA initialized")
        return

    if not self.is_ema_ready():
        logger.warning(
            "Cannot save EMA weights: EMA not ready yet (step < ema_start_step)"
        )
        return

    try:
        # Save main transformer EMA
        if self.generator_ema is not None:
            ema_model = self.get_ema_model_copy()
            if ema_model is None:
                logger.warning("Failed to create EMA model copy")
            else:
                ema_save_dir = os.path.join(output_dir,
                                            f"ema_checkpoint-{step}")
                os.makedirs(ema_save_dir, exist_ok=True)

                # save as diffusers format
                from safetensors.torch import save_file

                from fastvideo.training.training_utils import (
                    custom_to_hf_state_dict, gather_state_dict_on_cpu_rank0)
                cpu_state = gather_state_dict_on_cpu_rank0(ema_model,
                                                           device=None)

                if self.global_rank == 0:
                    weight_path = os.path.join(
                        ema_save_dir, "diffusion_pytorch_model.safetensors")
                    diffusers_state_dict = custom_to_hf_state_dict(
                        cpu_state, ema_model.reverse_param_names_mapping)
                    save_file(diffusers_state_dict, weight_path)

                    config_dict = ema_model.hf_config
                    if "dtype" in config_dict:
                        del config_dict["dtype"]
                    config_path = os.path.join(ema_save_dir, "config.json")
                    with open(config_path, "w") as f:
                        json.dump(config_dict, f, indent=4)

                    logger.info("EMA weights saved to %s", weight_path)

                del ema_model

        # Save transformer_2 EMA
        if self.generator_ema_2 is not None:
            ema_2_model = self.get_ema_2_model_copy()
            if ema_2_model is None:
                logger.warning("Failed to create EMA_2 model copy")
            else:
                ema_2_save_dir = os.path.join(output_dir,
                                              f"ema_2_checkpoint-{step}")
                os.makedirs(ema_2_save_dir, exist_ok=True)

                # save as diffusers format
                from safetensors.torch import save_file

                from fastvideo.training.training_utils import (
                    custom_to_hf_state_dict, gather_state_dict_on_cpu_rank0)
                cpu_state_2 = gather_state_dict_on_cpu_rank0(ema_2_model,
                                                             device=None)

                if self.global_rank == 0:
                    weight_path_2 = os.path.join(
                        ema_2_save_dir,
                        "diffusion_pytorch_model.safetensors")
                    diffusers_state_dict_2 = custom_to_hf_state_dict(
                        cpu_state_2,
                        ema_2_model.reverse_param_names_mapping)
                    save_file(diffusers_state_dict_2, weight_path_2)

                    config_dict_2 = ema_2_model.hf_config
                    if "dtype" in config_dict_2:
                        del config_dict_2["dtype"]
                    config_path_2 = os.path.join(ema_2_save_dir,
                                                 "config.json")
                    with open(config_path_2, "w") as f:
                        json.dump(config_dict_2, f, indent=4)

                    logger.info("EMA_2 weights saved to %s", weight_path_2)

                del ema_2_model

    except Exception as e:
        logger.error("Failed to save EMA weights: %s", str(e))
fastvideo.training.distillation_pipeline.DistillationPipeline.train
train() -> None

Main training loop with distillation-specific logging.

Source code in fastvideo/training/distillation_pipeline.py
def train(self) -> None:
    """Main training loop with distillation-specific logging."""
    assert self.training_args.seed is not None, "seed must be set"
    seed = self.training_args.seed

    # Set the same seed within each SP group to ensure reproducibility
    if self.sp_world_size > 1:
        # Use the same seed for all processes within the same SP group
        sp_group_seed = seed + (self.global_rank // self.sp_world_size)
        set_random_seed(sp_group_seed)
        logger.info("Rank %s: Using SP group seed %s", self.global_rank,
                    sp_group_seed)
    else:
        set_random_seed(seed + self.global_rank)

    # Set random seeds for deterministic training
    self.noise_random_generator = torch.Generator(device="cpu").manual_seed(
        self.seed)
    self.noise_gen_cuda = torch.Generator(device="cuda").manual_seed(
        self.seed)
    self.validation_random_generator = torch.Generator(
        device="cpu").manual_seed(self.seed)
    logger.info("Initialized random seeds with seed: %s", seed)

    # Initialize current_trainstep for EMA ready checks
    #TODO: check if needed
    self.current_trainstep = self.init_steps

    # Resume from checkpoint if specified (this will restore random states)
    if self.training_args.resume_from_checkpoint:
        self._resume_from_checkpoint()
        logger.info("Resumed from checkpoint, random states restored")
    else:
        logger.info("Starting training from scratch")

    self.train_loader_iter = iter(self.train_dataloader)

    step_times: deque[float] = deque(maxlen=100)

    self._log_training_info()
    self._log_validation(self.transformer, self.training_args,
                         self.init_steps)

    progress_bar = tqdm(
        range(0, self.training_args.max_train_steps),
        initial=self.init_steps,
        desc="Steps",
        disable=self.local_rank > 0,
    )

    use_vsa = vsa_available and envs.FASTVIDEO_ATTENTION_BACKEND == "VIDEO_SPARSE_ATTN"
    for step in range(self.init_steps + 1,
                      self.training_args.max_train_steps + 1):
        start_time = time.perf_counter()
        if use_vsa:
            vsa_sparsity = self.training_args.VSA_sparsity
            vsa_decay_rate = self.training_args.VSA_decay_rate
            vsa_decay_interval_steps = self.training_args.VSA_decay_interval_steps
            if vsa_decay_interval_steps > 1:
                current_decay_times = min(step // vsa_decay_interval_steps,
                                          vsa_sparsity // vsa_decay_rate)
                current_vsa_sparsity = current_decay_times * vsa_decay_rate
            else:
                current_vsa_sparsity = vsa_sparsity
        else:
            current_vsa_sparsity = 0.0

        training_batch = TrainingBatch()
        self.current_trainstep = step
        training_batch.current_vsa_sparsity = current_vsa_sparsity

        if (step >= self.training_args.ema_start_step) and \
                (self.generator_ema is None) and \
                (self.training_args.ema_decay is not None) and \
                (self.training_args.ema_decay > 0):
            self.generator_ema = EMA_FSDP(
                self.transformer, decay=self.training_args.ema_decay)
            logger.info("Created generator EMA at step %s with decay=%s",
                        step, self.training_args.ema_decay)

            # Create EMA for transformer_2 if it exists
            if self.transformer_2 is not None and self.generator_ema_2 is None:
                self.generator_ema_2 = EMA_FSDP(
                    self.transformer_2, decay=self.training_args.ema_decay)
                logger.info(
                    "Created generator EMA_2 at step %s with decay=%s",
                    step, self.training_args.ema_decay)

        with torch.autocast("cuda", dtype=torch.bfloat16):
            training_batch = self.train_one_step(training_batch)

        total_loss = training_batch.total_loss
        generator_loss = training_batch.generator_loss
        fake_score_loss = training_batch.fake_score_loss
        grad_norm = training_batch.grad_norm

        step_time = time.perf_counter() - start_time
        step_times.append(step_time)
        avg_step_time = sum(step_times) / len(step_times)

        progress_bar.set_postfix({
            "total_loss":
            f"{total_loss:.4f}",
            "generator_loss":
            f"{generator_loss:.4f}",
            "fake_score_loss":
            f"{fake_score_loss:.4f}",
            "step_time":
            f"{step_time:.2f}s",
            "grad_norm":
            grad_norm,
            "ema":
            "✓" if (self.generator_ema is not None and self.is_ema_ready())
            else "✗",
            "ema2":
            "✓" if (self.generator_ema_2 is not None
                    and self.is_ema_ready()) else "✗",
        })
        progress_bar.update(1)

        if self.global_rank == 0:
            # Prepare logging data
            log_data = {
                "train_total_loss":
                total_loss,
                "train_fake_score_loss":
                fake_score_loss,
                "learning_rate":
                self.lr_scheduler.get_last_lr()[0],
                "fake_score_learning_rate":
                self.fake_score_lr_scheduler.get_last_lr()[0],
                "step_time":
                step_time,
                "avg_step_time":
                avg_step_time,
                "grad_norm":
                grad_norm,
            }
            # Only log generator loss when generator is actually trained
            if (step % self.generator_update_interval == 0):
                log_data["train_generator_loss"] = generator_loss
            if use_vsa:
                log_data["VSA_train_sparsity"] = current_vsa_sparsity

            if self.generator_ema is not None or self.generator_ema_2 is not None:
                log_data["ema_enabled"] = self.generator_ema is not None
                log_data["ema_2_enabled"] = self.generator_ema_2 is not None
                log_data["ema_decay"] = self.training_args.ema_decay
            else:
                log_data["ema_enabled"] = False
                log_data["ema_2_enabled"] = False

            ema_stats = self.get_ema_stats()
            log_data.update(ema_stats)

            if training_batch.dmd_latent_vis_dict:
                dmd_additional_logs = {
                    "generator_timestep":
                    training_batch.
                    dmd_latent_vis_dict["generator_timestep"].item(),
                    "dmd_timestep":
                    training_batch.dmd_latent_vis_dict["dmd_timestep"].item(
                    ),
                }
                log_data.update(dmd_additional_logs)

            faker_score_additional_logs = {
                "fake_score_timestep":
                training_batch.
                fake_score_latent_vis_dict["fake_score_timestep"].item(),
            }
            log_data.update(faker_score_additional_logs)

            self.tracker.log(log_data, step)

        # Save training state checkpoint (for resuming training)
        if (self.training_args.training_state_checkpointing_steps > 0
                and step %
                self.training_args.training_state_checkpointing_steps == 0):
            print("rank", self.global_rank,
                  "save training state checkpoint at step", step)
            save_distillation_checkpoint(
                self.transformer,
                self.fake_score_transformer,
                self.global_rank,
                self.training_args.output_dir,
                step,
                self.optimizer,
                self.fake_score_optimizer,
                self.train_dataloader,
                self.lr_scheduler,
                self.fake_score_lr_scheduler,
                self.noise_random_generator,
                self.generator_ema,
                # MoE support
                generator_transformer_2=getattr(self, 'transformer_2',
                                                None),
                real_score_transformer_2=getattr(
                    self, 'real_score_transformer_2', None),
                fake_score_transformer_2=getattr(
                    self, 'fake_score_transformer_2', None),
                generator_optimizer_2=getattr(self, 'optimizer_2', None),
                fake_score_optimizer_2=getattr(self,
                                               'fake_score_optimizer_2',
                                               None),
                generator_scheduler_2=getattr(self, 'lr_scheduler_2', None),
                fake_score_scheduler_2=getattr(self,
                                               'fake_score_lr_scheduler_2',
                                               None),
                generator_ema_2=getattr(self, 'generator_ema_2', None))

            if self.transformer:
                self.transformer.train()
            self.sp_group.barrier()

        # Save weight-only checkpoint
        if (self.training_args.weight_only_checkpointing_steps > 0
                and step %
                self.training_args.weight_only_checkpointing_steps == 0):
            print("rank", self.global_rank,
                  "save weight-only checkpoint at step", step)
            save_distillation_checkpoint(
                self.transformer,
                self.fake_score_transformer,
                self.global_rank,
                self.training_args.output_dir,
                f"{step}_weight_only",
                only_save_generator_weight=True,
                generator_ema=self.generator_ema,
                # MoE support
                generator_transformer_2=getattr(self, 'transformer_2',
                                                None),
                real_score_transformer_2=getattr(
                    self, 'real_score_transformer_2', None),
                fake_score_transformer_2=getattr(
                    self, 'fake_score_transformer_2', None),
                generator_optimizer_2=getattr(self, 'optimizer_2', None),
                fake_score_optimizer_2=getattr(self,
                                               'fake_score_optimizer_2',
                                               None),
                generator_scheduler_2=getattr(self, 'lr_scheduler_2', None),
                fake_score_scheduler_2=getattr(self,
                                               'fake_score_lr_scheduler_2',
                                               None),
                generator_ema_2=getattr(self, 'generator_ema_2', None))

            if self.training_args.use_ema and self.is_ema_ready():
                self.save_ema_weights(self.training_args.output_dir, step)

        if self.training_args.log_validation and step % self.training_args.validation_steps == 0:
            if self.training_args.log_visualization:
                self.visualize_intermediate_latents(training_batch,
                                                    self.training_args,
                                                    step)
            self._log_validation(self.transformer, self.training_args, step)

    self.tracker.finish()

    # Save final training state checkpoint
    print("rank", self.global_rank,
          "save final training state checkpoint at step",
          self.training_args.max_train_steps)
    save_distillation_checkpoint(
        self.transformer,
        self.fake_score_transformer,
        self.global_rank,
        self.training_args.output_dir,
        self.training_args.max_train_steps,
        self.optimizer,
        self.fake_score_optimizer,
        self.train_dataloader,
        self.lr_scheduler,
        self.fake_score_lr_scheduler,
        self.noise_random_generator,
        self.generator_ema,
        # MoE support
        generator_transformer_2=getattr(self, 'transformer_2', None),
        real_score_transformer_2=getattr(self, 'real_score_transformer_2',
                                         None),
        fake_score_transformer_2=getattr(self, 'fake_score_transformer_2',
                                         None),
        generator_optimizer_2=getattr(self, 'optimizer_2', None),
        fake_score_optimizer_2=getattr(self, 'fake_score_optimizer_2',
                                       None),
        generator_scheduler_2=getattr(self, 'lr_scheduler_2', None),
        fake_score_scheduler_2=getattr(self, 'fake_score_lr_scheduler_2',
                                       None),
        generator_ema_2=getattr(self, 'generator_ema_2', None))

    if self.training_args.use_ema and self.is_ema_ready():
        self.save_ema_weights(self.training_args.output_dir,
                              self.training_args.max_train_steps)

    if envs.FASTVIDEO_TORCH_PROFILER_DIR:
        logger.info("Stopping profiler...")
        self.profiler_controller.stop()
        logger.info("Profiler stopped.")

    if get_sp_group():
        cleanup_dist_env_and_memory()
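
Example

The loop's checkpointing and validation cadence is controlled by the TrainingArgs fields below (names taken from the source above, values illustrative):

args.training_state_checkpointing_steps = 500  # full resumable checkpoint every 500 steps
args.weight_only_checkpointing_steps = 250     # generator-weight-only checkpoint cadence
args.log_validation = True
args.validation_steps = 250                    # run validation/visualization every 250 steps
args.use_ema = True                            # also export EMA weights at weight-only saves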
fastvideo.training.distillation_pipeline.DistillationPipeline.visualize_intermediate_latents
visualize_intermediate_latents(training_batch: TrainingBatch, training_args: TrainingArgs, step: int)

Add visualization data to tracker logging and save frames to disk.

Source code in fastvideo/training/distillation_pipeline.py
def visualize_intermediate_latents(self, training_batch: TrainingBatch,
                                   training_args: TrainingArgs, step: int):
    """Add visualization data to tracker logging and save frames to disk."""
    tracker_loss_dict: dict[str, Any] = {}
    dmd_latents_vis_dict = training_batch.dmd_latent_vis_dict
    fake_score_latents_vis_dict = training_batch.fake_score_latent_vis_dict
    fake_score_log_keys = ['generator_pred_video']
    dmd_log_keys = ['faker_score_pred_video', 'real_score_pred_video']

    for latent_key in fake_score_log_keys:
        latents = fake_score_latents_vis_dict[latent_key]
        latents = latents.permute(0, 2, 1, 3, 4)

        if isinstance(self.vae.scaling_factor, torch.Tensor):
            latents = latents / self.vae.scaling_factor.to(
                latents.device, latents.dtype)
        else:
            latents = latents / self.vae.scaling_factor

        # Apply shifting if needed
        if (hasattr(self.vae, "shift_factor")
                and self.vae.shift_factor is not None):
            if isinstance(self.vae.shift_factor, torch.Tensor):
                latents += self.vae.shift_factor.to(latents.device,
                                                    latents.dtype)
            else:
                latents += self.vae.shift_factor
            with torch.autocast("cuda", dtype=torch.bfloat16):
                video = self.vae.decode(latents)
            video = (video / 2 + 0.5).clamp(0, 1)
            video = video.cpu().float()
            video = video.permute(0, 2, 1, 3, 4)
            video = (video * 255).numpy().astype(np.uint8)
            video_artifact = self.tracker.video(
                video, fps=24, format="mp4")  # change to 16 for Wan2.1
            if video_artifact is not None:
                tracker_loss_dict[latent_key] = video_artifact
            # Clean up references
            del video, latents

    # Process DMD training data if available - use decode_stage instead of self.vae.decode
    if 'generator_pred_video' in dmd_latents_vis_dict:
        for latent_key in dmd_log_keys:
            latents = dmd_latents_vis_dict[latent_key]
            latents = latents.permute(0, 2, 1, 3, 4)
            # decoded_latent = decode_stage(ForwardBatch(data_type="video", latents=latents), training_args)
            if isinstance(self.vae.scaling_factor, torch.Tensor):
                latents = latents / self.vae.scaling_factor.to(
                    latents.device, latents.dtype)
            else:
                latents = latents / self.vae.scaling_factor

            # Apply shifting if needed
            if (hasattr(self.vae, "shift_factor")
                    and self.vae.shift_factor is not None):
                if isinstance(self.vae.shift_factor, torch.Tensor):
                    latents += self.vae.shift_factor.to(
                        latents.device, latents.dtype)
                else:
                    latents += self.vae.shift_factor
            with torch.autocast("cuda", dtype=torch.bfloat16):
                video = self.vae.decode(latents)
            video = (video / 2 + 0.5).clamp(0, 1)
            video = video.cpu().float()
            video = video.permute(0, 2, 1, 3, 4)
            video = (video * 255).numpy().astype(np.uint8)
            video_artifact = self.tracker.video(
                video, fps=24, format="mp4")  # change to 16 for Wan2.1
            if video_artifact is not None:
                tracker_loss_dict[latent_key] = video_artifact
            # Clean up references
            del video, latents

    # Log to tracker
    if self.global_rank == 0 and tracker_loss_dict:
        self.tracker.log_artifacts(tracker_loss_dict, step)
