communication_op

Functions

fastvideo.distributed.communication_op.sequence_model_parallel_all_gather

sequence_model_parallel_all_gather(input_: Tensor, dim: int = -1) -> Tensor

All-gather the input tensor across the sequence parallel group.

Source code in fastvideo/distributed/communication_op.py
def sequence_model_parallel_all_gather(input_: torch.Tensor,
                                       dim: int = -1) -> torch.Tensor:
    """All-gather the input tensor across model parallel group."""
    return get_sp_group().all_gather(input_, dim)
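
A minimal usage sketch (not from the source): it assumes torch.distributed and FastVideo's sequence parallel group have already been initialized elsewhere, and the tensor shapes and CUDA device below are illustrative assumptions.

import torch

from fastvideo.distributed.communication_op import (
    sequence_model_parallel_all_gather)

# Hypothetical activation shard: each rank holds a slice of the sequence,
# shaped [batch, seq_len // sp_world_size, hidden_dim].
local_hidden = torch.randn(2, 256, 64, device="cuda")

# Gather the sequence dimension from every rank in the sequence parallel
# group; with dim=1 the result is [batch, seq_len, hidden_dim].
full_hidden = sequence_model_parallel_all_gather(local_hidden, dim=1)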

fastvideo.distributed.communication_op.sequence_model_parallel_all_to_all_4D

sequence_model_parallel_all_to_all_4D(input_: Tensor, scatter_dim: int = 2, gather_dim: int = 1) -> Tensor

All-to-all communication of 4D tensors (e.g. QKV matrices) across the sequence parallel group.

Source code in fastvideo/distributed/communication_op.py
def sequence_model_parallel_all_to_all_4D(input_: torch.Tensor,
                                          scatter_dim: int = 2,
                                          gather_dim: int = 1) -> torch.Tensor:
    """All-to-all communication of 4D tensors (e.g. QKV matrices) across sequence parallel group."""
    return get_sp_group().all_to_all_4D(input_, scatter_dim, gather_dim)
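
A minimal sketch of the Ulysses-style redistribution this primitive is typically used for. The defaults (scatter_dim=2, gather_dim=1) match a [batch, seq, heads, head_dim] layout; that interpretation, plus the shapes and device below, are assumptions for illustration.

import torch

from fastvideo.distributed.communication_op import (
    sequence_model_parallel_all_to_all_4D)

# Before attention, each rank holds all heads for a slice of the sequence:
# [batch, seq_len // sp_world_size, num_heads, head_dim].
q_local = torch.randn(2, 128, 16, 64, device="cuda")

# Scatter the head dimension (dim 2) and gather the sequence dimension
# (dim 1), so each rank ends up with the full sequence for a subset of
# heads: [batch, seq_len, num_heads // sp_world_size, head_dim].
q_redistributed = sequence_model_parallel_all_to_all_4D(
    q_local, scatter_dim=2, gather_dim=1)

After attention, the inverse redistribution would swap the two dimension arguments to restore the sequence-sharded layout.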

fastvideo.distributed.communication_op.tensor_model_parallel_all_gather

tensor_model_parallel_all_gather(input_: Tensor, dim: int = -1) -> Tensor

All-gather the input tensor across the tensor model parallel group.

Source code in fastvideo/distributed/communication_op.py
def tensor_model_parallel_all_gather(input_: torch.Tensor,
                                     dim: int = -1) -> torch.Tensor:
    """All-gather the input tensor across model parallel group."""
    return get_tp_group().all_gather(input_, dim)
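
A minimal sketch of the usual use case, assuming the tensor parallel group is already initialized; the column-parallel-linear framing and the shapes below are illustrative assumptions.

import torch

from fastvideo.distributed.communication_op import (
    tensor_model_parallel_all_gather)

# A column-parallel linear layer leaves each rank with a shard of the
# output features: [num_tokens, out_features // tp_world_size].
partial_out = torch.randn(4, 1024, device="cuda")

# Gathering along the last dimension reassembles the full feature
# dimension on every rank: [num_tokens, out_features].
full_out = tensor_model_parallel_all_gather(partial_out, dim=-1)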

fastvideo.distributed.communication_op.tensor_model_parallel_all_reduce

tensor_model_parallel_all_reduce(input_: Tensor) -> Tensor

All-reduce the input tensor across the tensor model parallel group.

Source code in fastvideo/distributed/communication_op.py
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
    """All-reduce the input tensor across model parallel group."""
    return get_tp_group().all_reduce(input_)
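
A minimal sketch, again assuming an initialized tensor parallel group; the row-parallel-linear framing and shapes are illustrative assumptions.

import torch

from fastvideo.distributed.communication_op import (
    tensor_model_parallel_all_reduce)

# A row-parallel linear layer leaves each rank with a partial sum of the
# output; shapes are identical on every rank.
partial_out = torch.randn(4, 1024, device="cuda")

# Sum the partial results across the tensor parallel group so every rank
# holds the complete output.
full_out = tensor_model_parallel_all_reduce(partial_out)

Unlike all-gather, all-reduce does not change the tensor's shape: it combines element-wise partial sums rather than concatenating shards.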