fastvideo.training.training_pipeline
Module Contents
Classes
- TrainingPipeline: A pipeline for training a model. All training pipelines should inherit from this class. All reusable components and code should be implemented in this class.
Data
API
- class fastvideo.training.training_pipeline.TrainingPipeline(model_path: str, fastvideo_args: fastvideo.fastvideo_args.TrainingArgs, required_config_modules: list[str] | None = None, loaded_modules: dict[str, torch.nn.Module] | None = None)
Bases: fastvideo.pipelines.LoRAPipeline, abc.ABC
A pipeline for training a model. All training pipelines should inherit from this class. All reusable components and code should be implemented in this class.
Initialization
Initialize the pipeline. After initialization, the pipeline should be ready to use. It should be stateless and should not hold any per-batch state.
- create_pipeline_stages(fastvideo_args: fastvideo.fastvideo_args.FastVideoArgs)
- initialize_training_pipeline(training_args: fastvideo.fastvideo_args.TrainingArgs)
- abstract initialize_validation_pipeline(training_args: fastvideo.fastvideo_args.TrainingArgs)
- train_one_step(training_batch: fastvideo.pipelines.TrainingBatch) -> fastvideo.pipelines.TrainingBatch
- abstract visualize_intermediate_latents(training_batch: fastvideo.pipelines.TrainingBatch, training_args: fastvideo.fastvideo_args.TrainingArgs, step: int)
Add visualization data to wandb logging and save frames to disk.
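Because TrainingPipeline derives from abc.ABC and marks initialize_validation_pipeline and visualize_intermediate_latents as abstract, a concrete training pipeline must override both before Python will let it be instantiated. The sketch below reproduces that contract with plain-Python stand-ins and does not import fastvideo. BaseTrainingPipeline, IncompletePipeline, MyTrainingPipeline and can_instantiate are hypothetical names for illustration only, and the stand-in base omits the real constructor arguments (model_path, fastvideo_args, and so on).

```python
import abc


class BaseTrainingPipeline(abc.ABC):
    """Hypothetical stand-in for TrainingPipeline's abstract interface."""

    @abc.abstractmethod
    def initialize_validation_pipeline(self, training_args):
        """Subclasses build their validation pipeline here."""

    @abc.abstractmethod
    def visualize_intermediate_latents(self, training_batch, training_args, step):
        """Subclasses log visualizations and save frames here."""


class IncompletePipeline(BaseTrainingPipeline):
    # Only one abstract method is overridden, so instantiation raises TypeError.
    def initialize_validation_pipeline(self, training_args):
        return None


class MyTrainingPipeline(BaseTrainingPipeline):
    # Both abstract methods are overridden, so this class can be instantiated.
    def initialize_validation_pipeline(self, training_args):
        self.validation_ready = True

    def visualize_intermediate_latents(self, training_batch, training_args, step):
        return {"step": step}


def can_instantiate(cls) -> bool:
    """Return True if cls() succeeds, False if abstract methods are missing."""
    try:
        cls()
    except TypeError:
        return False
    return True


print(can_instantiate(IncompletePipeline))  # False
print(can_instantiate(MyTrainingPipeline))  # True
```

The failure happens at instantiation time rather than at class-definition time, so a missing override surfaces as soon as a training script tries to construct the pipeline.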
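The class docstring asks the pipeline to stay stateless with respect to batches, and train_one_step accepts a TrainingBatch and returns a TrainingBatch. A minimal sketch of how those two requirements fit together is to thread the batch through the training loop as a value. Everything here is an assumption for illustration: TrainingBatch is a hypothetical frozen dataclass with made-up fields (step, loss), not fastvideo.pipelines.TrainingBatch, StatelessTrainer is a hypothetical class, and the "loss" is a placeholder arithmetic expression rather than a real forward and backward pass.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class TrainingBatch:
    # Hypothetical fields; the real TrainingBatch defines its own.
    step: int = 0
    loss: float = 0.0


class StatelessTrainer:
    """Hypothetical stand-in that keeps no per-batch state on self."""

    def __init__(self, learning_rate: float) -> None:
        # Configuration may live on the object; batch data may not.
        self.learning_rate = learning_rate

    def train_one_step(self, training_batch: TrainingBatch) -> TrainingBatch:
        # Placeholder "loss" so the example runs without a model.
        loss = self.learning_rate * (training_batch.step + 1)
        # Return a new batch instead of mutating self or the input batch.
        return replace(training_batch, step=training_batch.step + 1, loss=loss)


trainer = StatelessTrainer(learning_rate=0.5)
batch = TrainingBatch()
for _ in range(3):
    batch = trainer.train_one_step(batch)
print(batch)  # TrainingBatch(step=3, loss=1.5)
```

Keeping all per-step data in the returned batch means the same pipeline object can be reused across steps without stale batch data leaking from one step into the next.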
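visualize_intermediate_latents is described as adding visualization data to wandb logging and saving frames to disk. The sketch below shows only those two side effects under stated assumptions. It assumes the latents have already been decoded into small grayscale frames (lists of rows of 0-255 integers), writes each frame as a binary PGM file under a per-step directory, and calls an injectable log_fn with a dict and a step keyword. In a real run log_fn could be wandb.log, and a real implementation would more likely log a wandb.Video or wandb.Image; this sketch logs only a frame count to stay dependency-free. save_frames_as_pgm and visualize_frames are hypothetical helpers, not part of fastvideo.

```python
import os
import tempfile
from typing import Callable, List, Optional


def save_frames_as_pgm(frames: List[List[List[int]]], out_dir: str) -> List[str]:
    """Write grayscale frames (rows of 0-255 ints) as binary PGM files."""
    os.makedirs(out_dir, exist_ok=True)
    paths = []
    for i, frame in enumerate(frames):
        height, width = len(frame), len(frame[0])
        path = os.path.join(out_dir, f"frame_{i:04d}.pgm")
        with open(path, "wb") as f:
            # PGM header: magic number, width, height, max gray value.
            f.write(f"P5\n{width} {height}\n255\n".encode("ascii"))
            for row in frame:
                f.write(bytes(row))  # one byte per pixel
        paths.append(path)
    return paths


def visualize_frames(frames, step: int, root_dir: str,
                     log_fn: Optional[Callable] = None) -> List[str]:
    """Save frames under root_dir/step_<step>/ and optionally log a summary."""
    step_dir = os.path.join(root_dir, f"step_{step:06d}")
    paths = save_frames_as_pgm(frames, step_dir)
    if log_fn is not None:
        # e.g. wandb.log, which accepts a dict plus a `step` keyword.
        log_fn({"visualization/num_frames": len(paths)}, step=step)
    return paths


# Usage: two tiny 2x3 frames and a stand-in logger that records its calls.
logged = []
frames = [[[0, 128, 255], [255, 128, 0]], [[10, 20, 30], [40, 50, 60]]]
with tempfile.TemporaryDirectory() as tmp:
    paths = visualize_frames(frames, step=7, root_dir=tmp,
                             log_fn=lambda data, step: logged.append((step, data)))
    sizes = [os.path.getsize(p) for p in paths]
print(sizes)   # [17, 17]
print(logged)  # [(7, {'visualization/num_frames': 2})]
```

Each file is 17 bytes: an 11-byte header ("P5\n3 2\n255\n") followed by 6 pixel bytes.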