forward_context

Functions

fastvideo.forward_context.get_forward_context

get_forward_context() -> ForwardContext

Get the current forward context.

Source code in fastvideo/forward_context.py
def get_forward_context() -> "ForwardContext":
    """Get the current forward context."""
    assert _forward_context is not None, (
        "Forward context is not set. "
        "Please use `set_forward_context` to set the forward context.")
    return _forward_context
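
For illustration, a minimal usage sketch: `log_current_timestep` is a hypothetical helper, not part of fastvideo, and it must run inside a `with set_forward_context(...)` block, since `get_forward_context` asserts that a context is active.

from fastvideo.forward_context import get_forward_context

def log_current_timestep() -> None:
    # Raises AssertionError if called outside set_forward_context.
    ctx = get_forward_context()
    print(f"timestep={ctx.current_timestep}, "
          f"attn_metadata={type(ctx.attn_metadata).__name__}")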

fastvideo.forward_context.set_forward_context

set_forward_context(current_timestep, attn_metadata, forward_batch: Optional[ForwardBatch] = None)

A context manager that stores the current forward context, such as the attention metadata for the pass, and restores the previous context on exit. It is also the place to inject logic common to every model forward pass.

Source code in fastvideo/forward_context.py
@contextmanager
def set_forward_context(current_timestep,
                        attn_metadata,
                        forward_batch: Optional["ForwardBatch"] = None):
    """A context manager that stores the current forward context,
    can be attention metadata, etc.
    Here we can inject common logic for every model forward pass.
    """
    global forward_start_time
    need_to_track_batchsize = track_batchsize and attn_metadata is not None
    if need_to_track_batchsize:
        forward_start_time = time.perf_counter()
    global _forward_context
    prev_context = _forward_context
    _forward_context = ForwardContext(current_timestep=current_timestep,
                                      attn_metadata=attn_metadata,
                                      forward_batch=forward_batch)

    try:
        yield
    finally:
        global last_logging_time, batchsize_logging_interval
        if need_to_track_batchsize:
            if hasattr(attn_metadata, "num_prefill_tokens"):
                # for v0 attention backends
                batchsize = attn_metadata.num_prefill_tokens + \
                    attn_metadata.num_decode_tokens
            else:
                # for v1 attention backends
                batchsize = attn_metadata.num_input_tokens
            now = time.perf_counter()
            # time measurement is in milliseconds
            batchsize_forward_time[batchsize].append(
                (now - forward_start_time) * 1000)
            if now - last_logging_time > batchsize_logging_interval:
                last_logging_time = now
                forward_stats = []
                for bs, times in batchsize_forward_time.items():
                    if len(times) <= 1:
                        # a single sample is likely a cudagraph capture
                        # or profiling run; skip it
                        continue
                    median = torch.quantile(torch.tensor(times), q=0.5).item()
                    median = round(median, 2)
                    forward_stats.append((bs, len(times), median))
                forward_stats.sort(key=lambda x: x[1], reverse=True)
                if forward_stats:
                    logger.info(("Batchsize forward time stats "
                                 "(batchsize, count, median_time(ms)): %s"),
                                forward_stats)
        _forward_context = prev_context
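
A minimal usage sketch of the context manager; the `torch.nn.Linear` stands in for a real video model, and passing `attn_metadata=None` skips the batch-size timing path shown above:

import torch

from fastvideo.forward_context import set_forward_context

model = torch.nn.Linear(8, 8)  # stand-in for a real model
latents = torch.randn(1, 8)

# Everything called inside the block sees the same ForwardContext via
# get_forward_context(); the previous context is restored on exit.
with set_forward_context(current_timestep=0, attn_metadata=None):
    out = model(latents)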