Skip to content

device_communicators

Modules

fastvideo.distributed.device_communicators.base_device_communicator

Classes

fastvideo.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase
DeviceCommunicatorBase(cpu_group: ProcessGroup, device: device | None = None, device_group: ProcessGroup | None = None, unique_name: str = '')

Base class for device-specific communicator with autograd support. It can use the cpu_group to initialize the communicator. If the device has PyTorch integration (PyTorch can recognize its communication backend), the device_group will also be given.

Source code in fastvideo/distributed/device_communicators/base_device_communicator.py
def __init__(self,
             cpu_group: ProcessGroup,
             device: torch.device | None = None,
             device_group: ProcessGroup | None = None,
             unique_name: str = ""):
    """Record the process groups and cache rank/size information.

    Args:
        cpu_group: Process group used for every rank/size query below.
        device: Device this communicator works on; a falsy value falls
            back to ``torch.device("cpu")``.
        device_group: Optional process group stored for the collective
            and point-to-point calls made by the other methods.
        unique_name: Free-form label stored on the instance.
    """
    # Plain arguments first.
    self.cpu_group = cpu_group
    self.device_group = device_group
    self.unique_name = unique_name
    self.device = device or torch.device("cpu")

    # Group-relative information, all derived from ``cpu_group``.
    self.rank = dist.get_rank(cpu_group)
    self.world_size = dist.get_world_size(cpu_group)
    # ``ranks[i]`` is the global rank of the i-th member of the group.
    self.ranks = dist.get_process_group_ranks(cpu_group)

    # Information about the default (global) process group.
    self.global_rank = dist.get_rank()
    self.global_world_size = dist.get_world_size()

    # Position of this process inside ``cpu_group``.
    group_rank = dist.get_group_rank(self.cpu_group, self.global_rank)
    self.rank_in_group = group_rank
Functions
fastvideo.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase.all_gather
all_gather(input_: Tensor, dim: int = -1) -> Tensor

Performs an all_gather operation with gradient support.

Source code in fastvideo/distributed/device_communicators/base_device_communicator.py
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Gather ``input_`` from the group along ``dim`` with gradient support.

    Args:
        input_: Local tensor contributed by this rank.
        dim: Dimension passed to the autograd function. A negative value
            is shifted by ``input_.dim()`` before the call.

    Returns:
        Whatever ``DistributedAutograd.AllGather.apply`` returns.
    """
    gather_dim = dim + input_.dim() if dim < 0 else dim
    gather_fn = DistributedAutograd.AllGather
    return gather_fn.apply(self.device_group, input_, self.world_size,
                           gather_dim)
fastvideo.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase.all_reduce
all_reduce(input_: Tensor, op: ReduceOp | None = SUM) -> Tensor

Performs an all_reduce operation with gradient support.

Source code in fastvideo/distributed/device_communicators/base_device_communicator.py
def all_reduce(self,
               input_: torch.Tensor,
               op: dist.ReduceOp | None = ReduceOp.SUM) -> torch.Tensor:
    """Reduce ``input_`` across the group with gradient support.

    Args:
        input_: Local tensor to reduce.
        op: Reduction operator forwarded unchanged to the autograd
            function; defaults to ``ReduceOp.SUM``.

    Returns:
        Whatever ``DistributedAutograd.AllReduce.apply`` returns.
    """
    reduce_fn = DistributedAutograd.AllReduce
    return reduce_fn.apply(self.device_group, input_, op)
fastvideo.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase.all_to_all_4D
all_to_all_4D(input_: Tensor, scatter_dim: int = 2, gather_dim: int = 1) -> Tensor

Performs a 4D all-to-all operation with gradient support.

Source code in fastvideo/distributed/device_communicators/base_device_communicator.py
def all_to_all_4D(self,
                  input_: torch.Tensor,
                  scatter_dim: int = 2,
                  gather_dim: int = 1) -> torch.Tensor:
    """Run the 4D all-to-all exchange with gradient support.

    Args:
        input_: Local tensor to exchange.
        scatter_dim: Scatter dimension forwarded to the autograd function.
        gather_dim: Gather dimension forwarded to the autograd function.

    Returns:
        Whatever ``DistributedAutograd.AllToAll4D.apply`` returns.
    """
    exchange_fn = DistributedAutograd.AllToAll4D
    return exchange_fn.apply(self.device_group, input_, self.world_size,
                             scatter_dim, gather_dim)
fastvideo.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase.gather
gather(input_: Tensor, dst: int = 0, dim: int = -1) -> Tensor | None

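Gathers the input tensor from all ranks in the group onto the destination rank and concatenates the pieces along dim; every other rank receives None. NOTE: We assume that the input tensor is on the same device across all the ranks. NOTE: dst is the local rank of the destination rank.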

Source code in fastvideo/distributed/device_communicators/base_device_communicator.py
def gather(self,
           input_: torch.Tensor,
           dst: int = 0,
           dim: int = -1) -> torch.Tensor | None:
    """
    Gather ``input_`` from every rank onto ``dst`` and concatenate.

    NOTE: We assume that the input tensor is on the same device across
    all the ranks.
    NOTE: `dst` is the local rank of the destination rank.

    Args:
        input_: Local tensor contributed by this rank.
        dst: Index into ``self.ranks`` identifying the receiving rank.
        dim: Concatenation dimension; negative values count from the end.

    Returns:
        On the destination rank, the gathered pieces concatenated along
        ``dim``; ``None`` on every other rank.
    """
    ndim = input_.dim()
    assert -ndim <= dim < ndim, (
        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
    # Normalize a negative dimension to its non-negative equivalent.
    cat_dim = dim + ndim if dim < 0 else dim

    # Only the destination needs receive buffers, one per group member.
    is_destination = self.rank_in_group == dst
    pieces = None
    if is_destination:
        pieces = [torch.empty_like(input_) for _ in range(self.world_size)]

    # ``self.ranks[dst]`` translates the local rank into the rank
    # expected by ``torch.distributed.gather``.
    torch.distributed.gather(input_,
                             pieces,
                             dst=self.ranks[dst],
                             group=self.device_group)

    if not is_destination:
        return None
    return torch.cat(pieces, dim=cat_dim)
fastvideo.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase.recv
recv(size: Size, dtype: dtype, src: int | None = None) -> Tensor

Receives a tensor from the source rank.

Source code in fastvideo/distributed/device_communicators/base_device_communicator.py
def recv(self,
         size: torch.Size,
         dtype: torch.dtype,
         src: int | None = None) -> torch.Tensor:
    """Receives a tensor from the source rank.

    NOTE: `src` is the local rank of the source rank.

    Args:
        size: Shape of the tensor to receive.
        dtype: Element type of the tensor to receive.
        src: Index into ``self.ranks`` of the sender. When ``None``, the
            previous rank in the group is used (wrapping around).

    Returns:
        A tensor allocated on ``self.device`` that the receive call has
        written into.
    """
    if src is None:
        src = (self.rank_in_group - 1) % self.world_size

    buffer = torch.empty(size, dtype=dtype, device=self.device)
    # Translate the local rank through ``self.ranks`` for torch.distributed.
    torch.distributed.recv(buffer,
                           src=self.ranks[src],
                           group=self.device_group)
    return buffer
fastvideo.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase.send
send(tensor: Tensor, dst: int | None = None) -> None

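Sends a tensor to the destination rank. Note that this base implementation calls torch.distributed.send, which PyTorch documents as a synchronous (blocking) send.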

Source code in fastvideo/distributed/device_communicators/base_device_communicator.py
def send(self, tensor: torch.Tensor, dst: int | None = None) -> None:
    """Sends a tensor to the destination rank.

    NOTE: `dst` is the local rank of the destination rank.
    NOTE(review): the previous summary called this "non-blocking", but the
    call below is ``torch.distributed.send``, which PyTorch documents as a
    synchronous send.

    Args:
        tensor: Tensor to send.
        dst: Index into ``self.ranks`` of the receiver. When ``None``, the
            next rank in the group is used (wrapping around).
    """
    if dst is None:
        dst = (self.rank_in_group + 1) % self.world_size
    # Translate the local rank through ``self.ranks`` for torch.distributed.
    torch.distributed.send(tensor,
                           dst=self.ranks[dst],
                           group=self.device_group)
fastvideo.distributed.device_communicators.base_device_communicator.DistributedAutograd

Collection of autograd functions for distributed operations.

This class provides custom autograd functions for distributed operations like all_reduce, all_gather, and all_to_all. Each operation is implemented as a static inner class with proper forward and backward implementations.

Classes
fastvideo.distributed.device_communicators.base_device_communicator.DistributedAutograd.AllGather

Bases: Function

Differentiable all_gather operation.

The operation gathers tensors from all ranks and concatenates them along a specified dimension. The backward pass uses reduce_scatter to efficiently distribute gradients back to source ranks.

fastvideo.distributed.device_communicators.base_device_communicator.DistributedAutograd.AllReduce

Bases: Function

Differentiable all_reduce operation.

The gradient of all_reduce is another all_reduce operation since the operation combines values from all ranks equally.

fastvideo.distributed.device_communicators.base_device_communicator.DistributedAutograd.AllToAll4D

Bases: Function

Differentiable all_to_all operation specialized for 4D tensors.

This operation is particularly useful for attention operations where we need to redistribute data across ranks for efficient parallel processing.

The operation supports two modes:

1. scatter_dim=2, gather_dim=1: Used for redistributing attention heads
2. scatter_dim=1, gather_dim=2: Used for redistributing sequence dimensions

fastvideo.distributed.device_communicators.cpu_communicator

Classes

fastvideo.distributed.device_communicators.cpu_communicator.CpuCommunicator
CpuCommunicator(cpu_group: ProcessGroup, device: device | None = None, device_group: ProcessGroup | None = None, unique_name: str = '')

Bases: DeviceCommunicatorBase

Source code in fastvideo/distributed/device_communicators/cpu_communicator.py
def __init__(self,
             cpu_group: ProcessGroup,
             device: torch.device | None = None,
             device_group: ProcessGroup | None = None,
             unique_name: str = ""):
    """Initialize the base communicator and choose ``self.dist_module``.

    ``self.dist_module`` defaults to ``torch.distributed``. It is replaced
    by ``_CPUSHMDistributed(self)`` only when all three conditions hold:
    the platform reports an x86 CPU, ``torch.ops._C`` exposes
    ``init_shm_manager``, and ``unique_name`` starts with ``"tp"``.

    Args:
        cpu_group: Forwarded to the base class.
        device: Forwarded to the base class.
        device_group: Forwarded to the base class.
        unique_name: Forwarded to the base class; its prefix also takes
            part in the selection described above.
    """
    super().__init__(cpu_group, device, device_group, unique_name)
    self.dist_module = torch.distributed

    from fastvideo.platforms import current_platform

    # The checks stay chained so each one runs only if the previous passed.
    is_x86 = current_platform.get_cpu_architecture() == CpuArchEnum.X86
    has_shm_op = is_x86 and hasattr(torch.ops._C, "init_shm_manager")
    if has_shm_op and unique_name.startswith("tp"):
        self.dist_module = _CPUSHMDistributed(self)
Functions
fastvideo.distributed.device_communicators.cpu_communicator.CpuCommunicator.gather
gather(input_: Tensor, dst: int = 0, dim: int = -1) -> Tensor | None

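Gathers the input tensor from all ranks in the group onto the destination rank and concatenates the pieces along dim; every other rank receives None. NOTE: We assume that the input tensor is on the same device across all the ranks. NOTE: dst is the local rank of the destination rank.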

Source code in fastvideo/distributed/device_communicators/cpu_communicator.py
def gather(self,
           input_: torch.Tensor,
           dst: int = 0,
           dim: int = -1) -> torch.Tensor | None:
    """
    NOTE: We assume that the input tensor is on the same device across
    all the ranks.
    NOTE: `dst` is the local rank of the destination rank.
    """
    world_size = self.world_size
    assert -input_.dim() <= dim < input_.dim(), (
        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
    if dim < 0:
        # Convert negative dim to positive.
        dim += input_.dim()

    # Allocate output tensor.
    if self.rank_in_group == dst:
        gather_list = [torch.empty_like(input_) for _ in range(world_size)]
    else:
        gather_list = None

    # Gather.
    self.dist_module.gather(input_,
                            gather_list,
                            dst=self.ranks[dst],
                            group=self.device_group)

    if self.rank_in_group == dst:
        output_tensor = torch.cat(gather_list, dim=dim)
    else:
        output_tensor = None
    return output_tensor

fastvideo.distributed.device_communicators.cuda_communicator

Classes

fastvideo.distributed.device_communicators.cuda_communicator.CudaCommunicator
CudaCommunicator(cpu_group: ProcessGroup, device: device | None = None, device_group: ProcessGroup | None = None, unique_name: str = '')

Bases: DeviceCommunicatorBase

Source code in fastvideo/distributed/device_communicators/cuda_communicator.py
def __init__(self,
             cpu_group: ProcessGroup,
             device: torch.device | None = None,
             device_group: ProcessGroup | None = None,
             unique_name: str = ""):
    """Initialize the base communicator and, if needed, a PyNccl one.

    ``self.pynccl_comm`` stays ``None`` unless the group has more than one
    member, in which case a ``PyNcclCommunicator`` is created on
    ``self.cpu_group`` and ``self.device``.

    Args:
        cpu_group: Forwarded to the base class.
        device: Forwarded to the base class.
        device_group: Forwarded to the base class.
        unique_name: Forwarded to the base class.
    """
    super().__init__(cpu_group, device, device_group, unique_name)

    from fastvideo.distributed.device_communicators.pynccl import (
        PyNcclCommunicator)

    self.pynccl_comm: PyNcclCommunicator | None = None
    if self.world_size <= 1:
        # Nothing to create when the group has a single member.
        return
    self.pynccl_comm = PyNcclCommunicator(group=self.cpu_group,
                                          device=self.device)
Functions
fastvideo.distributed.device_communicators.cuda_communicator.CudaCommunicator.recv
recv(size: Size, dtype: dtype, src: int | None = None) -> Tensor

Receives a tensor from the source rank.

Source code in fastvideo/distributed/device_communicators/cuda_communicator.py
def recv(self,
         size: torch.Size,
         dtype: torch.dtype,
         src: int | None = None) -> torch.Tensor:
    """Receives a tensor from the source rank."""
    """NOTE: `src` is the local rank of the source rank."""
    if src is None:
        src = (self.rank_in_group - 1) % self.world_size

    tensor = torch.empty(size, dtype=dtype, device=self.device)
    pynccl_comm = self.pynccl_comm
    if pynccl_comm is not None and not pynccl_comm.disabled:
        pynccl_comm.recv(tensor, src)
    else:
        torch.distributed.recv(tensor, self.ranks[src], self.device_group)
    return tensor
fastvideo.distributed.device_communicators.cuda_communicator.CudaCommunicator.send
send(tensor: Tensor, dst: int | None = None) -> None

Sends a tensor to the destination rank in a non-blocking way

Source code in fastvideo/distributed/device_communicators/cuda_communicator.py
def send(self, tensor: torch.Tensor, dst: int | None = None) -> None:
    """Sends a tensor to the destination rank in a non-blocking way"""
    """NOTE: `dst` is the local rank of the destination rank."""
    if dst is None:
        dst = (self.rank_in_group + 1) % self.world_size

    pynccl_comm = self.pynccl_comm
    if pynccl_comm is not None and not pynccl_comm.disabled:
        pynccl_comm.send(tensor, dst)
    else:
        torch.distributed.send(tensor, self.ranks[dst], self.device_group)

fastvideo.distributed.device_communicators.npu_communicator

Classes

fastvideo.distributed.device_communicators.npu_communicator.NpuCommunicator
NpuCommunicator(cpu_group: ProcessGroup, device: device | None = None, device_group: ProcessGroup | None = None, unique_name: str = '')

Bases: DeviceCommunicatorBase

Source code in fastvideo/distributed/device_communicators/npu_communicator.py
def __init__(self,
             cpu_group: ProcessGroup,
             device: torch.device | None = None,
             device_group: ProcessGroup | None = None,
             unique_name: str = ""):
    """Initialize the base communicator and, if needed, a PyHccl one.

    ``self.pyhccl_comm`` stays ``None`` unless the group has more than one
    member, in which case a ``PyHcclCommunicator`` is created on
    ``self.cpu_group`` and ``self.device``.

    Args:
        cpu_group: Forwarded to the base class.
        device: Forwarded to the base class.
        device_group: Forwarded to the base class.
        unique_name: Forwarded to the base class.
    """
    super().__init__(cpu_group, device, device_group, unique_name)

    from fastvideo.distributed.device_communicators.pyhccl import (
        PyHcclCommunicator)

    self.pyhccl_comm: PyHcclCommunicator | None = None
    if self.world_size <= 1:
        # Nothing to create when the group has a single member.
        return
    self.pyhccl_comm = PyHcclCommunicator(group=self.cpu_group,
                                          device=self.device)
Functions
fastvideo.distributed.device_communicators.npu_communicator.NpuCommunicator.recv
recv(size: Size, dtype: dtype, src: int | None = None) -> Tensor

Receives a tensor from the source rank.

Source code in fastvideo/distributed/device_communicators/npu_communicator.py
def recv(self,
         size: torch.Size,
         dtype: torch.dtype,
         src: int | None = None) -> torch.Tensor:
    """Receives a tensor from the source rank."""
    """NOTE: `src` is the local rank of the source rank."""
    if src is None:
        src = (self.rank_in_group - 1) % self.world_size

    tensor = torch.empty(size, dtype=dtype, device=self.device)
    pyhccl_comm = self.pyhccl_comm
    if pyhccl_comm is not None and not pyhccl_comm.disabled:
        pyhccl_comm.recv(tensor, src)
    else:
        torch.distributed.recv(tensor, self.ranks[src], self.device_group)
    return tensor
fastvideo.distributed.device_communicators.npu_communicator.NpuCommunicator.send
send(tensor: Tensor, dst: int | None = None) -> None

Sends a tensor to the destination rank in a non-blocking way

Source code in fastvideo/distributed/device_communicators/npu_communicator.py
def send(self, tensor: torch.Tensor, dst: int | None = None) -> None:
    """Sends a tensor to the destination rank in a non-blocking way"""
    """NOTE: `dst` is the local rank of the destination rank."""
    if dst is None:
        dst = (self.rank_in_group + 1) % self.world_size

    pyhccl_comm = self.pyhccl_comm
    if pyhccl_comm is not None and not pyhccl_comm.disabled:
        pyhccl_comm.send(tensor, dst)
    else:
        torch.distributed.send(tensor, self.ranks[dst], self.device_group)

fastvideo.distributed.device_communicators.pyhccl

Classes

fastvideo.distributed.device_communicators.pyhccl.PyHcclCommunicator
PyHcclCommunicator(group: ProcessGroup | StatelessProcessGroup, device: int | str | device, library_path: str | None = None)

Parameters:

Name Type Description Default
group ProcessGroup | StatelessProcessGroup

the process group to work on. If None, it will use the default process group.

required
device int | str | device

the device to bind the PyHcclCommunicator to. If None, it will be bound to f"npu:{local_rank}".

required
library_path str | None

the path to the HCCL library. If None, it will use the default library path.

None

It is the caller's responsibility to make sure each communicator is bound to a unique device.

Source code in fastvideo/distributed/device_communicators/pyhccl.py
def __init__(
    self,
    group: ProcessGroup | StatelessProcessGroup,
    device: int | str | torch.device,
    library_path: str | None = None,
):
    """
    Create an HCCL communicator over ``group`` bound to ``device``.

    Args:
        group: the process group to work on. A regular ``ProcessGroup``
            must not use the HCCL backend (asserted below).
            NOTE(review): earlier docs said ``None`` selects the default
            process group; this constructor has no explicit ``None``
            handling - confirm before relying on it.
        device: the device to bind the PyHcclCommunicator to. An ``int``
            is turned into ``npu:{device}`` and a ``str`` is passed to
            ``torch.device``. NOTE(review): earlier docs said ``None``
            binds to f"npu:{local_rank}", but the ``isinstance`` assert
            below rejects ``None``.
        library_path: the path to the HCCL library, forwarded to
            ``HCCLLibrary``. Defaults to ``None``.

    It is the caller's responsibility to make sure each communicator
    is bound to a unique device.

    After construction ``self.available`` / ``self.disabled`` tell whether
    a communicator was created. On the two early-return paths
    (``world_size == 1`` or the HCCL library failing to load) the
    attributes ``self.device``, ``self.unique_id`` and ``self.comm`` are
    never set.
    """

    if not isinstance(group, StatelessProcessGroup):
        # torch.distributed path: it must be initialized, and the group
        # must not itself be an HCCL group.
        assert dist.is_initialized()
        assert dist.get_backend(group) != dist.Backend.HCCL, (
            "PyHcclCommunicator should be attached to a non-HCCL group.")
        # note: this rank is the rank in the group
        self.rank = dist.get_rank(group)
        self.world_size = dist.get_world_size(group)
    else:
        # StatelessProcessGroup carries its own rank and world size.
        self.rank = group.rank
        self.world_size = group.world_size

    self.group = group

    # if world_size == 1, no need to create communicator
    if self.world_size == 1:
        self.available = False
        self.disabled = True
        return

    try:
        self.hccl = HCCLLibrary(library_path)
    except Exception:
        logger.warning("disable hccl because of missing HCCL library")
        # disable because of missing HCCL library
        # e.g. in a non-NPU environment
        self.available = False
        self.disabled = True
        return

    self.available = True
    self.disabled = False

    logger.info("FastVideo is using pyhccl")

    # Normalize ``device`` to a ``torch.device``.
    if isinstance(device, int):
        device = torch.device(f"npu:{device}")
    elif isinstance(device, str):
        device = torch.device(device)
    # now `device` is a `torch.device` object
    assert isinstance(device, torch.device)
    self.device = device

    if self.rank == 0:
        # get the unique id from HCCL
        with torch.npu.device(device):
            self.unique_id = self.hccl.hcclGetUniqueId()
    else:
        # construct an empty unique id; it is filled in by the
        # broadcast below
        self.unique_id = hcclUniqueId()

    # Distribute the unique id generated on group rank 0 to every member.
    if not isinstance(group, StatelessProcessGroup):
        # Ship the id as a byte tensor, then copy the received bytes back
        # into the local id structure one by one.
        tensor = torch.ByteTensor(list(self.unique_id.internal))
        ranks = dist.get_process_group_ranks(group)
        # arg `src` in `broadcast` is the global rank
        dist.broadcast(tensor, src=ranks[0], group=group)
        byte_list = tensor.tolist()
        for i, byte in enumerate(byte_list):
            self.unique_id.internal[i] = byte
    else:
        self.unique_id = group.broadcast_obj(self.unique_id, src=0)

    # hccl communicator and stream will use this device
    # `torch.npu.device` is a context manager that changes the
    # current npu device to the specified one
    with torch.npu.device(device):
        self.comm: hcclComm_t = self.hccl.hcclCommInitRank(
            self.world_size, self.unique_id, self.rank)

        stream = current_stream()
        # A small all_reduce for warmup.
        data = torch.zeros(1, device=device)
        self.all_reduce(data)
        # Wait for the warmup to finish before dropping the buffer.
        stream.synchronize()
        del data
Functions

Functions

fastvideo.distributed.device_communicators.pynccl

Classes

fastvideo.distributed.device_communicators.pynccl.PyNcclCommunicator
PyNcclCommunicator(group: ProcessGroup | StatelessProcessGroup, device: int | str | device, library_path: str | None = None)

Parameters:

Name Type Description Default
group ProcessGroup | StatelessProcessGroup

the process group to work on. If None, it will use the default process group.

required
device int | str | device

the device to bind the PyNcclCommunicator to. If None, it will be bound to f"cuda:{local_rank}".

required
library_path str | None

the path to the NCCL library. If None, it will use the default library path.

None

It is the caller's responsibility to make sure each communicator is bound to a unique device.

Source code in fastvideo/distributed/device_communicators/pynccl.py
def __init__(
    self,
    group: ProcessGroup | StatelessProcessGroup,
    device: int | str | torch.device,
    library_path: str | None = None,
):
    """
    Create an NCCL communicator over ``group`` bound to ``device``.

    Args:
        group: the process group to work on. A regular ``ProcessGroup``
            must not use the NCCL backend (asserted below).
            NOTE(review): earlier docs said ``None`` selects the default
            process group; this constructor has no explicit ``None``
            handling - confirm before relying on it.
        device: the device to bind the PyNcclCommunicator to. An ``int``
            is turned into ``cuda:{device}`` and a ``str`` is passed to
            ``torch.device``. NOTE(review): earlier docs said ``None``
            binds to f"cuda:{local_rank}", but the ``isinstance`` assert
            below rejects ``None``.
        library_path: the path to the NCCL library, forwarded to
            ``NCCLLibrary``. Defaults to ``None``.

    It is the caller's responsibility to make sure each communicator
    is bound to a unique device.

    After construction ``self.available`` / ``self.disabled`` tell whether
    a communicator was created. On the two early-return paths
    (``world_size == 1`` or the NCCL library failing to load) the
    attributes ``self.device``, ``self.unique_id`` and ``self.comm`` are
    never set.
    """
    if not isinstance(group, StatelessProcessGroup):
        # torch.distributed path: it must be initialized, and the group
        # must not itself be an NCCL group.
        assert dist.is_initialized()
        assert dist.get_backend(group) != dist.Backend.NCCL, (
            "PyNcclCommunicator should be attached to a non-NCCL group.")
        # note: this rank is the rank in the group
        self.rank = dist.get_rank(group)
        self.world_size = dist.get_world_size(group)
    else:
        # StatelessProcessGroup carries its own rank and world size.
        self.rank = group.rank
        self.world_size = group.world_size

    self.group = group

    # if world_size == 1, no need to create communicator
    if self.world_size == 1:
        self.available = False
        self.disabled = True
        return
    try:
        self.nccl = NCCLLibrary(library_path)
    except Exception as exc:
        # disable because of missing NCCL library
        # e.g. in a non-GPU environment
        # Still best-effort (no re-raise), but leave a trace so a silently
        # disabled communicator can be diagnosed, mirroring the HCCL
        # communicator's behavior.
        logger.warning("disable nccl because of missing NCCL library: %s",
                       exc)
        self.available = False
        self.disabled = True
        return

    self.available = True
    self.disabled = False

    logger.info("FastVideo is using nccl==%s", self.nccl.ncclGetVersion())

    if self.rank == 0:
        # get the unique id from NCCL
        self.unique_id = self.nccl.ncclGetUniqueId()
    else:
        # construct an empty unique id; it is filled in by the
        # broadcast below
        self.unique_id = ncclUniqueId()

    # Distribute the unique id generated on group rank 0 to every member.
    if not isinstance(group, StatelessProcessGroup):
        # Ship the id as a byte tensor, then copy the received bytes back
        # into the local id structure one by one.
        tensor = torch.ByteTensor(list(self.unique_id.internal))
        ranks = dist.get_process_group_ranks(group)
        # arg `src` in `broadcast` is the global rank
        dist.broadcast(tensor, src=ranks[0], group=group)
        byte_list = tensor.tolist()
        for i, byte in enumerate(byte_list):
            self.unique_id.internal[i] = byte
    else:
        self.unique_id = group.broadcast_obj(self.unique_id, src=0)

    # Normalize ``device`` to a ``torch.device``.
    if isinstance(device, int):
        device = torch.device(f"cuda:{device}")
    elif isinstance(device, str):
        device = torch.device(device)
    # now `device` is a `torch.device` object
    assert isinstance(device, torch.device)
    self.device = device
    # nccl communicator and stream will use this device
    # `torch.cuda.device` is a context manager that changes the
    # current cuda device to the specified one
    with torch.cuda.device(device):
        self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
            self.world_size, self.unique_id, self.rank)

        stream = current_stream()
        # A small all_reduce for warmup.
        data = torch.zeros(1, device=device)
        self.all_reduce(data)
        # Wait for the warmup to finish (when a stream object is
        # available) before dropping the buffer.
        if stream is not None:
            stream.synchronize()
        del data
Functions

Functions