Skip to content

cpu_communicator

Classes

fastvideo.distributed.device_communicators.cpu_communicator.CpuCommunicator

CpuCommunicator(cpu_group: ProcessGroup, device: device | None = None, device_group: ProcessGroup | None = None, unique_name: str = '')

Bases: DeviceCommunicatorBase

Source code in fastvideo/distributed/device_communicators/cpu_communicator.py
def __init__(self,
             cpu_group: ProcessGroup,
             device: torch.device | None = None,
             device_group: ProcessGroup | None = None,
             unique_name: str = ""):
    """Create a CPU communicator and pick the collective backend it uses.

    By default ``self.dist_module`` is ``torch.distributed``. It is swapped
    for ``_CPUSHMDistributed(self)`` only when all three conditions hold:

    * the current platform reports an x86 CPU architecture,
    * ``torch.ops._C`` exposes ``init_shm_manager``, and
    * ``unique_name`` starts with ``"tp"``.

    (Presumably ``_CPUSHMDistributed`` is a shared-memory backend, judging
    by the ``init_shm_manager`` op it depends on; confirm against its
    definition.)

    Args:
        cpu_group: Process group forwarded to the base communicator.
        device: Optional device forwarded to the base communicator.
        device_group: Optional device process group forwarded to the base
            communicator.
        unique_name: Name of this communicator; a ``"tp"`` prefix is one of
            the conditions for selecting the alternative backend.
    """
    super().__init__(cpu_group, device, device_group, unique_name)
    # Start from the stock torch.distributed collectives.
    self.dist_module = torch.distributed

    from fastvideo.platforms import current_platform

    # Each guard mirrors one clause of the enabling condition, evaluated in
    # the same order (and with the same short-circuiting) as a chained `and`.
    is_x86 = current_platform.get_cpu_architecture() == CpuArchEnum.X86
    if not is_x86:
        return
    if not hasattr(torch.ops._C, "init_shm_manager"):
        return
    if not unique_name.startswith("tp"):
        return
    self.dist_module = _CPUSHMDistributed(self)

Functions

fastvideo.distributed.device_communicators.cpu_communicator.CpuCommunicator.gather
gather(input_: Tensor, dst: int = 0, dim: int = -1) -> Tensor | None

NOTE: We assume that the input tensor is on the same device across all ranks.

NOTE: `dst` is the local rank of the destination, i.e. its index within this communicator's group (it is mapped to a global rank internally).

Source code in fastvideo/distributed/device_communicators/cpu_communicator.py
def gather(self,
           input_: torch.Tensor,
           dst: int = 0,
           dim: int = -1) -> torch.Tensor | None:
    """
    NOTE: We assume that the input tensor is on the same device across
    all the ranks.
    NOTE: `dst` is the local rank of the destination rank.
    """
    world_size = self.world_size
    assert -input_.dim() <= dim < input_.dim(), (
        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
    if dim < 0:
        # Convert negative dim to positive.
        dim += input_.dim()

    # Allocate output tensor.
    if self.rank_in_group == dst:
        gather_list = [torch.empty_like(input_) for _ in range(world_size)]
    else:
        gather_list = None

    # Gather.
    self.dist_module.gather(input_,
                            gather_list,
                            dst=self.ranks[dst],
                            group=self.device_group)

    if self.rank_in_group == dst:
        output_tensor = torch.cat(gather_list, dim=dim)
    else:
        output_tensor = None
    return output_tensor