
self_forcing_distillation_pipeline

Classes

fastvideo.training.self_forcing_distillation_pipeline.SelfForcingDistillationPipeline

SelfForcingDistillationPipeline(model_path: str, fastvideo_args: TrainingArgs, required_config_modules: list[str] | None = None, loaded_modules: dict[str, Module] | None = None)

Bases: DistillationPipeline

A distillation pipeline that alternates between generator and critic training, following the self-forcing methodology.

This implementation follows the self-forcing approach where:

1. Generator and critic are trained in alternating steps
2. Generator loss uses DMD-style loss with the critic as fake score
3. Critic loss trains the fake score model to distinguish real vs fake
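
A minimal sketch of this alternating schedule, assuming the critic is updated every step and the generator only every dfake_gen_update_ratio steps (the ratio below matches the default read in initialize_training_pipeline; the loop itself is illustrative, not pipeline code):

def alternating_schedule_sketch(max_train_steps: int,
                                dfake_gen_update_ratio: int = 5) -> None:
    # Illustrative only: mirrors the check used in train_one_step,
    # where the generator is trained when step % ratio == 0.
    for step in range(1, max_train_steps + 1):
        train_generator = (step % dfake_gen_update_ratio == 0)
        role = "generator + critic" if train_generator else "critic only"
        print(f"step {step}: {role}")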

Source code in fastvideo/training/training_pipeline.py
def __init__(
        self,
        model_path: str,
        fastvideo_args: TrainingArgs,
        required_config_modules: list[str] | None = None,
        loaded_modules: dict[str, torch.nn.Module] | None = None) -> None:
    fastvideo_args.inference_mode = False
    self.lora_training = fastvideo_args.lora_training
    if self.lora_training and fastvideo_args.lora_rank is None:
        raise ValueError("lora rank must be set when using lora training")

    set_random_seed(fastvideo_args.seed)  # for lora param init
    super().__init__(model_path, fastvideo_args, required_config_modules,
                     loaded_modules)  # type: ignore
    self.tracker = DummyTracker()

Functions

fastvideo.training.self_forcing_distillation_pipeline.SelfForcingDistillationPipeline.critic_loss
critic_loss(training_batch: TrainingBatch) -> tuple[Tensor, dict[str, Any]]

Compute critic loss using flow matching between noise and generator output. The critic learns to predict the flow from noise to the generator's output.

Source code in fastvideo/training/self_forcing_distillation_pipeline.py
def critic_loss(
        self, training_batch: TrainingBatch
) -> tuple[torch.Tensor, dict[str, Any]]:
    """
    Compute critic loss using flow matching between noise and generator output.
    The critic learns to predict the flow from noise to the generator's output.
    """
    updated_batch, flow_matching_loss = self.faker_score_forward(
        training_batch)
    training_batch.fake_score_latent_vis_dict = updated_batch.fake_score_latent_vis_dict
    log_dict: dict[str, Any] = {}

    return flow_matching_loss, log_dict
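
As a rough illustration of the flow matching objective above (not the actual faker_score_forward implementation): the critic is asked to predict the velocity from noise to the generator's output along a linear path, and the loss is the mean-squared error against that target. A minimal sketch, assuming timesteps normalized to [0, 1] and a hypothetical critic call signature:

import torch
import torch.nn.functional as F

def flow_matching_loss_sketch(critic, generator_output: torch.Tensor,
                              timesteps: torch.Tensor) -> torch.Tensor:
    # Hypothetical critic(noisy, timesteps) signature; illustrative only.
    noise = torch.randn_like(generator_output)
    # Linear interpolation between the generator output (t=0) and noise (t=1).
    t = timesteps.view(-1, *([1] * (generator_output.dim() - 1)))
    noisy = (1 - t) * generator_output + t * noise
    # Target velocity for the linear path: direction from data to noise.
    target = noise - generator_output
    pred = critic(noisy, timesteps)
    return F.mse_loss(pred, target)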
fastvideo.training.self_forcing_distillation_pipeline.SelfForcingDistillationPipeline.generate_and_sync_list
generate_and_sync_list(num_blocks: int, num_denoising_steps: int, device: device) -> list[int]

Generate and synchronize random exit flags across distributed processes.

Source code in fastvideo/training/self_forcing_distillation_pipeline.py
def generate_and_sync_list(self, num_blocks: int, num_denoising_steps: int,
                           device: torch.device) -> list[int]:
    """Generate and synchronize random exit flags across distributed processes."""
    logger.info(
        "RANK: %s, enter generate_and_sync_list blocks=%s steps=%s device=%s",
        self.global_rank,
        num_blocks,
        num_denoising_steps,
        str(device),
        local_main_process_only=False)
    rank = dist.get_rank() if dist.is_initialized() else 0

    if rank == 0:
        # Generate random indices
        indices = torch.randint(low=0,
                                high=num_denoising_steps,
                                size=(num_blocks, ),
                                device=device)
        if self.last_step_only:
            indices = torch.ones_like(indices) * (num_denoising_steps - 1)
    else:
        indices = torch.empty(num_blocks, dtype=torch.long, device=device)

    if dist.is_initialized():
        dist.broadcast(indices,
                       src=0)  # Broadcast the random indices to all ranks
    flags = indices.tolist()
    logger.info(
        "RANK: %s, exit generate_and_sync_list flags_len=%s first=%s",
        self.global_rank,
        len(flags),
        flags[0] if len(flags) > 0 else None,
        local_main_process_only=False)
    return flags
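
A hedged usage example: the call returns one denoising-step index per block, identical on every rank because the values are broadcast from rank 0. The wrapper function and the example output below are hypothetical.

import torch

def sample_exit_flags(pipeline) -> list[int]:
    # `pipeline` stands in for an initialized SelfForcingDistillationPipeline.
    flags = pipeline.generate_and_sync_list(num_blocks=7,
                                            num_denoising_steps=4,
                                            device=torch.device("cuda"))
    # e.g. [2, 0, 3, 1, 3, 0, 2]; with last_step_only=True every entry
    # would be num_denoising_steps - 1.
    return flags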
fastvideo.training.self_forcing_distillation_pipeline.SelfForcingDistillationPipeline.generator_loss
generator_loss(training_batch: TrainingBatch) -> tuple[Tensor, dict[str, Any]]

Compute generator loss using DMD-style approach. The generator tries to fool the critic (fake_score_transformer).

Source code in fastvideo/training/self_forcing_distillation_pipeline.py
def generator_loss(
        self, training_batch: TrainingBatch
) -> tuple[torch.Tensor, dict[str, Any]]:
    """
    Compute generator loss using DMD-style approach.
    The generator tries to fool the critic (fake_score_transformer).
    """
    with set_forward_context(
            current_timestep=training_batch.timesteps,
            attn_metadata=training_batch.attn_metadata_vsa):
        generator_pred_video = self._generator_multi_step_simulation_forward(
            training_batch)

    with set_forward_context(current_timestep=training_batch.timesteps,
                             attn_metadata=training_batch.attn_metadata):
        dmd_loss = self._dmd_forward(
            generator_pred_video=generator_pred_video,
            training_batch=training_batch)

    log_dict = {
        "dmdtrain_gradient_norm": torch.tensor(0.0, device=self.device)
    }

    return dmd_loss, log_dict
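
For orientation, the DMD-style objective compares the real score model with the critic (fake score model) on a noised version of the generator's prediction; the difference of their predictions acts as the gradient pushing the generator toward the real distribution. The sketch below is a simplification with placeholder model signatures, not the _dmd_forward implementation:

import torch
import torch.nn.functional as F

def dmd_loss_sketch(generator_pred_video: torch.Tensor, real_score_model,
                    fake_score_model, timesteps: torch.Tensor) -> torch.Tensor:
    # Placeholder model(noisy, timesteps) signatures; timesteps assumed in [0, 1].
    noise = torch.randn_like(generator_pred_video)
    t = timesteps.view(-1, *([1] * (generator_pred_video.dim() - 1)))
    noisy = (1 - t) * generator_pred_video + t * noise
    with torch.no_grad():
        real_pred = real_score_model(noisy, timesteps)
        fake_pred = fake_score_model(noisy, timesteps)
        grad = fake_pred - real_pred  # direction away from the fake-score estimate
    # Surrogate whose gradient w.r.t. the generator output is proportional to `grad`.
    return 0.5 * F.mse_loss(generator_pred_video,
                            (generator_pred_video - grad).detach())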
fastvideo.training.self_forcing_distillation_pipeline.SelfForcingDistillationPipeline.initialize_training_pipeline
initialize_training_pipeline(training_args: TrainingArgs)

Initialize the self-forcing training pipeline.

Source code in fastvideo/training/self_forcing_distillation_pipeline.py
def initialize_training_pipeline(self, training_args: TrainingArgs):
    """Initialize the self-forcing training pipeline."""
    # Check if FSDP2 auto wrap is enabled - not supported for self-forcing distillation
    if os.environ.get("FASTVIDEO_FSDP2_AUTOWRAP", "0") == "1":
        raise NotImplementedError(
            "FASTVIDEO_FSDP2_AUTOWRAP is not implemented for self-forcing distillation. "
            "Please set FASTVIDEO_FSDP2_AUTOWRAP=0 or unset the environment variable."
        )

    logger.info("Initializing self-forcing distillation pipeline...")

    self.generator_ema: EMA_FSDP | None = None
    self.generator_ema_2: EMA_FSDP | None = None

    super().initialize_training_pipeline(training_args)
    try:
        logger.info("RANK: %s, entered initialize_training_pipeline",
                    self.global_rank,
                    local_main_process_only=False)
    except Exception:
        logger.info("Entered initialize_training_pipeline (rank unknown)")

    self.noise_scheduler = SelfForcingFlowMatchScheduler(
        num_inference_steps=1000,
        shift=5.0,
        sigma_min=0.0,
        extra_one_step=True,
        training=True)
    self.dfake_gen_update_ratio = getattr(training_args,
                                          'dfake_gen_update_ratio', 5)

    self.num_frame_per_block = getattr(training_args, 'num_frame_per_block',
                                       3)
    self.independent_first_frame = getattr(training_args,
                                           'independent_first_frame', False)
    self.same_step_across_blocks = getattr(training_args,
                                           'same_step_across_blocks', False)
    self.last_step_only = getattr(training_args, 'last_step_only', False)
    self.context_noise = getattr(training_args, 'context_noise', 0)

    self.kv_cache1: list[dict[str, Any]] | None = None
    self.crossattn_cache: list[dict[str, Any]] | None = None

    logger.info("Self-forcing generator update ratio: %s",
                self.dfake_gen_update_ratio)
    logger.info("RANK: %s, exiting initialize_training_pipeline",
                self.global_rank,
                local_main_process_only=False)
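
The initializer reads several optional attributes from TrainingArgs via getattr; the mapping below restates the fallback defaults used above (a summary of the source, not a configuration API):

SELF_FORCING_DEFAULTS = {
    "dfake_gen_update_ratio": 5,      # critic updates per generator update
    "num_frame_per_block": 3,         # frames generated per causal block
    "independent_first_frame": False,
    "same_step_across_blocks": False,
    "last_step_only": False,          # always exit at the final denoising step
    "context_noise": 0,
}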
fastvideo.training.self_forcing_distillation_pipeline.SelfForcingDistillationPipeline.train
train() -> None

Main training loop with self-forcing specific logging.

Source code in fastvideo/training/self_forcing_distillation_pipeline.py
@profile_region("profiler_region_training_train")
def train(self) -> None:
    """Main training loop with self-forcing specific logging."""
    assert self.training_args.seed is not None, "seed must be set"
    seed = self.training_args.seed

    # Set the same seed within each SP group to ensure reproducibility
    if self.sp_world_size > 1:
        # Use the same seed for all processes within the same SP group
        sp_group_seed = seed + (self.global_rank // self.sp_world_size)
        set_random_seed(sp_group_seed)
    else:
        set_random_seed(seed + self.global_rank)

    self.noise_random_generator = torch.Generator(device="cpu").manual_seed(
        self.seed)
    self.noise_gen_cuda = torch.Generator(device="cuda").manual_seed(
        self.seed)
    self.validation_random_generator = torch.Generator(
        device="cpu").manual_seed(self.seed)
    logger.info("Initialized random seeds with seed: %s", seed)

    self.current_trainstep = self.init_steps

    if self.training_args.resume_from_checkpoint:
        self._resume_from_checkpoint()
        logger.info("Resumed from checkpoint, random states restored")
    else:
        logger.info("Starting training from scratch")

    self.train_loader_iter = iter(self.train_dataloader)

    step_times: deque[float] = deque(maxlen=100)

    self._log_training_info()
    self._log_validation(self.transformer, self.training_args,
                         self.init_steps)

    progress_bar = tqdm(
        range(0, self.training_args.max_train_steps),
        initial=self.init_steps,
        desc="Steps",
        disable=self.local_rank > 0,
    )

    use_vsa = vsa_available and envs.FASTVIDEO_ATTENTION_BACKEND == "VIDEO_SPARSE_ATTN"
    for step in range(self.init_steps + 1,
                      self.training_args.max_train_steps + 1):
        start_time = time.perf_counter()
        if use_vsa:
            vsa_sparsity = self.training_args.VSA_sparsity
            vsa_decay_rate = self.training_args.VSA_decay_rate
            vsa_decay_interval_steps = self.training_args.VSA_decay_interval_steps
            if vsa_decay_interval_steps > 1:
                current_decay_times = min(step // vsa_decay_interval_steps,
                                          vsa_sparsity // vsa_decay_rate)
                current_vsa_sparsity = current_decay_times * vsa_decay_rate
            else:
                current_vsa_sparsity = vsa_sparsity
        else:
            current_vsa_sparsity = 0.0

        training_batch = TrainingBatch()
        self.current_trainstep = step
        training_batch.current_vsa_sparsity = current_vsa_sparsity

        if (step >= self.training_args.ema_start_step) and \
                (self.generator_ema is None) and (self.training_args.ema_decay > 0):
            self.generator_ema = EMA_FSDP(
                self.transformer, decay=self.training_args.ema_decay)
            logger.info("Created generator EMA at step %s with decay=%s",
                        step, self.training_args.ema_decay)

            # Create EMA for transformer_2 if it exists
            if self.transformer_2 is not None and self.generator_ema_2 is None:
                self.generator_ema_2 = EMA_FSDP(
                    self.transformer_2, decay=self.training_args.ema_decay)
                logger.info(
                    "Created generator EMA_2 at step %s with decay=%s",
                    step, self.training_args.ema_decay)

        with torch.autocast("cuda", dtype=torch.bfloat16):
            training_batch = self.train_one_step(training_batch)

        total_loss = training_batch.total_loss
        generator_loss = training_batch.generator_loss
        fake_score_loss = training_batch.fake_score_loss
        grad_norm = training_batch.grad_norm

        step_time = time.perf_counter() - start_time
        step_times.append(step_time)
        avg_step_time = sum(step_times) / len(step_times)

        progress_bar.set_postfix({
            "total_loss":
            f"{total_loss:.4f}",
            "generator_loss":
            f"{generator_loss:.4f}",
            "fake_score_loss":
            f"{fake_score_loss:.4f}",
            "step_time":
            f"{step_time:.2f}s",
            "grad_norm":
            grad_norm,
            "ema":
            "✓" if (self.generator_ema is not None and self.is_ema_ready())
            else "✗",
            "ema2":
            "✓" if (self.generator_ema_2 is not None
                    and self.is_ema_ready()) else "✗",
        })
        progress_bar.update(1)

        if self.global_rank == 0:
            log_data = {
                "train_total_loss":
                total_loss,
                "train_fake_score_loss":
                fake_score_loss,
                "learning_rate":
                self.lr_scheduler.get_last_lr()[0],
                "fake_score_learning_rate":
                self.fake_score_lr_scheduler.get_last_lr()[0],
                "step_time":
                step_time,
                "avg_step_time":
                avg_step_time,
                "grad_norm":
                grad_norm,
            }
            if (step % self.dfake_gen_update_ratio == 0):
                log_data["train_generator_loss"] = generator_loss
            if use_vsa:
                log_data["VSA_train_sparsity"] = current_vsa_sparsity

            if self.generator_ema is not None or self.generator_ema_2 is not None:
                log_data["ema_enabled"] = self.generator_ema is not None
                log_data["ema_2_enabled"] = self.generator_ema_2 is not None
                log_data["ema_decay"] = self.training_args.ema_decay
            else:
                log_data["ema_enabled"] = False
                log_data["ema_2_enabled"] = False

            ema_stats = self.get_ema_stats()
            log_data.update(ema_stats)

            if training_batch.dmd_latent_vis_dict:
                dmd_additional_logs = {
                    "generator_timestep":
                    training_batch.
                    dmd_latent_vis_dict["generator_timestep"].item(),
                    "dmd_timestep":
                    training_batch.dmd_latent_vis_dict["dmd_timestep"].item(
                    ),
                }
                log_data.update(dmd_additional_logs)

            faker_score_additional_logs = {
                "fake_score_timestep":
                training_batch.
                fake_score_latent_vis_dict["fake_score_timestep"].item(),
            }
            log_data.update(faker_score_additional_logs)

            self.tracker.log(log_data, step)

            if self.training_args.log_validation and step % self.training_args.validation_steps == 0 and self.training_args.log_visualization:
                self.visualize_intermediate_latents(training_batch,
                                                    self.training_args,
                                                    step)

        if (self.training_args.training_state_checkpointing_steps > 0
                and step %
                self.training_args.training_state_checkpointing_steps == 0):
            print("rank", self.global_rank,
                  "save training state checkpoint at step", step)
            save_distillation_checkpoint(
                self.transformer,
                self.fake_score_transformer,
                self.global_rank,
                self.training_args.output_dir,
                step,
                self.optimizer,
                self.fake_score_optimizer,
                self.train_dataloader,
                self.lr_scheduler,
                self.fake_score_lr_scheduler,
                self.noise_random_generator,
                self.generator_ema,
                # MoE support
                generator_transformer_2=getattr(self, 'transformer_2',
                                                None),
                real_score_transformer_2=getattr(
                    self, 'real_score_transformer_2', None),
                fake_score_transformer_2=getattr(
                    self, 'fake_score_transformer_2', None),
                generator_optimizer_2=getattr(self, 'optimizer_2', None),
                fake_score_optimizer_2=getattr(self,
                                               'fake_score_optimizer_2',
                                               None),
                generator_scheduler_2=getattr(self, 'lr_scheduler_2', None),
                fake_score_scheduler_2=getattr(self,
                                               'fake_score_lr_scheduler_2',
                                               None),
                generator_ema_2=getattr(self, 'generator_ema_2', None))

            if self.transformer:
                self.transformer.train()
            self.sp_group.barrier()

        if (self.training_args.weight_only_checkpointing_steps > 0
                and step %
                self.training_args.weight_only_checkpointing_steps == 0):
            print("rank", self.global_rank,
                  "save weight-only checkpoint at step", step)
            save_distillation_checkpoint(
                self.transformer,
                self.fake_score_transformer,
                self.global_rank,
                self.training_args.output_dir,
                f"{step}_weight_only",
                only_save_generator_weight=True,
                generator_ema=self.generator_ema,
                # MoE support
                generator_transformer_2=getattr(self, 'transformer_2',
                                                None),
                real_score_transformer_2=getattr(
                    self, 'real_score_transformer_2', None),
                fake_score_transformer_2=getattr(
                    self, 'fake_score_transformer_2', None),
                generator_optimizer_2=getattr(self, 'optimizer_2', None),
                fake_score_optimizer_2=getattr(self,
                                               'fake_score_optimizer_2',
                                               None),
                generator_scheduler_2=getattr(self, 'lr_scheduler_2', None),
                fake_score_scheduler_2=getattr(self,
                                               'fake_score_lr_scheduler_2',
                                               None),
                generator_ema_2=getattr(self, 'generator_ema_2', None))

            if self.training_args.use_ema and self.is_ema_ready():
                self.save_ema_weights(self.training_args.output_dir, step)

        if self.training_args.log_validation and step % self.training_args.validation_steps == 0:
            self._log_validation(self.transformer, self.training_args, step)

    self.tracker.finish()

    print("rank", self.global_rank,
          "save final training state checkpoint at step",
          self.training_args.max_train_steps)
    save_distillation_checkpoint(
        self.transformer,
        self.fake_score_transformer,
        self.global_rank,
        self.training_args.output_dir,
        self.training_args.max_train_steps,
        self.optimizer,
        self.fake_score_optimizer,
        self.train_dataloader,
        self.lr_scheduler,
        self.fake_score_lr_scheduler,
        self.noise_random_generator,
        self.generator_ema,
        # MoE support
        generator_transformer_2=getattr(self, 'transformer_2', None),
        real_score_transformer_2=getattr(self, 'real_score_transformer_2',
                                         None),
        fake_score_transformer_2=getattr(self, 'fake_score_transformer_2',
                                         None),
        generator_optimizer_2=getattr(self, 'optimizer_2', None),
        fake_score_optimizer_2=getattr(self, 'fake_score_optimizer_2',
                                       None),
        generator_scheduler_2=getattr(self, 'lr_scheduler_2', None),
        fake_score_scheduler_2=getattr(self, 'fake_score_lr_scheduler_2',
                                       None),
        generator_ema_2=getattr(self, 'generator_ema_2', None))

    if self.training_args.use_ema and self.is_ema_ready():
        self.save_ema_weights(self.training_args.output_dir,
                              self.training_args.max_train_steps)

    if envs.FASTVIDEO_TORCH_PROFILER_DIR:
        logger.info("Stopping profiler...")
        self.profiler_controller.stop()
        logger.info("Profiler stopped.")

    if get_sp_group():
        cleanup_dist_env_and_memory()
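
A minimal, hypothetical driver for the loop above; the model path is a placeholder and how training_args is constructed is outside the scope of this page:

from fastvideo.fastvideo_args import TrainingArgs  # assumed import path
from fastvideo.training.self_forcing_distillation_pipeline import (
    SelfForcingDistillationPipeline)

def run_distillation(training_args: TrainingArgs) -> None:
    pipeline = SelfForcingDistillationPipeline(
        model_path="Wan-AI/Wan2.1-T2V-1.3B-Diffusers",  # placeholder checkpoint
        fastvideo_args=training_args)
    pipeline.train()  # runs the alternating generator/critic loop shown above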
fastvideo.training.self_forcing_distillation_pipeline.SelfForcingDistillationPipeline.train_one_step
train_one_step(training_batch: TrainingBatch) -> TrainingBatch

Self-forcing training step that alternates between generator and critic training.

Source code in fastvideo/training/self_forcing_distillation_pipeline.py
def train_one_step(self, training_batch: TrainingBatch) -> TrainingBatch:
    """
    Self-forcing training step that alternates between generator and critic training.
    """
    gradient_accumulation_steps = getattr(self.training_args,
                                          'gradient_accumulation_steps', 1)
    train_generator = (self.current_trainstep %
                       self.dfake_gen_update_ratio == 0)

    batches = []
    for _ in range(gradient_accumulation_steps):
        batch = self._prepare_distillation(training_batch)
        batch = self._get_next_batch(batch)
        batch = self._normalize_dit_input(batch)
        batch = self._prepare_dit_inputs(batch)
        batch = self._build_attention_metadata(batch)
        batch.attn_metadata_vsa = copy.deepcopy(batch.attn_metadata)
        if batch.attn_metadata is not None:
            batch.attn_metadata.VSA_sparsity = 0.0
        batches.append(batch)

    training_batch.dmd_latent_vis_dict = {}
    training_batch.fake_score_latent_vis_dict = {}

    if train_generator:
        logger.debug("Training generator at step %s",
                     self.current_trainstep)
        self.optimizer.zero_grad()
        if self.transformer_2 is not None:
            self.optimizer_2.zero_grad()
        total_generator_loss = 0.0
        generator_log_dict = {}

        for batch in batches:
            # Create a new batch with detached tensors
            batch_gen = TrainingBatch()
            for key, value in batch.__dict__.items():
                if isinstance(value, torch.Tensor):
                    setattr(batch_gen, key, value.detach().clone())
                elif isinstance(value, dict):
                    setattr(
                        batch_gen, key, {
                            k:
                            v.detach().clone() if isinstance(
                                v, torch.Tensor) else copy.deepcopy(v)
                            for k, v in value.items()
                        })
                else:
                    setattr(batch_gen, key, copy.deepcopy(value))

            generator_loss, gen_log_dict = self.generator_loss(batch_gen)
            with set_forward_context(current_timestep=batch_gen.timesteps,
                                     attn_metadata=batch_gen.attn_metadata):
                (generator_loss / gradient_accumulation_steps).backward()
            total_generator_loss += generator_loss.detach().item()
            generator_log_dict.update(gen_log_dict)
            # Store visualization data from generator training
            if hasattr(batch_gen, 'dmd_latent_vis_dict'):
                training_batch.dmd_latent_vis_dict.update(
                    batch_gen.dmd_latent_vis_dict)

        # Only clip gradients and step optimizer for the model that is currently training
        if hasattr(
                self, 'train_transformer_2'
        ) and self.train_transformer_2 and self.transformer_2 is not None:
            self._clip_model_grad_norm_(batch_gen, self.transformer_2)
            self.optimizer_2.step()
            self.lr_scheduler_2.step()
        else:
            self._clip_model_grad_norm_(batch_gen, self.transformer)
            self.optimizer.step()
            self.lr_scheduler.step()

        if self.generator_ema is not None:
            if hasattr(
                    self, 'train_transformer_2'
            ) and self.train_transformer_2 and self.transformer_2 is not None:
                # Update EMA for transformer_2 when training it
                if self.generator_ema_2 is not None:
                    self.generator_ema_2.update(self.transformer_2)
            else:
                self.generator_ema.update(self.transformer)

        avg_generator_loss = torch.tensor(total_generator_loss /
                                          gradient_accumulation_steps,
                                          device=self.device)
        world_group = get_world_group()
        world_group.all_reduce(avg_generator_loss,
                               op=torch.distributed.ReduceOp.AVG)
        training_batch.generator_loss = avg_generator_loss.item()
    else:
        training_batch.generator_loss = 0.0

    logger.debug("Training critic at step %s", self.current_trainstep)
    self.fake_score_optimizer.zero_grad()
    total_critic_loss = 0.0
    critic_log_dict = {}

    for batch in batches:
        # Create a new batch with detached tensors
        batch_critic = TrainingBatch()
        for key, value in batch.__dict__.items():
            if isinstance(value, torch.Tensor):
                setattr(batch_critic, key, value.detach().clone())
            elif isinstance(value, dict):
                setattr(
                    batch_critic, key, {
                        k:
                        v.detach().clone()
                        if isinstance(v, torch.Tensor) else copy.deepcopy(v)
                        for k, v in value.items()
                    })
            else:
                setattr(batch_critic, key, copy.deepcopy(value))

        critic_loss, crit_log_dict = self.critic_loss(batch_critic)
        with set_forward_context(current_timestep=batch_critic.timesteps,
                                 attn_metadata=batch_critic.attn_metadata):
            (critic_loss / gradient_accumulation_steps).backward()
        total_critic_loss += critic_loss.detach().item()
        critic_log_dict.update(crit_log_dict)
        # Store visualization data from critic training
        if hasattr(batch_critic, 'fake_score_latent_vis_dict'):
            training_batch.fake_score_latent_vis_dict.update(
                batch_critic.fake_score_latent_vis_dict)

    if self.train_fake_score_transformer_2 and self.fake_score_transformer_2 is not None:
        self._clip_model_grad_norm_(batch_critic,
                                    self.fake_score_transformer_2)
        self.fake_score_optimizer_2.step()
        self.fake_score_lr_scheduler_2.step()
    else:
        self._clip_model_grad_norm_(batch_critic,
                                    self.fake_score_transformer)
        self.fake_score_optimizer.step()
        self.fake_score_lr_scheduler.step()

    avg_critic_loss = torch.tensor(total_critic_loss /
                                   gradient_accumulation_steps,
                                   device=self.device)
    world_group = get_world_group()
    world_group.all_reduce(avg_critic_loss,
                           op=torch.distributed.ReduceOp.AVG)
    training_batch.fake_score_loss = avg_critic_loss.item()

    training_batch.total_loss = training_batch.generator_loss + training_batch.fake_score_loss
    return training_batch
fastvideo.training.self_forcing_distillation_pipeline.SelfForcingDistillationPipeline.visualize_intermediate_latents
visualize_intermediate_latents(training_batch: TrainingBatch, training_args: TrainingArgs, step: int)

Add visualization data to tracker logging and save frames to disk.

Source code in fastvideo/training/self_forcing_distillation_pipeline.py
def visualize_intermediate_latents(self, training_batch: TrainingBatch,
                                   training_args: TrainingArgs, step: int):
    """Add visualization data to tracker logging and save frames to disk."""
    tracker_loss_dict: dict[str, Any] = {}

    # Debug logging
    if hasattr(training_batch, 'dmd_latent_vis_dict'):
        logger.info("DMD latent keys: %s",
                    list(training_batch.dmd_latent_vis_dict.keys()))
    if hasattr(training_batch, 'fake_score_latent_vis_dict'):
        logger.info("Fake score latent keys: %s",
                    list(training_batch.fake_score_latent_vis_dict.keys()))

    # Process generator predictions if available
    if hasattr(
            training_batch,
            'dmd_latent_vis_dict') and training_batch.dmd_latent_vis_dict:
        dmd_latents_vis_dict = training_batch.dmd_latent_vis_dict
        dmd_log_keys = [
            'generator_pred_video', 'real_score_pred_video',
            'faker_score_pred_video'
        ]

        for latent_key in dmd_log_keys:
            if latent_key in dmd_latents_vis_dict:
                logger.info("Processing DMD latent: %s", latent_key)
                latents = dmd_latents_vis_dict[latent_key]
                if not isinstance(latents, torch.Tensor):
                    logger.warning("Expected tensor for %s, got %s",
                                   latent_key, type(latents))
                    continue

                latents = latents.detach()
                latents = latents.permute(0, 2, 1, 3, 4)

                if isinstance(self.vae.scaling_factor, torch.Tensor):
                    latents = latents / self.vae.scaling_factor.to(
                        latents.device, latents.dtype)
                else:
                    latents = latents / self.vae.scaling_factor

                if (hasattr(self.vae, "shift_factor")
                        and self.vae.shift_factor is not None):
                    if isinstance(self.vae.shift_factor, torch.Tensor):
                        latents += self.vae.shift_factor.to(
                            latents.device, latents.dtype)
                    else:
                        latents += self.vae.shift_factor

                with torch.autocast("cuda", dtype=torch.bfloat16):
                    video = self.vae.decode(latents)
                video = (video / 2 + 0.5).clamp(0, 1)
                video = video.cpu().float()
                video = video.permute(0, 2, 1, 3, 4)
                video = (video * 255).numpy().astype(np.uint8)
                video_artifact = self.tracker.video(video,
                                                    fps=24,
                                                    format="mp4")
                if video_artifact is not None:
                    tracker_loss_dict[f"dmd_{latent_key}"] = video_artifact
                del video, latents

    # Process critic predictions
    if hasattr(training_batch, 'fake_score_latent_vis_dict'
               ) and training_batch.fake_score_latent_vis_dict:
        fake_score_latents_vis_dict = training_batch.fake_score_latent_vis_dict
        fake_score_log_keys = ['generator_pred_video']

        for latent_key in fake_score_log_keys:
            if latent_key in fake_score_latents_vis_dict:
                logger.info("Processing critic latent: %s", latent_key)
                latents = fake_score_latents_vis_dict[latent_key]
                if not isinstance(latents, torch.Tensor):
                    logger.warning("Expected tensor for %s, got %s",
                                   latent_key, type(latents))
                    continue

                latents = latents.detach()
                latents = latents.permute(0, 2, 1, 3, 4)

                if isinstance(self.vae.scaling_factor, torch.Tensor):
                    latents = latents / self.vae.scaling_factor.to(
                        latents.device, latents.dtype)
                else:
                    latents = latents / self.vae.scaling_factor

                if (hasattr(self.vae, "shift_factor")
                        and self.vae.shift_factor is not None):
                    if isinstance(self.vae.shift_factor, torch.Tensor):
                        latents += self.vae.shift_factor.to(
                            latents.device, latents.dtype)
                    else:
                        latents += self.vae.shift_factor

                with torch.autocast("cuda", dtype=torch.bfloat16):
                    video = self.vae.decode(latents)
                video = (video / 2 + 0.5).clamp(0, 1)
                video = video.cpu().float()
                video = video.permute(0, 2, 1, 3, 4)
                video = (video * 255).numpy().astype(np.uint8)
                video_artifact = self.tracker.video(video,
                                                    fps=24,
                                                    format="mp4")
                if video_artifact is not None:
                    tracker_loss_dict[
                        f"critic_{latent_key}"] = video_artifact
                del video, latents

    # Log metadata
    if hasattr(
            training_batch,
            'dmd_latent_vis_dict') and training_batch.dmd_latent_vis_dict:
        if "generator_timestep" in training_batch.dmd_latent_vis_dict:
            tracker_loss_dict[
                "generator_timestep"] = training_batch.dmd_latent_vis_dict[
                    "generator_timestep"].item()
        if "dmd_timestep" in training_batch.dmd_latent_vis_dict:
            tracker_loss_dict[
                "dmd_timestep"] = training_batch.dmd_latent_vis_dict[
                    "dmd_timestep"].item()

    if hasattr(
            training_batch, 'fake_score_latent_vis_dict'
    ) and training_batch.fake_score_latent_vis_dict and "fake_score_timestep" in training_batch.fake_score_latent_vis_dict:
        tracker_loss_dict[
            "fake_score_timestep"] = training_batch.fake_score_latent_vis_dict[
                "fake_score_timestep"].item()

    # Log final dict contents
    logger.info("Final tracker_loss_dict keys: %s",
                list(tracker_loss_dict.keys()))

    if self.global_rank == 0 and tracker_loss_dict:
        self.tracker.log_artifacts(tracker_loss_dict, step)
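
The latent-to-pixel conversion repeated in both branches above can be summarized by a small helper; this is a sketch distilled from the code, not an existing function in the module (tensor-valued scaling and shift factors are simplified here):

import numpy as np
import torch

def latents_to_uint8_video(vae, latents: torch.Tensor) -> np.ndarray:
    # Hypothetical helper mirroring the decode path above.
    latents = latents.detach().permute(0, 2, 1, 3, 4)
    latents = latents / vae.scaling_factor
    if getattr(vae, "shift_factor", None) is not None:
        latents = latents + vae.shift_factor
    with torch.autocast("cuda", dtype=torch.bfloat16):
        video = vae.decode(latents)
    video = (video / 2 + 0.5).clamp(0, 1)            # [-1, 1] -> [0, 1]
    video = video.cpu().float().permute(0, 2, 1, 3, 4)
    return (video * 255).numpy().astype(np.uint8)    # [0, 1] -> uint8 frames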

Functions