fastvideo.v1.training.training_utils#

Module Contents#

Functions#

• clip_grad_norm_ – Clip the gradient norm of parameters.

• clip_grad_norm_while_handling_failing_dtensor_cases

• compute_density_for_timestep_sampling – Compute the density for sampling the timesteps when doing SD3 training.

• gather_state_dict_on_cpu_rank0

• get_sigmas

• load_checkpoint – Load checkpoint following finetrainer’s distributed checkpoint approach. Returns the step number from which training should resume.

• normalize_dit_input

• save_checkpoint – Save checkpoint following finetrainer’s distributed checkpoint approach. Saves both distributed checkpoint and consolidated model weights.

Data#

• logger

API#

fastvideo.v1.training.training_utils.clip_grad_norm_(parameters: Union[torch.Tensor, List[torch.Tensor]], max_norm: float, norm_type: float = 2.0, error_if_nonfinite: bool = False, foreach: Optional[bool] = None, pp_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None) → torch.Tensor[source]#

Clip the gradient norm of parameters.

Gradient norm clipping requires computing the gradient norm over the entire model. torch.nn.utils.clip_grad_norm_ only computes the gradient norm along the DP/FSDP/TP dimensions, so the norm must be manually reduced across pipeline-parallel (PP) stages. See https://github.com/pytorch/torchtitan/issues/596 for details.

Parameters:
  • parameters (torch.Tensor or List[torch.Tensor]) – Tensors that will have gradients normalized.

  • max_norm (float) – Maximum norm of the gradients after clipping.

  • norm_type (float, defaults to 2.0) – Type of p-norm to use. Can be inf for infinity norm.

  • error_if_nonfinite (bool, defaults to False) – If True, an error is thrown if the total norm of the gradients from parameters is nan, inf, or -inf.

  • foreach (bool, defaults to None) – Use the faster foreach-based implementation. If None, use the foreach implementation for CUDA and CPU native tensors and silently fall back to the slow implementation for other device types.

  • pp_mesh (torch.distributed.device_mesh.DeviceMesh, defaults to None) – Pipeline parallel device mesh. If not None, will reduce gradient norm across PP stages.

Returns:

Total norm of the gradients

Return type:

torch.Tensor
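
A minimal single-process usage sketch (no pipeline parallelism, so pp_mesh stays None; the model, optimizer, and loss below are illustrative placeholders):

```python
import torch

from fastvideo.v1.training.training_utils import clip_grad_norm_

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

loss = model(torch.randn(2, 8)).pow(2).mean()
loss.backward()

# Clip to a maximum 2-norm of 1.0; the returned tensor is the total gradient norm.
total_norm = clip_grad_norm_(
    list(model.parameters()),
    max_norm=1.0,
    norm_type=2.0,
    pp_mesh=None,  # pass the pipeline-parallel DeviceMesh to reduce across PP stages
)
optimizer.step()
optimizer.zero_grad()
```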

fastvideo.v1.training.training_utils.clip_grad_norm_while_handling_failing_dtensor_cases(parameters: Union[torch.Tensor, List[torch.Tensor]], max_norm: float, norm_type: float = 2.0, error_if_nonfinite: bool = False, foreach: Optional[bool] = None, pp_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None) → Optional[torch.Tensor][source]#
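
This variant carries no docstring; judging only from its name and the Optional[torch.Tensor] return type, it appears to apply the same clipping but tolerate DTensor configurations where the norm computation fails, presumably returning None in that case rather than raising. A hedged sketch under that assumption, reusing the model from the example above:

```python
from fastvideo.v1.training.training_utils import (
    clip_grad_norm_while_handling_failing_dtensor_cases,
)

# Assumption: may return None when DTensor-based clipping is unsupported
# for these parameters, instead of raising.
total_norm = clip_grad_norm_while_handling_failing_dtensor_cases(
    list(model.parameters()), max_norm=1.0
)
if total_norm is not None:
    print(f"pre-clip gradient norm: {total_norm.item():.4f}")
```
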
fastvideo.v1.training.training_utils.compute_density_for_timestep_sampling(weighting_scheme: str, batch_size: int, generator, logit_mean: Optional[float] = None, logit_std: Optional[float] = None, mode_scale: Optional[float] = None)[source]#

Compute the density for sampling the timesteps when doing SD3 training.

Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.

SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
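
A hedged usage sketch follows. The weighting-scheme name ("logit_normal") and the mapping of sampled densities to integer timestep indices (num_train_timesteps = 1000) are assumptions carried over from the diffusers training scripts this helper originates from:

```python
import torch

from fastvideo.v1.training.training_utils import compute_density_for_timestep_sampling

u = compute_density_for_timestep_sampling(
    weighting_scheme="logit_normal",  # assumed scheme name, as in diffusers
    batch_size=4,
    generator=torch.Generator().manual_seed(0),
    logit_mean=0.0,
    logit_std=1.0,
)
# Map densities in [0, 1) to discrete scheduler timestep indices (1000 assumed).
timestep_indices = (u * 1000).long()
```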

fastvideo.v1.training.training_utils.gather_state_dict_on_cpu_rank0(model, device: Optional[torch.device] = None) → Dict[str, Any][source]#
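
This helper is undocumented; its name suggests it collects a full, unsharded state dict into CPU memory on rank 0. A hedged sketch of how it might be used for a consolidated save (the transformer variable and output path are illustrative):

```python
import torch
import torch.distributed as dist

from fastvideo.v1.training.training_utils import gather_state_dict_on_cpu_rank0

# Assumes torch.distributed is initialized and `transformer` may be sharded.
cpu_state = gather_state_dict_on_cpu_rank0(transformer)
if dist.get_rank() == 0:
    torch.save(cpu_state, "consolidated_weights.pt")  # illustrative path
```
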
fastvideo.v1.training.training_utils.get_sigmas(noise_scheduler, device, timesteps, n_dim=4, dtype=torch.float32) → torch.Tensor[source]#
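
get_sigmas is also undocumented; a helper of the same name in the diffusers training scripts looks up the scheduler’s sigmas for the sampled timesteps and reshapes them to n_dim dimensions so they broadcast against the latents. A sketch assuming this version behaves the same (latents, noise, noise_scheduler, and timesteps are placeholders from the surrounding training step):

```python
sigmas = get_sigmas(
    noise_scheduler,
    device=latents.device,
    timesteps=timesteps,
    n_dim=latents.ndim,
    dtype=latents.dtype,
)
# Flow-matching style interpolation between clean latents and noise.
noisy_latents = (1.0 - sigmas) * latents + sigmas * noise
```
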
fastvideo.v1.training.training_utils.load_checkpoint(transformer, rank, checkpoint_path, optimizer=None, dataloader=None, scheduler=None, noise_generator=None) → int[source]#

Load checkpoint following finetrainer’s distributed checkpoint approach. Returns the step number from which training should resume.
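
A hedged resume sketch; the checkpoint path and the surrounding loop variables (transformer, rank, optimizer, dataloader, lr_scheduler, max_train_steps) are placeholders:

```python
from fastvideo.v1.training.training_utils import load_checkpoint

resume_step = load_checkpoint(
    transformer,
    rank,
    "outputs/checkpoint-1000",  # placeholder checkpoint directory
    optimizer=optimizer,
    dataloader=dataloader,
    scheduler=lr_scheduler,
)
for step in range(resume_step, max_train_steps):
    ...  # regular training step
```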

fastvideo.v1.training.training_utils.logger[source]#

‘init_logger(…)’

fastvideo.v1.training.training_utils.normalize_dit_input(model_type, latents, args=None) → torch.Tensor[source]#
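
normalize_dit_input is undocumented; its signature suggests it applies model-family-specific normalization (for example, VAE scaling factors) to raw latents before the DiT forward pass. A hedged one-line sketch; "hunyuan" is an illustrative model_type, not a documented value:

```python
from fastvideo.v1.training.training_utils import normalize_dit_input

# "hunyuan" is an illustrative model_type; `latents` are raw VAE latents.
latents = normalize_dit_input("hunyuan", latents)
```
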
fastvideo.v1.training.training_utils.save_checkpoint(transformer, rank, output_dir, step, optimizer=None, dataloader=None, scheduler=None, noise_generator=None) → None[source]#

Save checkpoint following finetrainer’s distributed checkpoint approach. Saves both distributed checkpoint and consolidated model weights.
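
A hedged sketch of periodic checkpointing inside the training loop; the interval, output directory, and loop variables are placeholders:

```python
from fastvideo.v1.training.training_utils import save_checkpoint

if step % 500 == 0:  # placeholder checkpoint interval
    save_checkpoint(
        transformer,
        rank,
        "outputs",  # placeholder output directory
        step,
        optimizer=optimizer,
        dataloader=dataloader,
        scheduler=lr_scheduler,
    )
```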