fastvideo.training.training_utils
Module Contents
Classes
EMA_FSDP – FSDP2-friendly EMA with two modes: local_shard and rank0_full.
Functions
clip_grad_norm_ – Clip the gradient norm of parameters.
compute_density_for_timestep_sampling – Compute the density for sampling the timesteps when doing SD3 training.
custom_to_hf_state_dict – Convert fastvideo’s custom model format to diffusers format using reverse_param_names_mapping.
get_constant_schedule – Create a schedule with a constant learning rate, using the learning rate set in optimizer.
get_constant_schedule_with_warmup – Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate increases linearly between 0 and the initial lr set in the optimizer.
get_cosine_schedule_with_min_lr – Create a schedule with a learning rate that decreases following the values of the cosine function between the initial lr set in the optimizer to a minimum lr (min_lr_ratio * initial_lr), after a warmup period during which it increases linearly between 0 and the initial lr set in the optimizer.
get_cosine_schedule_with_warmup – Create a schedule with a learning rate that decreases following the values of the cosine function between the initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the initial lr set in the optimizer.
get_cosine_with_hard_restarts_schedule_with_warmup – Create a schedule with a learning rate that decreases following the values of the cosine function between the initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases linearly between 0 and the initial lr set in the optimizer.
get_linear_schedule_with_warmup – Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
get_piecewise_constant_schedule – Create a schedule with a piecewise constant learning rate: the learning rate set in the optimizer is multiplied by a factor that changes according to step_rules.
get_polynomial_decay_schedule_with_warmup – Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the optimizer to end lr defined by lr_end, after a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
get_scheduler – Unified API to get any scheduler from its name.
load_checkpoint – Load checkpoint following finetrainer’s distributed checkpoint approach. Returns the step number from which training should resume.
load_distillation_checkpoint – Load distillation checkpoint with both generator and fake_score models. Supports MoE (Mixture of Experts) models with transformer_2 variants. Returns the step number from which training should resume.
save_checkpoint – Save checkpoint following finetrainer’s distributed checkpoint approach. Saves both distributed checkpoint and consolidated model weights.
save_distillation_checkpoint – Save distillation checkpoint with both generator and fake_score models. Supports MoE (Mixture of Experts) models with transformer_2 variants. Saves both distributed checkpoint and consolidated model weights. Consolidated weights for inference are saved for the generator model only.
Data
API
- class fastvideo.training.training_utils.EMA_FSDP(module, decay: float = 0.999, mode: str = 'local_shard')
FSDP2-friendly EMA with two modes:
mode="local_shard" (default): maintain a float32 CPU EMA of the local parameter shards on every rank. Provides a context manager to temporarily swap the EMA weights into the live model for a teacher forward pass.
mode="rank0_full": maintain a consolidated float32 CPU EMA of the full parameters on rank 0 only, using gather_state_dict_on_cpu_rank0(). Useful for checkpoint export; not for teacher forward passes.
Usage (local_shard for CM teacher):
    ema = EMA_FSDP(model, decay=0.999, mode="local_shard")
    for step in …:
        ema.update(model)
    with ema.apply_to_model(model):
        with torch.no_grad():
            y_teacher = model(…)
Usage (rank0_full for export):
    ema = EMA_FSDP(model, decay=0.999, mode="rank0_full")
    ema.update(model)
    ema.state_dict()  # on rank 0
Initialization
- copy_to_unwrapped(module) → None
Copy EMA weights into a non-sharded (unwrapped) module. Intended for export/eval. For mode="rank0_full", only rank 0 has the full EMA state.
- load_state_dict(sd: dict[str, torch.Tensor])
- state_dict() → dict[str, torch.Tensor]
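The averaging that an EMA performs can be illustrated without FSDP or tensors. The sketch below assumes the conventional rule ema = decay * ema + (1 - decay) * param; ema_update is a hypothetical helper written for this illustration and is not part of EMA_FSDP, whose internals are not shown on this page.

```python
def ema_update(ema: dict[str, float], params: dict[str, float], decay: float = 0.999) -> None:
    """In-place EMA step (assumed rule): ema <- decay * ema + (1 - decay) * param."""
    for name, value in params.items():
        ema[name] = decay * ema[name] + (1.0 - decay) * value


# Track a single scalar "weight" that stays at 1.0 while the EMA starts at 0.0.
ema_state = {"w": 0.0}
for _ in range(3):
    ema_update(ema_state, {"w": 1.0}, decay=0.5)
# With decay=0.5 the EMA moves 0.0 -> 0.5 -> 0.75 -> 0.875.
```

A decay close to 1 (such as the default 0.999) makes the average move slowly, which is why EMA weights are typically smoother than the live weights.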
- fastvideo.training.training_utils.TYPE_TO_SCHEDULER_FUNCTION: dict[fastvideo.training.training_utils.SchedulerType, fastvideo.training.training_utils.SchedulerFunction]
None
- fastvideo.training.training_utils.clip_grad_norm_(parameters: torch.Tensor | list[torch.Tensor], max_norm: float, norm_type: float = 2.0, error_if_nonfinite: bool = False, foreach: bool | None = None, pp_mesh: torch.distributed.device_mesh.DeviceMesh | None = None) → torch.Tensor
Clip the gradient norm of parameters.
Gradient norm clipping requires computing the gradient norm over the entire model. torch.nn.utils.clip_grad_norm_ only computes the gradient norm along the DP/FSDP/TP dimensions, so the gradient norm must be reduced manually across PP stages. See https://github.com/pytorch/torchtitan/issues/596 for details.
- Parameters:
parameters (torch.Tensor or List[torch.Tensor]) – Tensors that will have gradients normalized.
max_norm (float) – Maximum norm of the gradients after clipping.
norm_type (float, defaults to 2.0) – Type of p-norm to use. Can be inf for infinity norm.
error_if_nonfinite (bool, defaults to False) – If True, an error is thrown if the total norm of the gradients from parameters is nan, inf, or -inf.
foreach (bool, defaults to None) – Use the faster foreach-based implementation. If None, use the foreach implementation for CUDA and CPU native tensors and silently fall back to the slow implementation for other device types.
pp_mesh (torch.distributed.device_mesh.DeviceMesh, defaults to None) – Pipeline parallel device mesh. If not None, will reduce gradient norm across PP stages.
- Returns:
Total norm of the gradients
- Return type:
torch.Tensor
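The arithmetic behind reducing a norm across pipeline stages can be sketched with plain floats. The snippet below is an illustration only, not the function's implementation: combine_stage_norms and clip_coefficient are hypothetical helpers, and the small eps in the denominator is an assumption borrowed from the usual clipping formula.

```python
import math


def combine_stage_norms(stage_norms: list[float], norm_type: float = 2.0) -> float:
    """Combine per-stage p-norms into the p-norm over all stages."""
    if math.isinf(norm_type):
        # The infinity norm of the whole model is the largest per-stage value.
        return max(stage_norms)
    return sum(n ** norm_type for n in stage_norms) ** (1.0 / norm_type)


def clip_coefficient(total_norm: float, max_norm: float, eps: float = 1e-6) -> float:
    """Factor by which every gradient is scaled; never larger than 1."""
    return min(1.0, max_norm / (total_norm + eps))


# Two stages with local 2-norms 3 and 4 give a global norm of 5.
total = combine_stage_norms([3.0, 4.0])
scale = clip_coefficient(total, max_norm=1.0)
```

Clipping each stage against only its local norm would use a different scale per stage, which is why the norm has to be combined across stages before the coefficient is computed.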
- fastvideo.training.training_utils.clip_grad_norm_while_handling_failing_dtensor_cases(parameters: torch.Tensor | list[torch.Tensor], max_norm: float, norm_type: float = 2.0, error_if_nonfinite: bool = False, foreach: bool | None = None, pp_mesh: torch.distributed.device_mesh.DeviceMesh | None = None) → torch.Tensor | None
- fastvideo.training.training_utils.compute_density_for_timestep_sampling(weighting_scheme: str, batch_size: int, generator, logit_mean: float | None = None, logit_std: float | None = None, mode_scale: float | None = None)
Compute the density for sampling the timesteps when doing SD3 training.
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
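As a rough illustration of what such a density looks like, the sketch below follows the helper from the diffusers pull request cited above, rewritten with the standard library instead of torch. The scheme names "logit_normal" and "mode", the fallback to uniform sampling, and the default values are assumptions taken from that helper; fastvideo's version may differ.

```python
import math
import random


def sample_u(weighting_scheme: str, batch_size: int, rng: random.Random,
             logit_mean: float = 0.0, logit_std: float = 1.0,
             mode_scale: float = 1.29) -> list[float]:
    """Draw batch_size values in [0, 1] that are later mapped to timesteps."""
    if weighting_scheme == "logit_normal":
        # Sigmoid of a normal sample: mass concentrates around sigmoid(logit_mean).
        return [1.0 / (1.0 + math.exp(-rng.gauss(logit_mean, logit_std)))
                for _ in range(batch_size)]
    u = [rng.random() for _ in range(batch_size)]
    if weighting_scheme == "mode":
        # Warp uniform samples so the density is reshaped around the middle of the range.
        u = [1.0 - x - mode_scale * (math.cos(math.pi * x / 2) ** 2 - 1.0 + x) for x in u]
    return u  # any other scheme: plain uniform sampling


rng = random.Random(0)
samples = {name: sample_u(name, 1000, rng) for name in ("logit_normal", "mode", "uniform")}
```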
- fastvideo.training.training_utils.count_trainable(model: torch.nn.Module) → int
- fastvideo.training.training_utils.custom_to_hf_state_dict(state_dict: dict[str, Any] | collections.abc.Iterator[tuple[str, torch.Tensor]], reverse_param_names_mapping: dict[str, tuple[str, int, int]]) → dict[str, Any]
Convert fastvideo’s custom model format to diffusers format using reverse_param_names_mapping.
- Parameters:
state_dict – State dict in fastvideo’s custom format
reverse_param_names_mapping – Reverse mapping from fastvideo’s custom format to diffusers format
- Returns:
State dict in diffusers format
- fastvideo.training.training_utils.gather_state_dict_on_cpu_rank0(model, device: torch.device | None = None) → dict[str, Any]
- fastvideo.training.training_utils.get_constant_schedule(optimizer: torch.optim.Optimizer, last_epoch: int = -1) → torch.optim.lr_scheduler.LambdaLR
Create a schedule with a constant learning rate, using the learning rate set in optimizer.
- Parameters:
optimizer (torch.optim.Optimizer) – The optimizer for which to schedule the learning rate.
last_epoch (int, optional, defaults to -1) – The index of the last epoch when resuming training.
- Returns:
torch.optim.lr_scheduler.LambdaLR with the appropriate schedule.
- fastvideo.training.training_utils.get_constant_schedule_with_warmup(optimizer: torch.optim.Optimizer, num_warmup_steps: int, last_epoch: int = -1) → torch.optim.lr_scheduler.LambdaLR
Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate increases linearly between 0 and the initial lr set in the optimizer.
- Parameters:
optimizer (torch.optim.Optimizer) – The optimizer for which to schedule the learning rate.
num_warmup_steps (int) – The number of steps for the warmup phase.
last_epoch (int, optional, defaults to -1) – The index of the last epoch when resuming training.
- Returns:
torch.optim.lr_scheduler.LambdaLR with the appropriate schedule.
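Because the return type is a LambdaLR, the schedule amounts to a function from the step index to a multiplier on the initial lr. A minimal sketch of such a multiplier for this schedule, assuming the usual linear ramp (constant_with_warmup_factor is a hypothetical name, not part of this module):

```python
def constant_with_warmup_factor(step: int, num_warmup_steps: int) -> float:
    """Multiplier on the initial lr: linear ramp from 0 to 1, then constant at 1."""
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    return 1.0


# Multipliers for the first few steps of a 4-step warmup.
factors = [constant_with_warmup_factor(s, num_warmup_steps=4) for s in range(6)]
# -> [0.0, 0.25, 0.5, 0.75, 1.0, 1.0]
```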
- fastvideo.training.training_utils.get_cosine_schedule_with_min_lr(optimizer: torch.optim.Optimizer, num_warmup_steps: int, num_training_steps: int, min_lr_ratio: float = 0.1, num_cycles: float = 0.5, last_epoch: int = -1) → torch.optim.lr_scheduler.LambdaLR
Create a schedule with a learning rate that decreases following the values of the cosine function between the initial lr set in the optimizer to a minimum lr (min_lr_ratio * initial_lr), after a warmup period during which it increases linearly between 0 and the initial lr set in the optimizer.
- Parameters:
optimizer (torch.optim.Optimizer) – The optimizer for which to schedule the learning rate.
num_warmup_steps (int) – The number of steps for the warmup phase.
num_training_steps (int) – The total number of training steps.
min_lr_ratio (float, optional, defaults to 0.1) – The ratio of minimum learning rate to initial learning rate.
num_cycles (float, optional, defaults to 0.5) – The number of periods of the cosine function in a schedule.
last_epoch (int, optional, defaults to -1) – The index of the last epoch when resuming training.
- Returns:
torch.optim.lr_scheduler.LambdaLR with the appropriate schedule.
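One way to realise a cosine decay with a floor is to rescale the cosine term into the interval [min_lr_ratio, 1]. The sketch below shows that interpolation; the exact formula used by this function is an assumption here, and cosine_min_lr_factor is a hypothetical name.

```python
import math


def cosine_min_lr_factor(step: int, num_warmup_steps: int, num_training_steps: int,
                         min_lr_ratio: float = 0.1, num_cycles: float = 0.5) -> float:
    """Multiplier on the initial lr: linear warmup, then cosine decay down to min_lr_ratio."""
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * 2.0 * num_cycles * progress))
    # Rescale the cosine term from [0, 1] to [min_lr_ratio, 1] (assumed interpolation).
    return min_lr_ratio + (1.0 - min_lr_ratio) * cosine


start = cosine_min_lr_factor(10, num_warmup_steps=10, num_training_steps=110)   # end of warmup
end = cosine_min_lr_factor(110, num_warmup_steps=10, num_training_steps=110)    # last step
```

With the defaults, the multiplier is 1.0 at the end of warmup and 0.1 at the final step, so the lr never drops below min_lr_ratio * initial_lr.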
- fastvideo.training.training_utils.get_cosine_schedule_with_warmup(optimizer: torch.optim.Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1) → torch.optim.lr_scheduler.LambdaLR
Create a schedule with a learning rate that decreases following the values of the cosine function between the initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the initial lr set in the optimizer.
- Parameters:
optimizer (torch.optim.Optimizer) – The optimizer for which to schedule the learning rate.
num_warmup_steps (int) – The number of steps for the warmup phase.
num_training_steps (int) – The total number of training steps.
num_cycles (float, optional, defaults to 0.5) – The number of periods of the cosine function in a schedule (the default is to just decrease from the max value to 0 following a half-cosine).
last_epoch (int, optional, defaults to -1) – The index of the last epoch when resuming training.
- Returns:
torch.optim.lr_scheduler.LambdaLR with the appropriate schedule.
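The half-cosine shape described above can be written as a step-to-multiplier function. This is a sketch in the form commonly used for such schedules, not necessarily the exact code behind this function; cosine_warmup_factor is a hypothetical name.

```python
import math


def cosine_warmup_factor(step: int, num_warmup_steps: int, num_training_steps: int,
                         num_cycles: float = 0.5) -> float:
    """Multiplier on the initial lr: linear warmup, then cosine decay to 0."""
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
    # num_cycles=0.5 covers half a cosine period: from 1 at progress=0 to 0 at progress=1.
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * 2.0 * num_cycles * progress)))


halfway = cosine_warmup_factor(60, num_warmup_steps=10, num_training_steps=110)
final = cosine_warmup_factor(110, num_warmup_steps=10, num_training_steps=110)
```

Halfway through the decay phase the multiplier is about 0.5, and it reaches 0 at the last training step.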
- fastvideo.training.training_utils.get_cosine_with_hard_restarts_schedule_with_warmup(optimizer: torch.optim.Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1) → torch.optim.lr_scheduler.LambdaLR
Create a schedule with a learning rate that decreases following the values of the cosine function between the initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases linearly between 0 and the initial lr set in the optimizer.
- Parameters:
optimizer (torch.optim.Optimizer) – The optimizer for which to schedule the learning rate.
num_warmup_steps (int) – The number of steps for the warmup phase.
num_training_steps (int) – The total number of training steps.
num_cycles (int, optional, defaults to 1) – The number of hard restarts to use.
last_epoch (int, optional, defaults to -1) – The index of the last epoch when resuming training.
- Returns:
torch.optim.lr_scheduler.LambdaLR with the appropriate schedule.
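A hard restart means the multiplier jumps back to 1 at the start of each cycle instead of rising smoothly. The sketch below shows the form commonly used for this schedule, where the modulo produces the jump; it is an illustration under that assumption, and hard_restarts_factor is a hypothetical name.

```python
import math


def hard_restarts_factor(step: int, num_warmup_steps: int, num_training_steps: int,
                         num_cycles: int = 1) -> float:
    """Multiplier on the initial lr: warmup, then num_cycles half-cosine decays with restarts."""
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
    if progress >= 1.0:
        return 0.0
    # The modulo resets the position inside the current cycle to 0 at every restart.
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((num_cycles * progress) % 1.0))))


# Two cycles over 100 steps without warmup: a restart happens at step 50.
just_before_restart = hard_restarts_factor(49, 0, 100, num_cycles=2)
at_restart = hard_restarts_factor(50, 0, 100, num_cycles=2)
```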
- fastvideo.training.training_utils.get_linear_schedule_with_warmup(optimizer: torch.optim.Optimizer, num_warmup_steps: int, num_training_steps: int, last_epoch: int = -1) → torch.optim.lr_scheduler.LambdaLR
Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
- Parameters:
optimizer (torch.optim.Optimizer) – The optimizer for which to schedule the learning rate.
num_warmup_steps (int) – The number of steps for the warmup phase.
num_training_steps (int) – The total number of training steps.
last_epoch (int, optional, defaults to -1) – The index of the last epoch when resuming training.
- Returns:
torch.optim.lr_scheduler.LambdaLR with the appropriate schedule.
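The triangular shape described above can be sketched as a step-to-multiplier function. This is a minimal illustration of the description, not necessarily the exact implementation; linear_warmup_decay_factor is a hypothetical name.

```python
def linear_warmup_decay_factor(step: int, num_warmup_steps: int, num_training_steps: int) -> float:
    """Multiplier on the initial lr: linear ramp up during warmup, then linear decay to 0."""
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    remaining = num_training_steps - step
    return max(0.0, remaining / max(1, num_training_steps - num_warmup_steps))


# 10 warmup steps followed by 100 decay steps.
peak = linear_warmup_decay_factor(10, 10, 110)
midpoint = linear_warmup_decay_factor(60, 10, 110)
```

The max with 0.0 keeps the multiplier from going negative if training runs past num_training_steps.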
- fastvideo.training.training_utils.get_piecewise_constant_schedule(optimizer: torch.optim.Optimizer, step_rules: str, last_epoch: int = -1) → torch.optim.lr_scheduler.LambdaLR
Create a schedule with a piecewise constant learning rate: the learning rate set in the optimizer is multiplied by a factor that changes according to step_rules.
- Parameters:
optimizer (torch.optim.Optimizer) – The optimizer for which to schedule the learning rate.
step_rules (string) – The rules for the learning rate. For example, step_rules="1:10,0.1:20,0.01:30,0.005" means that the learning rate is multiplied by 1 for the first 10 steps, by 0.1 for the next 20 steps, by 0.01 for the next 30 steps, and by 0.005 for all remaining steps.
last_epoch (int, optional, defaults to -1) – The index of the last epoch when resuming training.
- Returns:
torch.optim.lr_scheduler.LambdaLR with the appropriate schedule.
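Parsing a step_rules string into a lookup function can be sketched as follows. The snippet assumes each number after a colon is an absolute step boundary, which is how the similarly named diffusers helper reads the string; if the numbers are instead durations, as the wording above could also be read, the boundaries would have to be accumulated first. parse_step_rules is a hypothetical name.

```python
from typing import Callable


def parse_step_rules(step_rules: str) -> Callable[[int], float]:
    """Turn "mult:until,...,last_mult" into a step -> multiplier function."""
    *rules, last = step_rules.split(",")
    boundaries = []
    for rule in rules:
        value, until = rule.split(":")
        boundaries.append((int(until), float(value)))
    boundaries.sort()
    last_multiplier = float(last)

    def factor(step: int) -> float:
        # Return the multiplier of the first boundary the step has not yet reached.
        for until, value in boundaries:
            if step < until:
                return value
        return last_multiplier

    return factor


factor = parse_step_rules("1:10,0.1:20,0.01:30,0.005")
multipliers = [factor(s) for s in (0, 9, 10, 29, 30)]
```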
- fastvideo.training.training_utils.get_polynomial_decay_schedule_with_warmup(optimizer: torch.optim.Optimizer, num_warmup_steps: int, num_training_steps: int, lr_end: float = 1e-07, power: float = 1.0, last_epoch: int = -1) → torch.optim.lr_scheduler.LambdaLR
Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the optimizer to end lr defined by lr_end, after a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
- Parameters:
optimizer (torch.optim.Optimizer) – The optimizer for which to schedule the learning rate.
num_warmup_steps (int) – The number of steps for the warmup phase.
num_training_steps (int) – The total number of training steps.
lr_end (float, optional, defaults to 1e-7) – The end LR.
power (float, optional, defaults to 1.0) – Power factor.
last_epoch (int, optional, defaults to -1) – The index of the last epoch when resuming training.
Note: power defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT implementation at https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37
- Returns:
torch.optim.lr_scheduler.LambdaLR with the appropriate schedule.
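Because lr_end is an absolute learning rate while a LambdaLR works with multipliers, the decay has to be computed in lr units and then divided by the initial lr. The sketch below illustrates this; lr_init stands in for the lr read from the optimizer, and polynomial_decay_factor is a hypothetical name rather than this module's code.

```python
def polynomial_decay_factor(step: int, num_warmup_steps: int, num_training_steps: int,
                            lr_init: float, lr_end: float = 1e-7, power: float = 1.0) -> float:
    """Multiplier on lr_init: linear warmup, then polynomial decay from lr_init to lr_end."""
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    if step > num_training_steps:
        return lr_end / lr_init  # hold the end lr after training is over
    decay_steps = max(1, num_training_steps - num_warmup_steps)
    pct_remaining = 1.0 - (step - num_warmup_steps) / decay_steps
    decayed_lr = (lr_init - lr_end) * pct_remaining ** power + lr_end
    return decayed_lr / lr_init


at_peak = polynomial_decay_factor(10, 10, 110, lr_init=1e-3)
at_end = polynomial_decay_factor(110, 10, 110, lr_init=1e-3)
```

With power=1.0 the decay is linear; larger powers drop the lr faster early in the decay phase.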
- fastvideo.training.training_utils.get_scheduler(name: str | fastvideo.training.training_utils.SchedulerType, optimizer: torch.optim.Optimizer, step_rules: str | None = None, num_warmup_steps: int | None = None, num_training_steps: int | None = None, num_cycles: int = 1, power: float = 1.0, min_lr_ratio: float = 0.1, last_epoch: int = -1) → torch.optim.lr_scheduler.LambdaLR
Unified API to get any scheduler from its name.
- Parameters:
name (str or SchedulerType) – The name of the scheduler to use.
optimizer (torch.optim.Optimizer) – The optimizer that will be used during training.
step_rules (str, optional) – A string representing the step rules to use. This is only used by the PIECEWISE_CONSTANT scheduler.
num_warmup_steps (int, optional) – The number of warmup steps to do. This is not required by all schedulers (hence the argument being optional); the function will raise an error if it’s unset and the scheduler type requires it.
num_training_steps (int, optional) – The number of training steps to do. This is not required by all schedulers (hence the argument being optional); the function will raise an error if it’s unset and the scheduler type requires it.
num_cycles (int, optional) – The number of hard restarts used in the COSINE_WITH_RESTARTS scheduler.
power (float, optional, defaults to 1.0) – Power factor. See the POLYNOMIAL scheduler.
min_lr_ratio (float, optional, defaults to 0.1) – The ratio of minimum learning rate to initial learning rate. Used in the COSINE_WITH_MIN_LR scheduler.
last_epoch (int, optional, defaults to -1) – The index of the last epoch when resuming training.
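The behaviour described for the optional arguments (an error when a schedule needs a value that was left unset) can be sketched as a small name-based dispatcher. Everything in the snippet is illustrative: build_lr_lambda is a hypothetical function, and the string names "constant", "constant_with_warmup" and "linear" are assumptions rather than confirmed SchedulerType values.

```python
from typing import Callable, Optional


def build_lr_lambda(name: str, num_warmup_steps: Optional[int] = None,
                    num_training_steps: Optional[int] = None) -> Callable[[int], float]:
    """Pick a step -> multiplier function by name, validating the arguments it needs."""
    if name not in ("constant", "constant_with_warmup", "linear"):
        raise ValueError(f"unknown scheduler {name!r}")
    if name == "constant":
        return lambda step: 1.0
    if num_warmup_steps is None:
        raise ValueError(f"{name} requires num_warmup_steps")
    if name == "constant_with_warmup":
        return lambda step: min(1.0, step / max(1, num_warmup_steps))
    if num_training_steps is None:
        raise ValueError(f"{name} requires num_training_steps")
    decay_steps = max(1, num_training_steps - num_warmup_steps)
    return lambda step: (step / max(1, num_warmup_steps) if step < num_warmup_steps
                         else max(0.0, (num_training_steps - step) / decay_steps))


linear = build_lr_lambda("linear", num_warmup_steps=10, num_training_steps=110)
```

A function built this way has the shape expected by the lr_lambda argument of torch.optim.lr_scheduler.LambdaLR.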
- fastvideo.training.training_utils.get_sigmas(noise_scheduler, device, timesteps, n_dim=4, dtype=torch.float32) → torch.Tensor
- fastvideo.training.training_utils.load_checkpoint(transformer, rank, checkpoint_path, optimizer=None, dataloader=None, scheduler=None, noise_generator=None) → int
Load checkpoint following finetrainer’s distributed checkpoint approach. Returns the step number from which training should resume.
- fastvideo.training.training_utils.load_distillation_checkpoint(generator_transformer, fake_score_transformer, rank, checkpoint_path, generator_optimizer=None, fake_score_optimizer=None, dataloader=None, generator_scheduler=None, fake_score_scheduler=None, noise_generator=None, generator_ema=None, generator_transformer_2=None, real_score_transformer_2=None, fake_score_transformer_2=None, generator_optimizer_2=None, fake_score_optimizer_2=None, generator_scheduler_2=None, fake_score_scheduler_2=None, generator_ema_2=None) → int
Load distillation checkpoint with both generator and fake_score models. Supports MoE (Mixture of Experts) models with transformer_2 variants. Returns the step number from which training should resume.
- Parameters:
generator_transformer – Main generator transformer model
fake_score_transformer – Main fake score transformer model
generator_transformer_2 – Secondary generator transformer for MoE (optional)
real_score_transformer_2 – Secondary real score transformer for MoE (optional)
fake_score_transformer_2 – Secondary fake score transformer for MoE (optional)
generator_optimizer_2 – Optimizer for generator_transformer_2 (optional)
fake_score_optimizer_2 – Optimizer for fake_score_transformer_2 (optional)
generator_scheduler_2 – Scheduler for generator_transformer_2 (optional)
fake_score_scheduler_2 – Scheduler for fake_score_transformer_2 (optional)
generator_ema_2 – EMA for generator_transformer_2 (optional)
- fastvideo.training.training_utils.normalize_dit_input(model_type, latents, vae) → torch.Tensor
- fastvideo.training.training_utils.save_checkpoint(transformer, rank, output_dir, step, optimizer=None, dataloader=None, scheduler=None, noise_generator=None) → None
Save checkpoint following finetrainer’s distributed checkpoint approach. Saves both distributed checkpoint and consolidated model weights.
- fastvideo.training.training_utils.save_distillation_checkpoint(generator_transformer, fake_score_transformer, rank, output_dir, step, generator_optimizer=None, fake_score_optimizer=None, dataloader=None, generator_scheduler=None, fake_score_scheduler=None, noise_generator=None, generator_ema=None, only_save_generator_weight=False, generator_transformer_2=None, real_score_transformer_2=None, fake_score_transformer_2=None, generator_optimizer_2=None, fake_score_optimizer_2=None, generator_scheduler_2=None, fake_score_scheduler_2=None, generator_ema_2=None) → None
Save distillation checkpoint with both generator and fake_score models. Supports MoE (Mixture of Experts) models with transformer_2 variants. Saves both distributed checkpoint and consolidated model weights. Consolidated weights for inference are saved for the generator model only.
- Parameters:
generator_transformer – Main generator transformer model
fake_score_transformer – Main fake score transformer model
only_save_generator_weight – If True, only save the generator model weights for inference without saving distributed checkpoint for training resume.
generator_transformer_2 – Secondary generator transformer for MoE (optional)
real_score_transformer_2 – Secondary real score transformer for MoE (optional)
fake_score_transformer_2 – Secondary fake score transformer for MoE (optional)
generator_optimizer_2 – Optimizer for generator_transformer_2 (optional)
fake_score_optimizer_2 – Optimizer for fake_score_transformer_2 (optional)
generator_scheduler_2 – Scheduler for generator_transformer_2 (optional)
fake_score_scheduler_2 – Scheduler for fake_score_transformer_2 (optional)
generator_ema_2 – EMA for generator_transformer_2 (optional)
- fastvideo.training.training_utils.shard_latents_across_sp(latents: torch.Tensor, num_latent_t: int) → torch.Tensor
- fastvideo.training.training_utils.shift_timestep(timestep: torch.Tensor, shift: float, num_train_timestep: float) → torch.Tensor