
base_device_communicator

Classes

fastvideo.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase

DeviceCommunicatorBase(cpu_group: ProcessGroup, device: device | None = None, device_group: ProcessGroup | None = None, unique_name: str = '')

Base class for device-specific communicators with autograd support. The cpu_group can be used to initialize the communicator. If the device has PyTorch integration (i.e. PyTorch recognizes its communication backend), a device_group is also provided.

Source code in fastvideo/distributed/device_communicators/base_device_communicator.py
def __init__(self,
             cpu_group: ProcessGroup,
             device: torch.device | None = None,
             device_group: ProcessGroup | None = None,
             unique_name: str = ""):
    self.device = device or torch.device("cpu")
    self.cpu_group = cpu_group
    self.device_group = device_group
    self.unique_name = unique_name
    self.rank = dist.get_rank(cpu_group)
    self.world_size = dist.get_world_size(cpu_group)
    self.ranks = dist.get_process_group_ranks(cpu_group)
    self.global_rank = dist.get_rank()
    self.global_world_size = dist.get_world_size()
    self.rank_in_group = dist.get_group_rank(self.cpu_group,
                                             self.global_rank)
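
A minimal construction sketch, assuming torch.distributed has already been initialized (for example by torchrun) before any groups are created; the group choices, device selection, and unique_name below are illustrative, not part of the class API:

import torch
import torch.distributed as dist

from fastvideo.distributed.device_communicators.base_device_communicator import (
    DeviceCommunicatorBase)

# Assumes dist.init_process_group(...) has already been called by the launcher.
cpu_group = dist.new_group(backend="gloo")      # coordination over CPU
device_group = dist.new_group(backend="nccl")   # collectives on the GPU backend
device = torch.device(f"cuda:{dist.get_rank() % torch.cuda.device_count()}")

comm = DeviceCommunicatorBase(cpu_group=cpu_group,
                              device=device,
                              device_group=device_group,
                              unique_name="example")
print(comm.rank, comm.world_size, comm.rank_in_group)

The later examples on this page reuse this comm object.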

Functions

fastvideo.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase.all_gather
all_gather(input_: Tensor, dim: int = -1) -> Tensor

Performs an all_gather operation with gradient support.

Source code in fastvideo/distributed/device_communicators/base_device_communicator.py
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Performs an all_gather operation with gradient support."""
    if dim < 0:
        dim += input_.dim()
    return DistributedAutograd.AllGather.apply(self.device_group, input_,
                                               self.world_size, dim)
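
An illustrative call, reusing the comm from the construction sketch above; gathering along the last dimension concatenates the per-rank tensors there:

# Each rank contributes a (4, 8) tensor; every rank receives the same result.
x = torch.randn(4, 8, device=comm.device)
gathered = comm.all_gather(x, dim=-1)
# gathered.shape == (4, 8 * comm.world_size)
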
fastvideo.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase.all_reduce
all_reduce(input_: Tensor, op: ReduceOp | None = SUM) -> Tensor

Performs an all_reduce operation with gradient support.

Source code in fastvideo/distributed/device_communicators/base_device_communicator.py
def all_reduce(self,
               input_: torch.Tensor,
               op: dist.ReduceOp | None = ReduceOp.SUM) -> torch.Tensor:
    """Performs an all_reduce operation with gradient support."""
    return DistributedAutograd.AllReduce.apply(self.device_group, input_,
                                               op)
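
A sketch of an element-wise sum across ranks with the same comm; the result keeps the input's shape, and op defaults to ReduceOp.SUM:

x = torch.ones(2, 3, device=comm.device)
total = comm.all_reduce(x)          # every element now equals world_size
mean = total / comm.world_size      # divide afterwards if an average is wanted
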
fastvideo.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase.all_to_all_4D
all_to_all_4D(input_: Tensor, scatter_dim: int = 2, gather_dim: int = 1) -> Tensor

Performs a 4D all-to-all operation with gradient support.

Source code in fastvideo/distributed/device_communicators/base_device_communicator.py
def all_to_all_4D(self,
                  input_: torch.Tensor,
                  scatter_dim: int = 2,
                  gather_dim: int = 1) -> torch.Tensor:
    """Performs a 4D all-to-all operation with gradient support."""
    return DistributedAutograd.AllToAll4D.apply(self.device_group, input_,
                                                self.world_size,
                                                scatter_dim, gather_dim)
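
A shape-oriented sketch of the default mode (scatter_dim=2, gather_dim=1), reusing the comm from above; the dimension names are illustrative and num_heads is assumed to be divisible by world_size:

# Per-rank input:  (batch, seq_len // world_size, num_heads, head_dim)
# Per-rank output: (batch, seq_len, num_heads // world_size, head_dim)
b, s_local, h, d = 1, 16, 8, 64
x = torch.randn(b, s_local, h, d, device=comm.device)
y = comm.all_to_all_4D(x, scatter_dim=2, gather_dim=1)
# y.shape == (b, s_local * comm.world_size, h // comm.world_size, d)
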
fastvideo.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase.gather
gather(input_: Tensor, dst: int = 0, dim: int = -1) -> Tensor | None

NOTE: We assume that the input tensor is on the same device across all the ranks.
NOTE: `dst` is the local rank of the destination rank.

Source code in fastvideo/distributed/device_communicators/base_device_communicator.py
def gather(self,
           input_: torch.Tensor,
           dst: int = 0,
           dim: int = -1) -> torch.Tensor | None:
    """
    NOTE: We assume that the input tensor is on the same device across
    all the ranks.
    NOTE: `dst` is the local rank of the destination rank.
    """
    world_size = self.world_size
    assert -input_.dim() <= dim < input_.dim(), (
        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
    if dim < 0:
        # Convert negative dim to positive.
        dim += input_.dim()

    # Allocate output tensor.
    if self.rank_in_group == dst:
        gather_list = [torch.empty_like(input_) for _ in range(world_size)]
    else:
        gather_list = None
    # Gather.
    torch.distributed.gather(input_,
                             gather_list,
                             dst=self.ranks[dst],
                             group=self.device_group)
    if self.rank_in_group == dst:
        output_tensor = torch.cat(gather_list, dim=dim)
    else:
        output_tensor = None
    return output_tensor
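
An illustrative gather onto local rank 0 with the same comm; only the destination rank receives the concatenated tensor, every other rank gets None:

x = torch.randn(4, 8, device=comm.device)
result = comm.gather(x, dst=0, dim=0)
if comm.rank_in_group == 0:
    print(result.shape)     # (4 * world_size, 8)
else:
    assert result is None
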
fastvideo.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase.recv
recv(size: Size, dtype: dtype, src: int | None = None) -> Tensor

Receives a tensor from the source rank.

Source code in fastvideo/distributed/device_communicators/base_device_communicator.py
def recv(self,
         size: torch.Size,
         dtype: torch.dtype,
         src: int | None = None) -> torch.Tensor:
    """Receives a tensor from the source rank."""
    """NOTE: `src` is the local rank of the source rank."""
    if src is None:
        src = (self.rank_in_group - 1) % self.world_size

    tensor = torch.empty(size, dtype=dtype, device=self.device)
    torch.distributed.recv(tensor, self.ranks[src], self.device_group)
    return tensor
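
A receiving-side sketch with the default source, which is the previous local rank in the group; it is only half of an exchange, and the matching send on the neighbouring rank is shown in the ring example after send below:

# Blocks until the previous rank sends a matching (4, 8) float32 tensor.
incoming = comm.recv(torch.Size((4, 8)), torch.float32)
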
fastvideo.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase.send
send(tensor: Tensor, dst: int | None = None) -> None

Sends a tensor to the destination rank in a non-blocking way

Source code in fastvideo/distributed/device_communicators/base_device_communicator.py
def send(self, tensor: torch.Tensor, dst: int | None = None) -> None:
    """Sends a tensor to the destination rank in a non-blocking way"""
    """NOTE: `dst` is the local rank of the destination rank."""
    if dst is None:
        dst = (self.rank_in_group + 1) % self.world_size
    torch.distributed.send(tensor, self.ranks[dst], self.device_group)
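
A ring-pass sketch pairing send with recv through their default neighbours, reusing the comm from the construction sketch and assuming world_size > 1; even local ranks send first and odd local ranks receive first so every point-to-point call finds a matching peer:

payload = torch.full((4,), float(comm.rank_in_group), device=comm.device)
if comm.rank_in_group % 2 == 0:
    comm.send(payload)                                    # to the next local rank
    received = comm.recv(payload.size(), payload.dtype)   # from the previous local rank
else:
    received = comm.recv(payload.size(), payload.dtype)
    comm.send(payload)
# After the exchange, `received` holds the previous rank's payload.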

fastvideo.distributed.device_communicators.base_device_communicator.DistributedAutograd

Collection of autograd functions for distributed operations.

This class provides custom autograd functions for distributed operations like all_reduce, all_gather, and all_to_all. Each operation is implemented as a static inner class with proper forward and backward implementations.

Classes

fastvideo.distributed.device_communicators.base_device_communicator.DistributedAutograd.AllGather

Bases: Function

Differentiable all_gather operation.

The operation gathers tensors from all ranks and concatenates them along a specified dimension. The backward pass uses reduce_scatter to efficiently distribute gradients back to source ranks.
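
A gradient-flow sketch through the communicator's all_gather wrapper, reusing the comm from the construction sketch; because the backward pass is a reduce_scatter, each rank ends up with the summed gradient of its own input chunk:

x = torch.randn(4, 8, device=comm.device, requires_grad=True)
out = comm.all_gather(x, dim=0)      # (4 * world_size, 8) on every rank
(out ** 2).sum().backward()
# x.grad has the same (4, 8) shape as x: the backward reduce_scatter summed,
# across ranks, the gradient chunk that corresponds to this rank's input.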

fastvideo.distributed.device_communicators.base_device_communicator.DistributedAutograd.AllReduce

Bases: Function

Differentiable all_reduce operation.

The gradient of all_reduce is another all_reduce operation since the operation combines values from all ranks equally.
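
A minimal check of that symmetry with the same comm; since the backward is itself an all_reduce, each rank's input gradient is the sum of every rank's incoming gradient:

x = torch.randn(3, device=comm.device, requires_grad=True)
y = comm.all_reduce(x)
y.sum().backward()
# Each rank contributes a gradient of ones, and the backward all_reduce sums
# them, so x.grad == world_size * torch.ones_like(x) on every rank.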

fastvideo.distributed.device_communicators.base_device_communicator.DistributedAutograd.AllToAll4D

Bases: Function

Differentiable all_to_all operation specialized for 4D tensors.

This operation is particularly useful for attention operations where we need to redistribute data across ranks for efficient parallel processing.

The operation supports two modes:

1. scatter_dim=2, gather_dim=1: Used for redistributing attention heads
2. scatter_dim=1, gather_dim=2: Used for redistributing sequence dimensions (see the sketch below)
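
A complementary sketch of the second mode, which undoes the first one shown earlier on this page; the shapes are illustrative and assume seq_len and num_heads are both divisible by world_size:

# Per-rank input:  (batch, seq_len, num_heads // world_size, head_dim)
# Per-rank output: (batch, seq_len // world_size, num_heads, head_dim)
b, s, h_local, d = 1, 64, 2, 64
x = torch.randn(b, s, h_local, d, device=comm.device)
y = comm.all_to_all_4D(x, scatter_dim=1, gather_dim=2)
# y.shape == (b, s // comm.world_size, h_local * comm.world_size, d)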