fastvideo.v1.distributed.communication_op

Module Contents

Functions

sequence_model_parallel_all_gather

All-gather the input tensor across the sequence parallel group.

sequence_model_parallel_all_to_all_4D

All-to-all communication of 4D tensors (e.g. QKV matrices) across the sequence parallel group.

tensor_model_parallel_all_gather

All-gather the input tensor across the tensor model parallel group.

tensor_model_parallel_all_reduce

All-reduce the input tensor across the tensor model parallel group.

API

fastvideo.v1.distributed.communication_op.sequence_model_parallel_all_gather(input_: torch.Tensor, dim: int = -1) -> torch.Tensor

All-gather the input tensor across the sequence parallel group, concatenating the per-rank shards along dim (default: the last dimension).
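To make the role of dim concrete, here is a minimal single-process sketch of what an all-gather produces. It does not call FastVideo or torch.distributed; simulated_all_gather is a hypothetical helper that works on plain nested lists, and the rank-order concatenation it performs is an assumption about the collective's result rather than a description of the library's implementation.

```python
from typing import List

Matrix = List[List[float]]


def simulated_all_gather(shards: List[Matrix], dim: int = -1) -> Matrix:
    """Concatenate per-rank 2D shards in rank order along `dim`.

    Hypothetical stand-in for a collective all-gather: after the real
    operation, every rank would hold this same concatenated result.
    """
    ndim = 2
    if not -ndim <= dim < ndim:
        raise ValueError(f"dim {dim} is out of range for a 2D shard")
    dim = dim % ndim  # normalise negative dims (-1 becomes 1)
    if dim == 0:
        # Stack the rows of rank 0, then rank 1, and so on.
        return [list(row) for shard in shards for row in shard]
    # dim == 1: extend each row with the matching row from every rank.
    n_rows = len(shards[0])
    return [[v for shard in shards for v in shard[r]] for r in range(n_rows)]


# Two simulated ranks, each holding a 2x2 shard.
rank0 = [[1.0, 2.0], [3.0, 4.0]]
rank1 = [[5.0, 6.0], [7.0, 8.0]]

gathered_last = simulated_all_gather([rank0, rank1])          # default dim=-1
gathered_first = simulated_all_gather([rank0, rank1], dim=0)
```

With the default dim=-1 the shards are joined side by side, so the result is 2x4; with dim=0 they are stacked, so the result is 4x2.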

fastvideo.v1.distributed.communication_op.sequence_model_parallel_all_to_all_4D(input_: torch.Tensor, scatter_dim: int = 2, gather_dim: int = 1) -> torch.Tensor

All-to-all communication of 4D tensors (e.g. QKV matrices) across the sequence parallel group. The input is scattered along scatter_dim and gathered along gather_dim.
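The following single-process sketch illustrates the data movement implied by the default arguments. It assumes a [batch, sequence, heads, head_dim] layout, which scatter_dim=2 and gather_dim=1 suggest but this page does not state. Under that assumption each rank starts with a slice of the sequence for all heads and ends with the full sequence for a slice of the heads. simulated_all_to_all_4d and make_shard are hypothetical helpers on nested lists, not FastVideo functions.

```python
from typing import List

# Assumed layout: [batch][sequence][head][head_dim]
Tensor4D = List[List[List[List[int]]]]


def simulated_all_to_all_4d(per_rank: List[Tensor4D]) -> List[Tensor4D]:
    """Scatter dim 2 (heads) and gather dim 1 (sequence) across ranks."""
    world = len(per_rank)
    n_heads = len(per_rank[0][0][0])
    if n_heads % world != 0:
        raise ValueError("head count must be divisible by the number of ranks")
    heads_per_rank = n_heads // world
    outputs: List[Tensor4D] = []
    for dst in range(world):
        lo, hi = dst * heads_per_rank, (dst + 1) * heads_per_rank
        out = []
        for b in range(len(per_rank[0])):
            seq = []
            for src in range(world):  # gather sequence chunks in rank order
                for token in per_rank[src][b]:
                    # keep only the heads assigned to the destination rank
                    seq.append([list(head) for head in token[lo:hi]])
            out.append(seq)
        outputs.append(out)
    return outputs


def make_shard(rank: int, local_seq: int = 2, n_heads: int = 2) -> Tensor4D:
    """Build a batch-1 shard whose values encode (global token * 10 + head)."""
    return [[[[(rank * local_seq + s) * 10 + h] for h in range(n_heads)]
             for s in range(local_seq)]]


inputs = [make_shard(0), make_shard(1)]    # each rank: 2 tokens, 2 heads
outputs = simulated_all_to_all_4d(inputs)  # each rank: 4 tokens, 1 head
```

In this toy run, rank 0 ends up with head 0 for tokens 0 to 3 and rank 1 with head 1 for the same tokens, which is the layout attention needs when each head must see the whole sequence.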

fastvideo.v1.distributed.communication_op.tensor_model_parallel_all_gather(input_: torch.Tensor, dim: int = -1) -> torch.Tensor

All-gather the input tensor across the tensor model parallel group, concatenating the per-rank shards along dim (default: the last dimension).
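A common reason to all-gather along the last dimension in tensor parallelism is a column-sharded linear layer: each rank multiplies the same input by its own column slice of the weight, and the gather reassembles the full output. The sketch below simulates that pattern in one process with nested lists. Whether FastVideo uses this function in exactly this way is an assumption, and matmul, column_shard, and gather_last_dim are hypothetical helpers.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain nested-list matrix product."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def column_shard(w: Matrix, rank: int, world: int) -> Matrix:
    """Return the contiguous column slice of `w` owned by `rank`."""
    width = len(w[0]) // world
    return [row[rank * width:(rank + 1) * width] for row in w]


def gather_last_dim(parts: List[Matrix]) -> Matrix:
    """Simulated all-gather with dim=-1: join shards side by side."""
    return [[v for part in parts for v in part[r]]
            for r in range(len(parts[0]))]


x = [[1.0, 2.0]]                       # replicated input, shape 1x2
w = [[1.0, 0.0, 2.0, 1.0],
     [0.0, 1.0, 1.0, 3.0]]             # full weight, shape 2x4
world_size = 2

# Each simulated rank computes only its slice of the output columns.
local_outputs = [matmul(x, column_shard(w, r, world_size))
                 for r in range(world_size)]
gathered = gather_last_dim(local_outputs)
reference = matmul(x, w)               # unsharded result for comparison
```

The gathered result matches the unsharded product because a column split of the weight only partitions the output columns.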

fastvideo.v1.distributed.communication_op.tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor

All-reduce the input tensor across the tensor model parallel group.
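As an illustration of what an all-reduce delivers, the sketch below simulates a row-sharded linear layer in one process: each rank holds a slice of the input features and the matching weight rows, computes a partial output, and the reduction combines the partials. It assumes the reduction is an element-wise sum, which is the default for torch.distributed.all_reduce but is not stated on this page. simulated_all_reduce_sum and partial_output are hypothetical helpers.

```python
from typing import List


def simulated_all_reduce_sum(per_rank: List[List[float]]) -> List[List[float]]:
    """Element-wise sum over ranks; every rank receives the same total."""
    total = [sum(vals) for vals in zip(*per_rank)]
    return [list(total) for _ in per_rank]


x = [1.0, 2.0, 3.0, 4.0]               # 4 input features
w = [[1.0, 0.0],
     [0.0, 1.0],
     [1.0, 1.0],
     [2.0, 0.0]]                       # full weight, shape 4x2
world_size = 2
chunk = len(x) // world_size           # features owned by each rank


def partial_output(rank: int) -> List[float]:
    """Contribution of one rank's feature slice to every output column."""
    lo, hi = rank * chunk, (rank + 1) * chunk
    return [sum(x[k] * w[k][j] for k in range(lo, hi))
            for j in range(len(w[0]))]


partials = [partial_output(r) for r in range(world_size)]
reduced = simulated_all_reduce_sum(partials)
```

Unlike an all-gather, the output has the same shape as each rank's input; only the values change, since every rank ends up holding the sum of all partial results.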