fastvideo.training.distillation_pipeline#
Module Contents#
Classes#
DistillationPipeline: A distillation pipeline for training a 3-step model. Inherits from TrainingPipeline to reuse training infrastructure.
Data#
API#
- class fastvideo.training.distillation_pipeline.DistillationPipeline(model_path: str, fastvideo_args: fastvideo.fastvideo_args.TrainingArgs, required_config_modules: list[str] | None = None, loaded_modules: dict[str, torch.nn.Module] | None = None)[source]#
Bases: fastvideo.training.training_pipeline.TrainingPipeline
A distillation pipeline for training a 3-step model. Inherits from TrainingPipeline to reuse training infrastructure.
Initialization
Initialize the pipeline. After initialization, the pipeline should be ready to use; it should be stateless and not hold any batch state.
- create_pipeline_stages(fastvideo_args: fastvideo.fastvideo_args.FastVideoArgs)[source]#
- faker_score_forward(training_batch: fastvideo.pipelines.TrainingBatch) → tuple[fastvideo.pipelines.TrainingBatch, torch.Tensor] [source]#
- initialize_training_pipeline(training_args: fastvideo.fastvideo_args.TrainingArgs)[source]#
Initialize the distillation training pipeline with multiple models.
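In score-distillation setups, "multiple models" usually means three roles: the trainable student/generator, a frozen teacher providing the real score, and a trainable critic providing the fake score (the role suggested by faker_score_forward above). The helper below is a hypothetical sketch of that arrangement; build_distillation_models and its return keys are illustrative, not part of the fastvideo API.

```python
# Hypothetical helper illustrating the three model roles typically used in
# distillation training; not a fastvideo function.
import copy

import torch


def build_distillation_models(student: torch.nn.Module) -> dict[str, torch.nn.Module]:
    # Frozen teacher ("real score"): provides targets, never updated.
    teacher = copy.deepcopy(student)
    teacher.requires_grad_(False)

    # Trainable critic ("fake score"): tracks the student's output distribution.
    critic = copy.deepcopy(student)

    return {"generator": student, "real_score": teacher, "fake_score": critic}
```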
- abstract initialize_validation_pipeline(training_args: fastvideo.fastvideo_args.TrainingArgs)[source]#
Initialize validation pipeline - must be implemented by subclasses.
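A minimal sketch of a concrete subclass, assuming only what the signatures above state; the class name, the validation_pipeline attribute, and the commented-out constructor call are illustrative placeholders.

```python
# Illustrative subclass: initialize_validation_pipeline is abstract, so a
# concrete pipeline must override it. Names below are placeholders.
from fastvideo.fastvideo_args import TrainingArgs
from fastvideo.training.distillation_pipeline import DistillationPipeline


class MyDistillationPipeline(DistillationPipeline):
    def initialize_validation_pipeline(self, training_args: TrainingArgs) -> None:
        # Build whatever inference pipeline the project uses for periodic
        # validation runs; omitted in this sketch.
        self.validation_pipeline = None  # placeholder


# Construction mirrors the class signature above; training_args would normally
# come from the project's CLI/config parsing.
# pipeline = MyDistillationPipeline(model_path="...", fastvideo_args=training_args)
```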
- train_one_step(training_batch: fastvideo.pipelines.TrainingBatch) → fastvideo.pipelines.TrainingBatch [source]#
- visualize_intermediate_latents(training_batch: fastvideo.pipelines.TrainingBatch, training_args: fastvideo.fastvideo_args.TrainingArgs, step: int)[source]#
Log visualization data to wandb and save frames to disk.
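The method name suggests the common pattern of logging decoded frames to wandb while also writing them to disk. The sketch below shows that pattern in isolation; the tensor layout, log key, and file format are assumptions, not taken from the fastvideo implementation.

```python
# Hypothetical logging helper; shapes, keys, and file layout are assumptions.
import os

import numpy as np
import wandb


def log_intermediate_frames(frames: np.ndarray, output_dir: str, step: int) -> None:
    # Log the clip to wandb as a video; wandb.Video accepts a uint8 array of
    # shape (time, channels, height, width).
    wandb.log({"intermediate_latents": wandb.Video(frames, fps=8, format="gif")},
              step=step)

    # Save individual frames to disk for offline inspection.
    frame_dir = os.path.join(output_dir, f"step_{step}")
    os.makedirs(frame_dir, exist_ok=True)
    for i, frame in enumerate(frames):
        np.save(os.path.join(frame_dir, f"frame_{i:04d}.npy"), frame)
```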