
training_utils

Classes

fastvideo.training.training_utils.EMA_FSDP

EMA_FSDP(module, decay: float = 0.999, mode: str = 'local_shard')
FSDP2-friendly EMA with two modes:
  • mode="local_shard" (default): maintain float32 CPU EMA of local parameter shards on every rank. Provides a context manager to temporarily swap EMA weights into the live model for teacher forward.
  • mode="rank0_full": maintain a consolidated float32 CPU EMA of full parameters on rank 0 only using gather_state_dict_on_cpu_rank0(). Useful for checkpoint export; not for teacher forward.

Usage (local_shard for CM teacher):

    ema = EMA_FSDP(model, decay=0.999, mode="local_shard")
    for step in ...:
        ema.update(model)
        with ema.apply_to_model(model):
            with torch.no_grad():
                y_teacher = model(...)

Usage (rank0_full for export):

    ema = EMA_FSDP(model, decay=0.999, mode="rank0_full")
    ema.update(model)
    ema.state_dict()  # on rank 0

Source code in fastvideo/training/training_utils.py
def __init__(self, module, decay: float = 0.999, mode: str = "local_shard"):
    self.decay = float(decay)
    self.mode = mode
    self.shadow: dict[str, torch.Tensor] = {}
    self.rank = dist.get_rank() if dist.is_initialized() else 0
    if self.mode not in {"local_shard", "rank0_full"}:
        raise ValueError(f"Unsupported EMA_FSDP mode: {self.mode}")
    self._init_shadow(module)

Functions

fastvideo.training.training_utils.EMA_FSDP.copy_to_unwrapped
copy_to_unwrapped(module) -> None

Copy EMA weights into a non-sharded (unwrapped) module. Intended for export/eval. For mode="rank0_full", only rank 0 has the full EMA state.
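
Example (a minimal sketch, not part of the library): export the EMA weights on rank 0, assuming `ema` is an EMA_FSDP created with mode="rank0_full" and kept updated during training, and `build_unwrapped_transformer()` is a hypothetical helper that constructs a plain, non-sharded copy of the model architecture.

    import torch
    import torch.distributed as dist

    # Sketch only: `ema` and build_unwrapped_transformer() are assumed to exist.
    if (not dist.is_initialized()) or dist.get_rank() == 0:
        plain_model = build_unwrapped_transformer()
        ema.copy_to_unwrapped(plain_model)  # copy float32 EMA weights into the module
        torch.save(plain_model.state_dict(), "ema_export.pt")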

Source code in fastvideo/training/training_utils.py
@torch.no_grad()
def copy_to_unwrapped(self, module) -> None:
    """
    Copy EMA weights into a non-sharded (unwrapped) module. Intended for export/eval.
    For mode="rank0_full", only rank 0 has the full EMA state.
    """
    if self.mode == "rank0_full" and self.rank != 0:
        return
    name_to_param = dict(module.named_parameters())
    for n, w in self.shadow.items():
        if n in name_to_param:
            p = name_to_param[n]
            p.data.copy_(w.to(dtype=p.dtype, device=p.device))

Functions

fastvideo.training.training_utils.clip_grad_norm_

clip_grad_norm_(parameters: Tensor | list[Tensor], max_norm: float, norm_type: float = 2.0, error_if_nonfinite: bool = False, foreach: bool | None = None, pp_mesh: DeviceMesh | None = None) -> Tensor

Clip the gradient norm of parameters.

Gradient norm clipping requires computing the gradient norm over the entire model. torch.nn.utils.clip_grad_norm_ only computes gradient norm along DP/FSDP/TP dimensions. We need to manually reduce the gradient norm across PP stages. See https://github.com/pytorch/torchtitan/issues/596 for details.

Parameters:

- parameters (`torch.Tensor` or `List[torch.Tensor]`, required): Tensors that will have gradients normalized.
- max_norm (`float`, required): Maximum norm of the gradients after clipping.
- norm_type (`float`, defaults to `2.0`): Type of p-norm to use. Can be `inf` for infinity norm.
- error_if_nonfinite (`bool`, defaults to `False`): If `True`, an error is thrown if the total norm of the gradients from `parameters` is `nan`, `inf`, or `-inf`.
- foreach (`bool`, defaults to `None`): Use the faster foreach-based implementation. If `None`, use the foreach implementation for CUDA and CPU native tensors and silently fall back to the slow implementation for other device types.
- pp_mesh (`torch.distributed.device_mesh.DeviceMesh`, defaults to `None`): Pipeline parallel device mesh. If not `None`, the gradient norm is reduced across PP stages (note: the current implementation raises `NotImplementedError` when a PP mesh is passed).

Returns:

- `torch.Tensor`: Total norm of the gradients.
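
Example (a minimal sketch of one training step; `model`, `optimizer` and `loss` are assumed to be defined elsewhere):

    from fastvideo.training.training_utils import clip_grad_norm_

    loss.backward()
    grad_norm = clip_grad_norm_(
        list(model.parameters()),  # pass a list; the function iterates it twice
        max_norm=1.0,
        norm_type=2.0,
        pp_mesh=None,              # PP reduction is not implemented yet; leave as None
    )
    optimizer.step()
    optimizer.zero_grad()
    # grad_norm.item() is the total (pre-clipping) gradient norm, useful for logging.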

Source code in fastvideo/training/training_utils.py
@torch.no_grad()
def clip_grad_norm_(
    parameters: torch.Tensor | list[torch.Tensor],
    max_norm: float,
    norm_type: float = 2.0,
    error_if_nonfinite: bool = False,
    foreach: bool | None = None,
    pp_mesh: torch.distributed.device_mesh.DeviceMesh | None = None,
) -> torch.Tensor:
    r"""
    Clip the gradient norm of parameters.

    Gradient norm clipping requires computing the gradient norm over the entire model.
    `torch.nn.utils.clip_grad_norm_` only computes gradient norm along DP/FSDP/TP dimensions.
    We need to manually reduce the gradient norm across PP stages.
    See https://github.com/pytorch/torchtitan/issues/596 for details.

    Args:
        parameters (`torch.Tensor` or `List[torch.Tensor]`):
            Tensors that will have gradients normalized.
        max_norm (`float`):
            Maximum norm of the gradients after clipping.
        norm_type (`float`, defaults to `2.0`):
            Type of p-norm to use. Can be `inf` for infinity norm.
        error_if_nonfinite (`bool`, defaults to `False`):
            If `True`, an error is thrown if the total norm of the gradients from `parameters` is `nan`, `inf`, or `-inf`.
        foreach (`bool`, defaults to `None`):
            Use the faster foreach-based implementation. If `None`, use the foreach implementation for CUDA and CPU native tensors
            and silently fall back to the slow implementation for other device types.
        pp_mesh (`torch.distributed.device_mesh.DeviceMesh`, defaults to `None`):
            Pipeline parallel device mesh. If not `None`, will reduce gradient norm across PP stages.

    Returns:
        `torch.Tensor`:
            Total norm of the gradients
    """
    grads = [p.grad for p in parameters if p.grad is not None]

    # TODO(aryan): Wait for next Pytorch release to use `torch.nn.utils.get_total_norm`
    # total_norm = torch.nn.utils.get_total_norm(grads, norm_type, error_if_nonfinite, foreach)
    total_norm = _get_total_norm(grads, norm_type, error_if_nonfinite, foreach)

    # If total_norm is a DTensor, the placements must be `torch.distributed._tensor.ops.math_ops._NormPartial`.
    # We can simply reduce the DTensor to get the total norm in this tensor's process group
    # and then convert it to a local tensor.
    # It has two purposes:
    #   1. to make sure the total norm is computed correctly when PP is used (see below)
    #   2. to return a reduced total_norm tensor whose .item() would return the correct value
    if isinstance(total_norm, torch.distributed.tensor.DTensor):
        # Will reach here if any non-PP parallelism is used.
        # If only using PP, total_norm will be a local tensor.
        total_norm = total_norm.full_tensor()

    if pp_mesh is not None:
        raise NotImplementedError("Pipeline parallel is not supported")
        if math.isinf(norm_type):
            dist.all_reduce(total_norm,
                            op=dist.ReduceOp.MAX,
                            group=pp_mesh.get_group())
        else:
            total_norm **= norm_type
            dist.all_reduce(total_norm,
                            op=dist.ReduceOp.SUM,
                            group=pp_mesh.get_group())
            total_norm **= 1.0 / norm_type

    _clip_grads_with_norm_(parameters, max_norm, total_norm, foreach)
    return total_norm

fastvideo.training.training_utils.compute_density_for_timestep_sampling

compute_density_for_timestep_sampling(weighting_scheme: str, batch_size: int, generator, logit_mean: float | None = None, logit_std: float | None = None, mode_scale: float | None = None)

Compute the density for sampling the timesteps when doing SD3 training.

Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.

SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
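
Example (a minimal sketch; mapping the returned densities to discrete timestep indices is the caller's responsibility, and the 1000-step count here is illustrative):

    import torch
    from fastvideo.training.training_utils import compute_density_for_timestep_sampling

    generator = torch.Generator(device="cpu").manual_seed(42)
    u = compute_density_for_timestep_sampling(
        weighting_scheme="logit_normal",
        batch_size=4,
        generator=generator,
        logit_mean=0.0,
        logit_std=1.0,
    )
    # u is a CPU tensor of shape (4,) with values in (0, 1).
    indices = (u * 1000).long()  # e.g. scale into 1000 scheduler timesteps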

Source code in fastvideo/training/training_utils.py
def compute_density_for_timestep_sampling(
    weighting_scheme: str,
    batch_size: int,
    generator,
    logit_mean: float | None = None,
    logit_std: float | None = None,
    mode_scale: float | None = None,
):
    """
    Compute the density for sampling the timesteps when doing SD3 training.

    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.

    SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
    """
    if weighting_scheme == "logit_normal":
        # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
        u = torch.normal(
            mean=logit_mean,
            std=logit_std,
            size=(batch_size, ),
            device="cpu",
            generator=generator,
        )
        u = torch.nn.functional.sigmoid(u)
    elif weighting_scheme == "mode":
        u = torch.rand(size=(batch_size, ), device="cpu", generator=generator)
        u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2)**2 - 1 + u)
    else:
        u = torch.rand(size=(batch_size, ), device="cpu", generator=generator)
    return u

fastvideo.training.training_utils.custom_to_hf_state_dict

custom_to_hf_state_dict(state_dict: dict[str, Any] | Iterator[tuple[str, Tensor]], reverse_param_names_mapping: dict[str, tuple[str, int, int]]) -> dict[str, Any]

Convert fastvideo's custom model format to diffusers format using reverse_param_names_mapping.

Parameters:

- state_dict (`dict[str, Any] | Iterator[tuple[str, Tensor]]`, required): State dict in fastvideo's custom format.
- reverse_param_names_mapping (`dict[str, tuple[str, int, int]]`, required): Reverse mapping from fastvideo's custom format to diffusers format.

Returns:

- `dict[str, Any]`: State dict in diffusers format.
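
Example (a minimal sketch with made-up key names; only the direct-mapping and pass-through paths are shown):

    import torch
    from fastvideo.training.training_utils import custom_to_hf_state_dict

    state_dict = {
        "blocks.0.ff.weight": torch.randn(8, 8),   # has a rename entry below
        "blocks.0.norm.weight": torch.randn(8),    # no entry: kept as-is
    }
    # training_key -> (diffusers_key, merge_index, num_params_to_merge);
    # merge_index=None marks a plain rename with no splitting.
    reverse_param_names_mapping = {
        "blocks.0.ff.weight": ("blocks.0.proj_out.weight", None, 0),
    }

    hf_sd = custom_to_hf_state_dict(state_dict, reverse_param_names_mapping)
    # hf_sd now holds "blocks.0.proj_out.weight" plus the untouched "blocks.0.norm.weight".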

Source code in fastvideo/training/training_utils.py
def custom_to_hf_state_dict(
    state_dict: dict[str, Any] | Iterator[tuple[str, torch.Tensor]],
    reverse_param_names_mapping: dict[str, tuple[str, int,
                                                 int]]) -> dict[str, Any]:
    """
    Convert fastvideo's custom model format to diffusers format using reverse_param_names_mapping.

    Args:
        state_dict: State dict in fastvideo's custom format
        reverse_param_names_mapping: Reverse mapping from fastvideo's custom format to diffusers format

    Returns:
        State dict in diffusers format
    """
    assert len(
        reverse_param_names_mapping) > 0, "reverse_param_names_mapping is empty"
    if isinstance(state_dict, Iterator):
        state_dict = dict(state_dict)
    new_state_dict = {}
    # Group parameters that need to be split (merged parameters)
    merge_groups: dict[str, list[tuple[str, int, int]]] = {}

    # First pass: collect all merge groups
    for training_key, (
            diffusers_key, merge_index,
            num_params_to_merge) in reverse_param_names_mapping.items():
        if merge_index is not None:
            # This is a merged parameter that needs to be split
            if training_key not in merge_groups:
                merge_groups[training_key] = []
            merge_groups[training_key].append(
                (diffusers_key, merge_index, num_params_to_merge))

    # Second pass: handle merged parameters by splitting them
    used_keys = set()
    for training_key, splits in merge_groups.items():
        if training_key in state_dict:
            v = state_dict[training_key]
            # Sort by merge_index to ensure correct order
            splits.sort(key=lambda x: x[1])
            total = splits[0][2]
            split_size = v.shape[0] // total
            split_tensors = torch.split(v, split_size, dim=0)

            for diffusers_key, split_index, _ in splits:
                new_state_dict[diffusers_key] = split_tensors[split_index]
            used_keys.add(training_key)

    # Third pass: handle regular parameters (direct mappings)
    for training_key, v in state_dict.items():
        if training_key in used_keys:
            continue

        if training_key in reverse_param_names_mapping:
            diffusers_key, merge_index, _ = reverse_param_names_mapping[
                training_key]
            if merge_index is None:
                # Direct mapping
                new_state_dict[diffusers_key] = v
        else:
            # No mapping found, keep as is
            new_state_dict[training_key] = v

    return new_state_dict

fastvideo.training.training_utils.get_constant_schedule

get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1) -> LambdaLR

Create a schedule with a constant learning rate, using the learning rate set in optimizer.

Parameters:

- optimizer ([`~torch.optim.Optimizer`], required): The optimizer for which to schedule the learning rate.
- last_epoch (`int`, *optional*, defaults to -1): The index of the last epoch when resuming training.

Return:

- `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.

Source code in fastvideo/training/training_utils.py
def get_constant_schedule(optimizer: Optimizer,
                          last_epoch: int = -1) -> LambdaLR:
    """
    Create a schedule with a constant learning rate, using the learning rate set in optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch)

fastvideo.training.training_utils.get_constant_schedule_with_warmup

get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1) -> LambdaLR

Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate increases linearly between 0 and the initial lr set in the optimizer.

Parameters:

- optimizer ([`~torch.optim.Optimizer`], required): The optimizer for which to schedule the learning rate.
- num_warmup_steps (`int`, required): The number of steps for the warmup phase.
- last_epoch (`int`, *optional*, defaults to -1): The index of the last epoch when resuming training.

Return:

- `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.

Source code in fastvideo/training/training_utils.py
def get_constant_schedule_with_warmup(optimizer: Optimizer,
                                      num_warmup_steps: int,
                                      last_epoch: int = -1) -> LambdaLR:
    """
    Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate
    increases linearly between 0 and the initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1.0, num_warmup_steps))
        return 1.0

    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)

fastvideo.training.training_utils.get_cosine_schedule_with_min_lr

get_cosine_schedule_with_min_lr(optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, min_lr_ratio: float = 0.1, num_cycles: float = 0.5, last_epoch: int = -1) -> LambdaLR

Create a schedule with a learning rate that decreases following the values of the cosine function between the initial lr set in the optimizer to a minimum lr (min_lr_ratio * initial_lr), after a warmup period during which it increases linearly between 0 and the initial lr set in the optimizer.

Parameters:

- optimizer ([`~torch.optim.Optimizer`], required): The optimizer for which to schedule the learning rate.
- num_warmup_steps (`int`, required): The number of steps for the warmup phase.
- num_training_steps (`int`, required): The total number of training steps.
- min_lr_ratio (`float`, *optional*, defaults to 0.1): The ratio of minimum learning rate to initial learning rate.
- num_cycles (`float`, *optional*, defaults to 0.5): The number of periods of the cosine function in a schedule.
- last_epoch (`int`, *optional*, defaults to -1): The index of the last epoch when resuming training.

Return:

- `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
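
Example (a minimal runnable sketch with a toy model; the warmup and step counts are illustrative):

    import torch
    from fastvideo.training.training_utils import get_cosine_schedule_with_min_lr

    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    scheduler = get_cosine_schedule_with_min_lr(
        optimizer,
        num_warmup_steps=100,
        num_training_steps=1000,
        min_lr_ratio=0.1,     # LR floor: 10% of the initial LR
    )
    for _ in range(1000):
        optimizer.step()      # normally preceded by loss.backward()
        scheduler.step()
    # After warmup the LR follows a cosine curve but never drops below 1e-5.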

Source code in fastvideo/training/training_utils.py
def get_cosine_schedule_with_min_lr(optimizer: Optimizer,
                                    num_warmup_steps: int,
                                    num_training_steps: int,
                                    min_lr_ratio: float = 0.1,
                                    num_cycles: float = 0.5,
                                    last_epoch: int = -1) -> LambdaLR:
    """
    Create a schedule with a learning rate that decreases following the values of the cosine function between the
    initial lr set in the optimizer to a minimum lr (min_lr_ratio * initial_lr), after a warmup period during which 
    it increases linearly between 0 and the initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        min_lr_ratio (`float`, *optional*, defaults to 0.1):
            The ratio of minimum learning rate to initial learning rate.
        num_cycles (`float`, *optional*, defaults to 0.5):
            The number of periods of the cosine function in a schedule.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / float(
            max(1, num_training_steps - num_warmup_steps))
        # Cosine decay from 1.0 to min_lr_ratio over num_cycles periods
        # Use the same formula as standard cosine but ensure minimum is min_lr_ratio instead of 0
        cosine_value = 0.5 * (
            1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
        # Ensure the value doesn't go below min_lr_ratio
        return max(min_lr_ratio, cosine_value)

    return LambdaLR(optimizer, lr_lambda, last_epoch)

fastvideo.training.training_utils.get_cosine_schedule_with_warmup

get_cosine_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1) -> LambdaLR

Create a schedule with a learning rate that decreases following the values of the cosine function between the initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the initial lr set in the optimizer.

Parameters:

- optimizer ([`~torch.optim.Optimizer`], required): The optimizer for which to schedule the learning rate.
- num_warmup_steps (`int`, required): The number of steps for the warmup phase.
- num_training_steps (`int`, required): The total number of training steps.
- num_cycles (`float`, *optional*, defaults to 0.5): The number of periods of the cosine function in a schedule (the default is to just decrease from the max value to 0 following a half-cosine).
- last_epoch (`int`, *optional*, defaults to -1): The index of the last epoch when resuming training.

Return:

- `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.

Source code in fastvideo/training/training_utils.py
def get_cosine_schedule_with_warmup(optimizer: Optimizer,
                                    num_warmup_steps: int,
                                    num_training_steps: int,
                                    num_cycles: float = 0.5,
                                    last_epoch: int = -1) -> LambdaLR:
    """
    Create a schedule with a learning rate that decreases following the values of the cosine function between the
    initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
    initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        num_periods (`float`, *optional*, defaults to 0.5):
            The number of periods of the cosine function in a schedule (the default is to just decrease from the max
            value to 0 following a half-cosine).
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / float(
            max(1, num_training_steps - num_warmup_steps))
        return max(
            0.0, 0.5 *
            (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))

    return LambdaLR(optimizer, lr_lambda, last_epoch)

fastvideo.training.training_utils.get_cosine_with_hard_restarts_schedule_with_warmup

get_cosine_with_hard_restarts_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1) -> LambdaLR

Create a schedule with a learning rate that decreases following the values of the cosine function between the initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases linearly between 0 and the initial lr set in the optimizer.

Parameters:

- optimizer ([`~torch.optim.Optimizer`], required): The optimizer for which to schedule the learning rate.
- num_warmup_steps (`int`, required): The number of steps for the warmup phase.
- num_training_steps (`int`, required): The total number of training steps.
- num_cycles (`int`, *optional*, defaults to 1): The number of hard restarts to use.
- last_epoch (`int`, *optional*, defaults to -1): The index of the last epoch when resuming training.

Return:

- `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.

Source code in fastvideo/training/training_utils.py
def get_cosine_with_hard_restarts_schedule_with_warmup(
        optimizer: Optimizer,
        num_warmup_steps: int,
        num_training_steps: int,
        num_cycles: int = 1,
        last_epoch: int = -1) -> LambdaLR:
    """
    Create a schedule with a learning rate that decreases following the values of the cosine function between the
    initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases
    linearly between 0 and the initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        num_cycles (`int`, *optional*, defaults to 1):
            The number of hard restarts to use.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / float(
            max(1, num_training_steps - num_warmup_steps))
        if progress >= 1.0:
            return 0.0
        return max(
            0.0, 0.5 * (1.0 + math.cos(math.pi *
                                       ((float(num_cycles) * progress) % 1.0))))

    return LambdaLR(optimizer, lr_lambda, last_epoch)

fastvideo.training.training_utils.get_linear_schedule_with_warmup

get_linear_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, last_epoch: int = -1) -> LambdaLR

Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.

Parameters:

- optimizer ([`~torch.optim.Optimizer`], required): The optimizer for which to schedule the learning rate.
- num_warmup_steps (`int`, required): The number of steps for the warmup phase.
- num_training_steps (`int`, required): The total number of training steps.
- last_epoch (`int`, *optional*, defaults to -1): The index of the last epoch when resuming training.

Return:

- `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.

Source code in fastvideo/training/training_utils.py
def get_linear_schedule_with_warmup(optimizer: Optimizer,
                                    num_warmup_steps: int,
                                    num_training_steps: int,
                                    last_epoch: int = -1) -> LambdaLR:
    """
    Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
    a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        return max(
            0.0,
            float(num_training_steps - current_step) /
            float(max(1, num_training_steps - num_warmup_steps)))

    return LambdaLR(optimizer, lr_lambda, last_epoch)

fastvideo.training.training_utils.get_piecewise_constant_schedule

get_piecewise_constant_schedule(optimizer: Optimizer, step_rules: str, last_epoch: int = -1) -> LambdaLR

Create a schedule with a constant learning rate, using the learning rate set in optimizer.

Parameters:

- optimizer ([`~torch.optim.Optimizer`], required): The optimizer for which to schedule the learning rate.
- step_rules (`string`, required): The rules for the learning rate, e.g. step_rules="1:10,0.1:20,0.01:30,0.005" means the learning rate is multiplied by 1 until step 10, by 0.1 until step 20, by 0.01 until step 30, and by 0.005 for all later steps.
- last_epoch (`int`, *optional*, defaults to -1): The index of the last epoch when resuming training.

Return:

- `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
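
Example (a minimal sketch with a toy optimizer; the rule string is the same one used above):

    import torch
    from fastvideo.training.training_utils import get_piecewise_constant_schedule

    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    # Multiplier 1 until step 10, 0.1 until step 20, 0.01 until step 30, then 0.005.
    scheduler = get_piecewise_constant_schedule(
        optimizer, step_rules="1:10,0.1:20,0.01:30,0.005")
    print(scheduler.get_last_lr())  # [0.001] at step 0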

Source code in fastvideo/training/training_utils.py
def get_piecewise_constant_schedule(optimizer: Optimizer,
                                    step_rules: str,
                                    last_epoch: int = -1) -> LambdaLR:
    """
    Create a schedule with a constant learning rate, using the learning rate set in optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        step_rules (`string`):
            The rules for the learning rate. ex: rule_steps="1:10,0.1:20,0.01:30,0.005" it means that the learning rate
            if multiple 1 for the first 10 steps, multiple 0.1 for the next 20 steps, multiple 0.01 for the next 30
            steps and multiple 0.005 for the other steps.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    rules_dict = {}
    rule_list = step_rules.split(",")
    for rule_str in rule_list[:-1]:
        value_str, steps_str = rule_str.split(":")
        steps = int(steps_str)
        value = float(value_str)
        rules_dict[steps] = value
    last_lr_multiple = float(rule_list[-1])

    def create_rules_function(
            rules_dict: dict,
            last_lr_multiple: float) -> Callable[[int], float]:

        def rule_func(steps: int) -> float:
            for step_threshold, lr_multiple in sorted(rules_dict.items()):
                if steps < step_threshold:
                    return lr_multiple
            return last_lr_multiple

        return rule_func

    rules_func = create_rules_function(rules_dict, last_lr_multiple)

    return LambdaLR(optimizer, rules_func, last_epoch=last_epoch)

fastvideo.training.training_utils.get_polynomial_decay_schedule_with_warmup

get_polynomial_decay_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, lr_end: float = 1e-07, power: float = 1.0, last_epoch: int = -1) -> LambdaLR

Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the optimizer to end lr defined by lr_end, after a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.

Parameters:

- optimizer ([`~torch.optim.Optimizer`], required): The optimizer for which to schedule the learning rate.
- num_warmup_steps (`int`, required): The number of steps for the warmup phase.
- num_training_steps (`int`, required): The total number of training steps.
- lr_end (`float`, *optional*, defaults to 1e-7): The end LR.
- power (`float`, *optional*, defaults to 1.0): Power factor.
- last_epoch (`int`, *optional*, defaults to -1): The index of the last epoch when resuming training.

Note: power defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT implementation at https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37

Return:

- `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.

Source code in fastvideo/training/training_utils.py
def get_polynomial_decay_schedule_with_warmup(
    optimizer: Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    lr_end: float = 1e-7,
    power: float = 1.0,
    last_epoch: int = -1,
) -> LambdaLR:
    """
    Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the
    optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the
    initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        lr_end (`float`, *optional*, defaults to 1e-7):
            The end LR.
        power (`float`, *optional*, defaults to 1.0):
            Power factor.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT
    implementation at
    https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.

    """

    lr_init = optimizer.defaults["lr"]
    if not (lr_init > lr_end):
        raise ValueError(
            f"lr_end ({lr_end}) must be smaller than initial lr ({lr_init})")

    def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        elif current_step > num_training_steps:
            return lr_end / lr_init  # as LambdaLR multiplies by lr_init
        else:
            lr_range = lr_init - lr_end
            decay_steps = num_training_steps - num_warmup_steps
            pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps
            decay = lr_range * pct_remaining**power + lr_end
            return decay / lr_init  # as LambdaLR multiplies by lr_init

    return LambdaLR(optimizer, lr_lambda, last_epoch)

fastvideo.training.training_utils.get_scheduler

get_scheduler(name: str | SchedulerType, optimizer: Optimizer, step_rules: str | None = None, num_warmup_steps: int | None = None, num_training_steps: int | None = None, num_cycles: int = 1, power: float = 1.0, min_lr_ratio: float = 0.1, last_epoch: int = -1) -> LambdaLR

Unified API to get any scheduler from its name.

Parameters:

- name (`str` or `SchedulerType`, required): The name of the scheduler to use.
- optimizer (`torch.optim.Optimizer`, required): The optimizer that will be used during training.
- step_rules (`str`, *optional*): A string representing the step rules to use. This is only used by the `PIECEWISE_CONSTANT` scheduler.
- num_warmup_steps (`int`, *optional*): The number of warmup steps to do. This is not required by all schedulers (hence the argument being optional); the function will raise an error if it's unset and the scheduler type requires it.
- num_training_steps (`int`, *optional*): The number of training steps to do. This is not required by all schedulers (hence the argument being optional); the function will raise an error if it's unset and the scheduler type requires it.
- num_cycles (`int`, *optional*, defaults to 1): The number of hard restarts used in the `COSINE_WITH_RESTARTS` scheduler.
- power (`float`, *optional*, defaults to 1.0): Power factor. See the `POLYNOMIAL` scheduler.
- min_lr_ratio (`float`, *optional*, defaults to 0.1): The ratio of minimum learning rate to initial learning rate. Used in the `COSINE_WITH_MIN_LR` scheduler.
- last_epoch (`int`, *optional*, defaults to -1): The index of the last epoch when resuming training.
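
Example (a minimal sketch; it assumes `SchedulerType` is importable from this module, which the source below references, and passes the enum member directly rather than guessing its string spelling):

    import torch
    from fastvideo.training.training_utils import SchedulerType, get_scheduler

    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    scheduler = get_scheduler(
        name=SchedulerType.COSINE_WITH_MIN_LR,
        optimizer=optimizer,
        num_warmup_steps=100,
        num_training_steps=1000,
        min_lr_ratio=0.1,
    )
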
Source code in fastvideo/training/training_utils.py
def get_scheduler(
    name: str | SchedulerType,
    optimizer: Optimizer,
    step_rules: str | None = None,
    num_warmup_steps: int | None = None,
    num_training_steps: int | None = None,
    num_cycles: int = 1,
    power: float = 1.0,
    min_lr_ratio: float = 0.1,
    last_epoch: int = -1,
) -> LambdaLR:
    """
    Unified API to get any scheduler from its name.

    Args:
        name (`str` or `SchedulerType`):
            The name of the scheduler to use.
        optimizer (`torch.optim.Optimizer`):
            The optimizer that will be used during training.
        step_rules (`str`, *optional*):
            A string representing the step rules to use. This is only used by the `PIECEWISE_CONSTANT` scheduler.
        num_warmup_steps (`int`, *optional*):
            The number of warmup steps to do. This is not required by all schedulers (hence the argument being
            optional), the function will raise an error if it's unset and the scheduler type requires it.
        num_training_steps (`int``, *optional*):
            The number of training steps to do. This is not required by all schedulers (hence the argument being
            optional), the function will raise an error if it's unset and the scheduler type requires it.
        num_cycles (`int`, *optional*):
            The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
        power (`float`, *optional*, defaults to 1.0):
            Power factor. See `POLYNOMIAL` scheduler
        min_lr_ratio (`float`, *optional*, defaults to 0.1):
            The ratio of minimum learning rate to initial learning rate. Used in `COSINE_WITH_MIN_LR` scheduler.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.
    """
    name = SchedulerType(name)
    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
    if name == SchedulerType.CONSTANT:
        return schedule_func(optimizer, last_epoch=last_epoch)

    if name == SchedulerType.PIECEWISE_CONSTANT:
        return schedule_func(optimizer,
                             step_rules=step_rules,
                             last_epoch=last_epoch)

    # All other schedulers require `num_warmup_steps`
    if num_warmup_steps is None:
        raise ValueError(
            f"{name} requires `num_warmup_steps`, please provide that argument."
        )

    if name == SchedulerType.CONSTANT_WITH_WARMUP:
        return schedule_func(optimizer,
                             num_warmup_steps=num_warmup_steps,
                             last_epoch=last_epoch)

    # All other schedulers require `num_training_steps`
    if num_training_steps is None:
        raise ValueError(
            f"{name} requires `num_training_steps`, please provide that argument."
        )

    if name == SchedulerType.COSINE_WITH_RESTARTS:
        return schedule_func(
            optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
            num_cycles=num_cycles,
            last_epoch=last_epoch,
        )

    if name == SchedulerType.POLYNOMIAL:
        return schedule_func(
            optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
            power=power,
            last_epoch=last_epoch,
        )

    if name == SchedulerType.COSINE_WITH_MIN_LR:
        return schedule_func(
            optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
            min_lr_ratio=min_lr_ratio,
            last_epoch=last_epoch,
        )

    return schedule_func(optimizer,
                         num_warmup_steps=num_warmup_steps,
                         num_training_steps=num_training_steps,
                         last_epoch=last_epoch)

fastvideo.training.training_utils.load_checkpoint

load_checkpoint(transformer, rank, checkpoint_path, optimizer=None, dataloader=None, scheduler=None, noise_generator=None) -> int

Load checkpoint following finetrainer's distributed checkpoint approach. Returns the step number from which training should resume.
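
Example (a minimal resume-and-save sketch; `transformer`, `optimizer`, `scheduler`, `dataloader` and `noise_generator` are assumed to be constructed elsewhere, and the paths and step counts are illustrative):

    import os
    import torch.distributed as dist
    from fastvideo.training.training_utils import load_checkpoint, save_checkpoint

    rank = dist.get_rank() if dist.is_initialized() else 0
    start_step = load_checkpoint(transformer, rank,
                                 os.path.join("outputs", "checkpoint-500"),
                                 optimizer=optimizer,
                                 dataloader=dataloader,
                                 scheduler=scheduler,
                                 noise_generator=noise_generator)

    for step in range(start_step + 1, 1000):
        ...  # one training step
        if step % 250 == 0:
            save_checkpoint(transformer, rank, "outputs", step,
                            optimizer=optimizer,
                            dataloader=dataloader,
                            scheduler=scheduler,
                            noise_generator=noise_generator)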

Source code in fastvideo/training/training_utils.py
def load_checkpoint(transformer,
                    rank,
                    checkpoint_path,
                    optimizer=None,
                    dataloader=None,
                    scheduler=None,
                    noise_generator=None) -> int:
    """
    Load checkpoint following finetrainer's distributed checkpoint approach.
    Returns the step number from which training should resume.
    """
    if not os.path.exists(checkpoint_path):
        logger.warning("Checkpoint path %s does not exist", checkpoint_path)
        return 0

    # Extract step number from checkpoint path
    step = int(os.path.basename(checkpoint_path).split('-')[-1])

    if rank == 0:
        logger.info("Loading checkpoint from step %s", step)

    dcp_dir = os.path.join(checkpoint_path, "distributed_checkpoint")

    if not os.path.exists(dcp_dir):
        logger.warning("Distributed checkpoint directory %s does not exist",
                       dcp_dir)
        return 0

    states = {
        "model": ModelWrapper(transformer),
        "random_state": RandomStateWrapper(noise_generator),
    }

    if optimizer is not None:
        states["optimizer"] = OptimizerWrapper(transformer, optimizer)

    if dataloader is not None:
        states["dataloader"] = dataloader

    if scheduler is not None:
        states["scheduler"] = SchedulerWrapper(scheduler)

    logger.info("rank: %s, loading distributed checkpoint from %s",
                rank,
                dcp_dir,
                local_main_process_only=False)

    begin_time = time.perf_counter()
    dcp.load(states, checkpoint_id=dcp_dir)
    end_time = time.perf_counter()

    logger.info("rank: %s, distributed checkpoint loaded in %.2f seconds",
                rank,
                end_time - begin_time,
                local_main_process_only=False)
    logger.info("--> checkpoint loaded from step %s", step)

    return step

fastvideo.training.training_utils.load_distillation_checkpoint

load_distillation_checkpoint(generator_transformer, fake_score_transformer, rank, checkpoint_path, generator_optimizer=None, fake_score_optimizer=None, dataloader=None, generator_scheduler=None, fake_score_scheduler=None, noise_generator=None, generator_ema=None, generator_transformer_2=None, real_score_transformer_2=None, fake_score_transformer_2=None, generator_optimizer_2=None, fake_score_optimizer_2=None, generator_scheduler_2=None, fake_score_scheduler_2=None, generator_ema_2=None) -> int

Load distillation checkpoint with both generator and fake_score models. Supports MoE (Mixture of Experts) models with transformer_2 variants. Returns the step number from which training should resume.

Parameters:

- generator_transformer (required): Main generator transformer model.
- fake_score_transformer (required): Main fake score transformer model.
- generator_transformer_2 (optional, defaults to None): Secondary generator transformer for MoE.
- real_score_transformer_2 (optional, defaults to None): Secondary real score transformer for MoE.
- fake_score_transformer_2 (optional, defaults to None): Secondary fake score transformer for MoE.
- generator_optimizer_2 (optional, defaults to None): Optimizer for generator_transformer_2.
- fake_score_optimizer_2 (optional, defaults to None): Optimizer for fake_score_transformer_2.
- generator_scheduler_2 (optional, defaults to None): Scheduler for generator_transformer_2.
- fake_score_scheduler_2 (optional, defaults to None): Scheduler for fake_score_transformer_2.
- generator_ema_2 (optional, defaults to None): EMA for generator_transformer_2.
Source code in fastvideo/training/training_utils.py
def load_distillation_checkpoint(
        generator_transformer,
        fake_score_transformer,
        rank,
        checkpoint_path,
        generator_optimizer=None,
        fake_score_optimizer=None,
        dataloader=None,
        generator_scheduler=None,
        fake_score_scheduler=None,
        noise_generator=None,
        generator_ema=None,
        # MoE support
        generator_transformer_2=None,
        real_score_transformer_2=None,
        fake_score_transformer_2=None,
        generator_optimizer_2=None,
        fake_score_optimizer_2=None,
        generator_scheduler_2=None,
        fake_score_scheduler_2=None,
        generator_ema_2=None) -> int:
    """
    Load distillation checkpoint with both generator and fake_score models.
    Supports MoE (Mixture of Experts) models with transformer_2 variants.
    Returns the step number from which training should resume.

    Args:
        generator_transformer: Main generator transformer model
        fake_score_transformer: Main fake score transformer model
        generator_transformer_2: Secondary generator transformer for MoE (optional)
        real_score_transformer_2: Secondary real score transformer for MoE (optional)
        fake_score_transformer_2: Secondary fake score transformer for MoE (optional)
        generator_optimizer_2: Optimizer for generator_transformer_2 (optional)
        fake_score_optimizer_2: Optimizer for fake_score_transformer_2 (optional)
        generator_scheduler_2: Scheduler for generator_transformer_2 (optional)
        fake_score_scheduler_2: Scheduler for fake_score_transformer_2 (optional)
        generator_ema_2: EMA for generator_transformer_2 (optional)
    """
    if not os.path.exists(checkpoint_path):
        logger.warning("Distillation checkpoint path %s does not exist",
                       checkpoint_path)
        return 0

    # Extract step number from checkpoint path
    step = int(os.path.basename(checkpoint_path).split('-')[-1])

    if rank == 0:
        logger.info("Loading distillation checkpoint from step %s", step)

    # Load generator distributed checkpoint
    generator_dcp_dir = os.path.join(checkpoint_path, "distributed_checkpoint",
                                     "generator")
    if not os.path.exists(generator_dcp_dir):
        logger.warning(
            "Generator distributed checkpoint directory %s does not exist",
            generator_dcp_dir)
        return 0

    generator_states = {
        "model": ModelWrapper(generator_transformer),
    }

    if generator_optimizer is not None:
        generator_states["optimizer"] = OptimizerWrapper(
            generator_transformer, generator_optimizer)

    if dataloader is not None:
        generator_states["dataloader"] = dataloader

    if generator_scheduler is not None:
        generator_states["scheduler"] = SchedulerWrapper(generator_scheduler)

    logger.info("rank: %s, loading generator distributed checkpoint from %s",
                rank,
                generator_dcp_dir,
                local_main_process_only=False)

    begin_time = time.perf_counter()
    dcp.load(generator_states, checkpoint_id=generator_dcp_dir)
    end_time = time.perf_counter()

    logger.info(
        "rank: %s, generator distributed checkpoint loaded in %.2f seconds",
        rank,
        end_time - begin_time,
        local_main_process_only=False)

    # Load EMA state if available and generator_ema is provided
    if generator_ema is not None:
        try:
            ema_state = generator_states.get("ema")
            if ema_state is not None:
                generator_ema.load_state_dict(ema_state)
                logger.info("rank: %s, generator EMA state loaded successfully",
                            rank)
            else:
                logger.info("rank: %s, no EMA state found in checkpoint", rank)
        except Exception as e:
            logger.warning("rank: %s, failed to load EMA state: %s", rank,
                           str(e))

    # Load generator_2 distributed checkpoint (MoE support)
    if generator_transformer_2 is not None:
        generator_2_dcp_dir = os.path.join(checkpoint_path,
                                           "distributed_checkpoint",
                                           "generator_2")
        if os.path.exists(generator_2_dcp_dir):
            generator_2_states = {
                "model": ModelWrapper(generator_transformer_2),
            }

            if generator_optimizer_2 is not None:
                generator_2_states["optimizer"] = OptimizerWrapper(
                    generator_transformer_2, generator_optimizer_2)

            if dataloader is not None:
                generator_2_states["dataloader"] = dataloader

            if generator_scheduler_2 is not None:
                generator_2_states["scheduler"] = SchedulerWrapper(
                    generator_scheduler_2)

            logger.info(
                "rank: %s, loading generator_2 distributed checkpoint from %s",
                rank,
                generator_2_dcp_dir,
                local_main_process_only=False)

            begin_time = time.perf_counter()
            dcp.load(generator_2_states, checkpoint_id=generator_2_dcp_dir)
            end_time = time.perf_counter()

            logger.info(
                "rank: %s, generator_2 distributed checkpoint loaded in %.2f seconds",
                rank,
                end_time - begin_time,
                local_main_process_only=False)

            # Load EMA_2 state if available and generator_ema_2 is provided
            if generator_ema_2 is not None:
                try:
                    ema_2_state = generator_2_states.get("ema")
                    if ema_2_state is not None:
                        generator_ema_2.load_state_dict(ema_2_state)
                        logger.info(
                            "rank: %s, generator_2 EMA state loaded successfully",
                            rank)
                    else:
                        logger.info(
                            "rank: %s, no EMA_2 state found in checkpoint",
                            rank)
                except Exception as e:
                    logger.warning("rank: %s, failed to load EMA_2 state: %s",
                                   rank, str(e))
        else:
            logger.info("rank: %s, generator_2 checkpoint not found, skipping",
                        rank)

    # Load critic distributed checkpoint
    critic_dcp_dir = os.path.join(checkpoint_path, "distributed_checkpoint",
                                  "critic")
    if not os.path.exists(critic_dcp_dir):
        logger.warning(
            "Critic distributed checkpoint directory %s does not exist",
            critic_dcp_dir)
        return 0

    critic_states = {
        "model": ModelWrapper(fake_score_transformer),
    }

    if fake_score_optimizer is not None:
        critic_states["optimizer"] = OptimizerWrapper(fake_score_transformer,
                                                      fake_score_optimizer)

    if dataloader is not None:
        critic_states["dataloader"] = dataloader

    if fake_score_scheduler is not None:
        critic_states["scheduler"] = SchedulerWrapper(fake_score_scheduler)

    logger.info("rank: %s, loading critic distributed checkpoint from %s",
                rank,
                critic_dcp_dir,
                local_main_process_only=False)

    begin_time = time.perf_counter()
    dcp.load(critic_states, checkpoint_id=critic_dcp_dir)
    end_time = time.perf_counter()

    logger.info(
        "rank: %s, critic distributed checkpoint loaded in %.2f seconds",
        rank,
        end_time - begin_time,
        local_main_process_only=False)

    # Load critic_2 distributed checkpoint (MoE support)
    if fake_score_transformer_2 is not None:
        critic_2_dcp_dir = os.path.join(checkpoint_path,
                                        "distributed_checkpoint", "critic_2")
        if os.path.exists(critic_2_dcp_dir):
            critic_2_states = {
                "model": ModelWrapper(fake_score_transformer_2),
            }

            if fake_score_optimizer_2 is not None:
                critic_2_states["optimizer"] = OptimizerWrapper(
                    fake_score_transformer_2, fake_score_optimizer_2)

            if dataloader is not None:
                critic_2_states["dataloader"] = dataloader

            if fake_score_scheduler_2 is not None:
                critic_2_states["scheduler"] = SchedulerWrapper(
                    fake_score_scheduler_2)

            logger.info(
                "rank: %s, loading critic_2 distributed checkpoint from %s",
                rank,
                critic_2_dcp_dir,
                local_main_process_only=False)

            begin_time = time.perf_counter()
            dcp.load(critic_2_states, checkpoint_id=critic_2_dcp_dir)
            end_time = time.perf_counter()

            logger.info(
                "rank: %s, critic_2 distributed checkpoint loaded in %.2f seconds",
                rank,
                end_time - begin_time,
                local_main_process_only=False)
        else:
            logger.info("rank: %s, critic_2 checkpoint not found, skipping",
                        rank)

    # Load real_score_2 distributed checkpoint (MoE support)
    if real_score_transformer_2 is not None:
        real_score_2_dcp_dir = os.path.join(checkpoint_path,
                                            "distributed_checkpoint",
                                            "real_score_2")
        if os.path.exists(real_score_2_dcp_dir):
            real_score_2_states = {
                "model": ModelWrapper(real_score_transformer_2),
            }

            if dataloader is not None:
                real_score_2_states["dataloader"] = dataloader

            logger.info(
                "rank: %s, loading real_score_2 distributed checkpoint from %s",
                rank,
                real_score_2_dcp_dir,
                local_main_process_only=False)

            begin_time = time.perf_counter()
            dcp.load(real_score_2_states, checkpoint_id=real_score_2_dcp_dir)
            end_time = time.perf_counter()

            logger.info(
                "rank: %s, real_score_2 distributed checkpoint loaded in %.2f seconds",
                rank,
                end_time - begin_time,
                local_main_process_only=False)
        else:
            logger.info("rank: %s, real_score_2 checkpoint not found, skipping",
                        rank)

    # Load shared random state
    shared_dcp_dir = os.path.join(checkpoint_path, "distributed_checkpoint",
                                  "shared")
    if not os.path.exists(shared_dcp_dir):
        logger.warning("Shared random state directory %s does not exist",
                       shared_dcp_dir)
        return 0

    shared_states = {
        "random_state": RandomStateWrapper(noise_generator),
    }

    begin_time = time.perf_counter()
    dcp.load(shared_states, checkpoint_id=shared_dcp_dir)
    end_time = time.perf_counter()

    logger.info("rank: %s, shared random state loaded in %.2f seconds",
                rank,
                end_time - begin_time,
                local_main_process_only=False)
    logger.info("--> distillation checkpoint loaded from step %s", step)
    return step

fastvideo.training.training_utils.save_checkpoint

save_checkpoint(transformer, rank, output_dir, step, optimizer=None, dataloader=None, scheduler=None, noise_generator=None) -> None

Save checkpoint following finetrainer's distributed checkpoint approach. Saves both distributed checkpoint and consolidated model weights.

Source code in fastvideo/training/training_utils.py
def save_checkpoint(transformer,
                    rank,
                    output_dir,
                    step,
                    optimizer=None,
                    dataloader=None,
                    scheduler=None,
                    noise_generator=None) -> None:
    """
    Save checkpoint following finetrainer's distributed checkpoint approach.
    Saves both distributed checkpoint and consolidated model weights.
    """
    save_dir = os.path.join(output_dir, f"checkpoint-{step}")
    os.makedirs(save_dir, exist_ok=True)

    states = {
        "model": ModelWrapper(transformer),
        "random_state": RandomStateWrapper(noise_generator),
    }

    if optimizer is not None:
        states["optimizer"] = OptimizerWrapper(transformer, optimizer)

    if dataloader is not None:
        states["dataloader"] = dataloader

    if scheduler is not None:
        states["scheduler"] = SchedulerWrapper(scheduler)
    dcp_dir = os.path.join(save_dir, "distributed_checkpoint")
    logger.info("rank: %s, saving distributed checkpoint to %s",
                rank,
                dcp_dir,
                local_main_process_only=False)

    begin_time = time.perf_counter()
    dcp.save(states, checkpoint_id=dcp_dir)
    end_time = time.perf_counter()

    logger.info("rank: %s, distributed checkpoint saved in %.2f seconds",
                rank,
                end_time - begin_time,
                local_main_process_only=False)

    cpu_state = gather_state_dict_on_cpu_rank0(transformer, device=None)
    if rank == 0:
        # Save model weights (consolidated)
        transformer_save_dir = os.path.join(save_dir, "transformer")
        os.makedirs(transformer_save_dir, exist_ok=True)
        weight_path = os.path.join(transformer_save_dir,
                                   "diffusion_pytorch_model.safetensors")
        logger.info("rank: %s, saving consolidated checkpoint to %s",
                    rank,
                    weight_path,
                    local_main_process_only=False)

        # Convert training format to diffusers format and save
        diffusers_state_dict = custom_to_hf_state_dict(
            cpu_state, transformer.reverse_param_names_mapping)
        save_file(diffusers_state_dict, weight_path)

        logger.info("rank: %s, consolidated checkpoint saved to %s",
                    rank,
                    weight_path,
                    local_main_process_only=False)

        # Save model config
        config_dict = transformer.hf_config
        if "dtype" in config_dict:
            del config_dict["dtype"]  # TODO
        config_path = os.path.join(transformer_save_dir, "config.json")
        # save dict as json
        with open(config_path, "w") as f:
            json.dump(config_dict, f, indent=4)
        logger.info("--> checkpoint saved at step %s to %s", step, weight_path)

fastvideo.training.training_utils.save_distillation_checkpoint

save_distillation_checkpoint(generator_transformer, fake_score_transformer, rank, output_dir, step, generator_optimizer=None, fake_score_optimizer=None, dataloader=None, generator_scheduler=None, fake_score_scheduler=None, noise_generator=None, generator_ema=None, only_save_generator_weight=False, generator_transformer_2=None, real_score_transformer_2=None, fake_score_transformer_2=None, generator_optimizer_2=None, fake_score_optimizer_2=None, generator_scheduler_2=None, fake_score_scheduler_2=None, generator_ema_2=None) -> None

Save a distillation checkpoint containing both the generator and fake_score (critic) models. Supports MoE (Mixture of Experts) models via the transformer_2 variants. Writes a distributed checkpoint for training resume plus consolidated model weights; only the generator is consolidated, since that is the model used for inference.
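
As written in the source below, a checkpoint-{step} directory produced by this function is laid out roughly as follows (the *_2 entries appear only when the corresponding MoE arguments are passed):

checkpoint-{step}/
├── distributed_checkpoint/
│   ├── generator/        # model, optimizer, dataloader, scheduler, EMA
│   ├── generator_2/      # MoE only
│   ├── critic/           # fake_score model, optimizer, dataloader, scheduler
│   ├── critic_2/         # MoE only
│   ├── real_score_2/     # MoE only (model + dataloader state)
│   └── shared/           # random state (noise_generator)
├── generator_inference_transformer/
│   ├── diffusion_pytorch_model.safetensors   # consolidated, diffusers format (rank 0)
│   └── config.json
└── generator_2_inference_transformer/        # MoE only

With only_save_generator_weight=True, the distributed_checkpoint/ subtree is skipped entirely and only the consolidated inference weights are written.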

Parameters:

generator_transformer (required)
    Main generator transformer model
fake_score_transformer (required)
    Main fake score transformer model
only_save_generator_weight (default: False)
    If True, only save the generator model weights for inference without saving distributed checkpoint for training resume
generator_transformer_2 (default: None)
    Secondary generator transformer for MoE (optional)
real_score_transformer_2 (default: None)
    Secondary real score transformer for MoE (optional)
fake_score_transformer_2 (default: None)
    Secondary fake score transformer for MoE (optional)
generator_optimizer_2 (default: None)
    Optimizer for generator_transformer_2 (optional)
fake_score_optimizer_2 (default: None)
    Optimizer for fake_score_transformer_2 (optional)
generator_scheduler_2 (default: None)
    Scheduler for generator_transformer_2 (optional)
fake_score_scheduler_2 (default: None)
    Scheduler for fake_score_transformer_2 (optional)
generator_ema_2 (default: None)
    EMA for generator_transformer_2 (optional)
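
A minimal calling sketch without the MoE variants. Every model, optimizer, scheduler, and dataloader name below is a placeholder owned by the training loop; generator_ema is an EMA_FSDP instance as documented above.

import torch.distributed as dist

from fastvideo.training.training_utils import save_distillation_checkpoint

rank = dist.get_rank() if dist.is_initialized() else 0
save_distillation_checkpoint(generator_transformer,
                             fake_score_transformer,
                             rank,
                             "outputs/distill_run",
                             step,
                             generator_optimizer=generator_optimizer,
                             fake_score_optimizer=fake_score_optimizer,
                             dataloader=train_dataloader,
                             generator_scheduler=generator_scheduler,
                             fake_score_scheduler=fake_score_scheduler,
                             noise_generator=noise_generator,
                             generator_ema=generator_ema)

# For a lightweight, inference-only export (no training-resume state), pass
# only_save_generator_weight=True.
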
Source code in fastvideo/training/training_utils.py
def save_distillation_checkpoint(
        generator_transformer,
        fake_score_transformer,
        rank,
        output_dir,
        step,
        generator_optimizer=None,
        fake_score_optimizer=None,
        dataloader=None,
        generator_scheduler=None,
        fake_score_scheduler=None,
        noise_generator=None,
        generator_ema=None,
        only_save_generator_weight=False,
        # MoE support
        generator_transformer_2=None,
        real_score_transformer_2=None,
        fake_score_transformer_2=None,
        generator_optimizer_2=None,
        fake_score_optimizer_2=None,
        generator_scheduler_2=None,
        fake_score_scheduler_2=None,
        generator_ema_2=None) -> None:
    """
    Save distillation checkpoint with both generator and fake_score models.
    Supports MoE (Mixture of Experts) models with transformer_2 variants.
    Saves both distributed checkpoint and consolidated model weights.
    Only saves the generator model for inference (consolidated weights).

    Args:
        generator_transformer: Main generator transformer model
        fake_score_transformer: Main fake score transformer model
        only_save_generator_weight: If True, only save the generator model weights for inference
                                   without saving distributed checkpoint for training resume.
        generator_transformer_2: Secondary generator transformer for MoE (optional)
        real_score_transformer_2: Secondary real score transformer for MoE (optional) 
        fake_score_transformer_2: Secondary fake score transformer for MoE (optional)
        generator_optimizer_2: Optimizer for generator_transformer_2 (optional)
        fake_score_optimizer_2: Optimizer for fake_score_transformer_2 (optional)
        generator_scheduler_2: Scheduler for generator_transformer_2 (optional)
        fake_score_scheduler_2: Scheduler for fake_score_transformer_2 (optional)
        generator_ema_2: EMA for generator_transformer_2 (optional)
    """
    save_dir = os.path.join(output_dir, f"checkpoint-{step}")
    os.makedirs(save_dir, exist_ok=True)

    # Create directories for models
    inference_save_dir = os.path.join(save_dir,
                                      "generator_inference_transformer")

    # Only save distributed checkpoint if not only saving generator weight
    if not only_save_generator_weight:
        # Save generator distributed checkpoint
        generator_states = {
            "model": ModelWrapper(generator_transformer),
        }
        if generator_optimizer is not None:
            generator_states["optimizer"] = OptimizerWrapper(
                generator_transformer, generator_optimizer)
        if dataloader is not None:
            generator_states["dataloader"] = dataloader
        if generator_scheduler is not None:
            generator_states["scheduler"] = SchedulerWrapper(
                generator_scheduler)
        if generator_ema is not None:
            generator_states["ema"] = generator_ema.state_dict()

        generator_dcp_dir = os.path.join(save_dir, "distributed_checkpoint",
                                         "generator")
        logger.info("rank: %s, saving generator distributed checkpoint to %s",
                    rank,
                    generator_dcp_dir,
                    local_main_process_only=False)

        begin_time = time.perf_counter()
        dcp.save(generator_states, checkpoint_id=generator_dcp_dir)
        end_time = time.perf_counter()

        logger.info(
            "rank: %s, generator distributed checkpoint saved in %.2f seconds",
            rank,
            end_time - begin_time,
            local_main_process_only=False)

        # Save generator_2 distributed checkpoint (MoE support)
        if generator_transformer_2 is not None:
            generator_2_states = {
                "model": ModelWrapper(generator_transformer_2),
            }
            if generator_optimizer_2 is not None:
                generator_2_states["optimizer"] = OptimizerWrapper(
                    generator_transformer_2, generator_optimizer_2)
            if dataloader is not None:
                generator_2_states["dataloader"] = dataloader
            if generator_scheduler_2 is not None:
                generator_2_states["scheduler"] = SchedulerWrapper(
                    generator_scheduler_2)
            if generator_ema_2 is not None:
                generator_2_states["ema"] = generator_ema_2.state_dict()

            generator_2_dcp_dir = os.path.join(save_dir,
                                               "distributed_checkpoint",
                                               "generator_2")
            logger.info(
                "rank: %s, saving generator_2 distributed checkpoint to %s",
                rank,
                generator_2_dcp_dir,
                local_main_process_only=False)

            begin_time = time.perf_counter()
            dcp.save(generator_2_states, checkpoint_id=generator_2_dcp_dir)
            end_time = time.perf_counter()

            logger.info(
                "rank: %s, generator_2 distributed checkpoint saved in %.2f seconds",
                rank,
                end_time - begin_time,
                local_main_process_only=False)

        # Save critic distributed checkpoint
        critic_states = {
            "model": ModelWrapper(fake_score_transformer),
        }
        if fake_score_optimizer is not None:
            critic_states["optimizer"] = OptimizerWrapper(
                fake_score_transformer, fake_score_optimizer)
        if dataloader is not None:
            critic_states["dataloader"] = dataloader
        if fake_score_scheduler is not None:
            critic_states["scheduler"] = SchedulerWrapper(fake_score_scheduler)

        critic_dcp_dir = os.path.join(save_dir, "distributed_checkpoint",
                                      "critic")
        logger.info("rank: %s, saving critic distributed checkpoint to %s",
                    rank,
                    critic_dcp_dir,
                    local_main_process_only=False)

        begin_time = time.perf_counter()
        dcp.save(critic_states, checkpoint_id=critic_dcp_dir)
        end_time = time.perf_counter()

        logger.info(
            "rank: %s, critic distributed checkpoint saved in %.2f seconds",
            rank,
            end_time - begin_time,
            local_main_process_only=False)

        # Save critic_2 distributed checkpoint (MoE support)
        if fake_score_transformer_2 is not None:
            critic_2_states = {
                "model": ModelWrapper(fake_score_transformer_2),
            }
            if fake_score_optimizer_2 is not None:
                critic_2_states["optimizer"] = OptimizerWrapper(
                    fake_score_transformer_2, fake_score_optimizer_2)
            if dataloader is not None:
                critic_2_states["dataloader"] = dataloader
            if fake_score_scheduler_2 is not None:
                critic_2_states["scheduler"] = SchedulerWrapper(
                    fake_score_scheduler_2)

            critic_2_dcp_dir = os.path.join(save_dir, "distributed_checkpoint",
                                            "critic_2")
            logger.info(
                "rank: %s, saving critic_2 distributed checkpoint to %s",
                rank,
                critic_2_dcp_dir,
                local_main_process_only=False)

            begin_time = time.perf_counter()
            dcp.save(critic_2_states, checkpoint_id=critic_2_dcp_dir)
            end_time = time.perf_counter()

            logger.info(
                "rank: %s, critic_2 distributed checkpoint saved in %.2f seconds",
                rank,
                end_time - begin_time,
                local_main_process_only=False)

        # Save real_score_transformer_2 distributed checkpoint (MoE support)
        if real_score_transformer_2 is not None:
            real_score_2_states = {
                "model": ModelWrapper(real_score_transformer_2),
            }
            # Note: real_score_transformer_2 typically doesn't have optimizer/scheduler
            # since it's used for inference only, but we include dataloader for consistency
            if dataloader is not None:
                real_score_2_states["dataloader"] = dataloader

            real_score_2_dcp_dir = os.path.join(save_dir,
                                                "distributed_checkpoint",
                                                "real_score_2")
            logger.info(
                "rank: %s, saving real_score_2 distributed checkpoint to %s",
                rank,
                real_score_2_dcp_dir,
                local_main_process_only=False)

            begin_time = time.perf_counter()
            dcp.save(real_score_2_states, checkpoint_id=real_score_2_dcp_dir)
            end_time = time.perf_counter()

            logger.info(
                "rank: %s, real_score_2 distributed checkpoint saved in %.2f seconds",
                rank,
                end_time - begin_time,
                local_main_process_only=False)

        # Save shared random state separately
        shared_states = {
            "random_state": RandomStateWrapper(noise_generator),
        }
        shared_dcp_dir = os.path.join(save_dir, "distributed_checkpoint",
                                      "shared")

        dcp.save(shared_states, checkpoint_id=shared_dcp_dir)

    else:
        logger.info(
            "rank: %s, skipping distributed checkpoint save (only_save_generator_weight=True)",
            rank,
            local_main_process_only=False)

    # Save generator model weights (consolidated) for inference
    cpu_state = gather_state_dict_on_cpu_rank0(generator_transformer,
                                               device=None)

    if rank == 0:
        # Save generator model weights (consolidated) for inference
        os.makedirs(inference_save_dir, exist_ok=True)
        weight_path = os.path.join(inference_save_dir,
                                   "diffusion_pytorch_model.safetensors")
        logger.info(
            "rank: %s, saving consolidated generator inference checkpoint to %s",
            rank,
            weight_path,
            local_main_process_only=False)

        # Convert training format to diffusers format and save
        diffusers_state_dict = custom_to_hf_state_dict(
            cpu_state, generator_transformer.reverse_param_names_mapping)
        save_file(diffusers_state_dict, weight_path)

        logger.info(
            "rank: %s, consolidated generator inference checkpoint saved to %s",
            rank,
            weight_path,
            local_main_process_only=False)

        # Save model config
        config_dict = generator_transformer.hf_config
        if "dtype" in config_dict:
            del config_dict["dtype"]  # TODO
        config_path = os.path.join(inference_save_dir, "config.json")
        # save dict as json
        with open(config_path, "w") as f:
            json.dump(config_dict, f, indent=4)
        logger.info("--> distillation checkpoint saved at step %s to %s", step,
                    weight_path)

        # Save generator_2 model weights (consolidated) for inference (MoE support)
        if generator_transformer_2 is not None:
            inference_save_dir_2 = os.path.join(
                save_dir, "generator_2_inference_transformer")
            cpu_state_2 = gather_state_dict_on_cpu_rank0(
                generator_transformer_2, device=None)

            if rank == 0:
                os.makedirs(inference_save_dir_2, exist_ok=True)
                weight_path_2 = os.path.join(
                    inference_save_dir_2, "diffusion_pytorch_model.safetensors")
                logger.info(
                    "rank: %s, saving consolidated generator_2 inference checkpoint to %s",
                    rank,
                    weight_path_2,
                    local_main_process_only=False)

                # Convert training format to diffusers format and save
                diffusers_state_dict_2 = custom_to_hf_state_dict(
                    cpu_state_2,
                    generator_transformer_2.reverse_param_names_mapping)
                save_file(diffusers_state_dict_2, weight_path_2)

                logger.info(
                    "rank: %s, consolidated generator_2 inference checkpoint saved to %s",
                    rank,
                    weight_path_2,
                    local_main_process_only=False)

                # Save model config
                config_dict_2 = generator_transformer_2.hf_config
                if "dtype" in config_dict_2:
                    del config_dict_2["dtype"]  # TODO
                config_path_2 = os.path.join(inference_save_dir_2,
                                             "config.json")
                with open(config_path_2, "w") as f:
                    json.dump(config_dict_2, f, indent=4)
                logger.info(
                    "--> generator_2 distillation checkpoint saved at step %s to %s",
                    step, weight_path_2)