# communication_op

## Functions
### fastvideo.distributed.communication_op.sequence_model_parallel_all_gather

`sequence_model_parallel_all_gather(input_: Tensor, dim: int = -1) -> Tensor`

All-gather the input tensor across the sequence parallel group.
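A minimal usage sketch. The tensor shapes, the `device="cuda"` placement, and the assumption that FastVideo's distributed (sequence parallel) groups are already initialized are illustrative, not taken from the source:

```python
import torch

from fastvideo.distributed.communication_op import (
    sequence_model_parallel_all_gather)

# Each rank holds one shard of the sequence: [batch, local_seq_len, hidden].
local_hidden = torch.randn(2, 128, 1024, device="cuda")

# Gathering along dim=1 concatenates the shards from every rank, producing
# [batch, local_seq_len * sp_world_size, hidden].
full_hidden = sequence_model_parallel_all_gather(local_hidden, dim=1)
```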
### fastvideo.distributed.communication_op.sequence_model_parallel_all_gather_with_unpad

`sequence_model_parallel_all_gather_with_unpad(input_: Tensor, original_seq_len: int, dim: int = -1) -> Tensor`

All-gather the input tensor and remove padding.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `input_` | `Tensor` | Sharded (and possibly padded) tensor to gather | *required* |
| `original_seq_len` | `int` | Original sequence length before padding | *required* |
| `dim` | `int` | Dimension to gather along | `-1` |
Returns:

| Name | Type | Description |
|---|---|---|
| `Tensor` | `Tensor` | Gathered and unpadded tensor |
Source code in fastvideo/distributed/communication_op.py
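A minimal sketch of gathering padded shards and trimming back to the true length. The shapes, the true length of 1000, and the pre-initialized distributed setup are assumptions for illustration:

```python
import torch

from fastvideo.distributed.communication_op import (
    sequence_model_parallel_all_gather_with_unpad)

original_seq_len = 1000  # sequence length before padding was applied
padded_shard = torch.randn(2, 256, 1024, device="cuda")  # this rank's shard

# Gathers every rank's shard along dim=1, then slices the concatenated
# result down to original_seq_len, removing the padding.
full = sequence_model_parallel_all_gather_with_unpad(
    padded_shard, original_seq_len, dim=1)
```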
### fastvideo.distributed.communication_op.sequence_model_parallel_all_to_all_4D

`sequence_model_parallel_all_to_all_4D(input_: Tensor, scatter_dim: int = 2, gather_dim: int = 1) -> Tensor`

All-to-all communication of 4D tensors (e.g. QKV matrices) across the sequence parallel group.
Source code in fastvideo/distributed/communication_op.py
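A minimal sketch of the resharding commonly done around attention in sequence parallelism: switch from sequence-sharded to head-sharded before attention, then back afterwards. The `[batch, seq, heads, head_dim]` layout, the shapes, and the initialized sequence parallel group (size `sp`) are assumptions:

```python
import torch

from fastvideo.distributed.communication_op import (
    sequence_model_parallel_all_to_all_4D)

# Sequence dim is sharded locally; the head dim is still complete.
q = torch.randn(2, 128, 16, 64, device="cuda")  # [batch, seq, heads, head_dim]

# The defaults scatter the heads (dim 2) and gather the sequence (dim 1):
# the result is [2, 128 * sp, 16 // sp, 64], i.e. full sequence, local heads.
q_seq = sequence_model_parallel_all_to_all_4D(q)

# Swapping the two dims inverts the exchange after attention completes.
q_back = sequence_model_parallel_all_to_all_4D(
    q_seq, scatter_dim=1, gather_dim=2)
```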
### fastvideo.distributed.communication_op.sequence_model_parallel_shard

`sequence_model_parallel_shard(input_: Tensor, dim: int = 1) -> tuple[Tensor, int]`

Shard the input tensor across the sequence parallel group with optional padding.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `input_` | `Tensor` | Input tensor to shard | *required* |
| `dim` | `int` | Dimension to shard along | `1` |
Returns:

| Name | Type | Description |
|---|---|---|
| `tuple` | `tuple[Tensor, int]` | `(sharded_tensor, original_seq_len)`, where `sharded_tensor` is the sharded (and possibly padded) tensor and `original_seq_len` is the original sequence length before padding |
Source code in fastvideo/distributed/communication_op.py
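A minimal round-trip sketch pairing this function with `sequence_model_parallel_all_gather_with_unpad`. The shapes and the pre-initialized distributed setup are illustrative assumptions:

```python
import torch

from fastvideo.distributed.communication_op import (
    sequence_model_parallel_all_gather_with_unpad,
    sequence_model_parallel_shard,
)

hidden = torch.randn(2, 1000, 1024, device="cuda")  # 1000 may not divide evenly

# Pads dim=1 if needed so it divides the sequence parallel world size,
# then returns this rank's slice plus the pre-padding length (1000 here).
local_hidden, orig_len = sequence_model_parallel_shard(hidden, dim=1)

# Gather all slices and drop the padding to restore the original tensor.
restored = sequence_model_parallel_all_gather_with_unpad(
    local_hidden, orig_len, dim=1)
assert restored.shape == hidden.shape
```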
### fastvideo.distributed.communication_op.tensor_model_parallel_all_gather

`tensor_model_parallel_all_gather(input_: Tensor, dim: int = -1) -> Tensor`

All-gather the input tensor across the tensor model parallel group.
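A minimal sketch of gathering a hidden-dimension shard, e.g. the output of a column-parallel linear layer; that use case, the shapes, and the initialized tensor parallel group (size `tp`) are assumptions:

```python
import torch

from fastvideo.distributed.communication_op import (
    tensor_model_parallel_all_gather)

# Each rank computed a slice of the hidden dimension.
local_out = torch.randn(2, 128, 1024, device="cuda")

# Default dim=-1 concatenates the slices: result is [2, 128, 1024 * tp].
full_out = tensor_model_parallel_all_gather(local_out)
```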
### fastvideo.distributed.communication_op.tensor_model_parallel_all_reduce

All-reduce the input tensor across the tensor model parallel group.
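A minimal sketch of summing partial results, e.g. after a row-parallel linear layer. The call shape shown here (a single tensor in, a same-shaped tensor out) is an assumption inferred from the sibling entries above, since this entry does not list a signature; the shapes and the initialized setup are likewise illustrative:

```python
import torch

from fastvideo.distributed.communication_op import (
    tensor_model_parallel_all_reduce)

# Each rank holds a partial sum of the layer output.
partial = torch.randn(2, 128, 1024, device="cuda")

# Element-wise sum across all tensor parallel ranks; shape is unchanged.
total = tensor_model_parallel_all_reduce(partial)
```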