
communication_op

Functions

fastvideo.distributed.communication_op.sequence_model_parallel_all_gather

sequence_model_parallel_all_gather(input_: Tensor, dim: int = -1) -> Tensor

All-gather the input tensor across the sequence parallel group.

Source code in fastvideo/distributed/communication_op.py
def sequence_model_parallel_all_gather(input_: torch.Tensor,
                                       dim: int = -1) -> torch.Tensor:
    """All-gather the input tensor across model parallel group."""
    return get_sp_group().all_gather(input_, dim)
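
A minimal usage sketch, assuming the sequence parallel (SP) process group has already been initialized by the FastVideo runtime and that the code runs on every SP rank; the shapes and CUDA device are illustrative:

import torch

from fastvideo.distributed.communication_op import (
    sequence_model_parallel_all_gather)

# Each SP rank holds one shard of the hidden states along the last
# dimension (illustrative shape; 128 features per rank here).
local_hidden = torch.randn(2, 1024, 128, device="cuda")

# Gather the shards from all SP ranks; the gathered dimension grows by a
# factor of the SP world size on every rank.
full_hidden = sequence_model_parallel_all_gather(local_hidden, dim=-1)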

fastvideo.distributed.communication_op.sequence_model_parallel_all_gather_with_unpad

sequence_model_parallel_all_gather_with_unpad(input_: Tensor, original_seq_len: int, dim: int = -1) -> Tensor

All-gather the input tensor across the sequence parallel group and remove any padding that was added during sharding.

Parameters:

    input_ (Tensor): Sharded (and possibly padded) tensor to gather. Required.
    original_seq_len (int): Original sequence length before padding. Required.
    dim (int): Dimension to gather along. Default: -1.

Returns:

    Tensor: Gathered and unpadded tensor.

Source code in fastvideo/distributed/communication_op.py
def sequence_model_parallel_all_gather_with_unpad(
        input_: torch.Tensor,
        original_seq_len: int,
        dim: int = -1) -> torch.Tensor:
    """All-gather the input tensor and remove padding.

    Args:
        input_: Sharded (and possibly padded) tensor to gather
        original_seq_len: Original sequence length before padding
        dim: Dimension to gather along (default: -1)

    Returns:
        Tensor: Gathered and unpadded tensor
    """

    # First gather across all ranks
    gathered = get_sp_group().all_gather(input_, dim)

    current_seq_len = gathered.shape[dim]
    if current_seq_len > original_seq_len:
        gathered = unpad_sequence_tensor(gathered,
                                         original_seq_len,
                                         seq_dim=dim)

    return gathered
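
This function is the inverse of sequence_model_parallel_shard. A hedged round-trip sketch, assuming an initialized SP group and illustrative shapes:

import torch

from fastvideo.distributed.communication_op import (
    sequence_model_parallel_all_gather_with_unpad,
    sequence_model_parallel_shard)

# A sequence length that may not divide evenly by the SP world size.
hidden = torch.randn(2, 998, 512, device="cuda")

# Shard along the sequence dimension; padding is added automatically.
local_shard, original_seq_len = sequence_model_parallel_shard(hidden, dim=1)

# ... run sequence-parallel layers on `local_shard` ...

# Gather the shards back and drop the padding added during sharding.
restored = sequence_model_parallel_all_gather_with_unpad(
    local_shard, original_seq_len, dim=1)
assert restored.shape[1] == original_seq_len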

fastvideo.distributed.communication_op.sequence_model_parallel_all_to_all_4D

sequence_model_parallel_all_to_all_4D(input_: Tensor, scatter_dim: int = 2, gather_dim: int = 1) -> Tensor

All-to-all communication of 4D tensors (e.g. QKV matrices) across the sequence parallel group.

Source code in fastvideo/distributed/communication_op.py
def sequence_model_parallel_all_to_all_4D(input_: torch.Tensor,
                                          scatter_dim: int = 2,
                                          gather_dim: int = 1) -> torch.Tensor:
    """All-to-all communication of 4D tensors (e.g. QKV matrices) across sequence parallel group."""
    return get_sp_group().all_to_all_4D(input_, scatter_dim, gather_dim)
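
A sketch of the typical redistribution around attention, assuming an initialized SP group; the [batch, seq, heads, head_dim] layout below is an assumption about the caller, not something mandated by the API:

import torch

from fastvideo.distributed.communication_op import (
    sequence_model_parallel_all_to_all_4D)

batch, local_seq, heads, head_dim = 2, 256, 16, 64
q = torch.randn(batch, local_seq, heads, head_dim, device="cuda")

# Before attention: scatter over heads (dim 2) and gather over sequence
# (dim 1), so each rank sees the full sequence for a subset of heads.
q_full_seq = sequence_model_parallel_all_to_all_4D(q, scatter_dim=2, gather_dim=1)

# After attention: swap the dims to restore the original partitioning.
q_seq_sharded = sequence_model_parallel_all_to_all_4D(
    q_full_seq, scatter_dim=1, gather_dim=2)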

fastvideo.distributed.communication_op.sequence_model_parallel_shard

sequence_model_parallel_shard(input_: Tensor, dim: int = 1) -> tuple[Tensor, int]

Shard the input tensor across the sequence parallel group, padding the sequence dimension if it is not evenly divisible by the SP world size.

Parameters:

    input_ (Tensor): Input tensor to shard. Required.
    dim (int): Dimension to shard along. Default: 1.

Returns:

    tuple[Tensor, int]: (sharded_tensor, original_seq_len)
        - sharded_tensor: The sharded (and possibly padded) tensor
        - original_seq_len: Original sequence length before padding

Source code in fastvideo/distributed/communication_op.py
def sequence_model_parallel_shard(input_: torch.Tensor,
                                  dim: int = 1) -> tuple[torch.Tensor, int]:
    """Shard the input tensor across model parallel group with optional padding.

    Args:
        input_: Input tensor to shard
        dim: Dimension to shard along (default: 1)

    Returns:
        tuple: (sharded_tensor, original_seq_len)
            - sharded_tensor: The sharded (and possibly padded) tensor
            - original_seq_len: Original sequence length before padding
    """

    sp_rank = get_sp_parallel_rank()
    sp_world_size = get_sp_world_size()

    original_seq_len = input_.shape[dim]

    # Compute padding if needed
    padded_seq_len, padding_amount = compute_padding_for_sp(
        original_seq_len, sp_world_size)

    # Pad if necessary
    if padding_amount > 0:
        input_ = pad_sequence_tensor(input_, padded_seq_len, seq_dim=dim)

    elements_per_rank = padded_seq_len // sp_world_size

    # Sharding along dim
    input_ = input_.movedim(dim, 0)
    input_ = input_[sp_rank * elements_per_rank:(sp_rank + 1) *
                    elements_per_rank]
    input_ = input_.movedim(0, dim)

    return input_, original_seq_len
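
To make the padding arithmetic concrete, here is a local, single-process sketch of the same slicing logic; sp_rank and sp_world_size are stand-ins for what the SP group would report, and no distributed setup is required:

import torch
import torch.nn.functional as F

# Stand-ins for get_sp_parallel_rank() / get_sp_world_size().
sp_rank, sp_world_size = 1, 4

x = torch.randn(2, 998, 512)            # 998 is not divisible by 4
original_seq_len = x.shape[1]

# Round the sequence length up to the next multiple of the world size.
padded_seq_len = -(-original_seq_len // sp_world_size) * sp_world_size  # 1000
padding = padded_seq_len - original_seq_len                             # 2
if padding > 0:
    # F.pad pads dims from the last backwards: (last_lo, last_hi, dim1_lo, dim1_hi).
    x = F.pad(x, (0, 0, 0, padding))

elements_per_rank = padded_seq_len // sp_world_size                     # 250
local = x[:, sp_rank * elements_per_rank:(sp_rank + 1) * elements_per_rank]
print(local.shape)  # torch.Size([2, 250, 512])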

fastvideo.distributed.communication_op.tensor_model_parallel_all_gather

tensor_model_parallel_all_gather(input_: Tensor, dim: int = -1) -> Tensor

All-gather the input tensor across the tensor parallel group.

Source code in fastvideo/distributed/communication_op.py
def tensor_model_parallel_all_gather(input_: torch.Tensor,
                                     dim: int = -1) -> torch.Tensor:
    """All-gather the input tensor across model parallel group."""
    return get_tp_group().all_gather(input_, dim)
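
A usage sketch, assuming the tensor parallel (TP) group is initialized; a typical caller is a column-parallel linear layer whose output features are split across TP ranks (shapes illustrative):

import torch

from fastvideo.distributed.communication_op import (
    tensor_model_parallel_all_gather)

# With 4-way TP, each rank holds 128 of the 512 output features.
partial_out = torch.randn(2, 1024, 128, device="cuda")

# Collect the feature shards from all TP ranks onto every rank.
full_out = tensor_model_parallel_all_gather(partial_out, dim=-1)
# full_out.shape[-1] == 512 when the TP world size is 4.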

fastvideo.distributed.communication_op.tensor_model_parallel_all_reduce

tensor_model_parallel_all_reduce(input_: Tensor) -> Tensor

All-reduce the input tensor across the tensor parallel group.

Source code in fastvideo/distributed/communication_op.py
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
    """All-reduce the input tensor across model parallel group."""
    return get_tp_group().all_reduce(input_)
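
A usage sketch, assuming the TP group is initialized; a typical caller is a row-parallel linear layer, where each rank computes a partial matmul over its slice of the input features and the partial results must be summed:

import torch

from fastvideo.distributed.communication_op import (
    tensor_model_parallel_all_reduce)

# Partial result computed by this rank (illustrative shape).
partial = torch.randn(2, 1024, 512, device="cuda")

# Element-wise sum across all TP ranks; every rank receives the full result.
result = tensor_model_parallel_all_reduce(partial)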