distributed

Classes

fastvideo.distributed.StatelessProcessGroup dataclass

StatelessProcessGroup(rank: int, world_size: int, store: Store, data_expiration_seconds: int = 3600, send_dst_counter: dict[int, int] = dict(), recv_src_counter: dict[int, int] = dict(), broadcast_send_counter: int = 0, broadcast_recv_src_counter: dict[int, int] = dict(), entries: deque[tuple[str, float]] = deque())

A dataclass that holds a metadata store together with the rank and world_size of the group. Only use it to communicate metadata between processes. For data-plane communication, create NCCL-related objects.

Functions

fastvideo.distributed.StatelessProcessGroup.all_gather_obj
all_gather_obj(obj: Any) -> list[Any]

All gather an object from all ranks.

Source code in fastvideo/distributed/utils.py
def all_gather_obj(self, obj: Any) -> list[Any]:
    """All gather an object from all ranks."""
    gathered_objs = []
    for i in range(self.world_size):
        if i == self.rank:
            gathered_objs.append(obj)
            self.broadcast_obj(obj, src=self.rank)
        else:
            recv_obj = self.broadcast_obj(None, src=i)
            gathered_objs.append(recv_obj)
    return gathered_objs
fastvideo.distributed.StatelessProcessGroup.barrier
barrier()

A barrier to synchronize all ranks.

Source code in fastvideo/distributed/utils.py
def barrier(self):
    """A barrier to synchronize all ranks."""
    for i in range(self.world_size):
        if i == self.rank:
            self.broadcast_obj(None, src=self.rank)
        else:
            self.broadcast_obj(None, src=i)
fastvideo.distributed.StatelessProcessGroup.broadcast_obj
broadcast_obj(obj: Any | None, src: int) -> Any

Broadcast an object from a source rank to all other ranks. It does not clean up after all ranks have received the object, so use it only a limited number of times, e.g., during initialization.

Source code in fastvideo/distributed/utils.py
def broadcast_obj(self, obj: Any | None, src: int) -> Any:
    """Broadcast an object from a source rank to all other ranks.
    It does not clean up after all ranks have received the object.
    Use it for limited times, e.g., for initialization.
    """
    if self.rank == src:
        self.expire_data()
        key = (f"broadcast_from/{src}/"
               f"{self.broadcast_send_counter}")
        self.store.set(key, pickle.dumps(obj))
        self.broadcast_send_counter += 1
        self.entries.append((key, time.perf_counter()))
        return obj
    else:
        key = (f"broadcast_from/{src}/"
               f"{self.broadcast_recv_src_counter[src]}")
        recv_obj = pickle.loads(self.store.get(key))
        self.broadcast_recv_src_counter[src] += 1
        return recv_obj
fastvideo.distributed.StatelessProcessGroup.create staticmethod
create(host: str, port: int, rank: int, world_size: int, data_expiration_seconds: int = 3600) -> StatelessProcessGroup

A replacement for torch.distributed.init_process_group that does not pollute the global state.

If processes A and B have called torch.distributed.init_process_group to form a group, and we then want to form another group containing processes A, B, C, and D, this is not possible in PyTorch: A and B have already formed a group, and C and D cannot join it. This function is a workaround for this issue.

torch.distributed.init_process_group is a global call, while this function is a stateless call. It returns a StatelessProcessGroup object that can be used to exchange metadata. With this function, processes A and B can call StatelessProcessGroup.create to form a group, and processes A, B, C, and D can later call StatelessProcessGroup.create to form another group.

Source code in fastvideo/distributed/utils.py
@staticmethod
def create(
    host: str,
    port: int,
    rank: int,
    world_size: int,
    data_expiration_seconds: int = 3600,
) -> "StatelessProcessGroup":
    """A replacement for `torch.distributed.init_process_group` that does not
    pollute the global state.

    If we have process A and process B called `torch.distributed.init_process_group`
    to form a group, and then we want to form another group with process A, B, C,
    D, it is not possible in PyTorch, because process A and process B have already
    formed a group, and process C and process D cannot join that group. This
    function is a workaround for this issue.

    `torch.distributed.init_process_group` is a global call, while this function
    is a stateless call. It will return a `StatelessProcessGroup` object that can be
    used for exchanging metadata. With this function, process A and process B
    can call `StatelessProcessGroup.create` to form a group, and then process A, B,
    C, and D can call `StatelessProcessGroup.create` to form another group.
    """ # noqa
    store = TCPStore(
        host_name=host,
        port=port,
        world_size=world_size,
        is_master=(rank == 0),
    )

    return StatelessProcessGroup(
        rank=rank,
        world_size=world_size,
        store=store,
        data_expiration_seconds=data_expiration_seconds)
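
A minimal usage sketch (the host, port, and metadata payload below are illustrative values, not part of the API): every participating process calls create with the same host/port and its own rank, then uses the group for control-plane exchanges only.

def exchange_metadata(rank: int, world_size: int) -> None:
    from fastvideo.distributed import StatelessProcessGroup
    # Rank 0 hosts the TCPStore; all ranks must agree on host/port.
    pg = StatelessProcessGroup.create(host="127.0.0.1", port=29600,
                                      rank=rank, world_size=world_size)
    # Small picklable objects only; tensor traffic should go through NCCL/HCCL instead.
    info = {"rank": rank, "ready": True}
    all_infos = pg.all_gather_obj(info)
    pg.barrier()
    print(all_infos)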
fastvideo.distributed.StatelessProcessGroup.expire_data
expire_data() -> None

Expire data that is older than data_expiration_seconds seconds.

Source code in fastvideo/distributed/utils.py
def expire_data(self) -> None:
    """Expire data that is older than `data_expiration_seconds` seconds."""
    while self.entries:
        # check the oldest entry
        key, timestamp = self.entries[0]
        if time.perf_counter() - timestamp > self.data_expiration_seconds:
            self.store.delete_key(key)
            self.entries.popleft()
        else:
            break
fastvideo.distributed.StatelessProcessGroup.recv_obj
recv_obj(src: int) -> Any

Receive an object from a source rank.

Source code in fastvideo/distributed/utils.py
def recv_obj(self, src: int) -> Any:
    """Receive an object from a source rank."""
    obj = pickle.loads(
        self.store.get(f"send_to/{self.rank}/{self.recv_src_counter[src]}"))
    self.recv_src_counter[src] += 1
    return obj
fastvideo.distributed.StatelessProcessGroup.send_obj
send_obj(obj: Any, dst: int)

Send an object to a destination rank.

Source code in fastvideo/distributed/utils.py
def send_obj(self, obj: Any, dst: int):
    """Send an object to a destination rank."""
    self.expire_data()
    key = f"send_to/{dst}/{self.send_dst_counter[dst]}"
    self.store.set(key, pickle.dumps(obj))
    self.send_dst_counter[dst] += 1
    self.entries.append((key, time.perf_counter()))
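
A short sketch of pairing send_obj with recv_obj (pg created as in the earlier example; the payload is illustrative): objects travel through the shared TCPStore, and recv_obj blocks until the matching key appears.

if pg.rank == 0:
    pg.send_obj({"seed": 1234}, dst=1)
elif pg.rank == 1:
    cfg = pg.recv_obj(src=0)   # blocks until rank 0's object is stored
    print(cfg["seed"])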

Functions

fastvideo.distributed.divide

divide(numerator: int, denominator: int) -> int

Ensure that the numerator is divisible by the denominator and return the quotient.

Source code in fastvideo/distributed/utils.py
def divide(numerator: int, denominator: int) -> int:
    """Ensure that numerator is divisible by the denominator and return
    the division value."""
    ensure_divisibility(numerator, denominator)
    return numerator // denominator
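
For example, divide is just a checked integer division that fails loudly on a remainder:

from fastvideo.distributed import divide

assert divide(4096, 8) == 512
# divide(10, 3) raises AssertionError: "10 is not divisible by 3"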

fastvideo.distributed.ensure_divisibility

ensure_divisibility(numerator, denominator) -> None

Ensure that numerator is divisible by the denominator.

Source code in fastvideo/distributed/utils.py
def ensure_divisibility(numerator, denominator) -> None:
    """Ensure that numerator is divisible by the denominator."""
    assert numerator % denominator == 0, "{} is not divisible by {}".format(
        numerator, denominator)

fastvideo.distributed.get_dp_rank

get_dp_rank() -> int

Return my rank for the data parallel group.

Source code in fastvideo/distributed/parallel_state.py
def get_dp_rank() -> int:
    """Return my rank for the data parallel group."""
    return get_dp_group().rank_in_group

fastvideo.distributed.get_dp_world_size

get_dp_world_size() -> int

Return world size for the data parallel group.

Source code in fastvideo/distributed/parallel_state.py
def get_dp_world_size() -> int:
    """Return world size for the data parallel group."""
    return get_dp_group().world_size

fastvideo.distributed.get_local_torch_device

get_local_torch_device() -> device

Return the torch device for the current rank.

Source code in fastvideo/distributed/parallel_state.py
def get_local_torch_device() -> torch.device:
    """Return the torch device for the current rank."""
    from fastvideo.platforms import current_platform
    if current_platform.is_npu():
        device = torch.device(f"npu:{envs.LOCAL_RANK}")
    elif current_platform.is_cuda_alike() or current_platform.is_cuda():
        device = torch.device(f"cuda:{envs.LOCAL_RANK}")
    else:
        device = torch.device("mps")
    return device
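
A small usage sketch (assumes the process was launched with a per-rank LOCAL_RANK, e.g. via torchrun): allocate per-rank tensors directly on the returned device.

import torch
from fastvideo.distributed import get_local_torch_device

device = get_local_torch_device()
x = torch.zeros(8, device=device)   # lands on npu:/cuda:{LOCAL_RANK}, or mps otherwise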

fastvideo.distributed.get_sp_parallel_rank

get_sp_parallel_rank() -> int

Return my rank for the sequence model parallel group.

Source code in fastvideo/distributed/parallel_state.py
def get_sp_parallel_rank() -> int:
    """Return my rank for the sequence model parallel group."""
    return get_sp_group().rank_in_group

fastvideo.distributed.get_sp_world_size

get_sp_world_size() -> int

Return world size for the sequence model parallel group.

Source code in fastvideo/distributed/parallel_state.py
def get_sp_world_size() -> int:
    """Return world size for the sequence model parallel group."""
    return get_sp_group().world_size

fastvideo.distributed.get_tp_rank

get_tp_rank() -> int

Return my rank for the tensor model parallel group.

Source code in fastvideo/distributed/parallel_state.py
def get_tp_rank() -> int:
    """Return my rank for the tensor model parallel group."""
    return get_tp_group().rank_in_group

fastvideo.distributed.get_tp_world_size

get_tp_world_size() -> int

Return world size for the tensor model parallel group.

Source code in fastvideo/distributed/parallel_state.py
def get_tp_world_size() -> int:
    """Return world size for the tensor model parallel group."""
    return get_tp_group().world_size

fastvideo.distributed.get_world_rank

get_world_rank() -> int

Return my rank for the world group.

Source code in fastvideo/distributed/parallel_state.py
def get_world_rank() -> int:
    """Return my rank for the world group."""
    return get_world_group().rank

fastvideo.distributed.get_world_size

get_world_size() -> int

Return world size for the world group.

Source code in fastvideo/distributed/parallel_state.py
def get_world_size() -> int:
    """Return world size for the world group."""
    return get_world_group().world_size

fastvideo.distributed.init_logger

init_logger(name: str) -> _FastvideoLogger

The main purpose of this function is to ensure that loggers are retrieved in such a way that we can be sure the root fastvideo logger has already been configured.

Source code in fastvideo/logger.py
def init_logger(name: str) -> _FastvideoLogger:
    """The main purpose of this function is to ensure that loggers are
    retrieved in such a way that we can be sure the root fastvideo logger has
    already been configured."""

    logger = logging.getLogger(name)

    methods_to_patch = {
        "info_once": _print_info_once,
        "warning_once": _print_warning_once,
        "info": _info,
    }

    for method_name, method in methods_to_patch.items():
        setattr(logger, method_name,
                MethodType(method, logger))  # type: ignore[arg-type]

    return cast(_FastvideoLogger, logger)
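
A minimal sketch of how the patched logger is typically used (the message strings are illustrative):

from fastvideo.logger import init_logger

logger = init_logger(__name__)
logger.info("starting worker")
logger.warning_once("optional dependency missing; falling back")  # logged at most once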

fastvideo.distributed.initialize_model_parallel

initialize_model_parallel(tensor_model_parallel_size: int = 1, sequence_model_parallel_size: int = 1, data_parallel_size: int = 1, backend: str | None = None) -> None

Initialize model parallel groups.

Parameters:

    tensor_model_parallel_size (int, default 1): number of GPUs used for tensor model parallelism (used for the language encoder).
    sequence_model_parallel_size (int, default 1): number of GPUs used for sequence model parallelism (used for the DiT).
Source code in fastvideo/distributed/parallel_state.py
def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    sequence_model_parallel_size: int = 1,
    data_parallel_size: int = 1,
    backend: str | None = None,
) -> None:
    """
    Initialize model parallel groups.

    Arguments:
        tensor_model_parallel_size: number of GPUs used for tensor model
            parallelism (used for language encoder).
        sequence_model_parallel_size: number of GPUs used for sequence model
            parallelism (used for DiT).
    """
    # Get world size and rank. Ensure some consistencies.
    assert _WORLD is not None, "world group is not initialized, please call init_distributed_environment first"
    world_size: int = get_world_size()
    backend = backend or torch.distributed.get_backend(
        get_world_group().device_group)
    assert world_size >= tensor_model_parallel_size, f"world_size({world_size}) must be greater than or equal to tensor_model_parallel_size({tensor_model_parallel_size})"
    num_tensor_model_parallel_groups: int = (world_size //
                                             tensor_model_parallel_size)
    global _TP
    assert _TP is None, ("tensor model parallel group is already initialized")
    group_ranks = []
    for i in range(num_tensor_model_parallel_groups):
        ranks = list(
            range(i * tensor_model_parallel_size,
                  (i + 1) * tensor_model_parallel_size))
        group_ranks.append(ranks)

    # message queue broadcaster is only used in tensor model parallel group
    _TP = init_model_parallel_group(group_ranks,
                                    get_world_group().local_rank,
                                    backend,
                                    use_message_queue_broadcaster=True,
                                    group_name="tp")

    # Build the sequence model-parallel groups.
    num_sequence_model_parallel_groups: int = (world_size //
                                               sequence_model_parallel_size)
    global _SP
    assert _SP is None, ("sequence model parallel group is already initialized")
    group_ranks = []

    # Since SP is incompatible with TP and PP, we can use a simpler group creation logic
    for i in range(num_sequence_model_parallel_groups):
        # Create groups of consecutive ranks
        ranks = list(
            range(i * sequence_model_parallel_size,
                  (i + 1) * sequence_model_parallel_size))
        group_ranks.append(ranks)

    _SP = init_model_parallel_group(group_ranks,
                                    get_world_group().local_rank,
                                    backend,
                                    group_name="sp")

    # Build the data parallel groups.
    num_data_parallel_groups: int = sequence_model_parallel_size
    global _DP
    assert _DP is None, ("data parallel group is already initialized")
    group_ranks = []

    for i in range(num_data_parallel_groups):
        ranks = list(range(i, world_size, num_data_parallel_groups))
        group_ranks.append(ranks)

    _DP = init_model_parallel_group(group_ranks,
                                    get_world_group().local_rank,
                                    backend,
                                    group_name="dp")

fastvideo.distributed.model_parallel_is_initialized

model_parallel_is_initialized() -> bool

Check if tensor, sequence parallel groups are initialized.

Source code in fastvideo/distributed/parallel_state.py
def model_parallel_is_initialized() -> bool:
    """Check if tensor, sequence parallel groups are initialized."""
    return _TP is not None and _SP is not None and _DP is not None

fastvideo.distributed.sequence_model_parallel_all_gather

sequence_model_parallel_all_gather(input_: Tensor, dim: int = -1) -> Tensor

All-gather the input tensor across the sequence parallel group.

Source code in fastvideo/distributed/communication_op.py
def sequence_model_parallel_all_gather(input_: torch.Tensor,
                                       dim: int = -1) -> torch.Tensor:
    """All-gather the input tensor across model parallel group."""
    return get_sp_group().all_gather(input_, dim)

fastvideo.distributed.sequence_model_parallel_all_to_all_4D

sequence_model_parallel_all_to_all_4D(input_: Tensor, scatter_dim: int = 2, gather_dim: int = 1) -> Tensor

All-to-all communication of 4D tensors (e.g. QKV matrices) across the sequence parallel group.

Source code in fastvideo/distributed/communication_op.py
def sequence_model_parallel_all_to_all_4D(input_: torch.Tensor,
                                          scatter_dim: int = 2,
                                          gather_dim: int = 1) -> torch.Tensor:
    """All-to-all communication of 4D tensors (e.g. QKV matrices) across sequence parallel group."""
    return get_sp_group().all_to_all_4D(input_, scatter_dim, gather_dim)
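
A shape-level sketch (the Ulysses-style [batch, seq, heads, head_dim] layout is an assumption here, and the SP group must already be initialized): the default scatter_dim=2, gather_dim=1 trades a sharded sequence for sharded heads, and the reverse call undoes it.

import torch
from fastvideo.distributed import (sequence_model_parallel_all_to_all_4D,
                                   get_sp_world_size, get_local_torch_device)

p = get_sp_world_size()
q_local = torch.randn(1, 1024 // p, 16, 64, device=get_local_torch_device())

# [B, S/P, H, D] -> [B, S, H/P, D]: each rank now sees the full sequence for a subset of heads.
q_heads = sequence_model_parallel_all_to_all_4D(q_local, scatter_dim=2, gather_dim=1)

# [B, S, H/P, D] -> [B, S/P, H, D]: restore the sequence-sharded layout after attention.
q_back = sequence_model_parallel_all_to_all_4D(q_heads, scatter_dim=1, gather_dim=2)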

fastvideo.distributed.split_tensor_along_last_dim

split_tensor_along_last_dim(tensor: Tensor, num_partitions: int, contiguous_split_chunks: bool = False) -> Sequence[Tensor]

Split a tensor along its last dimension.

Parameters:

    tensor (Tensor, required): input tensor.
    num_partitions (int, required): number of partitions to split the tensor into.
    contiguous_split_chunks (bool, default False): if True, make each chunk contiguous in memory.

Returns:

    Sequence[Tensor]: a list of tensors.

Source code in fastvideo/distributed/utils.py
def split_tensor_along_last_dim(
    tensor: torch.Tensor,
    num_partitions: int,
    contiguous_split_chunks: bool = False,
) -> Sequence[torch.Tensor]:
    """ Split a tensor along its last dimension.

        Arguments:
            tensor: input tensor.
            num_partitions: number of partitions to split the tensor
            contiguous_split_chunks: If True, make each chunk contiguous
                                     in memory.

        Returns:
            A list of Tensors
    """
    # Get the size and dimension.
    last_dim = tensor.dim() - 1
    last_dim_size = divide(tensor.size()[last_dim], num_partitions)
    # Split.
    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
    # NOTE: torch.split does not create contiguous tensors by default.
    if contiguous_split_chunks:
        return tuple(chunk.contiguous() for chunk in tensor_list)

    return tuple(tensor_list)
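
A quick, self-contained example (no process group is needed for this helper):

import torch
from fastvideo.distributed import split_tensor_along_last_dim

x = torch.arange(12.0).reshape(3, 4)
a, b = split_tensor_along_last_dim(x, num_partitions=2)
# a.shape == b.shape == (3, 2); the chunks are views of x unless
# contiguous_split_chunks=True is passed.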

fastvideo.distributed.tensor_model_parallel_all_gather

tensor_model_parallel_all_gather(input_: Tensor, dim: int = -1) -> Tensor

All-gather the input tensor across the tensor model parallel group.

Source code in fastvideo/distributed/communication_op.py
def tensor_model_parallel_all_gather(input_: torch.Tensor,
                                     dim: int = -1) -> torch.Tensor:
    """All-gather the input tensor across model parallel group."""
    return get_tp_group().all_gather(input_, dim)

fastvideo.distributed.tensor_model_parallel_all_reduce

tensor_model_parallel_all_reduce(input_: Tensor) -> Tensor

All-reduce the input tensor across the tensor model parallel group.

Source code in fastvideo/distributed/communication_op.py
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
    """All-reduce the input tensor across model parallel group."""
    return get_tp_group().all_reduce(input_)
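
A sketch of the usual row-parallel linear pattern this call supports (shapes are illustrative, and tensor parallelism must already be initialized): each rank computes a partial product over its input shard, and the all-reduce sums the partials.

import torch
from fastvideo.distributed import (tensor_model_parallel_all_reduce,
                                   get_tp_world_size, get_local_torch_device)

tp = get_tp_world_size()
device = get_local_torch_device()
x_shard = torch.randn(2, 4096 // tp, device=device)     # activation shard
w_shard = torch.randn(4096 // tp, 1024, device=device)  # weight shard (row-parallel)

partial = x_shard @ w_shard
full = tensor_model_parallel_all_reduce(partial)         # summed across the TP group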

Modules

fastvideo.distributed.communication_op

Functions

fastvideo.distributed.communication_op.sequence_model_parallel_all_gather
sequence_model_parallel_all_gather(input_: Tensor, dim: int = -1) -> Tensor

All-gather the input tensor across the sequence parallel group.

Source code in fastvideo/distributed/communication_op.py
def sequence_model_parallel_all_gather(input_: torch.Tensor,
                                       dim: int = -1) -> torch.Tensor:
    """All-gather the input tensor across model parallel group."""
    return get_sp_group().all_gather(input_, dim)
fastvideo.distributed.communication_op.sequence_model_parallel_all_to_all_4D
sequence_model_parallel_all_to_all_4D(input_: Tensor, scatter_dim: int = 2, gather_dim: int = 1) -> Tensor

All-to-all communication of 4D tensors (e.g. QKV matrices) across the sequence parallel group.

Source code in fastvideo/distributed/communication_op.py
def sequence_model_parallel_all_to_all_4D(input_: torch.Tensor,
                                          scatter_dim: int = 2,
                                          gather_dim: int = 1) -> torch.Tensor:
    """All-to-all communication of 4D tensors (e.g. QKV matrices) across sequence parallel group."""
    return get_sp_group().all_to_all_4D(input_, scatter_dim, gather_dim)
fastvideo.distributed.communication_op.tensor_model_parallel_all_gather
tensor_model_parallel_all_gather(input_: Tensor, dim: int = -1) -> Tensor

All-gather the input tensor across the tensor model parallel group.

Source code in fastvideo/distributed/communication_op.py
def tensor_model_parallel_all_gather(input_: torch.Tensor,
                                     dim: int = -1) -> torch.Tensor:
    """All-gather the input tensor across model parallel group."""
    return get_tp_group().all_gather(input_, dim)
fastvideo.distributed.communication_op.tensor_model_parallel_all_reduce
tensor_model_parallel_all_reduce(input_: Tensor) -> Tensor

All-reduce the input tensor across the tensor model parallel group.

Source code in fastvideo/distributed/communication_op.py
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
    """All-reduce the input tensor across model parallel group."""
    return get_tp_group().all_reduce(input_)

fastvideo.distributed.device_communicators

Modules

fastvideo.distributed.device_communicators.base_device_communicator
Classes
fastvideo.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase
DeviceCommunicatorBase(cpu_group: ProcessGroup, device: device | None = None, device_group: ProcessGroup | None = None, unique_name: str = '')

Base class for device-specific communicator with autograd support. It can use the cpu_group to initialize the communicator. If the device has PyTorch integration (PyTorch can recognize its communication backend), the device_group will also be given.

Source code in fastvideo/distributed/device_communicators/base_device_communicator.py
def __init__(self,
             cpu_group: ProcessGroup,
             device: torch.device | None = None,
             device_group: ProcessGroup | None = None,
             unique_name: str = ""):
    self.device = device or torch.device("cpu")
    self.cpu_group = cpu_group
    self.device_group = device_group
    self.unique_name = unique_name
    self.rank = dist.get_rank(cpu_group)
    self.world_size = dist.get_world_size(cpu_group)
    self.ranks = dist.get_process_group_ranks(cpu_group)
    self.global_rank = dist.get_rank()
    self.global_world_size = dist.get_world_size()
    self.rank_in_group = dist.get_group_rank(self.cpu_group,
                                             self.global_rank)
Functions
fastvideo.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase.all_gather
all_gather(input_: Tensor, dim: int = -1) -> Tensor

Performs an all_gather operation with gradient support.

Source code in fastvideo/distributed/device_communicators/base_device_communicator.py
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Performs an all_gather operation with gradient support."""
    if dim < 0:
        dim += input_.dim()
    return DistributedAutograd.AllGather.apply(self.device_group, input_,
                                               self.world_size, dim)
fastvideo.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase.all_reduce
all_reduce(input_: Tensor, op: ReduceOp | None = SUM) -> Tensor

Performs an all_reduce operation with gradient support.

Source code in fastvideo/distributed/device_communicators/base_device_communicator.py
def all_reduce(self,
               input_: torch.Tensor,
               op: dist.ReduceOp | None = ReduceOp.SUM) -> torch.Tensor:
    """Performs an all_reduce operation with gradient support."""
    return DistributedAutograd.AllReduce.apply(self.device_group, input_,
                                               op)
fastvideo.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase.all_to_all_4D
all_to_all_4D(input_: Tensor, scatter_dim: int = 2, gather_dim: int = 1) -> Tensor

Performs a 4D all-to-all operation with gradient support.

Source code in fastvideo/distributed/device_communicators/base_device_communicator.py
def all_to_all_4D(self,
                  input_: torch.Tensor,
                  scatter_dim: int = 2,
                  gather_dim: int = 1) -> torch.Tensor:
    """Performs a 4D all-to-all operation with gradient support."""
    return DistributedAutograd.AllToAll4D.apply(self.device_group, input_,
                                                self.world_size,
                                                scatter_dim, gather_dim)
fastvideo.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase.gather
gather(input_: Tensor, dst: int = 0, dim: int = -1) -> Tensor | None

NOTE: We assume that the input tensor is on the same device across all the ranks. NOTE: dst is the local rank of the destination rank.

Source code in fastvideo/distributed/device_communicators/base_device_communicator.py
def gather(self,
           input_: torch.Tensor,
           dst: int = 0,
           dim: int = -1) -> torch.Tensor | None:
    """
    NOTE: We assume that the input tensor is on the same device across
    all the ranks.
    NOTE: `dst` is the local rank of the destination rank.
    """
    world_size = self.world_size
    assert -input_.dim() <= dim < input_.dim(), (
        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
    if dim < 0:
        # Convert negative dim to positive.
        dim += input_.dim()

    # Allocate output tensor.
    if self.rank_in_group == dst:
        gather_list = [torch.empty_like(input_) for _ in range(world_size)]
    else:
        gather_list = None
    # Gather.
    torch.distributed.gather(input_,
                             gather_list,
                             dst=self.ranks[dst],
                             group=self.device_group)
    if self.rank_in_group == dst:
        output_tensor = torch.cat(gather_list, dim=dim)
    else:
        output_tensor = None
    return output_tensor
fastvideo.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase.recv
recv(size: Size, dtype: dtype, src: int | None = None) -> Tensor

Receives a tensor from the source rank.

Source code in fastvideo/distributed/device_communicators/base_device_communicator.py
def recv(self,
         size: torch.Size,
         dtype: torch.dtype,
         src: int | None = None) -> torch.Tensor:
    """Receives a tensor from the source rank."""
    """NOTE: `src` is the local rank of the source rank."""
    if src is None:
        src = (self.rank_in_group - 1) % self.world_size

    tensor = torch.empty(size, dtype=dtype, device=self.device)
    torch.distributed.recv(tensor, self.ranks[src], self.device_group)
    return tensor
fastvideo.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase.send
send(tensor: Tensor, dst: int | None = None) -> None

Sends a tensor to the destination rank in a non-blocking way

Source code in fastvideo/distributed/device_communicators/base_device_communicator.py
def send(self, tensor: torch.Tensor, dst: int | None = None) -> None:
    """Sends a tensor to the destination rank in a non-blocking way"""
    """NOTE: `dst` is the local rank of the destination rank."""
    if dst is None:
        dst = (self.rank_in_group + 1) % self.world_size
    torch.distributed.send(tensor, self.ranks[dst], self.device_group)
fastvideo.distributed.device_communicators.base_device_communicator.DistributedAutograd

Collection of autograd functions for distributed operations.

This class provides custom autograd functions for distributed operations like all_reduce, all_gather, and all_to_all. Each operation is implemented as a static inner class with proper forward and backward implementations.

Classes
fastvideo.distributed.device_communicators.base_device_communicator.DistributedAutograd.AllGather

Bases: Function

Differentiable all_gather operation.

The operation gathers tensors from all ranks and concatenates them along a specified dimension. The backward pass uses reduce_scatter to efficiently distribute gradients back to source ranks.

fastvideo.distributed.device_communicators.base_device_communicator.DistributedAutograd.AllReduce

Bases: Function

Differentiable all_reduce operation.

The gradient of all_reduce is another all_reduce operation since the operation combines values from all ranks equally.

fastvideo.distributed.device_communicators.base_device_communicator.DistributedAutograd.AllToAll4D

Bases: Function

Differentiable all_to_all operation specialized for 4D tensors.

This operation is particularly useful for attention operations where we need to redistribute data across ranks for efficient parallel processing.

The operation supports two modes:

1. scatter_dim=2, gather_dim=1: used for redistributing attention heads.
2. scatter_dim=1, gather_dim=2: used for redistributing sequence dimensions.

fastvideo.distributed.device_communicators.cpu_communicator
Classes
fastvideo.distributed.device_communicators.cpu_communicator.CpuCommunicator
CpuCommunicator(cpu_group: ProcessGroup, device: device | None = None, device_group: ProcessGroup | None = None, unique_name: str = '')

Bases: DeviceCommunicatorBase

Source code in fastvideo/distributed/device_communicators/cpu_communicator.py
def __init__(self,
             cpu_group: ProcessGroup,
             device: torch.device | None = None,
             device_group: ProcessGroup | None = None,
             unique_name: str = ""):
    super().__init__(cpu_group, device, device_group, unique_name)
    self.dist_module = torch.distributed

    from fastvideo.platforms import current_platform

    if (current_platform.get_cpu_architecture()
            == CpuArchEnum.X86) and hasattr(
                torch.ops._C,
                "init_shm_manager") and unique_name.startswith("tp"):
        self.dist_module = _CPUSHMDistributed(self)
Functions
fastvideo.distributed.device_communicators.cpu_communicator.CpuCommunicator.gather
gather(input_: Tensor, dst: int = 0, dim: int = -1) -> Tensor | None

NOTE: We assume that the input tensor is on the same device across all the ranks. NOTE: dst is the local rank of the destination rank.

Source code in fastvideo/distributed/device_communicators/cpu_communicator.py
def gather(self,
           input_: torch.Tensor,
           dst: int = 0,
           dim: int = -1) -> torch.Tensor | None:
    """
    NOTE: We assume that the input tensor is on the same device across
    all the ranks.
    NOTE: `dst` is the local rank of the destination rank.
    """
    world_size = self.world_size
    assert -input_.dim() <= dim < input_.dim(), (
        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
    if dim < 0:
        # Convert negative dim to positive.
        dim += input_.dim()

    # Allocate output tensor.
    if self.rank_in_group == dst:
        gather_list = [torch.empty_like(input_) for _ in range(world_size)]
    else:
        gather_list = None

    # Gather.
    self.dist_module.gather(input_,
                            gather_list,
                            dst=self.ranks[dst],
                            group=self.device_group)

    if self.rank_in_group == dst:
        output_tensor = torch.cat(gather_list, dim=dim)
    else:
        output_tensor = None
    return output_tensor
fastvideo.distributed.device_communicators.cuda_communicator
Classes
fastvideo.distributed.device_communicators.cuda_communicator.CudaCommunicator
CudaCommunicator(cpu_group: ProcessGroup, device: device | None = None, device_group: ProcessGroup | None = None, unique_name: str = '')

Bases: DeviceCommunicatorBase

Source code in fastvideo/distributed/device_communicators/cuda_communicator.py
def __init__(self,
             cpu_group: ProcessGroup,
             device: torch.device | None = None,
             device_group: ProcessGroup | None = None,
             unique_name: str = ""):
    super().__init__(cpu_group, device, device_group, unique_name)

    from fastvideo.distributed.device_communicators.pynccl import (
        PyNcclCommunicator)

    self.pynccl_comm: PyNcclCommunicator | None = None
    if self.world_size > 1:
        self.pynccl_comm = PyNcclCommunicator(
            group=self.cpu_group,
            device=self.device,
        )
Functions
fastvideo.distributed.device_communicators.cuda_communicator.CudaCommunicator.recv
recv(size: Size, dtype: dtype, src: int | None = None) -> Tensor

Receives a tensor from the source rank.

Source code in fastvideo/distributed/device_communicators/cuda_communicator.py
def recv(self,
         size: torch.Size,
         dtype: torch.dtype,
         src: int | None = None) -> torch.Tensor:
    """Receives a tensor from the source rank."""
    """NOTE: `src` is the local rank of the source rank."""
    if src is None:
        src = (self.rank_in_group - 1) % self.world_size

    tensor = torch.empty(size, dtype=dtype, device=self.device)
    pynccl_comm = self.pynccl_comm
    if pynccl_comm is not None and not pynccl_comm.disabled:
        pynccl_comm.recv(tensor, src)
    else:
        torch.distributed.recv(tensor, self.ranks[src], self.device_group)
    return tensor
fastvideo.distributed.device_communicators.cuda_communicator.CudaCommunicator.send
send(tensor: Tensor, dst: int | None = None) -> None

Sends a tensor to the destination rank in a non-blocking way

Source code in fastvideo/distributed/device_communicators/cuda_communicator.py
def send(self, tensor: torch.Tensor, dst: int | None = None) -> None:
    """Sends a tensor to the destination rank in a non-blocking way"""
    """NOTE: `dst` is the local rank of the destination rank."""
    if dst is None:
        dst = (self.rank_in_group + 1) % self.world_size

    pynccl_comm = self.pynccl_comm
    if pynccl_comm is not None and not pynccl_comm.disabled:
        pynccl_comm.send(tensor, dst)
    else:
        torch.distributed.send(tensor, self.ranks[dst], self.device_group)
fastvideo.distributed.device_communicators.npu_communicator
Classes
fastvideo.distributed.device_communicators.npu_communicator.NpuCommunicator
NpuCommunicator(cpu_group: ProcessGroup, device: device | None = None, device_group: ProcessGroup | None = None, unique_name: str = '')

Bases: DeviceCommunicatorBase

Source code in fastvideo/distributed/device_communicators/npu_communicator.py
def __init__(self,
             cpu_group: ProcessGroup,
             device: torch.device | None = None,
             device_group: ProcessGroup | None = None,
             unique_name: str = ""):
    super().__init__(cpu_group, device, device_group, unique_name)

    from fastvideo.distributed.device_communicators.pyhccl import (
        PyHcclCommunicator)

    self.pyhccl_comm: PyHcclCommunicator | None = None
    if self.world_size > 1:
        self.pyhccl_comm = PyHcclCommunicator(
            group=self.cpu_group,
            device=self.device,
        )
Functions
fastvideo.distributed.device_communicators.npu_communicator.NpuCommunicator.recv
recv(size: Size, dtype: dtype, src: int | None = None) -> Tensor

Receives a tensor from the source rank.

Source code in fastvideo/distributed/device_communicators/npu_communicator.py
def recv(self,
         size: torch.Size,
         dtype: torch.dtype,
         src: int | None = None) -> torch.Tensor:
    """Receives a tensor from the source rank."""
    """NOTE: `src` is the local rank of the source rank."""
    if src is None:
        src = (self.rank_in_group - 1) % self.world_size

    tensor = torch.empty(size, dtype=dtype, device=self.device)
    pyhccl_comm = self.pyhccl_comm
    if pyhccl_comm is not None and not pyhccl_comm.disabled:
        pyhccl_comm.recv(tensor, src)
    else:
        torch.distributed.recv(tensor, self.ranks[src], self.device_group)
    return tensor
fastvideo.distributed.device_communicators.npu_communicator.NpuCommunicator.send
send(tensor: Tensor, dst: int | None = None) -> None

Sends a tensor to the destination rank in a non-blocking way

Source code in fastvideo/distributed/device_communicators/npu_communicator.py
def send(self, tensor: torch.Tensor, dst: int | None = None) -> None:
    """Sends a tensor to the destination rank in a non-blocking way"""
    """NOTE: `dst` is the local rank of the destination rank."""
    if dst is None:
        dst = (self.rank_in_group + 1) % self.world_size

    pyhccl_comm = self.pyhccl_comm
    if pyhccl_comm is not None and not pyhccl_comm.disabled:
        pyhccl_comm.send(tensor, dst)
    else:
        torch.distributed.send(tensor, self.ranks[dst], self.device_group)
fastvideo.distributed.device_communicators.pyhccl
Classes
fastvideo.distributed.device_communicators.pyhccl.PyHcclCommunicator
PyHcclCommunicator(group: ProcessGroup | StatelessProcessGroup, device: int | str | device, library_path: str | None = None)

Parameters:

    group (ProcessGroup | StatelessProcessGroup, required): the process group to work on. If None, the default process group is used.
    device (int | str | device, required): the device to bind the PyHcclCommunicator to. If None, it is bound to f"npu:{local_rank}".
    library_path (str | None, default None): the path to the HCCL library. If None, the default library path is used.

It is the caller's responsibility to make sure each communicator is bound to a unique device.

Source code in fastvideo/distributed/device_communicators/pyhccl.py
def __init__(
    self,
    group: ProcessGroup | StatelessProcessGroup,
    device: int | str | torch.device,
    library_path: str | None = None,
):
    """
    Args:
        group: the process group to work on. If None, it will use the
            default process group.
        device: the device to bind the PyHcclCommunicator to. If None,
            it will be bind to f"npu:{local_rank}".
        library_path: the path to the HCCL library. If None, it will
            use the default library path.
    It is the caller's responsibility to make sure each communicator
    is bind to a unique device.
    """

    if not isinstance(group, StatelessProcessGroup):
        assert dist.is_initialized()
        assert dist.get_backend(group) != dist.Backend.HCCL, (
            "PyHcclCommunicator should be attached to a non-HCCL group.")
        # note: this rank is the rank in the group
        self.rank = dist.get_rank(group)
        self.world_size = dist.get_world_size(group)
    else:
        self.rank = group.rank
        self.world_size = group.world_size

    self.group = group

    # if world_size == 1, no need to create communicator
    if self.world_size == 1:
        self.available = False
        self.disabled = True
        return

    try:
        self.hccl = HCCLLibrary(library_path)
    except Exception:
        logger.warning("disable hccl because of missing HCCL library")
        # disable because of missing HCCL library
        # e.g. in a non-NPU environment
        self.available = False
        self.disabled = True
        return

    self.available = True
    self.disabled = False

    logger.info("FastVideo is using pyhccl")

    if isinstance(device, int):
        device = torch.device(f"npu:{device}")
    elif isinstance(device, str):
        device = torch.device(device)
    # now `device` is a `torch.device` object
    assert isinstance(device, torch.device)
    self.device = device

    if self.rank == 0:
        # get the unique id from HCCL
        with torch.npu.device(device):
            self.unique_id = self.hccl.hcclGetUniqueId()
    else:
        # construct an empty unique id
        self.unique_id = hcclUniqueId()

    if not isinstance(group, StatelessProcessGroup):
        tensor = torch.ByteTensor(list(self.unique_id.internal))
        ranks = dist.get_process_group_ranks(group)
        # arg `src` in `broadcast` is the global rank
        dist.broadcast(tensor, src=ranks[0], group=group)
        byte_list = tensor.tolist()
        for i, byte in enumerate(byte_list):
            self.unique_id.internal[i] = byte
    else:
        self.unique_id = group.broadcast_obj(self.unique_id, src=0)

    # hccl communicator and stream will use this device
    # `torch.npu.device` is a context manager that changes the
    # current npu device to the specified one
    with torch.npu.device(device):
        self.comm: hcclComm_t = self.hccl.hcclCommInitRank(
            self.world_size, self.unique_id, self.rank)

        stream = current_stream()
        # A small all_reduce for warmup.
        data = torch.zeros(1, device=device)
        self.all_reduce(data)
        stream.synchronize()
        del data
Functions
Functions
fastvideo.distributed.device_communicators.pynccl
Classes
fastvideo.distributed.device_communicators.pynccl.PyNcclCommunicator
PyNcclCommunicator(group: ProcessGroup | StatelessProcessGroup, device: int | str | device, library_path: str | None = None)

Parameters:

    group (ProcessGroup | StatelessProcessGroup, required): the process group to work on. If None, the default process group is used.
    device (int | str | device, required): the device to bind the PyNcclCommunicator to. If None, it is bound to f"cuda:{local_rank}".
    library_path (str | None, default None): the path to the NCCL library. If None, the default library path is used.

It is the caller's responsibility to make sure each communicator is bound to a unique device.

Source code in fastvideo/distributed/device_communicators/pynccl.py
def __init__(
    self,
    group: ProcessGroup | StatelessProcessGroup,
    device: int | str | torch.device,
    library_path: str | None = None,
):
    """
    Args:
        group: the process group to work on. If None, it will use the
            default process group.
        device: the device to bind the PyNcclCommunicator to. If None,
            it will be bind to f"cuda:{local_rank}".
        library_path: the path to the NCCL library. If None, it will
            use the default library path.
    It is the caller's responsibility to make sure each communicator
    is bind to a unique device.
    """
    if not isinstance(group, StatelessProcessGroup):
        assert dist.is_initialized()
        assert dist.get_backend(group) != dist.Backend.NCCL, (
            "PyNcclCommunicator should be attached to a non-NCCL group.")
        # note: this rank is the rank in the group
        self.rank = dist.get_rank(group)
        self.world_size = dist.get_world_size(group)
    else:
        self.rank = group.rank
        self.world_size = group.world_size

    self.group = group

    # if world_size == 1, no need to create communicator
    if self.world_size == 1:
        self.available = False
        self.disabled = True
        return
    try:
        self.nccl = NCCLLibrary(library_path)
    except Exception:
        # disable because of missing NCCL library
        # e.g. in a non-GPU environment
        self.available = False
        self.disabled = True
        return

    self.available = True
    self.disabled = False

    logger.info("FastVideo is using nccl==%s", self.nccl.ncclGetVersion())

    if self.rank == 0:
        # get the unique id from NCCL
        self.unique_id = self.nccl.ncclGetUniqueId()
    else:
        # construct an empty unique id
        self.unique_id = ncclUniqueId()

    if not isinstance(group, StatelessProcessGroup):
        tensor = torch.ByteTensor(list(self.unique_id.internal))
        ranks = dist.get_process_group_ranks(group)
        # arg `src` in `broadcast` is the global rank
        dist.broadcast(tensor, src=ranks[0], group=group)
        byte_list = tensor.tolist()
        for i, byte in enumerate(byte_list):
            self.unique_id.internal[i] = byte
    else:
        self.unique_id = group.broadcast_obj(self.unique_id, src=0)
    if isinstance(device, int):
        device = torch.device(f"cuda:{device}")
    elif isinstance(device, str):
        device = torch.device(device)
    # now `device` is a `torch.device` object
    assert isinstance(device, torch.device)
    self.device = device
    # nccl communicator and stream will use this device
    # `torch.cuda.device` is a context manager that changes the
    # current cuda device to the specified one
    with torch.cuda.device(device):
        self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
            self.world_size, self.unique_id, self.rank)

        stream = current_stream()
        # A small all_reduce for warmup.
        data = torch.zeros(1, device=device)
        self.all_reduce(data)
        if stream is not None:
            stream.synchronize()
        del data
Functions
Functions

fastvideo.distributed.parallel_state

FastVideo distributed state. It takes over control of the distributed environment from PyTorch. The typical workflow is:

  • call init_distributed_environment to initialize the distributed environment.
  • call initialize_model_parallel or ensure_model_parallel_initialized to initialize the model parallel groups.
  • run any code that uses the distributed environment.
  • call destroy_model_parallel to destroy the model parallel groups.
  • call destroy_distributed_environment to destroy the distributed environment.

If you only need to use the distributed environment without model parallelism, you can skip the model parallel initialization and destruction steps.
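
A compact sketch of that lifecycle (function names as listed above; argument defaults are assumed):

from fastvideo.distributed import parallel_state as ps

ps.init_distributed_environment()                              # 1. world group
ps.initialize_model_parallel(sequence_model_parallel_size=2)   # 2. TP/SP/DP groups

# 3. ... code that uses the distributed environment ...

ps.destroy_model_parallel()                                    # 4. tear down TP/SP/DP groups
ps.destroy_distributed_environment()                           # 5. tear down the world group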

Classes

fastvideo.distributed.parallel_state.GroupCoordinator
GroupCoordinator(group_ranks: list[list[int]], local_rank: int, torch_distributed_backend: str | Backend, use_device_communicator: bool, use_message_queue_broadcaster: bool = False, group_name: str | None = None)

PyTorch ProcessGroup wrapper for a group of processes. PyTorch ProcessGroup is bound to one specific communication backend, e.g. NCCL, Gloo, MPI, etc. GroupCoordinator takes charge of all the communication operations among the processes in the group. It manages both CPU and device communication.

Source code in fastvideo/distributed/parallel_state.py
def __init__(
    self,
    group_ranks: list[list[int]],
    local_rank: int,
    torch_distributed_backend: str | Backend,
    use_device_communicator: bool,
    use_message_queue_broadcaster: bool = False,
    group_name: str | None = None,
):
    group_name = group_name or "anonymous"
    self.unique_name = _get_unique_name(group_name)
    _register_group(self)

    self.rank = torch.distributed.get_rank()
    self.local_rank = local_rank
    self.device_group = None
    self.cpu_group = None

    for ranks in group_ranks:
        device_group = torch.distributed.new_group(
            ranks, backend=torch_distributed_backend)
        # a group with `gloo` backend, to allow direct coordination between
        # processes through the CPU.
        cpu_group = torch.distributed.new_group(ranks, backend="gloo")
        if self.rank in ranks:
            self.ranks = ranks
            self.world_size = len(ranks)
            self.rank_in_group = ranks.index(self.rank)
            self.device_group = device_group
            self.cpu_group = cpu_group
    try:
        assert self.cpu_group is not None
        assert self.device_group is not None
    except Exception as e:
        print(f"rank: {self.rank} group not found")
        raise e

    from fastvideo.platforms import current_platform

    # TODO: fix it for other platforms
    self.device = get_local_torch_device()

    self.use_device_communicator = use_device_communicator
    self.device_communicator: DeviceCommunicatorBase = None  # type: ignore
    if use_device_communicator and self.world_size > 1:
        # Platform-aware device communicator selection
        if current_platform.is_cuda_alike():
            from fastvideo.distributed.device_communicators.cuda_communicator import (
                CudaCommunicator)
            self.device_communicator = CudaCommunicator(
                cpu_group=self.cpu_group,
                device=self.device,
                device_group=self.device_group,
                unique_name=self.unique_name,
            )
        elif current_platform.is_npu():
            from fastvideo.distributed.device_communicators.npu_communicator import (
                NpuCommunicator)
            self.device_communicator = NpuCommunicator(
                cpu_group=self.cpu_group,
                device=self.device,
                device_group=self.device_group,
                unique_name=self.unique_name,
            )
        else:
            # For MPS and CPU, use the CPU communicator
            self.device_communicator = CpuCommunicator(
                cpu_group=self.cpu_group,
                device=self.device,
                device_group=self.device_group,
                unique_name=self.unique_name,
            )

    self.mq_broadcaster = None

    from fastvideo.platforms import current_platform

    # TODO(will): check if this is needed
    # self.use_custom_op_call = current_platform.is_cuda_alike()
    self.use_custom_op_call = False
Attributes
fastvideo.distributed.parallel_state.GroupCoordinator.first_rank property
first_rank

Return the global rank of the first process in the group

fastvideo.distributed.parallel_state.GroupCoordinator.is_first_rank property
is_first_rank

Return whether the caller is the first process in the group

fastvideo.distributed.parallel_state.GroupCoordinator.is_last_rank property
is_last_rank

Return whether the caller is the last process in the group

fastvideo.distributed.parallel_state.GroupCoordinator.last_rank property
last_rank

Return the global rank of the last process in the group

fastvideo.distributed.parallel_state.GroupCoordinator.next_rank property
next_rank

Return the global rank of the process that follows the caller

fastvideo.distributed.parallel_state.GroupCoordinator.prev_rank property
prev_rank

Return the global rank of the process that precedes the caller

Functions
fastvideo.distributed.parallel_state.GroupCoordinator.all_reduce
all_reduce(input_: Tensor, op: ReduceOp | None = SUM) -> Tensor

User-facing all-reduce function before we actually call the all-reduce operation.

We need this because Dynamo does not support passing an arbitrary object (self in this case) to a custom op. Instead, we pass the group name as a string, look up the group coordinator from that name, and dispatch the all-reduce operation to it.

In addition, PyTorch custom ops do not support mutation or returning a new tensor in the same op. So we always make the all-reduce operation out-of-place.

Source code in fastvideo/distributed/parallel_state.py
def all_reduce(
        self,
        input_: torch.Tensor,
        op: torch.distributed.ReduceOp | None = ReduceOp.SUM
) -> torch.Tensor:
    """
    User-facing all-reduce function before we actually call the
    all-reduce operation.

    We need this because Dynamo does not support passing an arbitrary
    object (`self` in this case) to a custom op. We need to pass the
     group name as a string, and then look up the group coordinator from
     the group name, dispatch the all-reduce operation to the group
     coordinator.

    In addition, PyTorch custom ops do not support mutation or returning
    a new tensor in the same op. So we always make the all-reduce operation
    out-of-place.
    """
    # Bypass the function if we are using only 1 GPU.
    if self.world_size == 1:
        return input_

    if self.use_custom_op_call:
        return torch.ops.vllm.all_reduce(input_,
                                         group_name=self.unique_name)
    else:
        return self._all_reduce_out_place(input_, op=op)
fastvideo.distributed.parallel_state.GroupCoordinator.barrier
barrier() -> None

Barrier synchronization among the group. NOTE: don't use device_group here! barrier in NCCL is terrible because it is internally a broadcast operation with secretly created GPU tensors. It is easy to mess up the current device. Use the CPU group instead.

Source code in fastvideo/distributed/parallel_state.py
def barrier(self) -> None:
    """Barrier synchronization among the group.
    NOTE: don't use `device_group` here! `barrier` in NCCL is
    terrible because it is internally a broadcast operation with
    secretly created GPU tensors. It is easy to mess up the current
    device. Use the CPU group instead.
    """
    torch.distributed.barrier(group=self.cpu_group)
fastvideo.distributed.parallel_state.GroupCoordinator.broadcast
broadcast(input_: Tensor, src: int = 0)

Broadcast the input tensor. NOTE: src is the local rank of the source rank.

Source code in fastvideo/distributed/parallel_state.py
def broadcast(self, input_: torch.Tensor, src: int = 0):
    """Broadcast the input tensor.
    NOTE: `src` is the local rank of the source rank.
    """
    assert src < self.world_size, f"Invalid src rank ({src})"

    # Bypass the function if we are using only 1 GPU.
    if self.world_size == 1:
        return input_
    # Broadcast.
    torch.distributed.broadcast(input_,
                                src=self.ranks[src],
                                group=self.device_group)
    return input_
fastvideo.distributed.parallel_state.GroupCoordinator.broadcast_object
broadcast_object(obj: Any | None = None, src: int = 0)

Broadcast the input object. NOTE: src is the local rank of the source rank.

Source code in fastvideo/distributed/parallel_state.py
def broadcast_object(self, obj: Any | None = None, src: int = 0):
    """Broadcast the input object.
    NOTE: `src` is the local rank of the source rank.
    """
    assert src < self.world_size, f"Invalid src rank ({src})"

    # Bypass the function if we are using only 1 GPU.
    if self.world_size == 1:
        return obj
    if self.mq_broadcaster is not None:
        assert src == 0, "Message queue broadcaster only supports src=0"
        return self.mq_broadcaster.broadcast_object(obj)
    if self.rank_in_group == src:
        torch.distributed.broadcast_object_list([obj],
                                                src=self.ranks[src],
                                                group=self.cpu_group)
        return obj
    else:
        recv = [None]
        torch.distributed.broadcast_object_list(recv,
                                                src=self.ranks[src],
                                                group=self.cpu_group)
        return recv[0]
fastvideo.distributed.parallel_state.GroupCoordinator.broadcast_object_list
broadcast_object_list(obj_list: list[Any], src: int = 0, group: ProcessGroup | None = None)

Broadcast the input object list. NOTE: src is the local rank of the source rank.

Source code in fastvideo/distributed/parallel_state.py
def broadcast_object_list(self,
                          obj_list: list[Any],
                          src: int = 0,
                          group: ProcessGroup | None = None):
    """Broadcast the input object list.
    NOTE: `src` is the local rank of the source rank.
    """
    assert src < self.world_size, f"Invalid src rank ({src})"

    # Bypass the function if we are using only 1 GPU.
    if self.world_size == 1:
        return obj_list
    # Broadcast.
    torch.distributed.broadcast_object_list(obj_list,
                                            src=self.ranks[src],
                                            group=self.device_group)
    return obj_list
fastvideo.distributed.parallel_state.GroupCoordinator.broadcast_tensor_dict
broadcast_tensor_dict(tensor_dict: dict[str, Tensor | Any] | None = None, src: int = 0, group: ProcessGroup | None = None, metadata_group: ProcessGroup | None = None) -> dict[str, Tensor | Any] | None

Broadcast the input tensor dictionary. NOTE: src is the local rank of the source rank.

Source code in fastvideo/distributed/parallel_state.py
def broadcast_tensor_dict(
    self,
    tensor_dict: dict[str, torch.Tensor | Any] | None = None,
    src: int = 0,
    group: ProcessGroup | None = None,
    metadata_group: ProcessGroup | None = None
) -> dict[str, torch.Tensor | Any] | None:
    """Broadcast the input tensor dictionary.
    NOTE: `src` is the local rank of the source rank.
    """
    # Bypass the function if we are using only 1 GPU.
    if (not torch.distributed.is_initialized() or self.world_size == 1):
        return tensor_dict

    group = self.device_group
    metadata_group = self.cpu_group
    assert src < self.world_size, f"Invalid src rank ({src})"

    rank_in_group = self.rank_in_group
    if rank_in_group == src:
        metadata_list: list[tuple[Any, Any]] = []
        assert isinstance(
            tensor_dict,
            dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
        # `metadata_list` lives in CPU memory.
        # `broadcast_object_list` has serialization & deserialization,
        # all happening on CPU. Therefore, we can use the CPU group.
        self.broadcast_object(metadata_list, src=src)
        async_handles = []
        for tensor in tensor_list:
            if tensor.numel() == 0:
                # Skip broadcasting empty tensors.
                continue
            if tensor.is_cpu:
                # use metadata_group for CPU tensors
                handle = torch.distributed.broadcast(tensor,
                                                     src=self.ranks[src],
                                                     group=metadata_group,
                                                     async_op=True)
            else:
                # use group for GPU tensors
                handle = torch.distributed.broadcast(tensor,
                                                     src=self.ranks[src],
                                                     group=group,
                                                     async_op=True)
            async_handles.append(handle)
        for async_handle in async_handles:
            async_handle.wait()

    else:
        metadata_list = self.broadcast_object(None, src=src)
        tensor_dict = {}
        async_handles = []
        for key, value in metadata_list:
            if isinstance(value, TensorMetadata):
                tensor = torch.empty(value.size,
                                     dtype=value.dtype,
                                     device=value.device)
                if tensor.numel() == 0:
                    # Skip broadcasting empty tensors.
                    tensor_dict[key] = tensor
                    continue
                if tensor.is_cpu:
                    # use metadata_group for CPU tensors
                    handle = torch.distributed.broadcast(
                        tensor,
                        src=self.ranks[src],
                        group=metadata_group,
                        async_op=True)
                else:
                    # use group for GPU tensors
                    handle = torch.distributed.broadcast(
                        tensor,
                        src=self.ranks[src],
                        group=group,
                        async_op=True)
                async_handles.append(handle)
                tensor_dict[key] = tensor
            else:
                tensor_dict[key] = value
        for async_handle in async_handles:
            async_handle.wait()
    return tensor_dict
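A hedged sketch of how the source and non-source ranks typically call this method; the payload, the sequence-parallel group, and the helper imports are illustrative assumptions.

# Hedged sketch: broadcast a mixed dict of tensors and plain Python metadata.
import torch
from fastvideo.distributed.parallel_state import (get_local_torch_device,
                                                  get_sp_group)

sp = get_sp_group()
if sp.rank_in_group == 0:
    payload = {
        "latents": torch.randn(2, 16, 32, 32, device=get_local_torch_device()),
        "step": 3,  # non-tensor entries travel with the CPU metadata
    }
    sp.broadcast_tensor_dict(payload, src=0)
else:
    payload = sp.broadcast_tensor_dict(None, src=0)
# Every rank now holds equivalent "latents" and "step" entries.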
fastvideo.distributed.parallel_state.GroupCoordinator.gather
gather(input_: Tensor, dst: int = 0, dim: int = -1) -> Tensor | None

NOTE: We assume that the input tensor is on the same device across all the ranks. NOTE: dst is the local rank of the destination rank.

Source code in fastvideo/distributed/parallel_state.py
def gather(self,
           input_: torch.Tensor,
           dst: int = 0,
           dim: int = -1) -> torch.Tensor | None:
    """
    NOTE: We assume that the input tensor is on the same device across
    all the ranks.
    NOTE: `dst` is the local rank of the destination rank.
    """
    world_size = self.world_size
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_
    return self.device_communicator.gather(input_, dst, dim)
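A short sketch, assuming get_sp_group() and get_local_torch_device() are available and every rank contributes an equally shaped shard.

# Hedged sketch: gather per-rank shards onto local rank 0 of the group.
import torch
from fastvideo.distributed.parallel_state import (get_local_torch_device,
                                                  get_sp_group)

sp = get_sp_group()
shard = torch.randn(4, 8, device=get_local_torch_device())
full = sp.gather(shard, dst=0, dim=0)
# `full` is the concatenation along dim 0 on dst; other ranks typically get None.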
fastvideo.distributed.parallel_state.GroupCoordinator.recv
recv(size: Size, dtype: dtype, src: int | None = None) -> Tensor

Receives a tensor from the source rank. NOTE: src is the local rank of the source rank.

Source code in fastvideo/distributed/parallel_state.py
def recv(self,
         size: torch.Size,
         dtype: torch.dtype,
         src: int | None = None) -> torch.Tensor:
    """Receives a tensor from the source rank."""
    """NOTE: `src` is the local rank of the source rank."""
    return self.device_communicator.recv(size, dtype, src)
fastvideo.distributed.parallel_state.GroupCoordinator.recv_object
recv_object(src: int) -> Any

Receive the input object from the source rank. NOTE: src is the local rank of the source rank.

Source code in fastvideo/distributed/parallel_state.py
def recv_object(self, src: int) -> Any:
    """Receive the input object list from the source rank."""
    """NOTE: `src` is the local rank of the source rank."""

    assert src < self.world_size, f"Invalid src rank ({src})"

    assert src != self.rank_in_group, (
        "Invalid source rank. Source rank is the same as the current rank.")

    size_tensor = torch.empty(1, dtype=torch.long, device="cpu")

    # Receive object size
    rank_size = torch.distributed.recv(size_tensor,
                                       src=self.ranks[src],
                                       group=self.cpu_group)

    # Tensor to receive serialized objects into.
    object_tensor = torch.empty(  # type: ignore[call-overload]
        size_tensor.item(),  # type: ignore[arg-type]
        dtype=torch.uint8,
        device="cpu")

    rank_object = torch.distributed.recv(object_tensor,
                                         src=self.ranks[src],
                                         group=self.cpu_group)

    assert rank_object == rank_size, (
        "Received object sender rank does not match the size sender rank.")

    obj = pickle.loads(object_tensor.numpy().tobytes())

    return obj
fastvideo.distributed.parallel_state.GroupCoordinator.recv_tensor_dict
recv_tensor_dict(src: int | None = None, all_gather_group: Optional[GroupCoordinator] = None) -> dict[str, Tensor | Any] | None

Receive the input tensor dictionary. NOTE: src is the local rank of the source rank.

Source code in fastvideo/distributed/parallel_state.py
def recv_tensor_dict(
    self,
    src: int | None = None,
    all_gather_group: Optional["GroupCoordinator"] = None,
) -> dict[str, torch.Tensor | Any] | None:
    """Recv the input tensor dictionary.
    NOTE: `src` is the local rank of the source rank.
    """
    # Bypass the function if we are using only 1 GPU.
    if not torch.distributed.is_initialized() or self.world_size == 1:
        return None

    all_gather_size = (1 if all_gather_group is None else
                       all_gather_group.world_size)
    all_gather_rank = (0 if all_gather_group is None else
                       all_gather_group.rank_in_group)

    group = self.device_group
    metadata_group = self.cpu_group

    if src is None:
        src = (self.rank_in_group - 1) % self.world_size
    assert src < self.world_size, f"Invalid src rank ({src})"

    recv_metadata_list = self.recv_object(src=src)
    tensor_dict: dict[str, Any] = {}
    for key, value in recv_metadata_list:
        if isinstance(value, TensorMetadata):
            tensor = torch.empty(value.size,
                                 dtype=value.dtype,
                                 device=value.device)
            if tensor.numel() == 0:
                # Skip broadcasting empty tensors.
                tensor_dict[key] = tensor
                continue

            # send-allgather: send only a slice, then do allgather.
            use_all_gather = (all_gather_group is not None
                              and tensor.numel() % all_gather_size == 0)

            if use_all_gather:
                orig_shape = tensor.shape
                tensor = tensor.reshape(all_gather_size,
                                        -1)[all_gather_rank]

            if tensor.is_cpu:
                # use metadata_group for CPU tensors
                torch.distributed.recv(tensor,
                                       src=self.ranks[src],
                                       group=metadata_group)
            else:
                # use group for GPU tensors
                torch.distributed.recv(tensor,
                                       src=self.ranks[src],
                                       group=group)
            if use_all_gather:
                # do the allgather
                tensor = all_gather_group.all_gather(  # type: ignore
                    tensor, dim=0)
                tensor = tensor.reshape(orig_shape)

            tensor_dict[key] = tensor
        else:
            tensor_dict[key] = value
    return tensor_dict
fastvideo.distributed.parallel_state.GroupCoordinator.send
send(tensor: Tensor, dst: int | None = None) -> None

Sends a tensor to the destination rank in a non-blocking way. NOTE: dst is the local rank of the destination rank.

Source code in fastvideo/distributed/parallel_state.py
def send(self, tensor: torch.Tensor, dst: int | None = None) -> None:
    """Sends a tensor to the destination rank in a non-blocking way"""
    """NOTE: `dst` is the local rank of the destination rank."""
    self.device_communicator.send(tensor, dst)
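The send()/recv() pair is easiest to read side by side. A hedged two-rank sketch, assuming the shape, dtype, and the sequence-parallel group are known to both sides.

# Hedged sketch: rank 0 sends, rank 1 receives; both must agree on shape/dtype.
import torch
from fastvideo.distributed.parallel_state import (get_local_torch_device,
                                                  get_sp_group)

sp = get_sp_group()
shape, dtype = torch.Size([8, 16]), torch.float32
if sp.rank_in_group == 0:
    sp.send(torch.ones(shape, dtype=dtype, device=get_local_torch_device()),
            dst=1)  # `dst` is the local rank inside the group
else:
    t = sp.recv(shape, dtype, src=0)  # `src` is the local rank inside the group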
fastvideo.distributed.parallel_state.GroupCoordinator.send_object
send_object(obj: Any, dst: int) -> None

Send the input object to the destination rank. NOTE: dst is the local rank of the destination rank.

Source code in fastvideo/distributed/parallel_state.py
def send_object(self, obj: Any, dst: int) -> None:
    """Send the input object list to the destination rank."""
    """NOTE: `dst` is the local rank of the destination rank."""

    assert dst < self.world_size, f"Invalid dst rank ({dst})"

    assert dst != self.rank_in_group, (
        "Invalid destination rank. Destination rank is the same "
        "as the current rank.")

    # Serialize object to tensor and get the size as well
    object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)

    size_tensor = torch.tensor([object_tensor.numel()],
                               dtype=torch.long,
                               device="cpu")

    # Send object size

    torch.distributed.send(size_tensor,
                           dst=self.ranks[dst],
                           group=self.cpu_group)

    # Send object
    torch.distributed.send(object_tensor,
                           dst=self.ranks[dst],
                           group=self.cpu_group)

    return None
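A hedged sketch of the send_object()/recv_object() handshake between two local ranks of the same group; the payload is illustrative.

# Hedged sketch: CPU-side object hand-off; exactly one sender and one receiver.
from fastvideo.distributed.parallel_state import get_sp_group

sp = get_sp_group()
if sp.rank_in_group == 0:
    sp.send_object({"seed": 1234, "prompt": "a cat"}, dst=1)
else:
    meta = sp.recv_object(src=0)  # blocks until the pickled object arrives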
fastvideo.distributed.parallel_state.GroupCoordinator.send_tensor_dict
send_tensor_dict(tensor_dict: dict[str, Tensor | Any], dst: int | None = None, all_gather_group: Optional[GroupCoordinator] = None) -> dict[str, Tensor | Any] | None

Send the input tensor dictionary. NOTE: dst is the local rank of the destination rank.

Source code in fastvideo/distributed/parallel_state.py
def send_tensor_dict(
    self,
    tensor_dict: dict[str, torch.Tensor | Any],
    dst: int | None = None,
    all_gather_group: Optional["GroupCoordinator"] = None,
) -> dict[str, torch.Tensor | Any] | None:
    """Send the input tensor dictionary.
    NOTE: `dst` is the local rank of the destination rank.
    """
    # Bypass the function if we are using only 1 GPU.
    if not torch.distributed.is_initialized() or self.world_size == 1:
        return tensor_dict

    all_gather_size = (1 if all_gather_group is None else
                       all_gather_group.world_size)
    all_gather_rank = (0 if all_gather_group is None else
                       all_gather_group.rank_in_group)

    group = self.device_group
    metadata_group = self.cpu_group

    if dst is None:
        dst = (self.rank_in_group + 1) % self.world_size
    assert dst < self.world_size, f"Invalid dst rank ({dst})"

    metadata_list: list[tuple[Any, Any]] = []
    assert isinstance(
        tensor_dict,
        dict), f"Expecting a dictionary, got {type(tensor_dict)}"
    metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
    # `metadata_list` lives in CPU memory.
    # `send_object` has serialization & deserialization,
    # all happening on CPU. Therefore, we can use the CPU group.
    self.send_object(metadata_list, dst=dst)
    for tensor in tensor_list:
        if tensor.numel() == 0:
            # Skip sending empty tensors.
            continue

        # send-allgather: send only a slice, then do allgather.
        if (all_gather_group is not None
                and tensor.numel() % all_gather_size == 0):
            tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]

        if tensor.is_cpu:
            # use metadata_group for CPU tensors
            torch.distributed.send(tensor,
                                   dst=self.ranks[dst],
                                   group=metadata_group)
        else:
            # use group for GPU tensors
            torch.distributed.send(tensor, dst=self.ranks[dst], group=group)
    return None
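A hedged point-to-point sketch pairing send_tensor_dict() with recv_tensor_dict(); rank 0 produces and rank 1 consumes, and the dict contents are illustrative.

# Hedged sketch: hand a tensor dict from local rank 0 to local rank 1.
import torch
from fastvideo.distributed.parallel_state import (get_local_torch_device,
                                                  get_sp_group)

sp = get_sp_group()
if sp.rank_in_group == 0:
    sp.send_tensor_dict(
        {"hidden": torch.randn(1, 128, device=get_local_torch_device()),
         "layer": 5},
        dst=1)
else:
    received = sp.recv_tensor_dict(src=0)
    # received["hidden"] is a tensor; received["layer"] == 5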

Functions

fastvideo.distributed.parallel_state.destroy_model_parallel
destroy_model_parallel() -> None

Set the groups to none and destroy them.

Source code in fastvideo/distributed/parallel_state.py
def destroy_model_parallel() -> None:
    """Set the groups to none and destroy them."""
    global _TP
    if _TP:
        _TP.destroy()
    _TP = None

    global _SP
    if _SP:
        _SP.destroy()
    _SP = None

    global _DP
    if _DP:
        _DP.destroy()
    _DP = None
fastvideo.distributed.parallel_state.get_dp_rank
get_dp_rank() -> int

Return my rank for the data parallel group.

Source code in fastvideo/distributed/parallel_state.py
def get_dp_rank() -> int:
    """Return my rank for the data parallel group."""
    return get_dp_group().rank_in_group
fastvideo.distributed.parallel_state.get_dp_world_size
get_dp_world_size() -> int

Return world size for the data parallel group.

Source code in fastvideo/distributed/parallel_state.py
def get_dp_world_size() -> int:
    """Return world size for the data parallel group."""
    return get_dp_group().world_size
fastvideo.distributed.parallel_state.get_local_torch_device
get_local_torch_device() -> device

Return the torch device for the current rank.

Source code in fastvideo/distributed/parallel_state.py
def get_local_torch_device() -> torch.device:
    """Return the torch device for the current rank."""
    from fastvideo.platforms import current_platform
    if current_platform.is_npu():
        device = torch.device(f"npu:{envs.LOCAL_RANK}")
    elif current_platform.is_cuda_alike() or current_platform.is_cuda():
        device = torch.device(f"cuda:{envs.LOCAL_RANK}")
    else:
        device = torch.device("mps")
    return device
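A one-line usage sketch: allocate work tensors on the device that belongs to this rank.

# Hedged sketch: place per-rank buffers on the local accelerator (or MPS fallback).
import torch
from fastvideo.distributed.parallel_state import get_local_torch_device

latents = torch.zeros(1, 16, 64, 64, device=get_local_torch_device())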
fastvideo.distributed.parallel_state.get_sp_parallel_rank
get_sp_parallel_rank() -> int

Return my rank for the sequence model parallel group.

Source code in fastvideo/distributed/parallel_state.py
def get_sp_parallel_rank() -> int:
    """Return my rank for the sequence model parallel group."""
    return get_sp_group().rank_in_group
fastvideo.distributed.parallel_state.get_sp_world_size
get_sp_world_size() -> int

Return world size for the sequence model parallel group.

Source code in fastvideo/distributed/parallel_state.py
def get_sp_world_size() -> int:
    """Return world size for the sequence model parallel group."""
    return get_sp_group().world_size
fastvideo.distributed.parallel_state.get_tp_rank
get_tp_rank() -> int

Return my rank for the tensor model parallel group.

Source code in fastvideo/distributed/parallel_state.py
def get_tp_rank() -> int:
    """Return my rank for the tensor model parallel group."""
    return get_tp_group().rank_in_group
fastvideo.distributed.parallel_state.get_tp_world_size
get_tp_world_size() -> int

Return world size for the tensor model parallel group.

Source code in fastvideo/distributed/parallel_state.py
def get_tp_world_size() -> int:
    """Return world size for the tensor model parallel group."""
    return get_tp_group().world_size
fastvideo.distributed.parallel_state.get_world_rank
get_world_rank() -> int

Return my rank for the world group.

Source code in fastvideo/distributed/parallel_state.py
def get_world_rank() -> int:
    """Return my rank for the world group."""
    return get_world_group().rank
fastvideo.distributed.parallel_state.get_world_size
get_world_size() -> int

Return world size for the world group.

Source code in fastvideo/distributed/parallel_state.py
def get_world_size() -> int:
    """Return world size for the world group."""
    return get_world_group().world_size
fastvideo.distributed.parallel_state.initialize_model_parallel
initialize_model_parallel(tensor_model_parallel_size: int = 1, sequence_model_parallel_size: int = 1, data_parallel_size: int = 1, backend: str | None = None) -> None

Initialize model parallel groups.

Parameters:

tensor_model_parallel_size (int, default 1): number of GPUs used for tensor model parallelism (used for language encoder).
sequence_model_parallel_size (int, default 1): number of GPUs used for sequence model parallelism (used for DiT).
Source code in fastvideo/distributed/parallel_state.py
def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    sequence_model_parallel_size: int = 1,
    data_parallel_size: int = 1,
    backend: str | None = None,
) -> None:
    """
    Initialize model parallel groups.

    Arguments:
        tensor_model_parallel_size: number of GPUs used for tensor model
            parallelism (used for language encoder).
        sequence_model_parallel_size: number of GPUs used for sequence model
            parallelism (used for DiT).
    """
    # Get world size and rank. Ensure some consistencies.
    assert _WORLD is not None, "world group is not initialized, please call init_distributed_environment first"
    world_size: int = get_world_size()
    backend = backend or torch.distributed.get_backend(
        get_world_group().device_group)
    assert world_size >= tensor_model_parallel_size, f"world_size({world_size}) must be greater than or equal to tensor_model_parallel_size({tensor_model_parallel_size})"
    num_tensor_model_parallel_groups: int = (world_size //
                                             tensor_model_parallel_size)
    global _TP
    assert _TP is None, ("tensor model parallel group is already initialized")
    group_ranks = []
    for i in range(num_tensor_model_parallel_groups):
        ranks = list(
            range(i * tensor_model_parallel_size,
                  (i + 1) * tensor_model_parallel_size))
        group_ranks.append(ranks)

    # message queue broadcaster is only used in tensor model parallel group
    _TP = init_model_parallel_group(group_ranks,
                                    get_world_group().local_rank,
                                    backend,
                                    use_message_queue_broadcaster=True,
                                    group_name="tp")

    # Build the sequence model-parallel groups.
    num_sequence_model_parallel_groups: int = (world_size //
                                               sequence_model_parallel_size)
    global _SP
    assert _SP is None, ("sequence model parallel group is already initialized")
    group_ranks = []

    # Since SP is incompatible with TP and PP, we can use a simpler group creation logic
    for i in range(num_sequence_model_parallel_groups):
        # Create groups of consecutive ranks
        ranks = list(
            range(i * sequence_model_parallel_size,
                  (i + 1) * sequence_model_parallel_size))
        group_ranks.append(ranks)

    _SP = init_model_parallel_group(group_ranks,
                                    get_world_group().local_rank,
                                    backend,
                                    group_name="sp")

    # Build the data parallel groups.
    num_data_parallel_groups: int = sequence_model_parallel_size
    global _DP
    assert _DP is None, ("data parallel group is already initialized")
    group_ranks = []

    for i in range(num_data_parallel_groups):
        ranks = list(range(i, world_size, num_data_parallel_groups))
        group_ranks.append(ranks)

    _DP = init_model_parallel_group(group_ranks,
                                    get_world_group().local_rank,
                                    backend,
                                    group_name="dp")
fastvideo.distributed.parallel_state.initialize_sequence_parallel_group
initialize_sequence_parallel_group(sequence_model_parallel_size: int = 1, backend: str | None = None, group_name_suffix: str = '') -> GroupCoordinator

Initialize a sequence parallel group for a specific model.

This function creates a sequence parallel group that can be used with the patch_sequence_parallel_group context manager. It allows different models to use different sequence parallelism configurations.

Parameters:

sequence_model_parallel_size (int, default 1): number of GPUs used for sequence model parallelism.
backend (str | None, default None): communication backend to use.
group_name_suffix (str, default ''): optional suffix to make the group name unique.

Returns:

GroupCoordinator: A GroupCoordinator for sequence parallelism that can be used with the patch_sequence_parallel_group context manager.

Example usage
# Initialize sequence parallel group for model2
sp_group_model2 = initialize_sequence_parallel_group(
    sequence_model_parallel_size=2,
    group_name_suffix="model2"
)

# Use sequence parallelism for model2
with patch_sequence_parallel_group(sp_group_model2):
    # Run model2 with sequence parallelism
    output2 = model2(input2)
Source code in fastvideo/distributed/parallel_state.py
def initialize_sequence_parallel_group(
        sequence_model_parallel_size: int = 1,
        backend: str | None = None,
        group_name_suffix: str = "") -> GroupCoordinator:
    """Initialize a sequence parallel group for a specific model.

    This function creates a sequence parallel group that can be used with the
    patch_sequence_parallel_group context manager. It allows different models
    to use different sequence parallelism configurations.

    Arguments:
        sequence_model_parallel_size: number of GPUs used for sequence model parallelism.
        backend: communication backend to use.
        group_name_suffix: optional suffix to make the group name unique.

    Returns:
        A GroupCoordinator for sequence parallelism that can be used with
        the patch_sequence_parallel_group context manager.

    Example usage:
        ```python
        # Initialize sequence parallel group for model2
        sp_group_model2 = initialize_sequence_parallel_group(
            sequence_model_parallel_size=2,
            group_name_suffix="model2"
        )

        # Use sequence parallelism for model2
        with patch_sequence_parallel_group(sp_group_model2):
            # Run model2 with sequence parallelism
            output2 = model2(input2)
        ```
    """
    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()
    backend = backend or torch.distributed.get_backend(
        get_world_group().device_group)

    # Ensure the world size is compatible with the parallelism configuration
    assert world_size % sequence_model_parallel_size == 0, \
        f"World size ({world_size}) must be divisible by sequence_model_parallel_size ({sequence_model_parallel_size})"

    # Build the sequence model-parallel groups.
    num_sequence_model_parallel_groups: int = (world_size //
                                               sequence_model_parallel_size)
    sp_group_ranks = []

    for i in range(num_sequence_model_parallel_groups):
        # Create groups of consecutive ranks
        ranks = list(
            range(i * sequence_model_parallel_size,
                  (i + 1) * sequence_model_parallel_size))
        sp_group_ranks.append(ranks)

    # Create SP group coordinator with a unique name
    group_name = f"sp_{group_name_suffix}" if group_name_suffix else "sp"
    sp_group = init_model_parallel_group(sp_group_ranks,
                                         get_world_group().local_rank,
                                         backend,
                                         group_name=group_name)

    return sp_group
fastvideo.distributed.parallel_state.initialize_tensor_parallel_group
initialize_tensor_parallel_group(tensor_model_parallel_size: int = 1, backend: str | None = None, group_name_suffix: str = '') -> GroupCoordinator

Initialize a tensor parallel group for a specific model.

This function creates a tensor parallel group that can be used with the patch_tensor_parallel_group context manager. It allows different models to use different tensor parallelism configurations.

Parameters:

tensor_model_parallel_size (int, default 1): number of GPUs used for tensor model parallelism.
backend (str | None, default None): communication backend to use.
group_name_suffix (str, default ''): optional suffix to make the group name unique.

Returns:

GroupCoordinator: A GroupCoordinator for tensor parallelism that can be used with the patch_tensor_parallel_group context manager.

Example usage
# Initialize tensor parallel group for model1
tp_group_model1 = initialize_tensor_parallel_group(
    tensor_model_parallel_size=4,
    group_name_suffix="model1"
)

# Use tensor parallelism for model1
with patch_tensor_parallel_group(tp_group_model1):
    # Run model1 with tensor parallelism
    output1 = model1(input1)
Source code in fastvideo/distributed/parallel_state.py
def initialize_tensor_parallel_group(
        tensor_model_parallel_size: int = 1,
        backend: str | None = None,
        group_name_suffix: str = "") -> GroupCoordinator:
    """Initialize a tensor parallel group for a specific model.

    This function creates a tensor parallel group that can be used with the
    patch_tensor_parallel_group context manager. It allows different models
    to use different tensor parallelism configurations.

    Arguments:
        tensor_model_parallel_size: number of GPUs used for tensor model parallelism.
        backend: communication backend to use.
        group_name_suffix: optional suffix to make the group name unique.

    Returns:
        A GroupCoordinator for tensor parallelism that can be used with
        the patch_tensor_parallel_group context manager.

    Example usage:
        ```python
        # Initialize tensor parallel group for model1
        tp_group_model1 = initialize_tensor_parallel_group(
            tensor_model_parallel_size=4,
            group_name_suffix="model1"
        )

        # Use tensor parallelism for model1
        with patch_tensor_parallel_group(tp_group_model1):
            # Run model1 with tensor parallelism
            output1 = model1(input1)
        ```
    """
    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()
    backend = backend or torch.distributed.get_backend(
        get_world_group().device_group)

    # Ensure the world size is compatible with the parallelism configuration
    assert world_size % tensor_model_parallel_size == 0, \
        f"World size ({world_size}) must be divisible by tensor_model_parallel_size ({tensor_model_parallel_size})"

    # Build the tensor model-parallel groups.
    num_tensor_model_parallel_groups: int = (world_size //
                                             tensor_model_parallel_size)
    tp_group_ranks = []
    for i in range(num_tensor_model_parallel_groups):
        ranks = list(
            range(i * tensor_model_parallel_size,
                  (i + 1) * tensor_model_parallel_size))
        tp_group_ranks.append(ranks)

    # Create TP group coordinator with a unique name
    group_name = f"tp_{group_name_suffix}" if group_name_suffix else "tp"
    tp_group = init_model_parallel_group(tp_group_ranks,
                                         get_world_group().local_rank,
                                         backend,
                                         use_message_queue_broadcaster=True,
                                         group_name=group_name)

    return tp_group
fastvideo.distributed.parallel_state.is_the_same_node_as
is_the_same_node_as(pg: ProcessGroup | StatelessProcessGroup, source_rank: int = 0) -> list[int]

This is a collective operation that returns whether each rank is in the same node as the source rank. It tests whether processes are attached to the same memory system (shared access to shared memory).

Source code in fastvideo/distributed/parallel_state.py
def is_the_same_node_as(pg: ProcessGroup | StatelessProcessGroup,
                        source_rank: int = 0) -> list[int]:
    """
    This is a collective operation that returns if each rank is in the same node
    as the source rank. It tests if processes are attached to the same
    memory system (shared access to shared memory).
    """
    if isinstance(pg, ProcessGroup):
        assert torch.distributed.get_backend(
            pg) != torch.distributed.Backend.NCCL, (
                "in_the_same_node_as should be tested with a non-NCCL group.")
        # local rank inside the group
        rank = torch.distributed.get_rank(group=pg)
        world_size = torch.distributed.get_world_size(group=pg)

        # global ranks of the processes in the group
        ranks = torch.distributed.get_process_group_ranks(pg)
    else:
        rank = pg.rank
        world_size = pg.world_size
        ranks = list(range(world_size))

    # local tensor in each process to store the result
    is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32)

    magic_message = b"magic_message"
    shm = None

    try:
        with contextlib.suppress(OSError):
            if rank == source_rank:
                # create a shared memory segment
                shm = shared_memory.SharedMemory(create=True, size=128)
                shm.buf[:len(magic_message)] = magic_message
                if isinstance(pg, ProcessGroup):
                    torch.distributed.broadcast_object_list(
                        [shm.name], src=ranks[source_rank], group=pg)
                else:
                    pg.broadcast_obj(shm.name, src=source_rank)
                is_in_the_same_node[rank] = 1
            else:
                # try to open the shared memory segment
                if isinstance(pg, ProcessGroup):
                    recv = [None]
                    torch.distributed.broadcast_object_list(
                        recv, src=ranks[source_rank], group=pg)
                    name = recv[0]
                else:
                    name = pg.broadcast_obj(None, src=source_rank)
                # fix to https://stackoverflow.com/q/62748654/9191338
                # Python incorrectly tracks shared memory even if it is not
                # created by the process. The following patch is a workaround.
                with patch("multiprocessing.resource_tracker.register",
                           lambda *args, **kwargs: None):
                    shm = shared_memory.SharedMemory(name=name)
                if shm.buf[:len(magic_message)] == magic_message:
                    is_in_the_same_node[rank] = 1
    except Exception as e:
        logger.error("Error ignored in is_in_the_same_node: %s", e)
    finally:
        if shm:
            shm.close()

    if isinstance(pg, ProcessGroup):
        torch.distributed.barrier(group=pg)
    else:
        pg.barrier()

    # clean up the shared memory segment
    with contextlib.suppress(OSError):
        if rank == source_rank and shm:
            shm.unlink()

    if isinstance(pg, ProcessGroup):
        torch.distributed.all_reduce(is_in_the_same_node, group=pg)
        aggregated_data = is_in_the_same_node
    else:
        aggregated_data = torch.zeros_like(is_in_the_same_node)
        for i in range(world_size):
            rank_data = pg.broadcast_obj(is_in_the_same_node, src=i)
            aggregated_data += rank_data

    return [x == 1 for x in aggregated_data.tolist()]
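A hedged sketch: the check has to run on a non-NCCL group, so a gloo group is created here purely for illustration; torch.distributed is assumed to be initialized already.

# Hedged sketch: test node co-location over a CPU (gloo) group.
import torch
from fastvideo.distributed.parallel_state import is_the_same_node_as

cpu_group = torch.distributed.new_group(backend="gloo")
same_node = is_the_same_node_as(cpu_group, source_rank=0)
# e.g. [True, True, False, False] if ranks 2 and 3 run on a different machine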
fastvideo.distributed.parallel_state.model_parallel_is_initialized
model_parallel_is_initialized() -> bool

Check if the tensor, sequence, and data parallel groups are initialized.

Source code in fastvideo/distributed/parallel_state.py
def model_parallel_is_initialized() -> bool:
    """Check if tensor, sequence parallel groups are initialized."""
    return _TP is not None and _SP is not None and _DP is not None
fastvideo.distributed.parallel_state.patch_tensor_parallel_group
patch_tensor_parallel_group(tp_group: GroupCoordinator)

Patch the tp group temporarily until this function ends.

This method is for draft workers of speculative decoding to run draft model with different tp degree from that of target model workers.

Parameters:

tp_group (GroupCoordinator, required): the tp group coordinator.
Source code in fastvideo/distributed/parallel_state.py
@contextmanager
def patch_tensor_parallel_group(tp_group: GroupCoordinator):
    """Patch the tp group temporarily until this function ends.

    This method is for draft workers of speculative decoding to run draft model
    with different tp degree from that of target model workers.

    Args:
        tp_group (GroupCoordinator): the tp group coordinator
    """
    global _TP_STATE_PATCHED
    assert not _TP_STATE_PATCHED, "Should not call when it's already patched"

    _TP_STATE_PATCHED = True
    old_tp_group = get_tp_group()
    global _TP
    _TP = tp_group
    try:
        yield
    finally:
        # restore the original state
        _TP_STATE_PATCHED = False
        _TP = old_tp_group

fastvideo.distributed.utils

Classes

fastvideo.distributed.utils.StatelessProcessGroup dataclass
StatelessProcessGroup(rank: int, world_size: int, store: Store, data_expiration_seconds: int = 3600, send_dst_counter: dict[int, int] = dict(), recv_src_counter: dict[int, int] = dict(), broadcast_send_counter: int = 0, broadcast_recv_src_counter: dict[int, int] = dict(), entries: deque[tuple[str, float]] = deque())

A dataclass to hold a metadata store, and the rank, world_size of the group. Only use it to communicate metadata between processes. For data-plane communication, create NCCL-related objects.

Functions
fastvideo.distributed.utils.StatelessProcessGroup.all_gather_obj
all_gather_obj(obj: Any) -> list[Any]

All gather an object from all ranks.

Source code in fastvideo/distributed/utils.py
def all_gather_obj(self, obj: Any) -> list[Any]:
    """All gather an object from all ranks."""
    gathered_objs = []
    for i in range(self.world_size):
        if i == self.rank:
            gathered_objs.append(obj)
            self.broadcast_obj(obj, src=self.rank)
        else:
            recv_obj = self.broadcast_obj(None, src=i)
            gathered_objs.append(recv_obj)
    return gathered_objs
fastvideo.distributed.utils.StatelessProcessGroup.barrier
barrier()

A barrier to synchronize all ranks.

Source code in fastvideo/distributed/utils.py
def barrier(self):
    """A barrier to synchronize all ranks."""
    for i in range(self.world_size):
        if i == self.rank:
            self.broadcast_obj(None, src=self.rank)
        else:
            self.broadcast_obj(None, src=i)
fastvideo.distributed.utils.StatelessProcessGroup.broadcast_obj
broadcast_obj(obj: Any | None, src: int) -> Any

Broadcast an object from a source rank to all other ranks. It does not clean up after all ranks have received the object. Use it for limited times, e.g., for initialization.

Source code in fastvideo/distributed/utils.py
def broadcast_obj(self, obj: Any | None, src: int) -> Any:
    """Broadcast an object from a source rank to all other ranks.
    It does not clean up after all ranks have received the object.
    Use it for limited times, e.g., for initialization.
    """
    if self.rank == src:
        self.expire_data()
        key = (f"broadcast_from/{src}/"
               f"{self.broadcast_send_counter}")
        self.store.set(key, pickle.dumps(obj))
        self.broadcast_send_counter += 1
        self.entries.append((key, time.perf_counter()))
        return obj
    else:
        key = (f"broadcast_from/{src}/"
               f"{self.broadcast_recv_src_counter[src]}")
        recv_obj = pickle.loads(self.store.get(key))
        self.broadcast_recv_src_counter[src] += 1
        return recv_obj
fastvideo.distributed.utils.StatelessProcessGroup.create staticmethod
create(host: str, port: int, rank: int, world_size: int, data_expiration_seconds: int = 3600) -> StatelessProcessGroup

A replacement for torch.distributed.init_process_group that does not pollute the global state.

If we have process A and process B called torch.distributed.init_process_group to form a group, and then we want to form another group with process A, B, C, D, it is not possible in PyTorch, because process A and process B have already formed a group, and process C and process D cannot join that group. This function is a workaround for this issue.

torch.distributed.init_process_group is a global call, while this function is a stateless call. It will return a StatelessProcessGroup object that can be used for exchanging metadata. With this function, process A and process B can call StatelessProcessGroup.create to form a group, and then process A, B, C, and D can call StatelessProcessGroup.create to form another group.

Source code in fastvideo/distributed/utils.py
@staticmethod
def create(
    host: str,
    port: int,
    rank: int,
    world_size: int,
    data_expiration_seconds: int = 3600,
) -> "StatelessProcessGroup":
    """A replacement for `torch.distributed.init_process_group` that does not
    pollute the global state.

    If we have process A and process B called `torch.distributed.init_process_group`
    to form a group, and then we want to form another group with process A, B, C,
    D, it is not possible in PyTorch, because process A and process B have already
    formed a group, and process C and process D cannot join that group. This
    function is a workaround for this issue.

    `torch.distributed.init_process_group` is a global call, while this function
    is a stateless call. It will return a `StatelessProcessGroup` object that can be
    used for exchanging metadata. With this function, process A and process B
    can call `StatelessProcessGroup.create` to form a group, and then process A, B,
    C, and D can call `StatelessProcessGroup.create` to form another group.
    """ # noqa
    store = TCPStore(
        host_name=host,
        port=port,
        world_size=world_size,
        is_master=(rank == 0),
    )

    return StatelessProcessGroup(
        rank=rank,
        world_size=world_size,
        store=store,
        data_expiration_seconds=data_expiration_seconds)
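A hedged two-process sketch: each process calls create() with its own rank, then uses the group for metadata-only exchange. The host, port, and the RANK environment variable set by the launcher are illustrative assumptions.

# Hedged sketch: metadata-only group across two processes (no NCCL involved).
import os
import socket
from fastvideo.distributed.utils import StatelessProcessGroup

rank = int(os.environ["RANK"])  # assumed to be 0 in one process, 1 in the other
pg = StatelessProcessGroup.create(host="127.0.0.1", port=29601,
                                  rank=rank, world_size=2)
hostnames = pg.all_gather_obj(socket.gethostname())  # one hostname per rank
pg.barrier()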
fastvideo.distributed.utils.StatelessProcessGroup.expire_data
expire_data() -> None

Expire data that is older than data_expiration_seconds seconds.

Source code in fastvideo/distributed/utils.py
def expire_data(self) -> None:
    """Expire data that is older than `data_expiration_seconds` seconds."""
    while self.entries:
        # check the oldest entry
        key, timestamp = self.entries[0]
        if time.perf_counter() - timestamp > self.data_expiration_seconds:
            self.store.delete_key(key)
            self.entries.popleft()
        else:
            break
fastvideo.distributed.utils.StatelessProcessGroup.recv_obj
recv_obj(src: int) -> Any

Receive an object from a source rank.

Source code in fastvideo/distributed/utils.py
def recv_obj(self, src: int) -> Any:
    """Receive an object from a source rank."""
    obj = pickle.loads(
        self.store.get(f"send_to/{self.rank}/{self.recv_src_counter[src]}"))
    self.recv_src_counter[src] += 1
    return obj
fastvideo.distributed.utils.StatelessProcessGroup.send_obj
send_obj(obj: Any, dst: int)

Send an object to a destination rank.

Source code in fastvideo/distributed/utils.py
def send_obj(self, obj: Any, dst: int):
    """Send an object to a destination rank."""
    self.expire_data()
    key = f"send_to/{dst}/{self.send_dst_counter[dst]}"
    self.store.set(key, pickle.dumps(obj))
    self.send_dst_counter[dst] += 1
    self.entries.append((key, time.perf_counter()))

Functions

fastvideo.distributed.utils.divide
divide(numerator: int, denominator: int) -> int

Ensure that numerator is divisible by the denominator and return the division value.

Source code in fastvideo/distributed/utils.py
def divide(numerator: int, denominator: int) -> int:
    """Ensure that numerator is divisible by the denominator and return
    the division value."""
    ensure_divisibility(numerator, denominator)
    return numerator // denominator
fastvideo.distributed.utils.ensure_divisibility
ensure_divisibility(numerator, denominator) -> None

Ensure that numerator is divisible by the denominator.

Source code in fastvideo/distributed/utils.py
def ensure_divisibility(numerator, denominator) -> None:
    """Ensure that numerator is divisible by the denominator."""
    assert numerator % denominator == 0, "{} is not divisible by {}".format(
        numerator, denominator)
fastvideo.distributed.utils.split_tensor_along_last_dim
split_tensor_along_last_dim(tensor: Tensor, num_partitions: int, contiguous_split_chunks: bool = False) -> Sequence[Tensor]

Split a tensor along its last dimension.

Parameters:

tensor (Tensor, required): input tensor.
num_partitions (int, required): number of partitions to split the tensor into.
contiguous_split_chunks (bool, default False): if True, make each chunk contiguous in memory.

Returns:

Sequence[Tensor]: A list of Tensors.

Source code in fastvideo/distributed/utils.py
def split_tensor_along_last_dim(
    tensor: torch.Tensor,
    num_partitions: int,
    contiguous_split_chunks: bool = False,
) -> Sequence[torch.Tensor]:
    """ Split a tensor along its last dimension.

        Arguments:
            tensor: input tensor.
            num_partitions: number of partitions to split the tensor
            contiguous_split_chunks: If True, make each chunk contiguous
                                     in memory.

        Returns:
            A list of Tensors
    """
    # Get the size and dimension.
    last_dim = tensor.dim() - 1
    last_dim_size = divide(tensor.size()[last_dim], num_partitions)
    # Split.
    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
    # NOTE: torch.split does not create contiguous tensors by default.
    if contiguous_split_chunks:
        return tuple(chunk.contiguous() for chunk in tensor_list)

    return tuple(tensor_list)
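A small, self-contained sketch of the helper on a [4, 12] tensor.

# Hedged sketch: split a [4, 12] tensor into three contiguous chunks of [4, 4].
import torch
from fastvideo.distributed.utils import split_tensor_along_last_dim

x = torch.arange(48.0).reshape(4, 12)
chunks = split_tensor_along_last_dim(x, num_partitions=3,
                                     contiguous_split_chunks=True)
print([tuple(c.shape) for c in chunks])  # [(4, 4), (4, 4), (4, 4)]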