fastvideo.training.self_forcing_distillation_pipeline
Module Contents

Classes

| SelfForcingDistillationPipeline | A self-forcing distillation pipeline that alternates between training the generator and critic based on the self-forcing methodology. |

Data

API
- class fastvideo.training.self_forcing_distillation_pipeline.SelfForcingDistillationPipeline(model_path: str, fastvideo_args: fastvideo.fastvideo_args.TrainingArgs, required_config_modules: list[str] | None = None, loaded_modules: dict[str, torch.nn.Module] | None = None)[source]

Bases: fastvideo.training.distillation_pipeline.DistillationPipeline
A self-forcing distillation pipeline that alternates between training the generator and critic based on the self-forcing methodology.
This implementation follows the self-forcing approach, where:

- The generator and critic are trained in alternating steps.
- The generator loss uses a DMD-style loss, with the critic serving as the fake score model.
- The critic loss trains the fake score model to distinguish real from fake samples.
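The alternating schedule above can be sketched as a simple step-parity rule. This is a hypothetical illustration, not the pipeline's actual scheduling code; the function name and ratio parameters are invented for the example:

```python
def is_generator_step(step: int, generator_ratio: int = 1, critic_ratio: int = 1) -> bool:
    """Return True when `step` falls in the generator phase of the cycle.

    Hypothetical helper: `generator_ratio` / `critic_ratio` control how many
    consecutive steps each model is trained for before switching.
    """
    cycle = generator_ratio + critic_ratio
    return step % cycle < generator_ratio

# With a 1:1 ratio, generator and critic updates strictly alternate.
schedule = ["G" if is_generator_step(s) else "C" for s in range(6)]
```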
Initialization
Initialize the pipeline. After init, the pipeline should be ready to use. The pipeline should be stateless and not hold any batch state.
- critic_loss(training_batch: fastvideo.pipelines.TrainingBatch) → tuple[torch.Tensor, dict[str, Any]][source]

Compute critic loss using flow matching between noise and generator output. The critic learns to predict the flow from noise to the generator's output.
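A minimal sketch of such a flow-matching objective, assuming a linear interpolation path from noise (t=0) to the generator's output (t=1); the function name and signature here are illustrative, not the pipeline's actual API:

```python
import torch

def critic_flow_matching_loss(critic, generator_latents, noise, t):
    # Interpolate along the straight path between noise and the fake sample.
    x_t = (1.0 - t) * noise + t * generator_latents
    # The target velocity of a linear path is constant: endpoint minus start.
    target_velocity = generator_latents - noise
    pred_velocity = critic(x_t, t)
    return torch.nn.functional.mse_loss(pred_velocity, target_velocity)

# Toy check with a critic that always predicts zero velocity.
noise = torch.zeros(2, 4)
fake_latents = torch.ones(2, 4)
loss = critic_flow_matching_loss(lambda x, t: torch.zeros_like(x),
                                 fake_latents, noise, t=0.5)
```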
- generate_and_sync_list(num_blocks: int, num_denoising_steps: int, device: torch.device) → list[int][source]
Generate and synchronize random exit flags across distributed processes.
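One common way to implement this kind of synchronization is to sample on rank 0 and broadcast to all other ranks, so every process sees the same random exit flags. A sketch under that assumption (not necessarily the pipeline's exact implementation):

```python
import torch
import torch.distributed as dist

def generate_and_sync_list(num_blocks: int, num_denoising_steps: int,
                           device: torch.device) -> list[int]:
    # Rank 0 samples one random exit index per block...
    if not dist.is_initialized() or dist.get_rank() == 0:
        flags = torch.randint(0, num_denoising_steps, (num_blocks,), device=device)
    else:
        flags = torch.empty(num_blocks, dtype=torch.long, device=device)
    # ...then broadcasts so all processes share the same denoising schedule.
    if dist.is_initialized():
        dist.broadcast(flags, src=0)
    return flags.tolist()

# Single-process usage (no process group initialized).
exit_flags = generate_and_sync_list(4, 8, torch.device("cpu"))
```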
- generator_loss(training_batch: fastvideo.pipelines.TrainingBatch) → tuple[torch.Tensor, dict[str, Any]][source]
Compute generator loss using DMD-style approach. The generator tries to fool the critic (fake_score_transformer).
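A DMD-style loss is often implemented by denoising the generator's sample with both score models and regressing toward a target shifted by their disagreement. This sketch is a hypothetical illustration of that pattern; the function names and the exact gradient direction are assumptions, not the pipeline's verified implementation:

```python
import torch

def dmd_generator_loss(real_score_fn, fake_score_fn, generated_latents, t):
    # Hypothetical sketch: both score models denoise the generator's sample;
    # their disagreement gives a distribution-matching gradient direction.
    with torch.no_grad():
        real_pred = real_score_fn(generated_latents, t)  # teacher ("real") prediction
        fake_pred = fake_score_fn(generated_latents, t)  # critic ("fake") prediction
        grad = fake_pred - real_pred
        target = (generated_latents - grad).detach()
    # MSE against the shifted, detached target backpropagates `grad`
    # into the generator while leaving both score models frozen.
    return 0.5 * torch.nn.functional.mse_loss(generated_latents, target)

# Toy check: if the critic's prediction exceeds the teacher's by exactly 1
# everywhere, the loss is 0.5 * mean(1^2) = 0.5.
x = torch.ones(2, 3)
loss = dmd_generator_loss(lambda z, t: z, lambda z, t: z + 1.0, x, t=0.5)
```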
- initialize_training_pipeline(training_args: fastvideo.fastvideo_args.TrainingArgs)[source]
Initialize the self-forcing training pipeline.
- train_one_step(training_batch: fastvideo.pipelines.TrainingBatch) → fastvideo.pipelines.TrainingBatch[source]
Self-forcing training step that alternates between generator and critic training.
- visualize_intermediate_latents(training_batch: fastvideo.pipelines.TrainingBatch, training_args: fastvideo.fastvideo_args.TrainingArgs, step: int)[source]
Add visualization data to wandb logging and save frames to disk.