fastvideo.v1.distributed.device_communicators.cpu_communicator
Module Contents
Classes
API
- class fastvideo.v1.distributed.device_communicators.cpu_communicator.CpuCommunicator(cpu_group: torch.distributed.ProcessGroup, device: torch.device | None = None, device_group: torch.distributed.ProcessGroup | None = None, unique_name: str = '')
Bases: fastvideo.v1.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase
- all_gather(input_: torch.Tensor, dim: int = -1) → torch.Tensor
- all_reduce(input_: torch.Tensor, op: torch.distributed.ReduceOp | None = torch.distributed.ReduceOp.SUM) → torch.Tensor
- gather(input_: torch.Tensor, dst: int = 0, dim: int = -1) → torch.Tensor | None
NOTE: We assume that the input tensor is on the same device across all ranks.
NOTE: dst is the local rank of the destination rank.
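Below is a minimal usage sketch of the methods above. It assumes fastvideo is importable and that the default gloo process group can be passed directly as cpu_group; the _worker helper, the two-process launch, and the port choice are illustrative assumptions, not part of the fastvideo API.

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from fastvideo.v1.distributed.device_communicators.cpu_communicator import (
    CpuCommunicator)


def _worker(rank: int, world_size: int) -> None:
    # Rendezvous over localhost; gloo is the CPU backend.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Assumption: the default process group is acceptable as cpu_group.
    comm = CpuCommunicator(cpu_group=dist.group.WORLD)

    x = torch.full((2, 2), float(rank))

    # all_reduce defaults to ReduceOp.SUM: with two ranks holding 0.0 and
    # 1.0, every rank receives a tensor of ones.
    summed = comm.all_reduce(x)

    # all_gather concatenates along `dim`; here the result has shape (2, 4).
    gathered = comm.all_gather(x, dim=-1)

    # gather returns the combined tensor on dst and None on the other ranks.
    # Per the note above, dst is the local rank of the destination.
    on_dst = comm.gather(x, dst=0, dim=0)

    if rank == 0:
        print(summed, gathered.shape, on_dst.shape)

    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    mp.spawn(_worker, args=(world_size,), nprocs=world_size)
```

Running this prints the summed tensor and the gathered shapes on rank 0; the other rank receives None from gather, matching the Tensor | None return type.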