fastvideo.training.distillation_pipeline
Module Contents#
Classes#
DistillationPipeline | A distillation pipeline for training a 3-step model. Inherits from TrainingPipeline to reuse training infrastructure.
Data#
API#
- class fastvideo.training.distillation_pipeline.DistillationPipeline(model_path: str, fastvideo_args: fastvideo.fastvideo_args.TrainingArgs, required_config_modules: list[str] | None = None, loaded_modules: dict[str, torch.nn.Module] | None = None)[source]#
Bases:
fastvideo.training.training_pipeline.TrainingPipeline
A distillation pipeline for training a 3-step model. Inherits from TrainingPipeline to reuse training infrastructure.
Initialization
Initialize the pipeline. After init, the pipeline should be ready to use. The pipeline should be stateless and not hold any batch state.
- create_pipeline_stages(fastvideo_args: fastvideo.fastvideo_args.FastVideoArgs)[source]#
- faker_score_forward(training_batch: fastvideo.pipelines.TrainingBatch) tuple[fastvideo.pipelines.TrainingBatch, torch.Tensor] [source]#
Run the fake-score network forward pass, returning the updated training batch and the fake-score loss tensor.
- get_ema_model_copy() torch.nn.Module | None [source]#
Get a copy of the model with EMA weights applied.
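The idea behind the EMA copy can be sketched in isolation. The `decay` value and the flat parameter lists below are assumptions for illustration; the real pipeline maintains shadow weights for `torch.nn.Module` parameters:

```python
# Hypothetical sketch of exponential-moving-average weight tracking.
# `decay` and the flat float lists are assumptions, not the pipeline's
# actual internals.
def ema_update(ema_params, model_params, decay=0.999):
    """Blend current model weights into the EMA shadow copy."""
    return [decay * e + (1.0 - decay) * p
            for e, p in zip(ema_params, model_params)]

ema = [0.0]
for _ in range(3):
    ema = ema_update(ema, [1.0], decay=0.9)
```

After three updates toward a constant weight of 1.0 with decay 0.9, the shadow value is 0.271; the EMA copy lags the live model, which is what makes it smoother for inference.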
- initialize_training_pipeline(training_args: fastvideo.fastvideo_args.TrainingArgs)[source]#
Initialize the distillation training pipeline with multiple models.
- abstract initialize_validation_pipeline(training_args: fastvideo.fastvideo_args.TrainingArgs)[source]#
Initialize validation pipeline - must be implemented by subclasses.
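Because `initialize_validation_pipeline` is abstract, concrete subclasses must override it before the pipeline can be instantiated. This is the standard Python abstract-method contract; the class and return value below are illustrative stand-ins, not FastVideo code:

```python
from abc import ABC, abstractmethod

# Illustrative stand-ins for the real classes; DistillationPipeline
# similarly declares initialize_validation_pipeline as abstract, so
# instantiating a subclass without overriding it raises TypeError.
class PipelineBase(ABC):
    @abstractmethod
    def initialize_validation_pipeline(self, training_args):
        ...

class MyPipeline(PipelineBase):
    def initialize_validation_pipeline(self, training_args):
        return "validation ready"

result = MyPipeline().initialize_validation_pipeline(None)
```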
- is_ema_ready(current_step: int | None = None)[source]#
Check if EMA is ready for use (after ema_start_step).
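The readiness check reduces to a step comparison against `ema_start_step` (named in the description above). Treating a missing `current_step` as "not ready" is an assumption in this sketch:

```python
# Hypothetical mirror of the readiness check. The real method reads
# ema_start_step from the pipeline's training configuration; returning
# False for a None current_step is an assumption here.
def is_ema_ready(current_step, ema_start_step):
    if current_step is None:
        return False
    return current_step >= ema_start_step
```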
- load_module_from_path(model_path: str, module_type: str, training_args: fastvideo.fastvideo_args.TrainingArgs)[source]#
Load a module from a specific path using the same loading logic as the pipeline.
- Parameters:
model_path – Path to the model
module_type – Type of module to load (e.g., "transformer")
training_args – Training arguments
- Returns:
The loaded module
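The call shape can be illustrated with a mock dispatch on `module_type`. The real method reuses FastVideo's pipeline loading logic rather than a plain dict; the paths and the "vae" entry below are placeholders:

```python
# Hypothetical illustration of loading a module by type. The real
# implementation delegates to FastVideo's pipeline loader; this dict
# dispatch and the module names are stand-ins.
LOADERS = {
    "transformer": lambda path: f"transformer from {path}",
    "vae": lambda path: f"vae from {path}",
}

def load_module_from_path(model_path, module_type):
    if module_type not in LOADERS:
        raise ValueError(f"unknown module type: {module_type}")
    return LOADERS[module_type](model_path)

m = load_module_from_path("/models/example", "transformer")
```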
- save_ema_weights(output_dir: str, step: int)[source]#
Save EMA weights separately for inference purposes.
- train_one_step(training_batch: fastvideo.pipelines.TrainingBatch) fastvideo.pipelines.TrainingBatch [source]#
Run one distillation training step and return the updated training batch.
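The method's contract — a `TrainingBatch` goes in and the same batch comes back annotated with step results — can be mocked without FastVideo installed. Every name below is a stand-in, not the library's API:

```python
# Stand-in objects illustrating the train_one_step contract: the batch
# is consumed, annotated with the step's results, and returned. The
# attribute names and the loss computation are assumptions.
class FakeTrainingBatch:
    def __init__(self, latents):
        self.latents = latents
        self.total_loss = None

def train_one_step(batch):
    # stand-in for one distillation step's loss computation
    batch.total_loss = sum(x * x for x in batch.latents) / len(batch.latents)
    return batch

b = train_one_step(FakeTrainingBatch([0.5, -0.5]))
```

This in-and-out batch pattern is why the pipeline itself can stay stateless, as noted in the initialization description above: all per-step state rides on the batch object.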
- visualize_intermediate_latents(training_batch: fastvideo.pipelines.TrainingBatch, training_args: fastvideo.fastvideo_args.TrainingArgs, step: int)[source]#
Add visualization data to wandb logging and save frames to disk.