fastvideo.training.distillation_pipeline

Module Contents

Classes

DistillationPipeline

A distillation pipeline for training a 3-step model. Inherits from TrainingPipeline to reuse its training infrastructure.

Data

API

class fastvideo.training.distillation_pipeline.DistillationPipeline(model_path: str, fastvideo_args: fastvideo.fastvideo_args.TrainingArgs, required_config_modules: list[str] | None = None, loaded_modules: dict[str, torch.nn.Module] | None = None)

Bases: fastvideo.training.training_pipeline.TrainingPipeline

A distillation pipeline for training a 3-step model. Inherits from TrainingPipeline to reuse its training infrastructure.

Initialization

Initialize the pipeline. After initialization, the pipeline should be ready to use. The pipeline should be stateless and must not hold any batch state.
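The stateless contract can be illustrated with a minimal sketch, assuming that per-batch results travel in the batch object passed to and returned from train_one_step, as its signature (TrainingBatch in, TrainingBatch out) suggests. ToyTrainingBatch and ToyPipeline are hypothetical stand-ins, not FastVideo classes, and the toy loss is arbitrary.

```python
from __future__ import annotations

from dataclasses import dataclass, field, replace
from typing import Any


# Hypothetical stand-in for fastvideo.pipelines.TrainingBatch;
# the real class has different fields.
@dataclass(frozen=True)
class ToyTrainingBatch:
    step: int
    data: dict[str, Any] = field(default_factory=dict)
    loss: float | None = None


class ToyPipeline:
    """Holds configuration only; never stores per-batch values on self."""

    def __init__(self, scale: float) -> None:
        self.scale = scale  # configuration, not batch state

    def train_one_step(self, batch: ToyTrainingBatch) -> ToyTrainingBatch:
        # Results are written to a new batch object and returned, so the
        # pipeline object is unchanged between calls.
        values = batch.data["values"]
        loss = self.scale * sum(v * v for v in values) / len(values)
        return replace(batch, loss=loss)
```

Because nothing batch-specific lives on the pipeline, the same instance can process any sequence of batches without cleanup between steps.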

create_pipeline_stages(fastvideo_args: fastvideo.fastvideo_args.FastVideoArgs)

current_epoch: int

0

current_trainstep: int

None

faker_score_forward(training_batch: fastvideo.pipelines.TrainingBatch) -> tuple[fastvideo.pipelines.TrainingBatch, torch.Tensor]

init_steps: int

None

initialize_training_pipeline(training_args: fastvideo.fastvideo_args.TrainingArgs)

Initialize the distillation training pipeline with multiple models.

abstract initialize_validation_pipeline(training_args: fastvideo.fastvideo_args.TrainingArgs)

Initialize the validation pipeline. Must be implemented by subclasses.
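A minimal sketch of the subclassing contract, assuming the abstract method is enforced with Python's abc module (the reference only marks it "abstract"). ToyDistillationBase and ToyModelDistillation are hypothetical, and training_args is simplified to a plain dict instead of TrainingArgs.

```python
from __future__ import annotations

from abc import ABC, abstractmethod


class ToyDistillationBase(ABC):
    """Hypothetical stand-in for DistillationPipeline's abstract contract."""

    def __init__(self) -> None:
        self.validation_pipeline: object | None = None

    @abstractmethod
    def initialize_validation_pipeline(self, training_args: dict) -> None:
        """Subclasses build and assign self.validation_pipeline here."""


class ToyModelDistillation(ToyDistillationBase):
    def initialize_validation_pipeline(self, training_args: dict) -> None:
        # A real subclass would construct a model-specific inference
        # pipeline; this sketch only records the requested step count.
        self.validation_pipeline = {"num_inference_steps": training_args["steps"]}
```

With abc enforcement, instantiating the base class directly raises TypeError, so a missing override surfaces at construction time rather than mid-training.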

train() -> None

Main training loop with distillation-specific logging.
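The attributes train_dataloader, train_loader_iter, current_epoch and current_trainstep suggest a loop that draws batches from an iterator and starts a new epoch when the loader is exhausted. The sketch below shows that common pattern under that assumption; ToyTrainLoop is hypothetical, a plain list stands in for the StatefulDataLoader, and appending batch["idx"] stands in for train_one_step. It is not FastVideo's train() implementation.

```python
from __future__ import annotations

from collections.abc import Iterable, Iterator
from typing import Any


class ToyTrainLoop:
    """Hypothetical sketch of epoch/step bookkeeping, not FastVideo's train()."""

    def __init__(self, dataloader: Iterable[dict[str, Any]]) -> None:
        self.train_dataloader = dataloader
        self.train_loader_iter: Iterator[dict[str, Any]] = iter(dataloader)
        self.current_epoch = 0
        self.current_trainstep = 0

    def _next_batch(self) -> dict[str, Any]:
        try:
            return next(self.train_loader_iter)
        except StopIteration:
            # Loader exhausted: count a new epoch and restart the iterator.
            self.current_epoch += 1
            self.train_loader_iter = iter(self.train_dataloader)
            return next(self.train_loader_iter)

    def train(self, max_steps: int) -> list[int]:
        processed: list[int] = []
        while self.current_trainstep < max_steps:
            batch = self._next_batch()
            processed.append(batch["idx"])  # stand-in for train_one_step(batch)
            self.current_trainstep += 1
        return processed
```

Counting progress in steps rather than epochs keeps max_steps meaningful regardless of dataset size; a stateful loader such as torchdata's StatefulDataLoader additionally lets the iterator position be checkpointed and restored.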

train_dataloader: torchdata.stateful_dataloader.StatefulDataLoader

None

train_loader_iter: collections.abc.Iterator[dict[str, Any]]

None

train_one_step(training_batch: fastvideo.pipelines.TrainingBatch) -> fastvideo.pipelines.TrainingBatch

validation_pipeline: fastvideo.pipelines.ComposedPipelineBase

None

video_latent_shape: tuple[int, ...]

None

video_latent_shape_sp: tuple[int, ...]

None
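The reference does not say how video_latent_shape_sp relates to video_latent_shape. A plausible reading is that "_sp" denotes the per-rank shape under sequence parallelism, with one latent dimension split evenly across ranks. The sketch below assumes a (batch, channels, frames, height, width) layout split along the frame dimension; sp_latent_shape and sp_world_size are hypothetical names used only for illustration.

```python
from __future__ import annotations


def sp_latent_shape(
    full_shape: tuple[int, ...], sp_world_size: int, dim: int = 2
) -> tuple[int, ...]:
    """Hypothetical: per-rank shape when `dim` is split evenly across SP ranks."""
    if full_shape[dim] % sp_world_size != 0:
        raise ValueError(
            f"dimension {dim} of size {full_shape[dim]} "
            f"is not divisible by {sp_world_size}"
        )
    shape = list(full_shape)
    shape[dim] //= sp_world_size  # each rank holds an equal slice
    return tuple(shape)


# Assumed layout (B, C, T, H, W) with 4 sequence-parallel ranks.
per_rank = sp_latent_shape((1, 16, 20, 60, 104), sp_world_size=4)
```

Rejecting uneven splits up front avoids silently dropping frames on the last rank.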

visualize_intermediate_latents(training_batch: fastvideo.pipelines.TrainingBatch, training_args: fastvideo.fastvideo_args.TrainingArgs, step: int)

Add visualization data to the wandb (Weights & Biases) log and save frames to disk.

fastvideo.training.distillation_pipeline.logger

‘init_logger(…)’

fastvideo.training.distillation_pipeline.vsa_available

‘is_vsa_available(…)’
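The arguments to is_vsa_available are elided in the reference, so the sketch below shows only the generic pattern of computing a module-level availability flag once at import time. The helper is_module_available and the probed module names are hypothetical; it does not reproduce is_vsa_available.

```python
from __future__ import annotations

import importlib.util


def is_module_available(module_name: str) -> bool:
    """Hypothetical helper: True if a top-level module can be imported."""
    return importlib.util.find_spec(module_name) is not None


# Evaluated once when the module is imported, like a module-level flag.
json_available = is_module_available("json")
fake_available = is_module_available("surely_not_a_real_module_xyz")
```

Code paths can then branch on the flag (for example, choosing an optional attention backend) without wrapping every call in try/except ImportError.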