fastvideo.v1.training.checkpointing_utils

Module Contents

Classes

- ModelWrapper
- OptimizerWrapper
- RandomStateWrapper
- SchedulerWrapper

API

class fastvideo.v1.training.checkpointing_utils.ModelWrapper(model: torch.nn.Module) [source]

Bases: torch.distributed.checkpoint.stateful.Stateful

load_state_dict(state_dict: Dict[str, Any]) → None [source]
state_dict() → Dict[str, Any] [source]
class fastvideo.v1.training.checkpointing_utils.OptimizerWrapper(model: torch.nn.Module, optimizer: torch.optim.Optimizer) [source]

Bases: torch.distributed.checkpoint.stateful.Stateful

load_state_dict(state_dict: Dict[str, Any]) → None [source]
state_dict() → Dict[str, Any] [source]
class fastvideo.v1.training.checkpointing_utils.RandomStateWrapper(noise_generator: Optional[torch.Generator] = None) [source]

Bases: torch.distributed.checkpoint.stateful.Stateful

load_state_dict(state_dict: Dict[str, Any]) → None [source]
state_dict() → Dict[str, Any] [source]
class fastvideo.v1.training.checkpointing_utils.SchedulerWrapper(scheduler) [source]

Bases: torch.distributed.checkpoint.stateful.Stateful

load_state_dict(state_dict: Dict[str, Any]) → None [source]
state_dict() → Dict[str, Any] [source]