profiler

Utilities for managing the PyTorch profiler within FastVideo.

The profiler is shared across the process; this module adds a light-weight controller that gates collection based on named regions. Regions may be enabled through dedicated environment variables (e.g. FASTVIDEO_TORCH_PROFILE_MODEL_LOADING=1) or via the consolidated FASTVIDEO_TORCH_PROFILE_REGIONS comma-separated list (e.g. FASTVIDEO_TORCH_PROFILE_REGIONS=model_loading,training_dit).

Typical usage from client code::

controller = TorchProfilerController(profiler, activities)
with controller.region("training_dit"):
    run_training_step()

To introduce a new region, register it via :func:register_profiler_region and wrap the corresponding code in :meth:TorchProfilerController.region.

Classes

fastvideo.profiler.ProfilerRegion dataclass

ProfilerRegion(name: str, description: str, default_enabled: bool = False)

Metadata describing a profiler region.

fastvideo.profiler.TorchProfilerConfig dataclass

TorchProfilerConfig(regions: dict[str, bool])

Configuration for torch profiler region control.

Use :meth:from_env to construct an instance from the FASTVIDEO_TORCH_PROFILE_REGIONS environment variable. As the source below shows, :meth:from_env raises ValueError when the variable is empty or names no registered region, and unrecognized names are skipped with a warning. The resulting regions map is consumed by :class:TorchProfilerController to decide when collection should be enabled.

Functions

fastvideo.profiler.TorchProfilerConfig.from_env classmethod
from_env() -> TorchProfilerConfig

Build a configuration from process environment variables.

Source code in fastvideo/profiler.py
@classmethod
def from_env(cls) -> TorchProfilerConfig:
    """Build a configuration from process environment variables."""

    requested_regions = {
        token.strip()
        for token in (getattr(envs, "FASTVIDEO_TORCH_PROFILE_REGIONS", "")
                      or "").split(",") if token.strip()
    }

    if not requested_regions:
        available = ", ".join(region.name
                              for region in list_profiler_regions())
        raise ValueError(
            "FASTVIDEO_TORCH_PROFILE_REGIONS must list at least one region; "
            f"available regions: {available}")

    regions: dict[str, bool] = {}
    available_regions = list_profiler_regions()
    available_names = ", ".join(region.name for region in available_regions)

    for token in requested_regions:
        resolved = resolve_profiler_region(token)
        if resolved is None:
            logger.warning(
                "Unknown profiler region '%s'; available regions: %s",
                token, available_names)
            continue
        regions[resolved.name] = True

    if not regions:
        raise ValueError(
            "FASTVIDEO_TORCH_PROFILE_REGIONS did not match any known regions; "
            f"requested={sorted(requested_regions)}, available={available_names}"
        )

    return cls(regions=regions)
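The token handling above can be tried without FastVideo installed. The following is a minimal sketch rather than the module's code: `select_regions` and the `known` set are hypothetical stand-ins for the registry lookup, and unknown tokens are skipped silently here where the real method logs a warning.

```python
from __future__ import annotations


def select_regions(raw: str | None, known: set[str]) -> dict[str, bool]:
    """Mimic the token handling shown in from_env (hypothetical helper)."""
    # Split on commas, trim whitespace, and drop empty tokens.
    requested = {tok.strip() for tok in (raw or "").split(",") if tok.strip()}
    if not requested:
        raise ValueError("no regions requested")

    # Keep only tokens that name a known region. The real method logs a
    # warning for each unknown token; this sketch just skips it.
    regions = {name: True for name in requested if name in known}
    if not regions:
        raise ValueError(f"no known regions in {sorted(requested)}")
    return regions


known = {"model_loading", "training_dit"}
selected = select_regions(" model_loading, training_dit,,typo ", known)
print(sorted(selected))  # ['model_loading', 'training_dit']
```

Empty tokens and surrounding whitespace are discarded, so a trailing comma in the environment variable is harmless, while an empty value raises.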

fastvideo.profiler.TorchProfilerController

TorchProfilerController(profiler: Any, activities: Iterable[ProfilerActivity], config: TorchProfilerConfig | None = None, disabled: bool = False)

Helper that toggles torch profiler collection for named regions.

Parameters

profiler: The shared :class:torch.profiler.profile instance, or None if profiling is disabled.
activities: Iterable of :class:torch.profiler.ProfilerActivity recorded by the profiler.
config: Optional :class:TorchProfilerConfig. If omitted, :meth:from_env constructs one during initialization.
disabled: If True, build a no-op controller that holds no profiler and is not registered globally.

Examples

Enabling an existing region from the command line::

FASTVIDEO_TORCH_PROFILE_REGIONS=model_loading,training_dit python fastvideo/training/wan_training_pipeline.py ...

Wrapping a code block in a custom region::

controller = TorchProfilerController(profiler, activities)
with controller.region("training_validation"):
    run_validation_epoch()

Adding a new region requires three steps:
  1. Define an env var in envs.py.
  2. Register the region with :func:register_profiler_region in this module.
  3. Wrap the target code in :meth:region using the new name.

Source code in fastvideo/profiler.py
def __init__(
    self,
    profiler: Any,
    activities: Iterable[torch.profiler.ProfilerActivity],
    config: TorchProfilerConfig | None = None,
    disabled: bool = False,
) -> None:
    activities_tuple = tuple(activities)
    existing = get_global_controller()
    if existing is not None and not disabled:
        raise RuntimeError(
            "TorchProfilerController already initialized globally. Use get_global_controller()."
        )
    if disabled:
        self._profiler = None
        return

    self._profiler = profiler
    self._activities = activities_tuple
    self._config = config or TorchProfilerConfig.from_env()
    self._collection_enabled = False
    self._active_region_depth = 0
    logger.info(
        "PROFILER: TorchProfilerController initialized with config: %s",
        self._config)
    set_global_profiler(self._profiler)
    set_global_controller(self)

Attributes

fastvideo.profiler.TorchProfilerController.has_profiler property
has_profiler: bool

Return True when a profiler instance is available.

fastvideo.profiler.TorchProfilerController.is_enabled property
is_enabled: bool

Return True when the underlying profiler is collecting.

Functions

fastvideo.profiler.TorchProfilerController.is_region_enabled
is_region_enabled(region: str) -> bool

Return True if region should be collected.

Source code in fastvideo/profiler.py
def is_region_enabled(self, region: str) -> bool:
    """Return ``True`` if ``region`` should be collected."""

    if self._profiler is None:
        return False
    return self._config.regions.get(region, False)

fastvideo.profiler.TorchProfilerController.region
region(region: str)

Context manager that enables profiling for region if configured.

Source code in fastvideo/profiler.py
@contextlib.contextmanager
def region(self, region: str):
    """Context manager that enables profiling for ``region`` if configured."""

    if self._profiler is None:
        yield
        return

    if not self.is_region_enabled(region):
        yield
        return

    with torch.profiler.record_function(f"fastvideo.region::{region}"):
        self._active_region_depth += 1
        if self._active_region_depth == 1:
            logger.info(
                "PROFILER: Setting collection to True (depth=%s) for region %s",
                self._active_region_depth, region)
            self._set_collection(True)
        try:
            yield
        finally:
            self._active_region_depth -= 1
            logger.info("PROFILER: Decreasing active region depth to %s",
                        self._active_region_depth)
            if self._active_region_depth == 0:
                logger.info(
                    "PROFILER: Setting collection to False upon exiting region %s",
                    region)
                self._set_collection(False)
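The depth counter in the listing above means that nested regions switch collection only at the outermost entry and exit. A minimal sketch of that bookkeeping follows; `DepthGate` is a hypothetical stand-in that records toggles in a list instead of calling the profiler.

```python
import contextlib


class DepthGate:
    """Hypothetical stand-in for the controller's depth bookkeeping."""

    def __init__(self) -> None:
        self.depth = 0
        self.toggles = []  # records each collection switch as a bool

    @contextlib.contextmanager
    def region(self, name: str):
        self.depth += 1
        if self.depth == 1:
            self.toggles.append(True)  # outermost entry turns collection on
        try:
            yield
        finally:
            self.depth -= 1
            if self.depth == 0:
                self.toggles.append(False)  # outermost exit turns it off


gate = DepthGate()
with gate.region("training_dit"):
    with gate.region("training_validation"):
        pass  # the inner region does not toggle collection again
print(gate.toggles)  # [True, False]
```

Because the decrement sits in a `finally` clause, the depth returns to zero even if the wrapped code raises.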

fastvideo.profiler.TorchProfilerController.start
start() -> None

Start the profiler and pause collection until a region is entered.

Source code in fastvideo/profiler.py
def start(self) -> None:
    """Start the profiler and pause collection until a region is entered."""

    logger.info("PROFILER: Starting profiler...")
    if self._profiler is None:
        return
    self._profiler.start()
    logger.info("PROFILER: Profiler started")
    # Profiler starts with collection disabled by default.
    logger.info("PROFILER: Setting collection to False")
    self._set_collection(False)
    logger.info("PROFILER: Profiler started with collection disabled")

fastvideo.profiler.TorchProfilerController.stop
stop() -> None

Stop the profiler after disabling collection and clearing state.

Source code in fastvideo/profiler.py
def stop(self) -> None:
    """Stop the profiler after disabling collection and clearing state."""

    if self._profiler is None:
        return

    logger.info("PROFILER: Stopping profiler...")
    self._profiler.stop()
    logger.info("PROFILER: Profiler stopped")
    self._active_region_depth = 0
    set_global_profiler(None)
    set_global_controller(None)

Functions

fastvideo.profiler.get_global_profiler

get_global_profiler() -> profile | None

Return the global profiler instance if one was created.

Source code in fastvideo/profiler.py
def get_global_profiler() -> torch.profiler.profile | None:
    """Return the global profiler instance if one was created."""

    return _GLOBAL_PROFILER

fastvideo.profiler.get_or_create_profiler

get_or_create_profiler(trace_dir: str | None) -> TorchProfilerController

Create or reuse the process-wide torch profiler controller.

Source code in fastvideo/profiler.py
def get_or_create_profiler(trace_dir: str | None) -> TorchProfilerController:
    """Create or reuse the process-wide torch profiler controller."""

    existing = get_global_controller()
    if existing is not None:
        if trace_dir:
            logger.info("Reusing existing global torch profiler controller")
        return existing

    if not trace_dir:
        logger.info("Torch profiler disabled; returning no-op controller")
        return TorchProfilerController(None, _DEFAULT_ACTIVITIES, disabled=True)

    logger.info("Profiling enabled. Traces will be saved to: %s", trace_dir)
    logger.info(
        "Profiler config: record_shapes=%s, profile_memory=%s, with_stack=%s, with_flops=%s",
        envs.FASTVIDEO_TORCH_PROFILER_RECORD_SHAPES,
        envs.FASTVIDEO_TORCH_PROFILER_WITH_PROFILE_MEMORY,
        envs.FASTVIDEO_TORCH_PROFILER_WITH_STACK,
        envs.FASTVIDEO_TORCH_PROFILER_WITH_FLOPS,
    )
    logger.info("FASTVIDEO_TORCH_PROFILE_REGIONS=%s",
                envs.FASTVIDEO_TORCH_PROFILE_REGIONS)

    profiler = torch.profiler.profile(
        activities=_DEFAULT_ACTIVITIES,
        record_shapes=envs.FASTVIDEO_TORCH_PROFILER_RECORD_SHAPES,
        profile_memory=envs.FASTVIDEO_TORCH_PROFILER_WITH_PROFILE_MEMORY,
        with_stack=envs.FASTVIDEO_TORCH_PROFILER_WITH_STACK,
        with_flops=envs.FASTVIDEO_TORCH_PROFILER_WITH_FLOPS,
        on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_dir,
                                                                use_gzip=True),
    )
    controller = TorchProfilerController(profiler, _DEFAULT_ACTIVITIES)
    controller.start()
    logger.info("Torch profiler started")
    return controller

fastvideo.profiler.list_profiler_regions

list_profiler_regions() -> list[ProfilerRegion]

Return all registered profiler regions sorted by canonical name.

Source code in fastvideo/profiler.py
def list_profiler_regions() -> list[ProfilerRegion]:
    """Return all registered profiler regions sorted by canonical name."""

    return [_REGISTERED_REGIONS[name] for name in sorted(_REGISTERED_REGIONS)]

fastvideo.profiler.profile_region

profile_region(region: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]

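Decorator that runs a method inside the named profiler region when the instance exposes a profiler_controller attribute with an available profiler; otherwise the method runs unchanged.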

Source code in fastvideo/profiler.py
def profile_region(
        region: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Wrap a bound method so it runs inside a profiler region if available."""

    def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:

        @functools.wraps(fn)
        def wrapped(self, *args, **kwargs):
            controller = getattr(self, "profiler_controller", None)
            if controller is None or not controller.has_profiler:
                return fn(self, *args, **kwargs)
            with controller.region(region):
                return fn(self, *args, **kwargs)

        return wrapped

    return decorator
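A usage sketch of the decorator, under stated assumptions: `profile_region` is restated from the listing above so the snippet runs on its own, while `FakeController` and `Pipeline` are hypothetical classes that only illustrate the `profiler_controller` lookup.

```python
import contextlib
import functools


def profile_region(region):
    # Restated from the listing above so this snippet runs standalone.
    def decorator(fn):
        @functools.wraps(fn)
        def wrapped(self, *args, **kwargs):
            controller = getattr(self, "profiler_controller", None)
            if controller is None or not controller.has_profiler:
                return fn(self, *args, **kwargs)
            with controller.region(region):
                return fn(self, *args, **kwargs)
        return wrapped
    return decorator


class FakeController:
    """Hypothetical stand-in that records which regions were entered."""
    has_profiler = True

    def __init__(self):
        self.entered = []

    @contextlib.contextmanager
    def region(self, name):
        self.entered.append(name)
        yield


class Pipeline:
    """Hypothetical owner; the decorator reads `profiler_controller` from self."""

    def __init__(self, controller=None):
        self.profiler_controller = controller

    @profile_region("training_dit")
    def step(self, x):
        return x * 2


controller = FakeController()
result = Pipeline(controller).step(21)
plain = Pipeline().step(21)  # no controller: the method runs unwrapped
print(result, plain, controller.entered)  # 42 42 ['training_dit']
```

The decorator resolves the controller at call time, so one decorated class works both with and without profiling configured.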

fastvideo.profiler.register_profiler_region

register_profiler_region(name: str, description: str, *, default_enabled: bool = False) -> None

Register a profiler region so configuration can validate inputs.

Source code in fastvideo/profiler.py
def register_profiler_region(
    name: str,
    description: str,
    *,
    default_enabled: bool = False,
) -> None:
    """Register a profiler region so configuration can validate inputs."""

    canonical = _normalize_token(name)
    if canonical in _REGISTERED_REGIONS:
        raise ValueError(f"Profiler region {name!r} is already registered")

    region = ProfilerRegion(
        name=canonical,
        description=description,
        default_enabled=bool(default_enabled),
    )
    _REGISTERED_REGIONS[canonical] = region

fastvideo.profiler.resolve_profiler_region

resolve_profiler_region(name: str) -> ProfilerRegion | None

Return the registered region matching name or None if absent.

Source code in fastvideo/profiler.py
def resolve_profiler_region(name: str) -> ProfilerRegion | None:
    """Return the registered region matching ``name`` or ``None`` if absent."""

    canonical = _normalize_token(name)
    return _REGISTERED_REGIONS.get(canonical)
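
The register and resolve pair can be sketched as a small dictionary-backed registry. This is an illustration only: `Region`, `register`, `resolve`, and `_REGISTRY` are hypothetical names, and `_normalize` guesses that the module's `_normalize_token` (not shown on this page) trims and lower-cases its input.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class Region:
    name: str
    description: str
    default_enabled: bool = False


_REGISTRY: dict[str, Region] = {}


def _normalize(token: str) -> str:
    # Assumption: the real _normalize_token is not shown; trimming and
    # lower-casing is a guess at its behavior.
    return token.strip().lower()


def register(name: str, description: str, *, default_enabled: bool = False) -> None:
    key = _normalize(name)
    if key in _REGISTRY:
        raise ValueError(f"Profiler region {name!r} is already registered")
    _REGISTRY[key] = Region(key, description, default_enabled)


def resolve(name: str) -> Region | None:
    # Look up by canonical key; unknown names yield None.
    return _REGISTRY.get(_normalize(name))


register("training_validation", "Validation loop during training")
found = resolve("  Training_Validation ")
missing = resolve("no_such_region")
print(found.name, missing)  # training_validation None
```

Storing regions under their canonical key is what lets registration reject duplicates and lets lookups tolerate differently formatted tokens.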