vocab_parallel_embedding

Classes

fastvideo.layers.vocab_parallel_embedding.UnquantizedEmbeddingMethod

Bases: QuantizeMethodBase

Unquantized method for embeddings.

Functions

fastvideo.layers.vocab_parallel_embedding.UnquantizedEmbeddingMethod.create_weights
create_weights(layer: Module, input_size_per_partition: int, output_partition_sizes: list[int], input_size: int, output_size: int, params_dtype: dtype, **extra_weight_attrs)

Create weights for embedding layer.

Source code in fastvideo/layers/vocab_parallel_embedding.py
def create_weights(self, layer: torch.nn.Module,
                   input_size_per_partition: int,
                   output_partition_sizes: list[int], input_size: int,
                   output_size: int, params_dtype: torch.dtype,
                   **extra_weight_attrs):
    """Create weights for embedding layer."""

    weight = Parameter(torch.empty(
        sum(output_partition_sizes),
        input_size_per_partition,
        dtype=params_dtype,
    ),
                       requires_grad=False)
    set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
    layer.register_parameter("weight", weight)
    set_weight_attrs(weight, extra_weight_attrs)
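
A minimal sketch of calling create_weights on a bare module, assuming the fastvideo package is importable; the layer, sizes, and dtype below are illustrative only, not taken from this module:

import torch
from fastvideo.layers.vocab_parallel_embedding import UnquantizedEmbeddingMethod

# A bare module to receive the registered weight; any nn.Module works here.
layer = torch.nn.Module()

method = UnquantizedEmbeddingMethod()
method.create_weights(
    layer,
    input_size_per_partition=128,   # embedding_dim of this (hypothetical) shard
    output_partition_sizes=[544],   # padded vocab rows held by this shard
    input_size=128,
    output_size=1088,
    params_dtype=torch.float32,
)

# A frozen weight of shape [sum(output_partition_sizes), input_size_per_partition]
# is registered on the layer, with input_dim/output_dim attributes set for loading.
assert layer.weight.shape == (544, 128)
assert layer.weight.requires_grad is False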

fastvideo.layers.vocab_parallel_embedding.VocabParallelEmbedding

VocabParallelEmbedding(num_embeddings: int, embedding_dim: int, params_dtype: dtype | None = None, org_num_embeddings: int | None = None, padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, quant_config: QuantizationConfig | None = None, prefix: str = '')

Bases: Module

Embedding parallelized in the vocabulary dimension.

Adapted from torch.nn.Embedding. Note that we pad the vocabulary size to make sure it is divisible by the number of model-parallel GPUs.

In order to support various loading methods, we ensure that LoRA-added embeddings are always at the end of TP-sharded tensors. In other words, we shard base embeddings and LoRA embeddings separately (both padded), and place them in the same tensor. In this example, we will have the original vocab size = 1010, added vocab size = 16 and padding to 64. Therefore, the total vocab size with padding will be 1088 (because we first pad 1010 to 1024, add 16, and then pad to 1088). Therefore, the tensor format looks like the following:

TP1, rank 0 (no sharding):
                        |< --------BASE-------- >|< -BASE PADDING-- >|< -----LORA------ >|< -LORA PADDING-- >|
corresponding token_id: |  0  |  1  | ... | 1009 |  -1  | ... |  -1  | 1010 | ... | 1015 |  -1  | ... |  -1  |
                 index: |  0  |  1  | ... | 1009 | 1010 | ... | 1023 | 1024 | ... | 1039 | 1040 | ... | 1087 |

TP2, rank 0:
                        |< --------------------BASE--------------------- >|< -----LORA------ >|< -LORA PADDING- >|
corresponding token_id: |  0  |  1  |  2  | ... | 497  | 498 | ... | 511  | 1000 | ... | 1015 |  -1  | ... |  -1 |
                 index: |  0  |  1  |  2  | ... | 497  | 498 | ... | 511  | 512  | ... | 527  | 520  | ... | 543 |

TP2, rank 1:
                        |< -----------BASE----------- >|< -BASE PADDING- >|< -----------LORA PADDING----------- >|
corresponding token_id: | 512 | 513 | 514 | ... | 1009 |  -1 | ... |  -1  |  -1  | ... |  -1  |  -1  | ... |  -1 |
                 index: |  0  |  1  |  2  | ... | 497  | 498 | ... | 511  | 512  | ... | 519  | 520  | ... | 543 |
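
A quick sanity check of the arithmetic in the example above, as a small standalone sketch (the helper pad() below mirrors the rounding rule of pad_vocab_size):

def pad(size: int, pad_to: int) -> int:
    # Round size up to the nearest multiple of pad_to (the pad_vocab_size rule).
    return ((size + pad_to - 1) // pad_to) * pad_to

org_vocab, added, padding, tp_size = 1010, 16, 64, 2

org_padded = pad(org_vocab, padding)             # 1024
total_padded = pad(org_padded + added, padding)  # pad(1040, 64) == 1088
per_rank = total_padded // tp_size               # 544 rows per rank at TP2 (indices 0..543)

assert (org_padded, total_padded, per_rank) == (1024, 1088, 544)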

Parameters:

Name                Type                       Description                                 Default
num_embeddings      int                        vocabulary size.                            required
embedding_dim       int                        size of hidden state.                       required
params_dtype        dtype | None               type of the parameters.                     None
org_num_embeddings  int | None                 original vocabulary size (without LoRA).    None
padding_size        int                        padding size for the vocabulary.            DEFAULT_VOCAB_PADDING_SIZE
quant_config        QuantizationConfig | None  quant config for the layer.                 None
prefix              str                        full name of the layer in the state dict.   ''
Source code in fastvideo/layers/vocab_parallel_embedding.py
def __init__(self,
             num_embeddings: int,
             embedding_dim: int,
             params_dtype: torch.dtype | None = None,
             org_num_embeddings: int | None = None,
             padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
             quant_config: QuantizationConfig | None = None,
             prefix: str = ""):
    super().__init__()

    # Keep the input dimensions.
    tp_rank = get_tp_rank()
    self.tp_size = get_tp_world_size()
    self.num_embeddings = num_embeddings
    self.padding_size = padding_size
    self.org_vocab_size = org_num_embeddings or num_embeddings
    num_added_embeddings = num_embeddings - self.org_vocab_size
    self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size,
                                                self.padding_size)
    self.num_embeddings_padded = pad_vocab_size(
        self.org_vocab_size_padded + num_added_embeddings,
        self.padding_size)
    assert self.org_vocab_size_padded <= self.num_embeddings_padded

    self.shard_indices = self._get_indices(self.num_embeddings_padded,
                                           self.org_vocab_size_padded,
                                           self.num_embeddings,
                                           self.org_vocab_size, tp_rank,
                                           self.tp_size)
    self.embedding_dim = embedding_dim

    quant_method = None
    if quant_config is not None:
        quant_method = quant_config.get_quant_method(self, prefix=prefix)
    if quant_method is None:
        quant_method = UnquantizedEmbeddingMethod()

    # If we are making an embedding layer, then our quantization linear
    # method must implement the embedding operation. If we are another
    # layer type like ParallelLMHead, this is not important.
    is_embedding_layer = type(self) is VocabParallelEmbedding
    quant_method_implements_embedding = method_has_implemented_embedding(
        type(quant_method))
    if is_embedding_layer and not quant_method_implements_embedding:
        raise NotImplementedError(
            f"The class {type(quant_method).__name__} must implement "
            "the 'embedding' method, see UnquantizedEmbeddingMethod.")

    self.quant_method: QuantizeMethodBase = quant_method

    if params_dtype is None:
        params_dtype = torch.get_default_dtype()
    # Divide the weight matrix along the vocabulary dimension.
    self.num_added_embeddings = self.num_embeddings - self.org_vocab_size
    self.num_embeddings_per_partition = divide(self.num_embeddings_padded,
                                               self.tp_size)
    assert (self.shard_indices.num_elements_padded ==
            self.num_embeddings_per_partition)
    self.num_org_embeddings_per_partition = (
        self.shard_indices.org_vocab_end_index -
        self.shard_indices.org_vocab_start_index)
    self.num_added_embeddings_per_partition = (
        self.shard_indices.added_vocab_end_index -
        self.shard_indices.added_vocab_start_index)

    self.quant_method.create_weights(self,
                                     self.embedding_dim,
                                     [self.num_embeddings_per_partition],
                                     self.embedding_dim,
                                     self.num_embeddings_padded,
                                     params_dtype=params_dtype,
                                     weight_loader=self.weight_loader)
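
A hedged construction sketch: it assumes FastVideo's tensor-parallel process group has already been initialized (get_tp_rank / get_tp_world_size are called in __init__), and the sizes below are illustrative, matching the docstring example rather than any real model:

import torch
from fastvideo.layers.vocab_parallel_embedding import VocabParallelEmbedding

# Requires the tensor-parallel process group to be initialized beforehand.
embedding = VocabParallelEmbedding(
    num_embeddings=1026,       # 1010 base tokens + 16 LoRA-added tokens
    embedding_dim=128,
    params_dtype=torch.float16,
    org_num_embeddings=1010,   # base vocab size, before LoRA additions
    padding_size=64,
)

print(embedding.num_embeddings_padded)         # 1088
print(embedding.num_embeddings_per_partition)  # 1088 // tp_size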

Functions

fastvideo.layers.vocab_parallel_embedding.VocabParallelEmbedding.get_sharded_to_full_mapping
get_sharded_to_full_mapping() -> list[int] | None

Get a mapping that can be used to reindex the gathered logits for sampling.

During sampling, we gather logits from all ranks. The relationship of index->token_id will follow the same format as outlined in the class docstring. However, after the gather, we want to reindex the final logits tensor to map index->token_id one-to-one (the index is always equal to the token_id it corresponds to). The indices returned by this method allow us to do that.

Source code in fastvideo/layers/vocab_parallel_embedding.py
def get_sharded_to_full_mapping(self) -> list[int] | None:
    """Get a mapping that can be used to reindex the gathered
    logits for sampling.

    During sampling, we gather logits from all ranks. The relationship
    of index->token_id will follow the same format as outlined in the class
    docstring. However, after the gather, we want to reindex the final
    logits tensor to map index->token_id one-to-one (the index is always
    equal to the token_id it corresponds to). The indices returned by this
    method allow us to do that.
    """
    if self.tp_size < 2:
        return None

    base_embeddings: list[int] = []
    added_embeddings: list[int] = []
    padding: list[int] = []
    for tp_rank in range(self.tp_size):
        shard_indices = self._get_indices(self.num_embeddings_padded,
                                          self.org_vocab_size_padded,
                                          self.num_embeddings,
                                          self.org_vocab_size, tp_rank,
                                          self.tp_size)
        range_start = self.num_embeddings_per_partition * tp_rank
        range_end = self.num_embeddings_per_partition * (tp_rank + 1)
        base_embeddings.extend(
            range(range_start,
                  range_start + shard_indices.num_org_elements))
        padding.extend(
            range(range_start + shard_indices.num_org_elements,
                  range_start + shard_indices.num_org_elements_padded))
        added_embeddings.extend(
            range(
                range_start + shard_indices.num_org_elements_padded,
                range_start + shard_indices.num_org_elements_padded +
                shard_indices.num_added_elements))
        padding.extend(
            range(
                range_start + shard_indices.num_org_elements_padded +
                shard_indices.num_added_elements,
                range_start + shard_indices.num_org_elements_padded +
                shard_indices.num_added_elements_padded))
        assert (range_start + shard_indices.num_org_elements_padded +
                shard_indices.num_added_elements_padded == range_end)
    ret = base_embeddings + added_embeddings + padding
    assert len(ret) == self.num_embeddings_padded
    return ret
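
A toy sketch of how such a mapping can be applied after gathering logits across ranks. The mapping below is hand-written in the mapping's format (base rows first, padding moved to the end) rather than produced by the method, and the tensor shapes are invented for illustration:

import torch

# Hand-made toy: two ranks, each contributing 3 base rows and 1 padding row,
# gathered in rank order. Position i of the mapping names the gathered column
# that holds the row for final index i (base rows first, padding last).
mapping = [0, 1, 2, 4, 5, 6, 3, 7]

gathered_logits = torch.arange(8, dtype=torch.float32).unsqueeze(0)  # [1, 8]
index = torch.tensor(mapping, device=gathered_logits.device)
full_logits = gathered_logits.index_select(1, index)  # reindex the vocab dimension

# Columns are reordered so base token_ids come first and padding moves to the end.
print(full_logits)  # tensor([[0., 1., 2., 4., 5., 6., 3., 7.]])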

fastvideo.layers.vocab_parallel_embedding.VocabParallelEmbeddingShardIndices dataclass

VocabParallelEmbeddingShardIndices(padded_org_vocab_start_index: int, padded_org_vocab_end_index: int, padded_added_vocab_start_index: int, padded_added_vocab_end_index: int, org_vocab_start_index: int, org_vocab_end_index: int, added_vocab_start_index: int, added_vocab_end_index: int)

Indices for a shard of a vocab parallel embedding.

Functions

fastvideo.layers.vocab_parallel_embedding.pad_vocab_size

pad_vocab_size(vocab_size: int, pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int

Pad the vocab size up to the nearest multiple of the given value.

Source code in fastvideo/layers/vocab_parallel_embedding.py
def pad_vocab_size(vocab_size: int,
                   pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
    """Pad the vocab size to the given value."""
    return ((vocab_size + pad_to - 1) // pad_to) * pad_to
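
For example, using the values from the VocabParallelEmbedding docstring:

from fastvideo.layers.vocab_parallel_embedding import pad_vocab_size

print(pad_vocab_size(1010, 64))  # 1024 (rounded up to the next multiple of 64)
print(pad_vocab_size(1024, 64))  # 1024 (already a multiple, unchanged)
print(pad_vocab_size(1040, 64))  # 1088 (1024 padded base + 16 added tokens)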