fastvideo.v1.distributed.device_communicators.base_device_communicator#

Module Contents#

Classes#

DeviceCommunicatorBase

Base class for device-specific communicators with autograd support. The cpu_group can be used to initialize the communicator; if the device has PyTorch integration (i.e., PyTorch recognizes its communication backend), a device_group is provided as well.

DistributedAutograd

Collection of autograd functions for distributed operations.

API#

class fastvideo.v1.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase(cpu_group: torch.distributed.ProcessGroup, device: Optional[torch.device] = None, device_group: Optional[torch.distributed.ProcessGroup] = None, unique_name: str = '')[source]#

Base class for device-specific communicators with autograd support. The cpu_group can be used to initialize the communicator; if the device has PyTorch integration (i.e., PyTorch recognizes its communication backend), a device_group is provided as well.

Initialization
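
A minimal construction sketch, assuming a torchrun launch with an NCCL-backed default group plus a Gloo CPU group; the group layout and unique_name are illustrative, and in practice a device-specific subclass would normally be instantiated rather than the base class:

```python
import torch
import torch.distributed as dist

from fastvideo.v1.distributed.device_communicators.base_device_communicator import (
    DeviceCommunicatorBase,
)

# Assumed setup: launched via torchrun, one GPU per rank.
dist.init_process_group(backend="nccl")            # PyTorch-integrated device backend
cpu_group = dist.new_group(backend="gloo")         # CPU group used for initialization
device = torch.device(f"cuda:{dist.get_rank() % torch.cuda.device_count()}")

comm = DeviceCommunicatorBase(
    cpu_group=cpu_group,
    device=device,
    device_group=dist.group.WORLD,
    unique_name="base_demo",                       # illustrative name
)
```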

all_gather(input_: torch.Tensor, dim: int = -1) torch.Tensor[source]#

Performs an all_gather operation with gradient support.
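
Continuing the sketch above, the expected shape behaviour (per-rank shard sizes are illustrative):

```python
# Each rank holds a (2, 8) shard; gathering along dim=-1 concatenates the shards,
# giving (2, 8 * world_size). Gradients flow back through the AllGather autograd
# Function documented below, so each rank recovers its own slice.
x = torch.randn(2, 8, device=device, requires_grad=True)
gathered = comm.all_gather(x, dim=-1)
gathered.sum().backward()      # x.grad has the original (2, 8) shape
```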

all_reduce(input_: torch.Tensor, op: Optional[torch.distributed.ReduceOp] = ReduceOp.SUM) torch.Tensor[source]#

Performs an all_reduce operation with gradient support.
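
A short usage sketch continuing the setup above; per the signature, the reduction defaults to ReduceOp.SUM:

```python
# Every rank contributes its local tensor; all ranks receive the elementwise sum.
local = torch.full((4,), float(dist.get_rank()), device=device, requires_grad=True)
total = comm.all_reduce(local)     # op defaults to ReduceOp.SUM
total.mean().backward()            # gradient flows through the AllReduce Function
```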

all_to_all_4D(input_: torch.Tensor, scatter_dim: int = 2, gather_dim: int = 1) torch.Tensor[source]#

Performs a 4D all-to-all operation with gradient support.
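
The exact tensor layout is not spelled out here; the shapes below assume the Ulysses-style [batch, sequence, heads, head_dim] convention that the default scatter_dim/gather_dim values suggest, so treat them as an illustration:

```python
# Assumed layout: each rank starts with a sequence shard and all heads
# ([b, s/P, h, d]) and ends with the full sequence and a head shard ([b, s, h/P, d]).
qkv_shard = torch.randn(1, 16, 8, 64, device=device)
head_shard = comm.all_to_all_4D(qkv_shard, scatter_dim=2, gather_dim=1)
restored = comm.all_to_all_4D(head_shard, scatter_dim=1, gather_dim=2)  # inverse mode
```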

destroy() None[source]#
gather(input_: torch.Tensor, dst: int = 0, dim: int = -1) Optional[torch.Tensor][source]#

NOTE: The input tensor is assumed to be on the same device across all ranks.
NOTE: dst is the local rank of the destination rank.
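
A usage sketch; given the Optional return type, presumably only the destination rank receives the gathered tensor while the other ranks get None:

```python
shard = torch.arange(4, device=device, dtype=torch.float32)
full = comm.gather(shard, dst=0, dim=0)
if full is not None:          # expected to hold only on the destination rank
    print(full.shape)         # (4 * world_size,)
```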

recv(size: torch.Size, dtype: torch.dtype, src: Optional[int] = None) torch.Tensor[source]#

Receives a tensor from the source rank.

send(tensor: torch.Tensor, dst: Optional[int] = None) None[source]#

Sends a tensor to the destination rank in a non-blocking way.
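
A point-to-point sketch covering both recv and send; the defaults for src and dst are not documented here, so the example passes neighbouring ranks explicitly to form a simple ring exchange:

```python
rank, world = dist.get_rank(), dist.get_world_size()
payload = torch.full((3, 3), float(rank), device=device)
comm.send(payload, dst=(rank + 1) % world)          # non-blocking send to the next rank
received = comm.recv(size=payload.size(), dtype=payload.dtype,
                     src=(rank - 1) % world)        # receive from the previous rank
```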

class fastvideo.v1.distributed.device_communicators.base_device_communicator.DistributedAutograd[source]#

Collection of autograd functions for distributed operations.

This class provides custom autograd functions for distributed operations such as all_reduce, all_gather, and all_to_all. Each operation is implemented as an inner torch.autograd.Function subclass with explicit forward and backward implementations.
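
Because each inner class is a torch.autograd.Function, it can also be invoked directly through .apply with arguments matching the documented forward signatures. The names group, x, and world_size below are placeholders, and dist is torch.distributed:

```python
# Illustrative only: `group` is a torch.distributed ProcessGroup and `x` a tensor
# on the matching device.
reduced = DistributedAutograd.AllReduce.apply(group, x, dist.ReduceOp.SUM)
gathered = DistributedAutograd.AllGather.apply(group, x, world_size, -1)
```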

class AllGather(*args, **kwargs)[source]#

Bases: torch.autograd.Function

Differentiable all_gather operation.

The operation gathers tensors from all ranks and concatenates them along a specified dimension. The backward pass uses reduce_scatter to efficiently distribute gradients back to their source ranks (a simplified sketch follows this entry).

Initialization

static backward(ctx: Any, grad_output: torch.Tensor) Tuple[None, torch.Tensor, None, None][source]#
static forward(ctx: Any, group: torch.distributed.ProcessGroup, input_: torch.Tensor, world_size: int, dim: int) torch.Tensor[source]#
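
A simplified re-implementation of the same idea (not FastVideo's actual code) shows why reduce_scatter is the natural backward of all_gather; it fixes the gather dimension to 0 to keep the sketch short:

```python
import torch
import torch.distributed as dist


class _AllGatherSketch(torch.autograd.Function):
    """Toy differentiable all_gather along dim 0 (illustration only)."""

    @staticmethod
    def forward(ctx, group, x, world_size):
        ctx.group, ctx.world_size = group, world_size
        out = torch.empty(world_size * x.shape[0], *x.shape[1:],
                          dtype=x.dtype, device=x.device)
        dist.all_gather_into_tensor(out, x.contiguous(), group=group)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        # Each rank owns only its original shard, so the gradient of the
        # concatenation is scattered back: reduce_scatter sums the per-rank
        # contributions and returns this rank's slice.
        grad_input = torch.empty(grad_output.shape[0] // ctx.world_size,
                                 *grad_output.shape[1:],
                                 dtype=grad_output.dtype,
                                 device=grad_output.device)
        dist.reduce_scatter_tensor(grad_input, grad_output.contiguous(),
                                   group=ctx.group)
        return None, grad_input, None
```
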
class AllReduce(*args, **kwargs)[source]#

Bases: torch.autograd.Function

Differentiable all_reduce operation.

The gradient of all_reduce is another all_reduce, since the operation combines values from all ranks equally (a simplified sketch follows this entry).

Initialization

static backward(ctx: Any, grad_output: torch.Tensor) Tuple[None, torch.Tensor, None][source]#
static forward(ctx: Any, group: torch.distributed.ProcessGroup, input_: torch.Tensor, op: Optional[torch.distributed.ReduceOp] = None) torch.Tensor[source]#
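
The gradient identity described above, written out as a toy autograd Function (illustration only, not the library code):

```python
import torch
import torch.distributed as dist


class _AllReduceSketch(torch.autograd.Function):
    """Toy differentiable all_reduce (sum), for illustration only."""

    @staticmethod
    def forward(ctx, group, x):
        ctx.group = group
        out = x.clone()
        dist.all_reduce(out, op=dist.ReduceOp.SUM, group=group)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        # y = sum over ranks of x_r, and every rank's copy of y feeds its own
        # downstream loss, so dL/dx_r is the sum of all ranks' dL/dy:
        # another all_reduce.
        grad = grad_output.clone()
        dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=ctx.group)
        return None, grad
```
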
class AllToAll4D(*args, **kwargs)[source]#

Bases: torch.autograd.Function

Differentiable all_to_all operation specialized for 4D tensors.

This operation is particularly useful for attention, where data must be redistributed across ranks for efficient parallel processing (a reshape-and-permute sketch of mode 1 follows this entry).

The operation supports two modes:

  1. scatter_dim=2, gather_dim=1: Used for redistributing attention heads

  2. scatter_dim=1, gather_dim=2: Used for redistributing sequence dimensions

Initialization

static backward(ctx: Any, grad_output: torch.Tensor) Tuple[None, torch.Tensor, None, None, None][source]#
static forward(ctx: Any, group: torch.distributed.ProcessGroup, input_: torch.Tensor, world_size: int, scatter_dim: int, gather_dim: int) torch.Tensor[source]#
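
One common way to realize mode 1 (scatter_dim=2, gather_dim=1) is a reshape/permute around torch.distributed.all_to_all_single; the following is a hedged sketch of that pattern, not necessarily FastVideo's exact implementation:

```python
import torch
import torch.distributed as dist


def all_to_all_heads(x: torch.Tensor, group, world_size: int) -> torch.Tensor:
    """Trade a full head dimension for a full sequence: [b, s/P, h, d] -> [b, s, h/P, d]."""
    b, s_shard, h, d = x.shape
    h_shard = h // world_size
    # Split the head dimension into world_size chunks and move the chunk index to
    # the front, since all_to_all_single exchanges equal chunks along dim 0.
    x = x.reshape(b, s_shard, world_size, h_shard, d)
    x = x.permute(2, 0, 1, 3, 4).contiguous()           # [P, b, s/P, h/P, d]
    out = torch.empty_like(x)
    dist.all_to_all_single(out, x, group=group)
    # out[p] now holds rank p's sequence shard for this rank's head chunk;
    # stitching the shards back in rank order restores the full sequence.
    out = out.permute(1, 0, 2, 3, 4).contiguous()       # [b, P, s/P, h/P, d]
    return out.reshape(b, world_size * s_shard, h_shard, d)
```

Mode 2 is the inverse exchange, which also suggests why the backward pass of one mode can be expressed as a call with the scatter and gather dimensions swapped.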