parallel_state

FastVideo distributed state. It takes over control of the distributed environment from PyTorch. The typical workflow is:

  • call init_distributed_environment to initialize the distributed environment.
  • call initialize_model_parallel or ensure_model_parallel_initialized to initialize the model parallel groups.
  • run any code that uses the distributed environment.
  • call destroy_model_parallel to destroy the model parallel groups.
  • call destroy_distributed_environment to destroy the distributed environment.

If you only need the distributed environment without model parallelism, you can skip the model parallel initialization and destruction steps.
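A minimal sketch of this lifecycle, assuming a torchrun-style launch where init_distributed_environment can pick up rank and world size from the environment (its exact signature is not shown on this page, so the no-argument call is an assumption; check the source):

# Hedged sketch of the workflow above; argument defaults are assumptions.
from fastvideo.distributed.parallel_state import (
    init_distributed_environment, initialize_model_parallel,
    destroy_model_parallel, destroy_distributed_environment)

init_distributed_environment()                              # 1. global distributed environment
initialize_model_parallel(sequence_model_parallel_size=2)   # 2. model parallel groups (e.g. SP=2)

# 3. ... any code that uses the distributed state ...

destroy_model_parallel()                                    # 4. tear down model parallel groups
destroy_distributed_environment()                           # 5. tear down the distributed environment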

Classes

fastvideo.distributed.parallel_state.GroupCoordinator

GroupCoordinator(group_ranks: list[list[int]], local_rank: int, torch_distributed_backend: str | Backend, use_device_communicator: bool, use_message_queue_broadcaster: bool = False, group_name: str | None = None)

PyTorch ProcessGroup wrapper for a group of processes. PyTorch ProcessGroup is bound to one specific communication backend, e.g. NCCL, Gloo, MPI, etc. GroupCoordinator takes charge of all the communication operations among the processes in the group. It manages both CPU and device communication.

Source code in fastvideo/distributed/parallel_state.py
def __init__(
    self,
    group_ranks: list[list[int]],
    local_rank: int,
    torch_distributed_backend: str | Backend,
    use_device_communicator: bool,
    use_message_queue_broadcaster: bool = False,
    group_name: str | None = None,
):
    group_name = group_name or "anonymous"
    self.unique_name = _get_unique_name(group_name)
    _register_group(self)

    self.rank = torch.distributed.get_rank()
    self.local_rank = local_rank
    self.device_group = None
    self.cpu_group = None

    for ranks in group_ranks:
        device_group = torch.distributed.new_group(
            ranks, backend=torch_distributed_backend)
        # a group with `gloo` backend, to allow direct coordination between
        # processes through the CPU.
        cpu_group = torch.distributed.new_group(ranks, backend="gloo")
        if self.rank in ranks:
            self.ranks = ranks
            self.world_size = len(ranks)
            self.rank_in_group = ranks.index(self.rank)
            self.device_group = device_group
            self.cpu_group = cpu_group
    try:
        assert self.cpu_group is not None
        assert self.device_group is not None
    except Exception as e:
        print(f"rank: {self.rank} group not found")
        raise e

    from fastvideo.platforms import current_platform

    # TODO: fix it for other platforms
    self.device = get_local_torch_device()

    self.use_device_communicator = use_device_communicator
    self.device_communicator: DeviceCommunicatorBase = None  # type: ignore
    if use_device_communicator and self.world_size > 1:
        # Platform-aware device communicator selection
        if current_platform.is_cuda_alike():
            from fastvideo.distributed.device_communicators.cuda_communicator import (
                CudaCommunicator)
            self.device_communicator = CudaCommunicator(
                cpu_group=self.cpu_group,
                device=self.device,
                device_group=self.device_group,
                unique_name=self.unique_name,
            )
        elif current_platform.is_npu():
            from fastvideo.distributed.device_communicators.npu_communicator import (
                NpuCommunicator)
            self.device_communicator = NpuCommunicator(
                cpu_group=self.cpu_group,
                device=self.device,
                device_group=self.device_group,
                unique_name=self.unique_name,
            )
        else:
            # For MPS and CPU, use the CPU communicator
            self.device_communicator = CpuCommunicator(
                cpu_group=self.cpu_group,
                device=self.device,
                device_group=self.device_group,
                unique_name=self.unique_name,
            )

    self.mq_broadcaster = None

    from fastvideo.platforms import current_platform

    # TODO(will): check if this is needed
    # self.use_custom_op_call = current_platform.is_cuda_alike()
    self.use_custom_op_call = False

Attributes

fastvideo.distributed.parallel_state.GroupCoordinator.first_rank property
first_rank

Return the global rank of the first process in the group

fastvideo.distributed.parallel_state.GroupCoordinator.is_first_rank property
is_first_rank

Return whether the caller is the first process in the group

fastvideo.distributed.parallel_state.GroupCoordinator.is_last_rank property
is_last_rank

Return whether the caller is the last process in the group

fastvideo.distributed.parallel_state.GroupCoordinator.last_rank property
last_rank

Return the global rank of the last process in the group

fastvideo.distributed.parallel_state.GroupCoordinator.next_rank property
next_rank

Return the global rank of the process that follows the caller

fastvideo.distributed.parallel_state.GroupCoordinator.prev_rank property
prev_rank

Return the global rank of the process that precedes the caller
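A short sketch of how these properties are typically consumed, assuming the model parallel groups have already been initialized (get_sp_group is used here, but any GroupCoordinator behaves the same way):

from fastvideo.distributed.parallel_state import get_sp_group

group = get_sp_group()               # any GroupCoordinator works the same way
if group.is_first_rank:
    print(f"group leader is global rank {group.first_rank}")

send_to = group.next_rank            # global rank that follows this process in the group
recv_from = group.prev_rank          # global rank that precedes this process in the group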

Functions

fastvideo.distributed.parallel_state.GroupCoordinator.all_reduce
all_reduce(input_: Tensor, op: ReduceOp | None = SUM) -> Tensor

User-facing all-reduce function before we actually call the all-reduce operation.

We need this because Dynamo does not support passing an arbitrary object (self in this case) to a custom op. Instead, we pass the group name as a string, look up the group coordinator from that name, and dispatch the all-reduce operation to it.

In addition, PyTorch custom ops do not support mutation or returning a new tensor in the same op. So we always make the all-reduce operation out-of-place.

Source code in fastvideo/distributed/parallel_state.py
def all_reduce(
        self,
        input_: torch.Tensor,
        op: torch.distributed.ReduceOp | None = ReduceOp.SUM
) -> torch.Tensor:
    """
    User-facing all-reduce function before we actually call the
    all-reduce operation.

    We need this because Dynamo does not support passing an arbitrary
    object (`self` in this case) to a custom op. We need to pass the
     group name as a string, and then look up the group coordinator from
     the group name, dispatch the all-reduce operation to the group
     coordinator.

    In addition, PyTorch custom ops do not support mutation or returning
    a new tensor in the same op. So we always make the all-reduce operation
    out-of-place.
    """
    # Bypass the function if we are using only 1 GPU.
    if self.world_size == 1:
        return input_

    if self.use_custom_op_call:
        return torch.ops.vllm.all_reduce(input_,
                                         group_name=self.unique_name)
    else:
        return self._all_reduce_out_place(input_, op=op)
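For example, summing a per-rank scalar across the sequence-parallel group might look like the following sketch (get_sp_group returns the module-level sequence-parallel GroupCoordinator, as used elsewhere on this page):

# Sketch: sum a value across the sequence-parallel group.
import torch
from fastvideo.distributed.parallel_state import get_sp_group

sp = get_sp_group()
local_loss = torch.tensor([1.0], device=sp.device)
total_loss = sp.all_reduce(local_loss)   # out-of-place; op defaults to ReduceOp.SUM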
fastvideo.distributed.parallel_state.GroupCoordinator.barrier
barrier() -> None

Barrier synchronization among the group. NOTE: don't use device_group here! barrier in NCCL is terrible because it is internally a broadcast operation with secretly created GPU tensors. It is easy to mess up the current device. Use the CPU group instead.

Source code in fastvideo/distributed/parallel_state.py
def barrier(self) -> None:
    """Barrier synchronization among the group.
    NOTE: don't use `device_group` here! `barrier` in NCCL is
    terrible because it is internally a broadcast operation with
    secretly created GPU tensors. It is easy to mess up the current
    device. Use the CPU group instead.
    """
    torch.distributed.barrier(group=self.cpu_group)
fastvideo.distributed.parallel_state.GroupCoordinator.broadcast
broadcast(input_: Tensor, src: int = 0)

Broadcast the input tensor. NOTE: src is the local rank of the source rank.

Source code in fastvideo/distributed/parallel_state.py
def broadcast(self, input_: torch.Tensor, src: int = 0):
    """Broadcast the input tensor.
    NOTE: `src` is the local rank of the source rank.
    """
    assert src < self.world_size, f"Invalid src rank ({src})"

    # Bypass the function if we are using only 1 GPU.
    if self.world_size == 1:
        return input_
    # Broadcast.
    torch.distributed.broadcast(input_,
                                src=self.ranks[src],
                                group=self.device_group)
    return input_
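Because src is the rank inside the group rather than a global rank, broadcasting from the group's first process is always src=0; a brief sketch:

# Sketch: broadcast a tensor from local rank 0 of the group to the rest.
import torch
from fastvideo.distributed.parallel_state import get_sp_group

group = get_sp_group()
t = torch.empty(8, device=group.device)
if group.rank_in_group == 0:
    t.fill_(42.0)                 # only the source rank holds real data
t = group.broadcast(t, src=0)     # src is the local rank inside the group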
fastvideo.distributed.parallel_state.GroupCoordinator.broadcast_object
broadcast_object(obj: Any | None = None, src: int = 0)

Broadcast the input object. NOTE: src is the local rank of the source rank.

Source code in fastvideo/distributed/parallel_state.py
def broadcast_object(self, obj: Any | None = None, src: int = 0):
    """Broadcast the input object.
    NOTE: `src` is the local rank of the source rank.
    """
    assert src < self.world_size, f"Invalid src rank ({src})"

    # Bypass the function if we are using only 1 GPU.
    if self.world_size == 1:
        return obj
    if self.mq_broadcaster is not None:
        assert src == 0, "Message queue broadcaster only supports src=0"
        return self.mq_broadcaster.broadcast_object(obj)
    if self.rank_in_group == src:
        torch.distributed.broadcast_object_list([obj],
                                                src=self.ranks[src],
                                                group=self.cpu_group)
        return obj
    else:
        recv = [None]
        torch.distributed.broadcast_object_list(recv,
                                                src=self.ranks[src],
                                                group=self.cpu_group)
        return recv[0]
fastvideo.distributed.parallel_state.GroupCoordinator.broadcast_object_list
broadcast_object_list(obj_list: list[Any], src: int = 0, group: ProcessGroup | None = None)

Broadcast the input object list. NOTE: src is the local rank of the source rank.

Source code in fastvideo/distributed/parallel_state.py
def broadcast_object_list(self,
                          obj_list: list[Any],
                          src: int = 0,
                          group: ProcessGroup | None = None):
    """Broadcast the input object list.
    NOTE: `src` is the local rank of the source rank.
    """
    assert src < self.world_size, f"Invalid src rank ({src})"

    # Bypass the function if we are using only 1 GPU.
    if self.world_size == 1:
        return obj_list
    # Broadcast.
    torch.distributed.broadcast_object_list(obj_list,
                                            src=self.ranks[src],
                                            group=self.device_group)
    return obj_list
fastvideo.distributed.parallel_state.GroupCoordinator.broadcast_tensor_dict
broadcast_tensor_dict(tensor_dict: dict[str, Tensor | Any] | None = None, src: int = 0, group: ProcessGroup | None = None, metadata_group: ProcessGroup | None = None) -> dict[str, Tensor | Any] | None

Broadcast the input tensor dictionary. NOTE: src is the local rank of the source rank.

Source code in fastvideo/distributed/parallel_state.py
def broadcast_tensor_dict(
    self,
    tensor_dict: dict[str, torch.Tensor | Any] | None = None,
    src: int = 0,
    group: ProcessGroup | None = None,
    metadata_group: ProcessGroup | None = None
) -> dict[str, torch.Tensor | Any] | None:
    """Broadcast the input tensor dictionary.
    NOTE: `src` is the local rank of the source rank.
    """
    # Bypass the function if we are using only 1 GPU.
    if (not torch.distributed.is_initialized() or self.world_size == 1):
        return tensor_dict

    group = self.device_group
    metadata_group = self.cpu_group
    assert src < self.world_size, f"Invalid src rank ({src})"

    rank_in_group = self.rank_in_group
    if rank_in_group == src:
        metadata_list: list[tuple[Any, Any]] = []
        assert isinstance(
            tensor_dict,
            dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
        # `metadata_list` lives in CPU memory.
        # `broadcast_object_list` has serialization & deserialization,
        # all happening on CPU. Therefore, we can use the CPU group.
        self.broadcast_object(metadata_list, src=src)
        async_handles = []
        for tensor in tensor_list:
            if tensor.numel() == 0:
                # Skip broadcasting empty tensors.
                continue
            if tensor.is_cpu:
                # use metadata_group for CPU tensors
                handle = torch.distributed.broadcast(tensor,
                                                     src=self.ranks[src],
                                                     group=metadata_group,
                                                     async_op=True)
            else:
                # use group for GPU tensors
                handle = torch.distributed.broadcast(tensor,
                                                     src=self.ranks[src],
                                                     group=group,
                                                     async_op=True)
            async_handles.append(handle)
        for async_handle in async_handles:
            async_handle.wait()

    else:
        metadata_list = self.broadcast_object(None, src=src)
        tensor_dict = {}
        async_handles = []
        for key, value in metadata_list:
            if isinstance(value, TensorMetadata):
                tensor = torch.empty(value.size,
                                     dtype=value.dtype,
                                     device=value.device)
                if tensor.numel() == 0:
                    # Skip broadcasting empty tensors.
                    tensor_dict[key] = tensor
                    continue
                if tensor.is_cpu:
                    # use metadata_group for CPU tensors
                    handle = torch.distributed.broadcast(
                        tensor,
                        src=self.ranks[src],
                        group=metadata_group,
                        async_op=True)
                else:
                    # use group for GPU tensors
                    handle = torch.distributed.broadcast(
                        tensor,
                        src=self.ranks[src],
                        group=group,
                        async_op=True)
                async_handles.append(handle)
                tensor_dict[key] = tensor
            else:
                tensor_dict[key] = value
        for async_handle in async_handles:
            async_handle.wait()
    return tensor_dict
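In practice the source rank passes the dictionary and every other rank passes None and gets the reconstructed dictionary back; GPU tensors travel over the device group while the pickled metadata (and any CPU tensors) go over the CPU group, as the code above shows. A hedged sketch:

# Sketch: rank 0 of the group ships tensors plus plain metadata to the others.
import torch
from fastvideo.distributed.parallel_state import get_sp_group

group = get_sp_group()
if group.rank_in_group == 0:
    payload = {"latents": torch.randn(2, 4, device=group.device), "step": 10}
    group.broadcast_tensor_dict(payload, src=0)
else:
    payload = group.broadcast_tensor_dict(None, src=0)   # reconstructed on this rank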
fastvideo.distributed.parallel_state.GroupCoordinator.gather
gather(input_: Tensor, dst: int = 0, dim: int = -1) -> Tensor | None

NOTE: We assume that the input tensor is on the same device across all the ranks. NOTE: dst is the local rank of the destination rank.

Source code in fastvideo/distributed/parallel_state.py
def gather(self,
           input_: torch.Tensor,
           dst: int = 0,
           dim: int = -1) -> torch.Tensor | None:
    """
    NOTE: We assume that the input tensor is on the same device across
    all the ranks.
    NOTE: `dst` is the local rank of the destination rank.
    """
    world_size = self.world_size
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_
    return self.device_communicator.gather(input_, dst, dim)
fastvideo.distributed.parallel_state.GroupCoordinator.recv
recv(size: Size, dtype: dtype, src: int | None = None) -> Tensor

Receives a tensor from the source rank.

Source code in fastvideo/distributed/parallel_state.py
def recv(self,
         size: torch.Size,
         dtype: torch.dtype,
         src: int | None = None) -> torch.Tensor:
    """Receives a tensor from the source rank."""
    """NOTE: `src` is the local rank of the source rank."""
    return self.device_communicator.recv(size, dtype, src)
fastvideo.distributed.parallel_state.GroupCoordinator.recv_object
recv_object(src: int) -> Any

Receive the input object list from the source rank.

Source code in fastvideo/distributed/parallel_state.py
def recv_object(self, src: int) -> Any:
    """Receive the input object list from the source rank."""
    """NOTE: `src` is the local rank of the source rank."""

    assert src < self.world_size, f"Invalid src rank ({src})"

    assert src != self.rank_in_group, (
        "Invalid source rank. Source rank is the same as the current rank.")

    size_tensor = torch.empty(1, dtype=torch.long, device="cpu")

    # Receive object size
    rank_size = torch.distributed.recv(size_tensor,
                                       src=self.ranks[src],
                                       group=self.cpu_group)

    # Tensor to receive serialized objects into.
    object_tensor = torch.empty(  # type: ignore[call-overload]
        size_tensor.item(),  # type: ignore[arg-type]
        dtype=torch.uint8,
        device="cpu")

    rank_object = torch.distributed.recv(object_tensor,
                                         src=self.ranks[src],
                                         group=self.cpu_group)

    assert rank_object == rank_size, (
        "Received object sender rank does not match the size sender rank.")

    obj = pickle.loads(object_tensor.numpy().tobytes())

    return obj
fastvideo.distributed.parallel_state.GroupCoordinator.recv_tensor_dict
recv_tensor_dict(src: int | None = None, all_gather_group: Optional[GroupCoordinator] = None) -> dict[str, Tensor | Any] | None

Recv the input tensor dictionary. NOTE: src is the local rank of the source rank.

Source code in fastvideo/distributed/parallel_state.py
def recv_tensor_dict(
    self,
    src: int | None = None,
    all_gather_group: Optional["GroupCoordinator"] = None,
) -> dict[str, torch.Tensor | Any] | None:
    """Recv the input tensor dictionary.
    NOTE: `src` is the local rank of the source rank.
    """
    # Bypass the function if we are using only 1 GPU.
    if not torch.distributed.is_initialized() or self.world_size == 1:
        return None

    all_gather_size = (1 if all_gather_group is None else
                       all_gather_group.world_size)
    all_gather_rank = (0 if all_gather_group is None else
                       all_gather_group.rank_in_group)

    group = self.device_group
    metadata_group = self.cpu_group

    if src is None:
        src = (self.rank_in_group - 1) % self.world_size
    assert src < self.world_size, f"Invalid src rank ({src})"

    recv_metadata_list = self.recv_object(src=src)
    tensor_dict: dict[str, Any] = {}
    for key, value in recv_metadata_list:
        if isinstance(value, TensorMetadata):
            tensor = torch.empty(value.size,
                                 dtype=value.dtype,
                                 device=value.device)
            if tensor.numel() == 0:
                # Skip broadcasting empty tensors.
                tensor_dict[key] = tensor
                continue

            # send-allgather: send only a slice, then do allgather.
            use_all_gather = (all_gather_group is not None
                              and tensor.numel() % all_gather_size == 0)

            if use_all_gather:
                orig_shape = tensor.shape
                tensor = tensor.reshape(all_gather_size,
                                        -1)[all_gather_rank]

            if tensor.is_cpu:
                # use metadata_group for CPU tensors
                torch.distributed.recv(tensor,
                                       src=self.ranks[src],
                                       group=metadata_group)
            else:
                # use group for GPU tensors
                torch.distributed.recv(tensor,
                                       src=self.ranks[src],
                                       group=group)
            if use_all_gather:
                # do the allgather
                tensor = all_gather_group.all_gather(  # type: ignore
                    tensor, dim=0)
                tensor = tensor.reshape(orig_shape)

            tensor_dict[key] = tensor
        else:
            tensor_dict[key] = value
    return tensor_dict
fastvideo.distributed.parallel_state.GroupCoordinator.send
send(tensor: Tensor, dst: int | None = None) -> None

Sends a tensor to the destination rank in a non-blocking way

Source code in fastvideo/distributed/parallel_state.py
def send(self, tensor: torch.Tensor, dst: int | None = None) -> None:
    """Sends a tensor to the destination rank in a non-blocking way"""
    """NOTE: `dst` is the local rank of the destination rank."""
    self.device_communicator.send(tensor, dst)
fastvideo.distributed.parallel_state.GroupCoordinator.send_object
send_object(obj: Any, dst: int) -> None

Send the input object list to the destination rank.

Source code in fastvideo/distributed/parallel_state.py
def send_object(self, obj: Any, dst: int) -> None:
    """Send the input object list to the destination rank."""
    """NOTE: `dst` is the local rank of the destination rank."""

    assert dst < self.world_size, f"Invalid dst rank ({dst})"

    assert dst != self.rank_in_group, (
        "Invalid destination rank. Destination rank is the same "
        "as the current rank.")

    # Serialize object to tensor and get the size as well
    object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)

    size_tensor = torch.tensor([object_tensor.numel()],
                               dtype=torch.long,
                               device="cpu")

    # Send object size

    torch.distributed.send(size_tensor,
                           dst=self.ranks[dst],
                           group=self.cpu_group)

    # Send object
    torch.distributed.send(object_tensor,
                           dst=self.ranks[dst],
                           group=self.cpu_group)

    return None
fastvideo.distributed.parallel_state.GroupCoordinator.send_tensor_dict
send_tensor_dict(tensor_dict: dict[str, Tensor | Any], dst: int | None = None, all_gather_group: Optional[GroupCoordinator] = None) -> dict[str, Tensor | Any] | None

Send the input tensor dictionary. NOTE: dst is the local rank of the destination rank.

Source code in fastvideo/distributed/parallel_state.py
def send_tensor_dict(
    self,
    tensor_dict: dict[str, torch.Tensor | Any],
    dst: int | None = None,
    all_gather_group: Optional["GroupCoordinator"] = None,
) -> dict[str, torch.Tensor | Any] | None:
    """Send the input tensor dictionary.
    NOTE: `dst` is the local rank of the destination rank.
    """
    # Bypass the function if we are using only 1 GPU.
    if not torch.distributed.is_initialized() or self.world_size == 1:
        return tensor_dict

    all_gather_size = (1 if all_gather_group is None else
                       all_gather_group.world_size)
    all_gather_rank = (0 if all_gather_group is None else
                       all_gather_group.rank_in_group)

    group = self.device_group
    metadata_group = self.cpu_group

    if dst is None:
        dst = (self.rank_in_group + 1) % self.world_size
    assert dst < self.world_size, f"Invalid dst rank ({dst})"

    metadata_list: list[tuple[Any, Any]] = []
    assert isinstance(
        tensor_dict,
        dict), f"Expecting a dictionary, got {type(tensor_dict)}"
    metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
    # `metadata_list` lives in CPU memory.
    # `send_object_list` has serialization & deserialization,
    # all happening on CPU. Therefore, we can use the CPU group.
    self.send_object(metadata_list, dst=dst)
    for tensor in tensor_list:
        if tensor.numel() == 0:
            # Skip sending empty tensors.
            continue

        # send-allgather: send only a slice, then do allgather.
        if (all_gather_group is not None
                and tensor.numel() % all_gather_size == 0):
            tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]

        if tensor.is_cpu:
            # use metadata_group for CPU tensors
            torch.distributed.send(tensor,
                                   dst=self.ranks[dst],
                                   group=metadata_group)
        else:
            # use group for GPU tensors
            torch.distributed.send(tensor, dst=self.ranks[dst], group=group)
    return None
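send_tensor_dict and recv_tensor_dict are designed as a pair: when dst/src are left as None they default to the next and previous rank in the group. A hedged point-to-point sketch between the first two ranks of a group:

# Sketch: hand a dict of tensors from local rank 0 to local rank 1 of the group.
import torch
from fastvideo.distributed.parallel_state import get_sp_group

group = get_sp_group()
if group.rank_in_group == 0:
    group.send_tensor_dict({"hidden": torch.randn(4, 8, device=group.device)}, dst=1)
elif group.rank_in_group == 1:
    received = group.recv_tensor_dict(src=0)   # {"hidden": tensor rebuilt on this rank}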

Functions

fastvideo.distributed.parallel_state.destroy_model_parallel

destroy_model_parallel() -> None

Set the groups to none and destroy them.

Source code in fastvideo/distributed/parallel_state.py
def destroy_model_parallel() -> None:
    """Set the groups to none and destroy them."""
    global _TP
    if _TP:
        _TP.destroy()
    _TP = None

    global _SP
    if _SP:
        _SP.destroy()
    _SP = None

    global _DP
    if _DP:
        _DP.destroy()
    _DP = None

fastvideo.distributed.parallel_state.get_dp_rank

get_dp_rank() -> int

Return my rank for the data parallel group.

Source code in fastvideo/distributed/parallel_state.py
def get_dp_rank() -> int:
    """Return my rank for the data parallel group."""
    return get_dp_group().rank_in_group

fastvideo.distributed.parallel_state.get_dp_world_size

get_dp_world_size() -> int

Return world size for the data parallel group.

Source code in fastvideo/distributed/parallel_state.py
def get_dp_world_size() -> int:
    """Return world size for the data parallel group."""
    return get_dp_group().world_size

fastvideo.distributed.parallel_state.get_local_torch_device

get_local_torch_device() -> device

Return the torch device for the current rank.

Source code in fastvideo/distributed/parallel_state.py
def get_local_torch_device() -> torch.device:
    """Return the torch device for the current rank."""
    from fastvideo.platforms import current_platform
    if current_platform.is_npu():
        device = torch.device(f"npu:{envs.LOCAL_RANK}")
    elif current_platform.is_cuda_alike() or current_platform.is_cuda():
        device = torch.device(f"cuda:{envs.LOCAL_RANK}")
    else:
        device = torch.device("mps")
    return device
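This is also the device that GroupCoordinator stores in its device attribute; a one-line sketch:

import torch
from fastvideo.distributed.parallel_state import get_local_torch_device

x = torch.zeros(3, device=get_local_torch_device())   # cuda:{LOCAL_RANK}, npu:{LOCAL_RANK}, or mps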

fastvideo.distributed.parallel_state.get_sp_parallel_rank

get_sp_parallel_rank() -> int

Return my rank for the sequence model parallel group.

Source code in fastvideo/distributed/parallel_state.py
def get_sp_parallel_rank() -> int:
    """Return my rank for the sequence model parallel group."""
    return get_sp_group().rank_in_group

fastvideo.distributed.parallel_state.get_sp_world_size

get_sp_world_size() -> int

Return world size for the sequence model parallel group.

Source code in fastvideo/distributed/parallel_state.py
def get_sp_world_size() -> int:
    """Return world size for the sequence model parallel group."""
    return get_sp_group().world_size
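These two helpers are what a sequence-parallel layer typically uses to slice its input along the sequence dimension; a hedged sketch:

# Sketch: shard a [batch, seq, hidden] tensor across the sequence-parallel group.
import torch
from fastvideo.distributed.parallel_state import (get_sp_parallel_rank,
                                                  get_sp_world_size)

def shard_sequence(x: torch.Tensor) -> torch.Tensor:
    sp_rank, sp_size = get_sp_parallel_rank(), get_sp_world_size()
    assert x.shape[1] % sp_size == 0, "sequence length must be divisible by the SP world size"
    return x.chunk(sp_size, dim=1)[sp_rank]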

fastvideo.distributed.parallel_state.get_tp_rank

get_tp_rank() -> int

Return my rank for the tensor model parallel group.

Source code in fastvideo/distributed/parallel_state.py
def get_tp_rank() -> int:
    """Return my rank for the tensor model parallel group."""
    return get_tp_group().rank_in_group

fastvideo.distributed.parallel_state.get_tp_world_size

get_tp_world_size() -> int

Return world size for the tensor model parallel group.

Source code in fastvideo/distributed/parallel_state.py
def get_tp_world_size() -> int:
    """Return world size for the tensor model parallel group."""
    return get_tp_group().world_size

fastvideo.distributed.parallel_state.get_world_rank

get_world_rank() -> int

Return my rank for the world group.

Source code in fastvideo/distributed/parallel_state.py
def get_world_rank() -> int:
    """Return my rank for the world group."""
    return get_world_group().rank

fastvideo.distributed.parallel_state.get_world_size

get_world_size() -> int

Return world size for the world group.

Source code in fastvideo/distributed/parallel_state.py
def get_world_size() -> int:
    """Return world size for the world group."""
    return get_world_group().world_size

fastvideo.distributed.parallel_state.initialize_model_parallel

initialize_model_parallel(tensor_model_parallel_size: int = 1, sequence_model_parallel_size: int = 1, data_parallel_size: int = 1, backend: str | None = None) -> None

Initialize model parallel groups.

Parameters:

  tensor_model_parallel_size (int, default 1): number of GPUs used for tensor model parallelism (used for language encoder).
  sequence_model_parallel_size (int, default 1): number of GPUs used for sequence model parallelism (used for DiT).
Source code in fastvideo/distributed/parallel_state.py
def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    sequence_model_parallel_size: int = 1,
    data_parallel_size: int = 1,
    backend: str | None = None,
) -> None:
    """
    Initialize model parallel groups.

    Arguments:
        tensor_model_parallel_size: number of GPUs used for tensor model
            parallelism (used for language encoder).
        sequence_model_parallel_size: number of GPUs used for sequence model
            parallelism (used for DiT).
    """
    # Get world size and rank. Ensure some consistencies.
    assert _WORLD is not None, "world group is not initialized, please call init_distributed_environment first"
    world_size: int = get_world_size()
    backend = backend or torch.distributed.get_backend(
        get_world_group().device_group)
    assert world_size >= tensor_model_parallel_size, f"world_size({world_size}) must be greater than or equal to tensor_model_parallel_size({tensor_model_parallel_size})"
    num_tensor_model_parallel_groups: int = (world_size //
                                             tensor_model_parallel_size)
    global _TP
    assert _TP is None, ("tensor model parallel group is already initialized")
    group_ranks = []
    for i in range(num_tensor_model_parallel_groups):
        ranks = list(
            range(i * tensor_model_parallel_size,
                  (i + 1) * tensor_model_parallel_size))
        group_ranks.append(ranks)

    # message queue broadcaster is only used in tensor model parallel group
    _TP = init_model_parallel_group(group_ranks,
                                    get_world_group().local_rank,
                                    backend,
                                    use_message_queue_broadcaster=True,
                                    group_name="tp")

    # Build the sequence model-parallel groups.
    num_sequence_model_parallel_groups: int = (world_size //
                                               sequence_model_parallel_size)
    global _SP
    assert _SP is None, ("sequence model parallel group is already initialized")
    group_ranks = []

    # Since SP is incompatible with TP and PP, we can use a simpler group creation logic
    for i in range(num_sequence_model_parallel_groups):
        # Create groups of consecutive ranks
        ranks = list(
            range(i * sequence_model_parallel_size,
                  (i + 1) * sequence_model_parallel_size))
        group_ranks.append(ranks)

    _SP = init_model_parallel_group(group_ranks,
                                    get_world_group().local_rank,
                                    backend,
                                    group_name="sp")

    # Build the data parallel groups.
    num_data_parallel_groups: int = sequence_model_parallel_size
    global _DP
    assert _DP is None, ("data parallel group is already initialized")
    group_ranks = []

    for i in range(num_data_parallel_groups):
        ranks = list(range(i, world_size, num_data_parallel_groups))
        group_ranks.append(ranks)

    _DP = init_model_parallel_group(group_ranks,
                                    get_world_group().local_rank,
                                    backend,
                                    group_name="dp")

fastvideo.distributed.parallel_state.initialize_sequence_parallel_group

initialize_sequence_parallel_group(sequence_model_parallel_size: int = 1, backend: str | None = None, group_name_suffix: str = '') -> GroupCoordinator

Initialize a sequence parallel group for a specific model.

This function creates a sequence parallel group that can be used with the patch_sequence_parallel_group context manager. It allows different models to use different sequence parallelism configurations.

Parameters:

  sequence_model_parallel_size (int, default 1): number of GPUs used for sequence model parallelism.
  backend (str | None, default None): communication backend to use.
  group_name_suffix (str, default ''): optional suffix to make the group name unique.

Returns:

  GroupCoordinator: a GroupCoordinator for sequence parallelism that can be used with the patch_sequence_parallel_group context manager.

Example usage
# Initialize sequence parallel group for model2
sp_group_model2 = initialize_sequence_parallel_group(
    sequence_model_parallel_size=2,
    group_name_suffix="model2"
)

# Use sequence parallelism for model2
with patch_sequence_parallel_group(sp_group_model2):
    # Run model2 with sequence parallelism
    output2 = model2(input2)
Source code in fastvideo/distributed/parallel_state.py
def initialize_sequence_parallel_group(
        sequence_model_parallel_size: int = 1,
        backend: str | None = None,
        group_name_suffix: str = "") -> GroupCoordinator:
    """Initialize a sequence parallel group for a specific model.

    This function creates a sequence parallel group that can be used with the
    patch_sequence_parallel_group context manager. It allows different models
    to use different sequence parallelism configurations.

    Arguments:
        sequence_model_parallel_size: number of GPUs used for sequence model parallelism.
        backend: communication backend to use.
        group_name_suffix: optional suffix to make the group name unique.

    Returns:
        A GroupCoordinator for sequence parallelism that can be used with
        the patch_sequence_parallel_group context manager.

    Example usage:
        ```python
        # Initialize sequence parallel group for model2
        sp_group_model2 = initialize_sequence_parallel_group(
            sequence_model_parallel_size=2,
            group_name_suffix="model2"
        )

        # Use sequence parallelism for model2
        with patch_sequence_parallel_group(sp_group_model2):
            # Run model2 with sequence parallelism
            output2 = model2(input2)
        ```
    """
    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()
    backend = backend or torch.distributed.get_backend(
        get_world_group().device_group)

    # Ensure the world size is compatible with the parallelism configuration
    assert world_size % sequence_model_parallel_size == 0, \
        f"World size ({world_size}) must be divisible by sequence_model_parallel_size ({sequence_model_parallel_size})"

    # Build the sequence model-parallel groups.
    num_sequence_model_parallel_groups: int = (world_size //
                                               sequence_model_parallel_size)
    sp_group_ranks = []

    for i in range(num_sequence_model_parallel_groups):
        # Create groups of consecutive ranks
        ranks = list(
            range(i * sequence_model_parallel_size,
                  (i + 1) * sequence_model_parallel_size))
        sp_group_ranks.append(ranks)

    # Create SP group coordinator with a unique name
    group_name = f"sp_{group_name_suffix}" if group_name_suffix else "sp"
    sp_group = init_model_parallel_group(sp_group_ranks,
                                         get_world_group().local_rank,
                                         backend,
                                         group_name=group_name)

    return sp_group

fastvideo.distributed.parallel_state.initialize_tensor_parallel_group

initialize_tensor_parallel_group(tensor_model_parallel_size: int = 1, backend: str | None = None, group_name_suffix: str = '') -> GroupCoordinator

Initialize a tensor parallel group for a specific model.

This function creates a tensor parallel group that can be used with the patch_tensor_parallel_group context manager. It allows different models to use different tensor parallelism configurations.

Parameters:

  tensor_model_parallel_size (int, default 1): number of GPUs used for tensor model parallelism.
  backend (str | None, default None): communication backend to use.
  group_name_suffix (str, default ''): optional suffix to make the group name unique.

Returns:

  GroupCoordinator: a GroupCoordinator for tensor parallelism that can be used with the patch_tensor_parallel_group context manager.

Example usage
# Initialize tensor parallel group for model1
tp_group_model1 = initialize_tensor_parallel_group(
    tensor_model_parallel_size=4,
    group_name_suffix="model1"
)

# Use tensor parallelism for model1
with patch_tensor_parallel_group(tp_group_model1):
    # Run model1 with tensor parallelism
    output1 = model1(input1)
Source code in fastvideo/distributed/parallel_state.py
def initialize_tensor_parallel_group(
        tensor_model_parallel_size: int = 1,
        backend: str | None = None,
        group_name_suffix: str = "") -> GroupCoordinator:
    """Initialize a tensor parallel group for a specific model.

    This function creates a tensor parallel group that can be used with the
    patch_tensor_parallel_group context manager. It allows different models
    to use different tensor parallelism configurations.

    Arguments:
        tensor_model_parallel_size: number of GPUs used for tensor model parallelism.
        backend: communication backend to use.
        group_name_suffix: optional suffix to make the group name unique.

    Returns:
        A GroupCoordinator for tensor parallelism that can be used with
        the patch_tensor_parallel_group context manager.

    Example usage:
        ```python
        # Initialize tensor parallel group for model1
        tp_group_model1 = initialize_tensor_parallel_group(
            tensor_model_parallel_size=4,
            group_name_suffix="model1"
        )

        # Use tensor parallelism for model1
        with patch_tensor_parallel_group(tp_group_model1):
            # Run model1 with tensor parallelism
            output1 = model1(input1)
        ```
    """
    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()
    backend = backend or torch.distributed.get_backend(
        get_world_group().device_group)

    # Ensure the world size is compatible with the parallelism configuration
    assert world_size % tensor_model_parallel_size == 0, \
        f"World size ({world_size}) must be divisible by tensor_model_parallel_size ({tensor_model_parallel_size})"

    # Build the tensor model-parallel groups.
    num_tensor_model_parallel_groups: int = (world_size //
                                             tensor_model_parallel_size)
    tp_group_ranks = []
    for i in range(num_tensor_model_parallel_groups):
        ranks = list(
            range(i * tensor_model_parallel_size,
                  (i + 1) * tensor_model_parallel_size))
        tp_group_ranks.append(ranks)

    # Create TP group coordinator with a unique name
    group_name = f"tp_{group_name_suffix}" if group_name_suffix else "tp"
    tp_group = init_model_parallel_group(tp_group_ranks,
                                         get_world_group().local_rank,
                                         backend,
                                         use_message_queue_broadcaster=True,
                                         group_name=group_name)

    return tp_group

fastvideo.distributed.parallel_state.is_the_same_node_as

is_the_same_node_as(pg: ProcessGroup | StatelessProcessGroup, source_rank: int = 0) -> list[int]

This is a collective operation that returns whether each rank is in the same node as the source rank. It tests whether processes are attached to the same memory system (i.e., have shared access to the same shared memory).

Source code in fastvideo/distributed/parallel_state.py
def is_the_same_node_as(pg: ProcessGroup | StatelessProcessGroup,
                        source_rank: int = 0) -> list[int]:
    """
    This is a collective operation that returns if each rank is in the same node
    as the source rank. It tests if processes are attached to the same
    memory system (shared access to shared memory).
    """
    if isinstance(pg, ProcessGroup):
        assert torch.distributed.get_backend(
            pg) != torch.distributed.Backend.NCCL, (
                "in_the_same_node_as should be tested with a non-NCCL group.")
        # local rank inside the group
        rank = torch.distributed.get_rank(group=pg)
        world_size = torch.distributed.get_world_size(group=pg)

        # global ranks of the processes in the group
        ranks = torch.distributed.get_process_group_ranks(pg)
    else:
        rank = pg.rank
        world_size = pg.world_size
        ranks = list(range(world_size))

    # local tensor in each process to store the result
    is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32)

    magic_message = b"magic_message"
    shm = None

    try:
        with contextlib.suppress(OSError):
            if rank == source_rank:
                # create a shared memory segment
                shm = shared_memory.SharedMemory(create=True, size=128)
                shm.buf[:len(magic_message)] = magic_message
                if isinstance(pg, ProcessGroup):
                    torch.distributed.broadcast_object_list(
                        [shm.name], src=ranks[source_rank], group=pg)
                else:
                    pg.broadcast_obj(shm.name, src=source_rank)
                is_in_the_same_node[rank] = 1
            else:
                # try to open the shared memory segment
                if isinstance(pg, ProcessGroup):
                    recv = [None]
                    torch.distributed.broadcast_object_list(
                        recv, src=ranks[source_rank], group=pg)
                    name = recv[0]
                else:
                    name = pg.broadcast_obj(None, src=source_rank)
                # fix to https://stackoverflow.com/q/62748654/9191338
                # Python incorrectly tracks shared memory even if it is not
                # created by the process. The following patch is a workaround.
                with patch("multiprocessing.resource_tracker.register",
                           lambda *args, **kwargs: None):
                    shm = shared_memory.SharedMemory(name=name)
                if shm.buf[:len(magic_message)] == magic_message:
                    is_in_the_same_node[rank] = 1
    except Exception as e:
        logger.error("Error ignored in is_in_the_same_node: %s", e)
    finally:
        if shm:
            shm.close()

    if isinstance(pg, ProcessGroup):
        torch.distributed.barrier(group=pg)
    else:
        pg.barrier()

    # clean up the shared memory segment
    with contextlib.suppress(OSError):
        if rank == source_rank and shm:
            shm.unlink()

    if isinstance(pg, ProcessGroup):
        torch.distributed.all_reduce(is_in_the_same_node, group=pg)
        aggregated_data = is_in_the_same_node
    else:
        aggregated_data = torch.zeros_like(is_in_the_same_node)
        for i in range(world_size):
            rank_data = pg.broadcast_obj(is_in_the_same_node, src=i)
            aggregated_data += rank_data

    return [x == 1 for x in aggregated_data.tolist()]
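Because the check must not run on an NCCL group, pass a CPU-backed group such as the world group's cpu_group; a hedged sketch:

# Sketch: check which ranks share a node with global rank 0.
from fastvideo.distributed.parallel_state import get_world_group, is_the_same_node_as

same_node = is_the_same_node_as(get_world_group().cpu_group, source_rank=0)
# e.g. on a single 4-GPU node: [True, True, True, True]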

fastvideo.distributed.parallel_state.model_parallel_is_initialized

model_parallel_is_initialized() -> bool

Check if tensor, sequence parallel groups are initialized.

Source code in fastvideo/distributed/parallel_state.py
def model_parallel_is_initialized() -> bool:
    """Check if tensor, sequence parallel groups are initialized."""
    return _TP is not None and _SP is not None and _DP is not None

fastvideo.distributed.parallel_state.patch_tensor_parallel_group

patch_tensor_parallel_group(tp_group: GroupCoordinator)

Patch the tp group temporarily until this function ends.

This method is for draft workers in speculative decoding to run the draft model with a different tp degree from that of the target model workers.

Parameters:

  tp_group (GroupCoordinator, required): the tp group coordinator.
Source code in fastvideo/distributed/parallel_state.py
@contextmanager
def patch_tensor_parallel_group(tp_group: GroupCoordinator):
    """Patch the tp group temporarily until this function ends.

    This method is for draft workers of speculative decoding to run draft model
    with different tp degree from that of target model workers.

    Args:
        tp_group (GroupCoordinator): the tp group coordinator
    """
    global _TP_STATE_PATCHED
    assert not _TP_STATE_PATCHED, "Should not call when it's already patched"

    _TP_STATE_PATCHED = True
    old_tp_group = get_tp_group()
    global _TP
    _TP = tp_group
    try:
        yield
    finally:
        # restore the original state
        _TP_STATE_PATCHED = False
        _TP = old_tp_group