
layers

Modules

fastvideo.layers.activation

Custom activation functions.

Classes

fastvideo.layers.activation.GeluAndMul
GeluAndMul(approximate: str = 'none')

Bases: CustomOp

An activation function for GeGLU.

The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.

Shapes

x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
return: (batch_size, seq_len, d) or (num_tokens, d)

Source code in fastvideo/layers/activation.py
def __init__(self, approximate: str = "none"):
    super().__init__()
    self.approximate = approximate
    if approximate not in ("none", "tanh"):
        raise ValueError(f"Unknown approximate mode: {approximate}")
Functions
fastvideo.layers.activation.GeluAndMul.forward_native
forward_native(x: Tensor) -> Tensor

PyTorch-native implementation equivalent to forward().

Source code in fastvideo/layers/activation.py
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
    """PyTorch-native implementation equivalent to forward()."""
    d = x.shape[-1] // 2
    return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
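Example (a minimal sketch, assuming the module can be constructed outside a full FastVideo runtime; the shapes are illustrative only):

import torch
from fastvideo.layers.activation import GeluAndMul

# The input carries 2 * d = 128 channels; the output keeps d = 64.
x = torch.randn(2, 8, 128)            # (batch_size, seq_len, 2 * d)
act = GeluAndMul(approximate="tanh")
out = act.forward_native(x)           # GELU(x[..., :64]) * x[..., 64:]
print(out.shape)                      # torch.Size([2, 8, 64])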
fastvideo.layers.activation.NewGELU
NewGELU()

Bases: CustomOp

Source code in fastvideo/layers/activation.py
def __init__(self):
    super().__init__()
Functions
fastvideo.layers.activation.NewGELU.forward_native
forward_native(x: Tensor) -> Tensor

PyTorch-native implementation equivalent to forward().

Source code in fastvideo/layers/activation.py
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
    """PyTorch-native implementation equivalent to forward()."""
    c = math.sqrt(2.0 / math.pi)
    return 0.5 * x * (1.0 + torch.tanh(c *
                                       (x + 0.044715 * torch.pow(x, 3.0))))
fastvideo.layers.activation.QuickGELU
QuickGELU()

Bases: CustomOp

Source code in fastvideo/layers/activation.py
def __init__(self):
    super().__init__()
Functions
fastvideo.layers.activation.QuickGELU.forward_native
forward_native(x: Tensor) -> Tensor

PyTorch-native implementation equivalent to forward().

Source code in fastvideo/layers/activation.py
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
    """PyTorch-native implementation equivalent to forward()."""
    return x * torch.sigmoid(1.702 * x)
fastvideo.layers.activation.SiluAndMul
SiluAndMul()

Bases: CustomOp

An activation function for SwiGLU.

The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.

Shapes

x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
return: (num_tokens, d) or (batch_size, seq_len, d)

Source code in fastvideo/layers/activation.py
def __init__(self) -> None:
    super().__init__()
Functions
fastvideo.layers.activation.SiluAndMul.forward_native
forward_native(x: Tensor) -> Tensor

PyTorch-native implementation equivalent to forward().

Source code in fastvideo/layers/activation.py
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
    """PyTorch-native implementation equivalent to forward()."""
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

Functions

fastvideo.layers.activation.get_act_and_mul_fn
get_act_and_mul_fn(act_fn_name: str) -> Module

Get an activation-and-mul (e.g. SiluAndMul) function by name.

Source code in fastvideo/layers/activation.py
def get_act_and_mul_fn(act_fn_name: str) -> nn.Module:
    """Get an activation-and-mul (i.e. SiluAndMul) function by name."""
    act_fn_name = act_fn_name.lower()
    if act_fn_name not in _ACTIVATION_AND_MUL_REGISTRY:
        raise ValueError(
            f"Activation function {act_fn_name!r} is not supported.")

    return _ACTIVATION_AND_MUL_REGISTRY[act_fn_name]()
fastvideo.layers.activation.get_act_fn
get_act_fn(act_fn_name: str) -> Module

Get an activation function by name.

Source code in fastvideo/layers/activation.py
def get_act_fn(act_fn_name: str) -> nn.Module:
    """Get an activation function by name."""
    act_fn_name = act_fn_name.lower()
    if act_fn_name not in _ACTIVATION_REGISTRY:
        raise ValueError(
            f"Activation function {act_fn_name!r} is not supported.")

    return _ACTIVATION_REGISTRY[act_fn_name]()
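A hedged usage sketch of the two lookup helpers. "gelu_pytorch_tanh" appears as a default in fastvideo.layers.mlp.MLP, so it is assumed to be a valid key; "silu" is an assumed key for the act-and-mul registry and may differ from the actual registry contents.

from fastvideo.layers.activation import get_act_and_mul_fn, get_act_fn

act = get_act_fn("gelu_pytorch_tanh")    # plain activation module
gated_act = get_act_and_mul_fn("silu")   # assumed key; raises ValueError if unsupported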

fastvideo.layers.custom_op

Classes

fastvideo.layers.custom_op.CustomOp
CustomOp()

Bases: Module

Base class for custom ops. Dispatches the forward method to the appropriate backend.

Source code in fastvideo/layers/custom_op.py
def __init__(self) -> None:
    super().__init__()
    self._forward_method = self.dispatch_forward()
Functions
fastvideo.layers.custom_op.CustomOp.default_on staticmethod
default_on() -> bool

On by default if level < CompilationLevel.PIECEWISE. Specifying 'all' or 'none' in custom_op takes precedence.

Source code in fastvideo/layers/custom_op.py
@staticmethod
def default_on() -> bool:
    """
    On by default if level < CompilationLevel.PIECEWISE
    Specifying 'all' or 'none' in custom_op takes precedence.
    """
    raise NotImplementedError
fastvideo.layers.custom_op.CustomOp.forward_native
forward_native(*args, **kwargs) -> Any

PyTorch-native implementation of the forward method. This method is optional. If implemented, it can be used with compilers such as torch.compile or PyTorch XLA. Also, it can be used for testing purposes.

Source code in fastvideo/layers/custom_op.py
def forward_native(self, *args, **kwargs) -> Any:
    """PyTorch-native implementation of the forward method.
    This method is optional. If implemented, it can be used with compilers
    such as torch.compile or PyTorch XLA. Also, it can be used for testing
    purposes.
    """
    raise NotImplementedError
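A minimal sketch of how a CustomOp subclass might supply its PyTorch-native path. The op name and formula are hypothetical; backend-specific methods and how dispatch_forward selects between them are not covered on this page and are left out.

import torch
from fastvideo.layers.custom_op import CustomOp

class SquaredReLU(CustomOp):
    """Hypothetical op: only the PyTorch-native path is defined here."""

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        # relu(x)^2, usable with torch.compile or for testing, per the docstring above.
        return torch.relu(x) ** 2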

Functions

fastvideo.layers.layernorm

Custom normalization layers.

Classes

fastvideo.layers.layernorm.LayerNormScaleShift
LayerNormScaleShift(hidden_size: int, norm_type: str = 'rms', eps: float = 1e-06, elementwise_affine: bool = False, dtype: dtype = float32, compute_dtype: dtype | None = None, prefix: str = '')

Bases: Module

Fused operation that combines LayerNorm with scale and shift operations. This reduces memory bandwidth by combining memory-bound operations.

Source code in fastvideo/layers/layernorm.py
def __init__(
    self,
    hidden_size: int,
    norm_type: str = "rms",
    eps: float = 1e-6,
    elementwise_affine: bool = False,
    dtype: torch.dtype = torch.float32,
    compute_dtype: torch.dtype | None = None,
    prefix: str = "",
):
    super().__init__()
    self.compute_dtype = compute_dtype
    if norm_type == "rms":
        self.norm = RMSNorm(hidden_size,
                            has_weight=elementwise_affine,
                            eps=eps)
    elif norm_type == "layer":
        if self.compute_dtype == torch.float32:
            self.norm = FP32LayerNorm(hidden_size,
                                      elementwise_affine=elementwise_affine,
                                      eps=eps)
        else:
            self.norm = nn.LayerNorm(hidden_size,
                                     elementwise_affine=elementwise_affine,
                                     eps=eps,
                                     dtype=dtype)
    else:
        raise NotImplementedError(f"Norm type {norm_type} not implemented")
Functions
fastvideo.layers.layernorm.LayerNormScaleShift.forward
forward(x: Tensor, shift: Tensor, scale: Tensor) -> Tensor

Apply ln followed by scale and shift in a single fused operation.

Source code in fastvideo/layers/layernorm.py
def forward(self, x: torch.Tensor, shift: torch.Tensor,
            scale: torch.Tensor) -> torch.Tensor:
    """Apply ln followed by scale and shift in a single fused operation."""
    # x.shape: [batch_size, seq_len, inner_dim]
    normalized = self.norm(x)
    if self.compute_dtype == torch.float32:
        normalized = normalized.float()

    if scale.dim() == 4:
        # scale.shape: [batch_size, num_frames, 1, inner_dim]
        num_frames = scale.shape[1]
        frame_seqlen = normalized.shape[1] // num_frames
        output = (
            normalized.unflatten(dim=1, sizes=(num_frames, frame_seqlen)) *
            (1.0 + scale) + shift).flatten(1, 2)
    else:
        # scale.shape: [batch_size, 1, inner_dim]
        # shift.shape: [batch_size, 1, inner_dim]
        output = normalized * (1.0 + scale) + shift

    if self.compute_dtype == torch.float32:
        output = output.to(x.dtype)

    return output
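A hedged usage sketch of the fused norm plus modulation, assuming the layer can be constructed outside a full FastVideo runtime. In practice shift and scale are typically predicted from a conditioning signal (adaLN-style); random tensors stand in for them here.

import torch
from fastvideo.layers.layernorm import LayerNormScaleShift

norm_mod = LayerNormScaleShift(hidden_size=64, norm_type="layer")
x = torch.randn(2, 16, 64)        # [batch_size, seq_len, inner_dim]
shift = torch.randn(2, 1, 64)     # broadcast over the sequence dimension
scale = torch.randn(2, 1, 64)
out = norm_mod(x, shift, scale)   # norm(x) * (1 + scale) + shift
print(out.shape)                  # torch.Size([2, 16, 64])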
fastvideo.layers.layernorm.RMSNorm
RMSNorm(hidden_size: int, eps: float = 1e-06, dtype: dtype = float32, var_hidden_size: int | None = None, has_weight: bool = True)

Bases: CustomOp

Root mean square normalization.

Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight. Refer to https://arxiv.org/abs/1910.07467

Source code in fastvideo/layers/layernorm.py
def __init__(
    self,
    hidden_size: int,
    eps: float = 1e-6,
    dtype: torch.dtype = torch.float32,
    var_hidden_size: int | None = None,
    has_weight: bool = True,
) -> None:
    super().__init__()

    self.hidden_size = hidden_size
    self.variance_epsilon = eps
    self.variance_size_override = (None if var_hidden_size == hidden_size
                                   else var_hidden_size)
    self.has_weight = has_weight

    from fastvideo.platforms import current_platform

    self.weight = torch.ones(hidden_size) if current_platform.is_cuda_alike(
    ) else torch.ones(hidden_size, dtype=dtype)
    if self.has_weight:
        self.weight = nn.Parameter(self.weight)
Functions
fastvideo.layers.layernorm.RMSNorm.forward_native
forward_native(x: Tensor, residual: Tensor | None = None) -> Tensor | tuple[Tensor, Tensor]

PyTorch-native implementation equivalent to forward().

Source code in fastvideo/layers/layernorm.py
def forward_native(
    self,
    x: torch.Tensor,
    residual: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """PyTorch-native implementation equivalent to forward()."""
    orig_dtype = x.dtype
    x = x.to(torch.float32)
    if residual is not None:
        x = x + residual.to(torch.float32)
        residual = x.to(orig_dtype)

    hidden_size = x.shape[-1]
    if hidden_size != self.hidden_size:
        raise ValueError("Expected hidden_size to be "
                         f"{self.hidden_size}, but found: {hidden_size}")

    if self.variance_size_override is None:
        x_var = x
    else:
        if hidden_size < self.variance_size_override:
            raise ValueError(
                "Expected hidden_size to be at least "
                f"{self.variance_size_override}, but found: {hidden_size}")

        x_var = x[:, :, :self.variance_size_override]

    variance = x_var.pow(2).mean(dim=-1, keepdim=True)

    x = x * torch.rsqrt(variance + self.variance_epsilon)
    x = x.to(orig_dtype)
    if self.has_weight:
        x = x * self.weight
    if residual is None:
        return x
    else:
        return x, residual
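For reference, a standalone re-statement of what forward_native computes in the has_weight=False, no-residual case; this is a sketch that mirrors the formula above, not the library implementation.

import torch

def rms_norm_reference(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # x / sqrt(E[x^2] + eps), computed in float32 and cast back, as in forward_native.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype)

x = torch.randn(2, 16, 64, dtype=torch.bfloat16)
print(rms_norm_reference(x).shape)   # torch.Size([2, 16, 64])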
fastvideo.layers.layernorm.ScaleResidual
ScaleResidual(prefix: str = '')

Bases: Module

Applies gated residual connection.

Source code in fastvideo/layers/layernorm.py
def __init__(self, prefix: str = ""):
    super().__init__()
Functions
fastvideo.layers.layernorm.ScaleResidual.forward
forward(residual: Tensor, x: Tensor, gate: Tensor) -> Tensor

Apply gated residual connection.

Source code in fastvideo/layers/layernorm.py
def forward(self, residual: torch.Tensor, x: torch.Tensor,
            gate: torch.Tensor) -> torch.Tensor:
    """Apply gated residual connection."""
    # x.shape: [batch_size, seq_len, inner_dim]
    if gate.dim() == 4:
        # gate.shape: [batch_size, num_frames, 1, inner_dim]
        num_frames = gate.shape[1]
        frame_seqlen = x.shape[1] // num_frames
        return residual + (x.unflatten(
            dim=1, sizes=(num_frames, frame_seqlen)) * gate).flatten(1, 2)
    else:
        # gate.shape: [batch_size, 1, inner_dim]
        return residual + x * gate
fastvideo.layers.layernorm.ScaleResidualLayerNormScaleShift
ScaleResidualLayerNormScaleShift(hidden_size: int, norm_type: str = 'rms', eps: float = 1e-06, elementwise_affine: bool = False, dtype: dtype = float32, compute_dtype: dtype | None = None, prefix: str = '')

Bases: Module

Fused operation that combines:
1. Gated residual connection
2. LayerNorm
3. Scale and shift operations

This reduces memory bandwidth by combining memory-bound operations.

Source code in fastvideo/layers/layernorm.py
def __init__(
    self,
    hidden_size: int,
    norm_type: str = "rms",
    eps: float = 1e-6,
    elementwise_affine: bool = False,
    dtype: torch.dtype = torch.float32,
    compute_dtype: torch.dtype | None = None,
    prefix: str = "",
):
    super().__init__()
    if norm_type == "rms":
        self.norm = RMSNorm(hidden_size,
                            has_weight=elementwise_affine,
                            eps=eps,
                            dtype=dtype)
    elif norm_type == "layer":
        if compute_dtype == torch.float32:
            self.norm = FP32LayerNorm(hidden_size,
                                      elementwise_affine=elementwise_affine,
                                      eps=eps)
        else:
            self.norm = nn.LayerNorm(hidden_size,
                                     elementwise_affine=elementwise_affine,
                                     eps=eps,
                                     dtype=dtype)
    else:
        raise NotImplementedError(f"Norm type {norm_type} not implemented")
Functions
fastvideo.layers.layernorm.ScaleResidualLayerNormScaleShift.forward
forward(residual: Tensor, x: Tensor, gate: Tensor | int, shift: Tensor, scale: Tensor) -> tuple[Tensor, Tensor]

Apply gated residual connection, followed by layernorm and scale/shift in a single fused operation.

Returns:

tuple[Tensor, Tensor]: Tuple containing:
  • normalized and modulated output
  • residual value (value after residual connection but before normalization)
Source code in fastvideo/layers/layernorm.py
def forward(self, residual: torch.Tensor, x: torch.Tensor,
            gate: torch.Tensor | int, shift: torch.Tensor,
            scale: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Apply gated residual connection, followed by layernorm and 
    scale/shift in a single fused operation.

    Returns:
        Tuple containing:
        - normalized and modulated output
        - residual value (value after residual connection 
          but before normalization)
    """
    # x.shape: [batch_size, seq_len, inner_dim]
    # Apply residual connection with gating
    if isinstance(gate, int):
        # used by cross-attention, should be 1
        assert gate == 1
        residual_output = residual + x
    elif isinstance(gate, torch.Tensor):
        if gate.dim() == 4:
            # gate.shape: [batch_size, num_frames, 1, inner_dim]
            num_frames = gate.shape[1]
            frame_seqlen = x.shape[1] // num_frames
            residual_output = residual + (
                x.unflatten(dim=1, sizes=(num_frames, frame_seqlen)) *
                gate).flatten(1, 2)
        else:
            # used by bidirectional self attention
            # gate.shape: [batch_size, 1, inner_dim]
            residual_output = residual + x * gate
    else:
        raise ValueError(f"Gate type {type(gate)} not supported")
    # residual_output.shape: [batch_size, seq_len, inner_dim]

    # Apply normalization
    normalized = self.norm(residual_output)
    # Apply scale and shift
    if isinstance(scale, torch.Tensor) and scale.dim() == 4:
        # scale.shape: [batch_size, num_frames, 1, inner_dim]
        # shift.shape: [batch_size, num_frames, 1, inner_dim]
        num_frames = scale.shape[1]
        frame_seqlen = normalized.shape[1] // num_frames
        modulated = (
            normalized.unflatten(dim=1, sizes=(num_frames, frame_seqlen)) *
            (1.0 + scale) + shift).flatten(1, 2)
    else:
        modulated = normalized * (1.0 + scale) + shift
    return modulated, residual_output
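A hedged usage sketch with 3-D gate/shift/scale tensors of shape [batch_size, 1, inner_dim], again assuming the layer can be constructed outside a full FastVideo runtime.

import torch
from fastvideo.layers.layernorm import ScaleResidualLayerNormScaleShift

fused = ScaleResidualLayerNormScaleShift(hidden_size=64, norm_type="rms")
residual = torch.randn(2, 16, 64)
x = torch.randn(2, 16, 64)
gate = torch.randn(2, 1, 64)
shift = torch.randn(2, 1, 64)
scale = torch.randn(2, 1, 64)
modulated, new_residual = fused(residual, x, gate, shift, scale)
# new_residual == residual + x * gate
# modulated == norm(new_residual) * (1 + scale) + shift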

fastvideo.layers.linear

Classes

fastvideo.layers.linear.ColumnParallelLinear
ColumnParallelLinear(input_size: int, output_size: int, bias: bool = True, gather_output: bool = False, skip_bias_add: bool = False, params_dtype: dtype | None = None, quant_config: QuantizationConfig | None = None, output_sizes: list[int] | None = None, prefix: str = '')

Bases: LinearBase

Linear layer with column parallelism.

The linear layer is defined as Y = XA + b. A is parallelized along its second dimension as A = [A_1, ..., A_p].

Parameters:

input_size (int): first dimension of matrix A. Required.
output_size (int): second dimension of matrix A. Required.
bias (bool): If true, add bias. Default: True.
gather_output (bool): If true, call all-gather on the output and make Y available to all GPUs; otherwise, every GPU will have its own output Y_i = XA_i. Default: False.
skip_bias_add (bool): This was added to enable performance optimizations where bias can be fused with other element-wise operations. We skip adding bias but instead return it. Default: False.
params_dtype (dtype | None): Data type for the parameters. Default: None.
quant_config (QuantizationConfig | None): Quantization config. Default: None.
output_sizes (list[int] | None): List of output sizes packed into one output, e.g. for QKV the list would have size 3. Default: None.
prefix (str): The name of the layer in the state dict, including all parents (e.g. model.layers.0.qkv_proj). Default: ''.
Source code in fastvideo/layers/linear.py
def __init__(self,
             input_size: int,
             output_size: int,
             bias: bool = True,
             gather_output: bool = False,
             skip_bias_add: bool = False,
             params_dtype: torch.dtype | None = None,
             quant_config: QuantizationConfig | None = None,
             output_sizes: list[int] | None = None,
             prefix: str = ""):
    # Divide the weight matrix along the last dimension.
    self.tp_size = get_tp_world_size()
    self.input_size_per_partition = input_size
    self.output_size_per_partition = divide(output_size, self.tp_size)
    self.output_partition_sizes = [self.output_size_per_partition]
    # If QKV or MergedColumn, use output size of each partition.
    if hasattr(self, "output_sizes"):
        self.output_partition_sizes = [
            divide(output_size, self.tp_size)
            for output_size in self.output_sizes
        ]

    super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                     quant_config, prefix)

    self.gather_output = gather_output

    if output_sizes is None:
        output_sizes = [output_size]

    assert self.quant_method is not None
    self.quant_method.create_weights(
        layer=self,
        input_size_per_partition=self.input_size_per_partition,
        output_partition_sizes=self.output_partition_sizes,
        input_size=self.input_size,
        output_size=self.output_size,
        params_dtype=self.params_dtype,
        weight_loader=(
            self.weight_loader_v2 if self.quant_method.__class__.__name__
            in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
    if bias:
        self.bias = Parameter(
            torch.empty(
                self.output_size_per_partition,
                dtype=params_dtype,
            ))
        set_weight_attrs(self.bias, {
            "output_dim": 0,
            "weight_loader": self.weight_loader,
        })
    else:
        self.register_parameter("bias", None)
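A conceptual sketch of what column parallelism does to the math, in plain PyTorch rather than the FastVideo API (which requires an initialized tensor-parallel group): A is split along its output dimension, each rank computes Y_i = X A_i, and concatenating the shards (the role of gather_output=True) recovers Y = XA.

import torch

tp_size = 4
X = torch.randn(8, 32)                     # [num_tokens, input_size]
A = torch.randn(32, 64)                    # [input_size, output_size]
A_shards = A.chunk(tp_size, dim=1)         # each [32, 64 // tp_size]
Y_shards = [X @ A_i for A_i in A_shards]   # per-rank partial outputs Y_i = X A_i
Y = torch.cat(Y_shards, dim=1)             # what an all-gather would produce
assert torch.allclose(Y, X @ A, atol=1e-5)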
fastvideo.layers.linear.LinearBase
LinearBase(input_size: int, output_size: int, skip_bias_add: bool = False, params_dtype: dtype | None = None, quant_config: QuantizationConfig | None = None, prefix: str = '')

Bases: Module

Base linear layer.

Parameters:

input_size (int): input dimension of the linear layer. Required.
output_size (int): output dimension of the linear layer. Required.
bias: If true, add bias. Required.
skip_bias_add (bool): If true, skip adding bias but instead return it. Default: False.
params_dtype (dtype | None): Data type for the parameters. Default: None.
quant_config (QuantizationConfig | None): Quantization config. Default: None.
Source code in fastvideo/layers/linear.py
def __init__(
    self,
    input_size: int,
    output_size: int,
    skip_bias_add: bool = False,
    params_dtype: torch.dtype | None = None,
    quant_config: QuantizationConfig | None = None,
    prefix: str = "",
):
    super().__init__()

    # Keep input parameters
    self.input_size = input_size
    self.output_size = output_size
    self.skip_bias_add = skip_bias_add
    if params_dtype is None:
        params_dtype = torch.get_default_dtype()
    self.params_dtype = params_dtype
    self.quant_config = quant_config
    self.prefix = prefix
    if quant_config is None:
        self.quant_method: QuantizeMethodBase | None = UnquantizedLinearMethod(
        )
    else:
        self.quant_method = quant_config.get_quant_method(self,
                                                          prefix=prefix)
fastvideo.layers.linear.LinearMethodBase

Bases: QuantizeMethodBase

Base class for different (maybe quantized) linear methods.

Functions
fastvideo.layers.linear.LinearMethodBase.apply abstractmethod
apply(layer: Module, x: Tensor, bias: Tensor | None = None) -> Tensor

Apply the weights in layer to the input tensor. Expects create_weights to have been called before on the layer.

Source code in fastvideo/layers/linear.py
@abstractmethod
def apply(self,
          layer: torch.nn.Module,
          x: torch.Tensor,
          bias: torch.Tensor | None = None) -> torch.Tensor:
    """Apply the weights in layer to the input tensor.
    Expects create_weights to have been called before on the layer."""
    raise NotImplementedError
fastvideo.layers.linear.LinearMethodBase.create_weights abstractmethod
create_weights(layer: Module, input_size_per_partition: int, output_partition_sizes: list[int], input_size: int, output_size: int, params_dtype: dtype, **extra_weight_attrs) -> None

Create weights for a linear layer. The weights will be set as attributes of the layer.

Parameters:

layer (Module): The layer that is using the LinearMethodBase factory. Required.
input_size_per_partition (int): Size of the weight input dim on rank X. Required.
output_partition_sizes (list[int]): Sizes of the output dim of each logical weight on rank X. E.g., output_partition_sizes for QKVLinear is a list containing the widths of Wq, Wk, Wv on rank X. Required.
input_size (int): Size of the input dim of the weight across all ranks. Required.
output_size (int): Size of the output dim of the weight across all ranks. Required.
params_dtype (dtype): Datatype of the parameters. Required.
Source code in fastvideo/layers/linear.py
@abstractmethod
def create_weights(self, layer: torch.nn.Module,
                   input_size_per_partition: int,
                   output_partition_sizes: list[int], input_size: int,
                   output_size: int, params_dtype: torch.dtype,
                   **extra_weight_attrs) -> None:
    """Create weights for a linear layer. 
       The weights will be set as attributes of the layer.

    Args:
        layer: The layer that is using the LinearMethodBase factory.
        input_size_per_partition: Size of the weight input dim on rank X.
        output_partition_sizes: Sizes of the output dim of each logical 
            weight on rank X. E.g., output_partition_sizes for QKVLinear
            is a list containing the widths of Wq, Wk, Wv on rank X.
        input_size: Size of the input dim of the weight across all ranks.
        output_size: Size of the output dim of the weight across all ranks.
        params_dtype: Datatype of the parameters.
    """
    raise NotImplementedError
fastvideo.layers.linear.MergedColumnParallelLinear
MergedColumnParallelLinear(input_size: int, output_sizes: list[int], bias: bool = True, gather_output: bool = False, skip_bias_add: bool = False, params_dtype: dtype | None = None, quant_config: QuantizationConfig | None = None, prefix: str = '')

Bases: ColumnParallelLinear

Packed linear layers with column parallelism.

Similar to ColumnParallelLinear, but the weight matrix is concatenated along the output dimension. When the weight matrix is loaded, the different partitions are sharded separately.

Parameters:

input_size (int): input dimension of the linear layer. Required.
output_sizes (list[int]): list of output dimensions of the linear layer. Required.
bias (bool): If true, add bias. Default: True.
gather_output (bool): If true, call all-gather on the output and make the output available to all GPUs; otherwise, every GPU will have its own output. Default: False.
skip_bias_add (bool): This was added to enable performance optimizations where bias can be fused with other element-wise operations. We skip adding bias but instead return it. Default: False.
params_dtype (dtype | None): Data type for the parameters. Default: None.
quant_config (QuantizationConfig | None): Quantization config. Default: None.
prefix (str): The name of the layer in the state dict, including all parents (e.g. model.layers.0.qkv_proj). Default: ''.
Source code in fastvideo/layers/linear.py
def __init__(self,
             input_size: int,
             output_sizes: list[int],
             bias: bool = True,
             gather_output: bool = False,
             skip_bias_add: bool = False,
             params_dtype: torch.dtype | None = None,
             quant_config: QuantizationConfig | None = None,
             prefix: str = ""):
    self.output_sizes = output_sizes
    tp_size = get_tp_world_size()
    assert all(output_size % tp_size == 0 for output_size in output_sizes)
    super().__init__(input_size=input_size,
                     output_size=sum(output_sizes),
                     bias=bias,
                     gather_output=gather_output,
                     skip_bias_add=skip_bias_add,
                     params_dtype=params_dtype,
                     quant_config=quant_config,
                     prefix=prefix)
fastvideo.layers.linear.QKVParallelLinear
QKVParallelLinear(hidden_size: int, head_size: int, total_num_heads: int, total_num_kv_heads: int | None = None, bias: bool = True, skip_bias_add: bool = False, params_dtype: dtype | None = None, quant_config: QuantizationConfig | None = None, prefix: str = '')

Bases: ColumnParallelLinear

Linear layers for the attention's QKV transformation.

Linear layers for the linear transformation of the query, key, and value vectors in the attention layer. The weight matrix is concatenated along the output dimension. The layer is parallelized along the head dimension. When the number of key/value heads is smaller than the number of query heads (e.g., multi-query/grouped-query attention), the key/value head may be replicated while the query heads are partitioned.

Parameters:

hidden_size (int): input hidden state size of the transformer. Required.
head_size (int): size of each attention head. Required.
total_num_heads (int): total number of attention query heads. Required.
total_num_kv_heads (int | None): total number of attention key/value heads. If None, assume total_num_kv_heads = total_num_heads. Default: None.
bias (bool): If true, add bias. Default: True.
skip_bias_add (bool): This was added to enable performance optimizations where bias can be fused with other element-wise operations. We skip adding bias but instead return it. Default: False.
params_dtype (dtype | None): Data type for the parameters. Default: None.
quant_config (QuantizationConfig | None): Quantization config. Default: None.
prefix (str): The name of the layer in the state dict, including all parents (e.g. model.layers.0.qkv_proj). Default: ''.
Source code in fastvideo/layers/linear.py
def __init__(self,
             hidden_size: int,
             head_size: int,
             total_num_heads: int,
             total_num_kv_heads: int | None = None,
             bias: bool = True,
             skip_bias_add: bool = False,
             params_dtype: torch.dtype | None = None,
             quant_config: QuantizationConfig | None = None,
             prefix: str = ""):
    self.hidden_size = hidden_size
    self.head_size = head_size
    self.total_num_heads = total_num_heads
    if total_num_kv_heads is None:
        total_num_kv_heads = total_num_heads
    self.total_num_kv_heads = total_num_kv_heads
    # Divide the weight matrix along the last dimension.
    tp_size = get_tp_world_size()
    self.num_heads = divide(self.total_num_heads, tp_size)
    if tp_size >= self.total_num_kv_heads:
        self.num_kv_heads = 1
        self.num_kv_head_replicas = divide(tp_size, self.total_num_kv_heads)
    else:
        self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
        self.num_kv_head_replicas = 1
    input_size = self.hidden_size
    output_size = (self.num_heads +
                   2 * self.num_kv_heads) * tp_size * self.head_size
    self.output_sizes = [
        self.num_heads * self.head_size * tp_size,  # q_proj
        self.num_kv_heads * self.head_size * tp_size,  # k_proj
        self.num_kv_heads * self.head_size * tp_size,  # v_proj 
    ]

    super().__init__(input_size=input_size,
                     output_size=output_size,
                     bias=bias,
                     gather_output=False,
                     skip_bias_add=skip_bias_add,
                     params_dtype=params_dtype,
                     quant_config=quant_config,
                     prefix=prefix)
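A hedged walk-through of the head-partitioning arithmetic above, using assumed sizes (32 query heads, 8 KV heads, head_size 128, tensor-parallel size 4); divide here is a local stand-in for the library helper of the same name.

def divide(a: int, b: int) -> int:
    assert a % b == 0
    return a // b

tp_size = 4
total_num_heads, total_num_kv_heads, head_size = 32, 8, 128
num_heads = divide(total_num_heads, tp_size)              # 8 query heads per rank
if tp_size >= total_num_kv_heads:
    num_kv_heads = 1
    num_kv_head_replicas = divide(tp_size, total_num_kv_heads)
else:
    num_kv_heads = divide(total_num_kv_heads, tp_size)    # 2 KV heads per rank
    num_kv_head_replicas = 1
print(num_heads, num_kv_heads, num_kv_head_replicas)      # 8 2 1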
fastvideo.layers.linear.ReplicatedLinear
ReplicatedLinear(input_size: int, output_size: int, bias: bool = True, skip_bias_add: bool = False, params_dtype: dtype | None = None, quant_config: QuantizationConfig | None = None, prefix: str = '')

Bases: LinearBase

Replicated linear layer.

Parameters:

input_size (int): input dimension of the linear layer. Required.
output_size (int): output dimension of the linear layer. Required.
bias (bool): If true, add bias. Default: True.
skip_bias_add (bool): If true, skip adding bias but instead return it. Default: False.
params_dtype (dtype | None): Data type for the parameters. Default: None.
quant_config (QuantizationConfig | None): Quantization config. Default: None.
prefix (str): The name of the layer in the state dict, including all parents (e.g. model.layers.0.qkv_proj). Default: ''.
Source code in fastvideo/layers/linear.py
def __init__(self,
             input_size: int,
             output_size: int,
             bias: bool = True,
             skip_bias_add: bool = False,
             params_dtype: torch.dtype | None = None,
             quant_config: QuantizationConfig | None = None,
             prefix: str = ""):
    super().__init__(input_size,
                     output_size,
                     skip_bias_add,
                     params_dtype,
                     quant_config,
                     prefix=prefix)

    # All linear layers support a quant method.
    assert self.quant_method is not None
    self.quant_method.create_weights(self,
                                     self.input_size, [self.output_size],
                                     self.input_size,
                                     self.output_size,
                                     self.params_dtype,
                                     weight_loader=self.weight_loader)

    if bias:
        self.bias = Parameter(
            torch.empty(
                self.output_size,
                dtype=self.params_dtype,
            ))
        set_weight_attrs(self.bias, {
            "output_dim": 0,
            "weight_loader": self.weight_loader,
        })
    else:
        self.register_parameter("bias", None)
fastvideo.layers.linear.RowParallelLinear
RowParallelLinear(input_size: int, output_size: int, bias: bool = True, input_is_parallel: bool = True, skip_bias_add: bool = False, params_dtype: dtype | None = None, reduce_results: bool = True, quant_config: QuantizationConfig | None = None, prefix: str = '')

Bases: LinearBase

Linear layer with row parallelism.

The linear layer is defined as Y = XA + b. A is parallelized along its first dimension and X along its second dimension as:

    A = [A_1; ...; A_p]  (stacked along the first dimension)
    X = [X_1, ..., X_p]

Arguments:
    input_size: first dimension of matrix A.
    output_size: second dimension of matrix A.
    bias: If true, add bias. Note that bias is not parallelized.
    input_is_parallel: If true, we assume that the input is already split across the GPUs and we do not split again.
    skip_bias_add: This was added to enable performance optimization where bias can be fused with other element-wise operations. We skip adding bias but instead return it.
    params_dtype: Data type for the parameters.
    quant_config: Quantization config.

Source code in fastvideo/layers/linear.py
def __init__(self,
             input_size: int,
             output_size: int,
             bias: bool = True,
             input_is_parallel: bool = True,
             skip_bias_add: bool = False,
             params_dtype: torch.dtype | None = None,
             reduce_results: bool = True,
             quant_config: QuantizationConfig | None = None,
             prefix: str = ""):
    # Divide the weight matrix along the first dimension.
    self.tp_rank = get_tp_rank()
    self.tp_size = get_tp_world_size()
    self.input_size_per_partition = divide(input_size, self.tp_size)
    self.output_size_per_partition = output_size
    self.output_partition_sizes = [output_size]

    super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                     quant_config, prefix)

    self.input_is_parallel = input_is_parallel
    self.reduce_results = reduce_results

    assert self.quant_method is not None
    self.quant_method.create_weights(
        layer=self,
        input_size_per_partition=self.input_size_per_partition,
        output_partition_sizes=self.output_partition_sizes,
        input_size=self.input_size,
        output_size=self.output_size,
        params_dtype=self.params_dtype,
        weight_loader=(
            self.weight_loader_v2 if self.quant_method.__class__.__name__
            in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
    if not reduce_results and (bias and not skip_bias_add):
        raise ValueError("When not reduce the results, adding bias to the "
                         "results can lead to incorrect results")

    if bias:
        self.bias = Parameter(
            torch.empty(self.output_size, dtype=params_dtype))
        set_weight_attrs(self.bias, {
            "output_dim": 0,
            "weight_loader": self.weight_loader,
        })
    else:
        self.register_parameter("bias", None)
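A conceptual sketch of row parallelism in plain PyTorch (not the FastVideo API): A is split along its input dimension and X along its last dimension, each rank computes a partial product, and summing the partials plays the role of the all-reduce performed when reduce_results=True.

import torch

tp_size = 4
X = torch.randn(8, 32)                      # [num_tokens, input_size]
A = torch.randn(32, 64)                     # [input_size, output_size]
X_shards = X.chunk(tp_size, dim=1)          # each [8, 32 // tp_size]
A_shards = A.chunk(tp_size, dim=0)          # each [32 // tp_size, 64]
partials = [x_i @ a_i for x_i, a_i in zip(X_shards, A_shards)]
Y = sum(partials)                           # equals X @ A up to floating-point error
assert torch.allclose(Y, X @ A, atol=1e-4)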
fastvideo.layers.linear.UnquantizedLinearMethod

Bases: LinearMethodBase

Linear method without quantization.

Functions

fastvideo.layers.linear.adjust_scalar_to_fused_array
adjust_scalar_to_fused_array(param: Tensor, loaded_weight: Tensor, shard_id: str | int) -> tuple[Tensor, Tensor]

For fused modules (QKV and MLP) we have an array of length N that holds 1 scale for each "logical" matrix. So the param is an array of length N. The loaded_weight corresponds to one of the shards on disk. Here, we slice the param based on the shard_id for loading.

Source code in fastvideo/layers/linear.py
def adjust_scalar_to_fused_array(
        param: torch.Tensor, loaded_weight: torch.Tensor,
        shard_id: str | int) -> tuple[torch.Tensor, torch.Tensor]:
    """For fused modules (QKV and MLP) we have an array of length
    N that holds 1 scale for each "logical" matrix. So the param
    is an array of length N. The loaded_weight corresponds to 
    one of the shards on disk. Here, we slice the param based on 
    the shard_id for loading.
    """
    qkv_idxs = {"q": 0, "k": 1, "v": 2}

    if isinstance(shard_id, str):
        shard_id = qkv_idxs[shard_id]
    elif not isinstance(shard_id, int):
        raise ValueError(f"Unknown Shard Id {shard_id}")

    # AutoFP8 scales do not have a shape
    # compressed-tensors scales do have a shape
    if len(loaded_weight.shape) != 0:
        assert loaded_weight.shape[0] == 1
        loaded_weight = loaded_weight[0]

    return param[shard_id], loaded_weight

fastvideo.layers.mlp

Classes

fastvideo.layers.mlp.MLP
MLP(input_dim: int, mlp_hidden_dim: int, output_dim: int | None = None, bias: bool = True, act_type: str = 'gelu_pytorch_tanh', dtype: dtype | None = None, prefix: str = '')

Bases: Module

MLP for DiT blocks (no gated linear units).

Source code in fastvideo/layers/mlp.py
def __init__(
    self,
    input_dim: int,
    mlp_hidden_dim: int,
    output_dim: int | None = None,
    bias: bool = True,
    act_type: str = "gelu_pytorch_tanh",
    dtype: torch.dtype | None = None,
    prefix: str = "",
):
    super().__init__()
    self.fc_in = ReplicatedLinear(
        input_dim,
        mlp_hidden_dim,  # For activation func like SiLU that need 2x width
        bias=bias,
        params_dtype=dtype)

    self.act = get_act_fn(act_type)
    if output_dim is None:
        output_dim = input_dim
    self.fc_out = ReplicatedLinear(mlp_hidden_dim,
                                   output_dim,
                                   bias=bias,
                                   params_dtype=dtype)

Functions

fastvideo.layers.quantization

Classes

Functions

fastvideo.layers.quantization.register_quantization_config
register_quantization_config(quantization: str)

Register a customized vllm quantization config.

When a quantization method is not supported by vllm, you can register a customized quantization config to support it.

Parameters:

quantization (str): The quantization method name. Required.

Examples:

>>> from fastvideo.layers.quantization import register_quantization_config
>>> from fastvideo.layers.quantization import get_quantization_config
>>> from fastvideo.layers.quantization.base_config import QuantizationConfig
>>>
>>> @register_quantization_config("my_quant")
... class MyQuantConfig(QuantizationConfig):
...     pass
>>>
>>> get_quantization_config("my_quant")
<class 'MyQuantConfig'>
Source code in fastvideo/layers/quantization/__init__.py
def register_quantization_config(quantization: str):
    """Register a customized vllm quantization config.

    When a quantization method is not supported by vllm, you can register a customized
    quantization config to support it.

    Args:
        quantization (str): The quantization method name.

    Examples:
        >>> from fastvideo.layers.quantization import register_quantization_config
        >>> from fastvideo.layers.quantization import get_quantization_config
        >>> from fastvideo.layers.quantization.base_config import QuantizationConfig
        >>>
        >>> @register_quantization_config("my_quant")
        ... class MyQuantConfig(QuantizationConfig):
        ...     pass
        >>>
        >>> get_quantization_config("my_quant")
        <class 'MyQuantConfig'>
    """  # noqa: E501

    def _wrapper(quant_config_cls):
        if quantization in QUANTIZATION_METHODS:
            raise ValueError(
                f"The quantization method `{quantization}` is already exists.")
        if not issubclass(quant_config_cls, QuantizationConfig):
            raise ValueError("The quantization config must be a subclass of "
                             "`QuantizationConfig`.")
        _CUSTOMIZED_METHOD_TO_QUANT_CONFIG[quantization] = quant_config_cls
        QUANTIZATION_METHODS.append(quantization)
        return quant_config_cls

    return _wrapper

Modules

fastvideo.layers.quantization.base_config
Classes
fastvideo.layers.quantization.base_config.QuantizationConfig
QuantizationConfig()

Bases: ABC

Base class for quantization configs.

Source code in fastvideo/layers/quantization/base_config.py
def __init__(self):
    super().__init__()
    # mapping is updated by models as they initialize
    self.packed_modules_mapping: dict[str, list[str]] = dict()
Functions
fastvideo.layers.quantization.base_config.QuantizationConfig.from_config abstractmethod classmethod
from_config(config: dict[str, Any]) -> QuantizationConfig

Create a config class from the model's quantization config.

Source code in fastvideo/layers/quantization/base_config.py
@classmethod
@abstractmethod
def from_config(cls, config: dict[str, Any]) -> "QuantizationConfig":
    """Create a config class from the model's quantization config."""
    raise NotImplementedError
fastvideo.layers.quantization.base_config.QuantizationConfig.get_config_filenames abstractmethod staticmethod
get_config_filenames() -> list[str]

List of filenames to search for in the model directory.

Source code in fastvideo/layers/quantization/base_config.py
@staticmethod
@abstractmethod
def get_config_filenames() -> list[str]:
    """List of filenames to search for in the model directory."""
    raise NotImplementedError
fastvideo.layers.quantization.base_config.QuantizationConfig.get_from_keys staticmethod
get_from_keys(config: dict[str, Any], keys: list[str]) -> Any

Get a value from the model's quantization config.

Source code in fastvideo/layers/quantization/base_config.py
@staticmethod
def get_from_keys(config: dict[str, Any], keys: list[str]) -> Any:
    """Get a value from the model's quantization config."""
    for key in keys:
        if key in config:
            return config[key]
    raise ValueError(f"Cannot find any of {keys} in the model's "
                     "quantization config.")
fastvideo.layers.quantization.base_config.QuantizationConfig.get_from_keys_or staticmethod
get_from_keys_or(config: dict[str, Any], keys: list[str], default: Any) -> Any

Get an optional value from the model's quantization config.

Source code in fastvideo/layers/quantization/base_config.py
@staticmethod
def get_from_keys_or(config: dict[str, Any], keys: list[str],
                     default: Any) -> Any:
    """Get a optional value from the model's quantization config."""
    try:
        return QuantizationConfig.get_from_keys(config, keys)
    except ValueError:
        return default
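A small usage sketch of the two static lookup helpers; the config dict and key names are hypothetical.

from fastvideo.layers.quantization.base_config import QuantizationConfig

cfg = {"bits": 8, "group_size": 128}
bits = QuantizationConfig.get_from_keys(cfg, ["bits", "w_bit"])             # 8
sym = QuantizationConfig.get_from_keys_or(cfg, ["sym", "symmetric"], True)  # falls back to True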
fastvideo.layers.quantization.base_config.QuantizationConfig.get_min_capability abstractmethod classmethod
get_min_capability() -> int

Minimum GPU capability to support the quantization method.

E.g., 70 for Volta, 75 for Turing, 80 for Ampere. This requirement is due to the custom CUDA kernels used by the quantization method.

Source code in fastvideo/layers/quantization/base_config.py
@classmethod
@abstractmethod
def get_min_capability(cls) -> int:
    """Minimum GPU capability to support the quantization method.

    E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
    This requirement is due to the custom CUDA kernels used by the
    quantization method.
    """
    raise NotImplementedError
fastvideo.layers.quantization.base_config.QuantizationConfig.get_name abstractmethod
get_name() -> QuantizationMethods

Name of the quantization method.

Source code in fastvideo/layers/quantization/base_config.py
@abstractmethod
def get_name(self) -> QuantizationMethods:
    """Name of the quantization method."""
    raise NotImplementedError
fastvideo.layers.quantization.base_config.QuantizationConfig.get_quant_method abstractmethod
get_quant_method(layer: Module, prefix: str) -> QuantizeMethodBase | None

Get the quantize method to use for the quantized layer.

Parameters:

layer (Module): The layer for the quant method. Required.
prefix (str): The full name of the layer in the state dict. Required.

Returns: The quantize method. None if the given layer doesn't support quant method.

Source code in fastvideo/layers/quantization/base_config.py
@abstractmethod
def get_quant_method(self, layer: torch.nn.Module,
                     prefix: str) -> QuantizeMethodBase | None:
    """Get the quantize method to use for the quantized layer.

    Args:
        layer: The layer for the quant method.
        prefix: The full name of the layer in the state dict
    Returns:
        The quantize method. None if the given layer doesn't support quant
        method.
    """
    raise NotImplementedError
fastvideo.layers.quantization.base_config.QuantizationConfig.get_supported_act_dtypes abstractmethod
get_supported_act_dtypes() -> list[dtype]

List of supported activation dtypes.

Source code in fastvideo/layers/quantization/base_config.py
@abstractmethod
def get_supported_act_dtypes(self) -> list[torch.dtype]:
    """List of supported activation dtypes."""
    raise NotImplementedError
fastvideo.layers.quantization.base_config.QuantizationConfig.override_quantization_method classmethod
override_quantization_method(hf_quant_cfg, user_quant) -> QuantizationMethods | None

Detects whether this quantization method can support a given checkpoint format by overriding the user-specified quantization method. This method should only be overridden by subclasses in exceptional circumstances.

Source code in fastvideo/layers/quantization/base_config.py
@classmethod
def override_quantization_method(cls, hf_quant_cfg,
                                 user_quant) -> QuantizationMethods | None:
    """
       Detects if this quantization method can support a given checkpoint
       format by overriding the user specified quantization method -- 
       this method should only be overwritten by subclasses in exceptional 
       circumstances
    """
    return None
fastvideo.layers.quantization.base_config.QuantizeMethodBase

Bases: ABC

Base class for different quantized methods.

Functions
fastvideo.layers.quantization.base_config.QuantizeMethodBase.apply abstractmethod
apply(layer: Module, *args, **kwargs) -> Tensor

Apply the weights in layer to the input tensor.

Expects create_weights to have been called before on the layer.

Source code in fastvideo/layers/quantization/base_config.py
@abstractmethod
def apply(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor:
    """Apply the weights in layer to the input tensor.

    Expects create_weights to have been called before on the layer."""
    raise NotImplementedError
fastvideo.layers.quantization.base_config.QuantizeMethodBase.create_weights abstractmethod
create_weights(layer: Module, *weight_args, **extra_weight_attrs)

Create weights for a layer.

The weights will be set as attributes of the layer.

Source code in fastvideo/layers/quantization/base_config.py
@abstractmethod
def create_weights(self, layer: torch.nn.Module, *weight_args,
                   **extra_weight_attrs):
    """Create weights for a layer.

    The weights will be set as attributes of the layer."""
    raise NotImplementedError
fastvideo.layers.quantization.base_config.QuantizeMethodBase.embedding
embedding(layer: Module, *args, **kwargs) -> Tensor

Gather embeddings in the layer based on indices in the input tensor.

Expects create_weights to have been called before on the layer.

Source code in fastvideo/layers/quantization/base_config.py
def embedding(self, layer: torch.nn.Module, *args,
              **kwargs) -> torch.Tensor:
    """Gather embeddings in the layer based on indices in the input tensor.

    Expects create_weights to have been called before on the layer."""
    raise NotImplementedError
fastvideo.layers.quantization.base_config.QuantizeMethodBase.process_weights_after_loading
process_weights_after_loading(layer: Module) -> None

Process the weight after loading.

This can be used for example, to transpose weights for computation.

Source code in fastvideo/layers/quantization/base_config.py
def process_weights_after_loading(self, layer: nn.Module) -> None:
    """Process the weight after loading.

    This can be used for example, to transpose weights for computation.
    """
    return
Functions
fastvideo.layers.quantization.base_config.method_has_implemented_embedding
method_has_implemented_embedding(method_class: type[QuantizeMethodBase]) -> bool

Not all quant methods have embedding implemented, so we need to check that it exists for our given method. We check this by making sure the function has been changed from the base implementation.

Source code in fastvideo/layers/quantization/base_config.py
def method_has_implemented_embedding(
        method_class: type[QuantizeMethodBase]) -> bool:
    """
    Not all quant methods have embedding implemented, so we need to check that
    it exists for our given method. We check this by making sure the function
    has been changed from the base implementation.
    """
    base_embedding = inspect.getattr_static(QuantizeMethodBase, "embedding",
                                            None)
    class_embedding = inspect.getattr_static(method_class, "embedding", None)

    return (class_embedding is not None
            and class_embedding is not base_embedding)

fastvideo.layers.rotary_embedding

Rotary Positional Embeddings.

Classes

fastvideo.layers.rotary_embedding.RotaryEmbedding
RotaryEmbedding(head_size: int, rotary_dim: int, max_position_embeddings: int, base: int | float, is_neox_style: bool, dtype: dtype)

Bases: CustomOp

Original rotary positional embedding.

Source code in fastvideo/layers/rotary_embedding.py
def __init__(
    self,
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: int | float,
    is_neox_style: bool,
    dtype: torch.dtype,
) -> None:
    super().__init__()
    self.head_size = head_size
    self.rotary_dim = rotary_dim
    self.max_position_embeddings = max_position_embeddings
    self.base = base
    self.is_neox_style = is_neox_style
    self.dtype = dtype

    cache = self._compute_cos_sin_cache()
    cache = cache.to(dtype)
    self.cos_sin_cache: torch.Tensor
    self.register_buffer("cos_sin_cache", cache, persistent=False)
Functions
fastvideo.layers.rotary_embedding.RotaryEmbedding.forward_native
forward_native(positions: Tensor, query: Tensor, key: Tensor, offsets: Tensor | None = None) -> tuple[Tensor, Tensor]

A PyTorch-native implementation of forward().

Source code in fastvideo/layers/rotary_embedding.py
def forward_native(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    offsets: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """A PyTorch-native implementation of forward()."""
    if offsets is not None:
        positions = positions + offsets
    positions = positions.flatten()
    num_tokens = positions.shape[0]
    cos_sin = self.cos_sin_cache.index_select(0, positions)
    cos, sin = cos_sin.chunk(2, dim=-1)

    query_shape = query.shape
    query = query.view(num_tokens, -1, self.head_size)
    query_rot = query[..., :self.rotary_dim]
    query_pass = query[..., self.rotary_dim:]
    query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
    query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

    key_shape = key.shape
    key = key.view(num_tokens, -1, self.head_size)
    key_rot = key[..., :self.rotary_dim]
    key_pass = key[..., self.rotary_dim:]
    key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
    key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
    return query, key

Functions

fastvideo.layers.rotary_embedding.apply_rotary_emb
apply_rotary_emb(x: Tensor, freqs_cis: Tensor | tuple[Tensor, Tensor], use_real: bool = True, use_real_unbind_dim: int = -1) -> Tensor

Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings to the given query or key tensor 'x' using the provided frequency tensor 'freqs_cis'. The input tensor is reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting tensor contains rotary embeddings and is returned as a real tensor.

Args:
    x (torch.Tensor): Query or key tensor to apply rotary embeddings to. [B, H, S, D]
    freqs_cis (Tuple[torch.Tensor]): Precomputed frequency tensors for complex exponentials. ([S, D], [S, D])

Returns:
    torch.Tensor: The input tensor with rotary embeddings applied.

Source code in fastvideo/layers/rotary_embedding.py
def apply_rotary_emb(
    x: torch.Tensor,
    freqs_cis: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
    use_real: bool = True,
    use_real_unbind_dim: int = -1,
) -> torch.Tensor:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
    to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
    reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
    tensors contain rotary embeddings and are returned as real tensors.
    Args:
        x (`torch.Tensor`):
            Query or key tensor to apply rotary embeddings to. [B, H, S, D]
        freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensors for complex exponentials. ([S, D], [S, D])
    Returns:
        torch.Tensor: The input tensor with rotary embeddings applied.
    """
    if use_real:
        cos, sin = freqs_cis  # [S, D]
        # Match Diffusers broadcasting (sequence_dim=2 case)
        cos = cos[None, None, :, :]
        sin = sin[None, None, :, :]
        cos, sin = cos.to(x.device), sin.to(x.device)

        if use_real_unbind_dim == -1:
            # Used for flux, cogvideox, hunyuan-dit
            x_real, x_imag = x.reshape(*x.shape[:-1], -1,
                                       2).unbind(-1)  # [B, S, H, D//2]
            x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        elif use_real_unbind_dim == -2:
            # Used for Stable Audio, OmniGen, CogView4 and Cosmos
            x_real, x_imag = x.reshape(*x.shape[:-1], 2,
                                       -1).unbind(-2)  # [B, S, H, D//2]
            x_rotated = torch.cat([-x_imag, x_real], dim=-1)
        else:
            raise ValueError(
                f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2."
            )

        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

        return out
    else:
        # used for lumina
        x_rotated = torch.view_as_complex(x.float().reshape(
            *x.shape[:-1], -1, 2))
        freqs_cis = freqs_cis.unsqueeze(2)
        x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)

        return x_out.type_as(x)
fastvideo.layers.rotary_embedding.get_1d_rotary_pos_embed
get_1d_rotary_pos_embed(dim: int, pos: FloatTensor | int, theta: float = 10000.0, theta_rescale_factor: float = 1.0, interpolation_factor: float = 1.0, dtype: dtype = float32) -> tuple[Tensor, Tensor]

Precompute the frequency tensor for complex exponential (cis) with given dimensions. (Note: cis means cos + i * sin, where i is the imaginary unit.)

This function calculates a frequency tensor with complex exponential using the given dimension 'dim' and the end index 'end'. The 'theta' parameter scales the frequencies.

Parameters:

dim (int): Dimension of the frequency tensor. Required.
pos (int or FloatTensor): Position indices for the frequency tensor. [S] or scalar. Required.
theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.
interpolation_factor (float): Factor to scale positions. Defaults to 1.0.

Returns:

tuple[Tensor, Tensor]: (freqs_cos, freqs_sin), the precomputed frequency tensors with real and imaginary parts separately. [S, D]

Source code in fastvideo/layers/rotary_embedding.py
def get_1d_rotary_pos_embed(
    dim: int,
    pos: torch.FloatTensor | int,
    theta: float = 10000.0,
    theta_rescale_factor: float = 1.0,
    interpolation_factor: float = 1.0,
    dtype: torch.dtype = torch.float32,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Precompute the frequency tensor for complex exponential (cis) with given dimensions.
    (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)

    This function calculates a frequency tensor with complex exponential using the given dimension 'dim'
    and the end index 'end'. The 'theta' parameter scales the frequencies.

    Args:
        dim (int): Dimension of the frequency tensor.
        pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
        theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
        theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.
        interpolation_factor (float, optional): Factor to scale positions. Defaults to 1.0.

    Returns:
        freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D]
    """
    if isinstance(pos, int):
        pos = torch.arange(pos).float()

    # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
    # has some connection to NTK literature
    if theta_rescale_factor != 1.0:
        theta *= theta_rescale_factor**(dim / (dim - 2))

    freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].to(dtype) / dim)
                   )  # [D/2]
    freqs = torch.outer(pos * interpolation_factor, freqs)  # [S, D/2]
    freqs_cos = freqs.cos()  # [S, D/2]
    freqs_sin = freqs.sin()  # [S, D/2]
    return freqs_cos, freqs_sin
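A hedged usage sketch (sizes are illustrative):

cos, sin = get_1d_rotary_pos_embed(64, 32)      # each [32, 32], i.e. [S, D/2]
# NTK-style rescaling for longer sequences via theta_rescale_factor:
cos_l, sin_l = get_1d_rotary_pos_embed(64, 1024, theta_rescale_factor=4.0)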
fastvideo.layers.rotary_embedding.get_meshgrid_nd
get_meshgrid_nd(start: int | tuple[int, ...], *args: int | tuple[int, ...], dim: int = 2) -> Tensor

Get n-D meshgrid with start, stop and num.

Parameters:

start (int or tuple): If len(args) == 0, start is num; if len(args) == 1, start is start and args[0] is stop with step 1; if len(args) == 2, start is start, args[0] is stop and args[1] is num. For n dimensions, start/stop/num should each be an int or an n-tuple. If an n-tuple is provided, the meshgrid is stacked following the dim order of the tuples. Required.
*args (int | tuple[int, ...]): See above. Default: ().
dim (int): Dimension of the meshgrid. Defaults to 2.

Returns:

grid (Tensor): [dim, ...]

Source code in fastvideo/layers/rotary_embedding.py
def get_meshgrid_nd(start: int | tuple[int, ...],
                    *args: int | tuple[int, ...],
                    dim: int = 2) -> torch.Tensor:
    """
    Get n-D meshgrid with start, stop and num.

    Args:
        start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop,
            step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
            should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in
            n-tuples.
        *args: See above.
        dim (int): Dimension of the meshgrid. Defaults to 2.

    Returns:
        grid (np.ndarray): [dim, ...]
    """
    if len(args) == 0:
        # start is grid_size
        num = _to_tuple(start, dim=dim)
        start = (0, ) * dim
        stop = num
    elif len(args) == 1:
        # start is start, args[0] is stop, step is 1
        start = _to_tuple(start, dim=dim)
        stop = _to_tuple(args[0], dim=dim)
        num = tuple(stop[i] - start[i] for i in range(dim))
    elif len(args) == 2:
        # start is start, args[0] is stop, args[1] is num
        start = _to_tuple(start, dim=dim)  # Left-Top       eg: 12,0
        stop = _to_tuple(args[0], dim=dim)  # Right-Bottom   eg: 20,32
        num = _to_tuple(args[1], dim=dim)  # Target Size    eg: 32,124
    else:
        raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")

    # PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False)
    axis_grid = []
    for i in range(dim):
        a, b, n = start[i], stop[i], num[i]
        g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
        axis_grid.append(g)
    grid = torch.meshgrid(*axis_grid, indexing="ij")  # dim x [W, H, D]
    grid = torch.stack(grid, dim=0)  # [dim, W, H, D]

    return grid
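A hedged usage sketch (sizes are illustrative):

grid = get_meshgrid_nd((4, 8, 8), dim=3)        # [3, 4, 8, 8]
t_idx, h_idx, w_idx = grid                      # each [4, 8, 8]; values run 0 .. n-1 along their own axis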
fastvideo.layers.rotary_embedding.get_nd_rotary_pos_embed
get_nd_rotary_pos_embed(rope_dim_list, start, *args, theta=10000.0, theta_rescale_factor: float | list[float] = 1.0, interpolation_factor: float | list[float] = 1.0, shard_dim: int = 0, sp_rank: int = 0, sp_world_size: int = 1, dtype: dtype = float32, start_frame: int = 0) -> tuple[Tensor, Tensor]

This is an n-d version of precompute_freqs_cis, i.e. RoPE for tokens with an n-d structure. It supports sequence parallelism by allowing a specific dimension to be sharded.

Parameters:

rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal n, and sum(rope_dim_list) should equal the head_dim of the attention layer. Required.
start (int | tuple of int | list of int): If len(args) == 0, start is num; if len(args) == 1, start is start and args[0] is stop with step 1; if len(args) == 2, start is start, args[0] is stop and args[1] is num. Required.
*args: See above. Default: ().
theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.
interpolation_factor (float): Factor to scale positions. Defaults to 1.0.
shard_dim (int): Which dimension to shard for sequence parallelism. Defaults to 0.
sp_rank (int): Rank in the sequence parallel group. Defaults to 0.
sp_world_size (int): World size of the sequence parallel group. Defaults to 1.

Returns:

tuple[Tensor, Tensor]: (cos, sin) tensors of shape [HW, D/2]

Source code in fastvideo/layers/rotary_embedding.py
def get_nd_rotary_pos_embed(
    rope_dim_list,
    start,
    *args,
    theta=10000.0,
    theta_rescale_factor: float | list[float] = 1.0,
    interpolation_factor: float | list[float] = 1.0,
    shard_dim: int = 0,
    sp_rank: int = 0,
    sp_world_size: int = 1,
    dtype: torch.dtype = torch.float32,
    start_frame: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure.
    Supports sequence parallelism by allowing sharding of a specific dimension.

    Args:
        rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n.
            sum(rope_dim_list) should equal to head_dim of attention layer.
        start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start,
            args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num.
        *args: See above.
        theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
        theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.
        interpolation_factor (float): Factor to scale positions. Defaults to 1.0.
        shard_dim (int): Which dimension to shard for sequence parallelism. Defaults to 0.
        sp_rank (int): Rank in the sequence parallel group. Defaults to 0.
        sp_world_size (int): World size of the sequence parallel group. Defaults to 1.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: (cos, sin) tensors of shape [HW, D/2]
    """
    # Get the full grid
    full_grid = get_meshgrid_nd(
        start, *args, dim=len(rope_dim_list))  # [3, W, H, D] / [2, W, H]

    if start_frame > 0:
        full_grid[0] += start_frame

    # Shard the grid if using sequence parallelism (sp_world_size > 1)
    assert shard_dim < len(
        rope_dim_list
    ), f"shard_dim {shard_dim} must be less than number of dimensions {len(rope_dim_list)}"
    if sp_world_size > 1:
        # Get the shape of the full grid
        grid_shape = list(full_grid.shape[1:])

        # Ensure the dimension to shard is divisible by sp_world_size
        assert grid_shape[shard_dim] % sp_world_size == 0, (
            f"Dimension {shard_dim} with size {grid_shape[shard_dim]} is not divisible "
            f"by sequence parallel world size {sp_world_size}")

        # Compute the start and end indices for this rank's shard
        shard_size = grid_shape[shard_dim] // sp_world_size
        start_idx = sp_rank * shard_size
        end_idx = (sp_rank + 1) * shard_size

        # Create slicing indices for each dimension
        slice_indices = [slice(None) for _ in range(len(grid_shape))]
        slice_indices[shard_dim] = slice(start_idx, end_idx)

        # Shard the grid
        # Update grid shape for the sharded dimension
        grid_shape[shard_dim] = grid_shape[shard_dim] // sp_world_size
        grid = torch.empty((len(rope_dim_list), ) + tuple(grid_shape),
                           dtype=full_grid.dtype)
        for i in range(len(rope_dim_list)):
            grid[i] = full_grid[i][tuple(slice_indices)]
    else:
        grid = full_grid

    if isinstance(theta_rescale_factor, int | float):
        theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
    elif isinstance(theta_rescale_factor,
                    list) and len(theta_rescale_factor) == 1:
        theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
    assert len(theta_rescale_factor) == len(
        rope_dim_list
    ), "len(theta_rescale_factor) should equal to len(rope_dim_list)"

    if isinstance(interpolation_factor, int | float):
        interpolation_factor = [interpolation_factor] * len(rope_dim_list)
    elif isinstance(interpolation_factor,
                    list) and len(interpolation_factor) == 1:
        interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
    assert len(interpolation_factor) == len(
        rope_dim_list
    ), "len(interpolation_factor) should equal to len(rope_dim_list)"

    # use 1/ndim of dimensions to encode grid_axis
    embs = []
    for i in range(len(rope_dim_list)):
        emb = get_1d_rotary_pos_embed(
            rope_dim_list[i],
            grid[i].reshape(-1),
            theta,
            theta_rescale_factor=theta_rescale_factor[i],
            interpolation_factor=interpolation_factor[i],
            dtype=dtype,
        )  # 2 x [WHD, rope_dim_list[i]]
        embs.append(emb)

    cos = torch.cat([emb[0] for emb in embs], dim=1)  # (WHD, D/2)
    sin = torch.cat([emb[1] for emb in embs], dim=1)  # (WHD, D/2)
    return cos, sin
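A hedged usage sketch. The 16/56/56 split of a 128-dim head across the (t, h, w) axes is illustrative, not a model default:

cos, sin = get_nd_rotary_pos_embed([16, 56, 56], (8, 16, 16))
# cos, sin: [8 * 16 * 16, 64] = [2048, 64]; with sp_world_size > 1, each rank receives
# only its shard of the dimension selected by shard_dim.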
fastvideo.layers.rotary_embedding.get_rotary_pos_embed
get_rotary_pos_embed(rope_sizes, hidden_size, heads_num, rope_dim_list, rope_theta, theta_rescale_factor=1.0, interpolation_factor=1.0, shard_dim: int = 0, dtype: dtype = float32, start_frame: int = 0) -> tuple[Tensor, Tensor]

Generate rotary positional embeddings for the given sizes.

Parameters:

rope_sizes: Tuple of dimensions (t, h, w). Required.
hidden_size: Hidden dimension size. Required.
heads_num: Number of attention heads. Required.
rope_dim_list: List of dimensions for each axis, or None. Required.
rope_theta: Base for frequency calculations. Required.
theta_rescale_factor: Rescale factor for theta. Defaults to 1.0.
interpolation_factor: Factor to scale positions. Defaults to 1.0.
shard_dim (int): Which dimension to shard for sequence parallelism. Defaults to 0.

Returns:

tuple[Tensor, Tensor]: Tuple of (cos, sin) tensors for rotary embeddings.

Source code in fastvideo/layers/rotary_embedding.py
def get_rotary_pos_embed(
    rope_sizes,
    hidden_size,
    heads_num,
    rope_dim_list,
    rope_theta,
    theta_rescale_factor=1.0,
    interpolation_factor=1.0,
    shard_dim: int = 0,
    dtype: torch.dtype = torch.float32,
    start_frame: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Generate rotary positional embeddings for the given sizes.

    Args:
        rope_sizes: Tuple of dimensions (t, h, w)
        hidden_size: Hidden dimension size
        heads_num: Number of attention heads
        rope_dim_list: List of dimensions for each axis, or None
        rope_theta: Base for frequency calculations
        theta_rescale_factor: Rescale factor for theta. Defaults to 1.0
        interpolation_factor: Factor to scale positions. Defaults to 1.0
        shard_dim: Which dimension to shard for sequence parallelism. Defaults to 0.

    Returns:
        Tuple of (cos, sin) tensors for rotary embeddings
    """

    target_ndim = 3
    head_dim = hidden_size // heads_num

    if rope_dim_list is None:
        rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]

    assert sum(
        rope_dim_list
    ) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"

    # Get SP info
    sp_group = get_sp_group()
    sp_rank = sp_group.rank_in_group
    sp_world_size = sp_group.world_size

    freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
        rope_dim_list,
        rope_sizes,
        theta=rope_theta,
        theta_rescale_factor=theta_rescale_factor,
        interpolation_factor=interpolation_factor,
        shard_dim=shard_dim,
        sp_rank=sp_rank,
        sp_world_size=sp_world_size,
        dtype=dtype,
        start_frame=start_frame,
    )
    return freqs_cos, freqs_sin
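A hedged call sketch. It assumes FastVideo's distributed state has already been initialized elsewhere, since this helper queries get_sp_group() internally; all sizes are illustrative:

cos, sin = get_rotary_pos_embed(
    rope_sizes=(8, 30, 53),        # (t, h, w) latent grid
    hidden_size=3072,
    heads_num=24,                  # head_dim = 128
    rope_dim_list=[16, 56, 56],    # must sum to head_dim
    rope_theta=10000,
)
# cos, sin: [8 * 30 * 53, 64] on a single rank; sharded along shard_dim under sequence parallelism.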

fastvideo.layers.utils

Utility methods for model layers.

fastvideo.layers.visual_embedding

Classes

fastvideo.layers.visual_embedding.ModulateProjection
ModulateProjection(hidden_size: int, factor: int = 2, act_layer: str = 'silu', dtype: dtype | None = None, prefix: str = '')

Bases: Module

Modulation layer for DiT blocks.

Source code in fastvideo/layers/visual_embedding.py
def __init__(
    self,
    hidden_size: int,
    factor: int = 2,
    act_layer: str = "silu",
    dtype: torch.dtype | None = None,
    prefix: str = "",
):
    super().__init__()
    self.factor = factor
    self.hidden_size = hidden_size
    self.linear = ReplicatedLinear(hidden_size,
                                   hidden_size * factor,
                                   bias=True,
                                   params_dtype=dtype)
    self.act = get_act_fn(act_layer)
fastvideo.layers.visual_embedding.PatchEmbed
PatchEmbed(patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True, bias=True, dtype=None, prefix: str = '')

Bases: Module

2D Image to Patch Embedding

Image to Patch Embedding using Conv2d

A convolution-based approach to patchifying a 2D image with an embedding projection.

Based on the impl in https://github.com/google-research/vision_transformer

Hacked together by / Copyright 2020 Ross Wightman

The _assert check in the forward function is removed to remain compatible with multi-resolution images.

Source code in fastvideo/layers/visual_embedding.py
def __init__(self,
             patch_size=16,
             in_chans=3,
             embed_dim=768,
             norm_layer=None,
             flatten=True,
             bias=True,
             dtype=None,
             prefix: str = ""):
    super().__init__()
    # Convert patch_size to 2-tuple
    if isinstance(patch_size, list | tuple):
        if len(patch_size) == 1:
            patch_size = (patch_size[0], patch_size[0])
    else:
        patch_size = (patch_size, patch_size)

    self.patch_size = patch_size
    self.flatten = flatten

    self.proj = nn.Conv3d(in_chans,
                          embed_dim,
                          kernel_size=patch_size,
                          stride=patch_size,
                          bias=bias,
                          dtype=dtype)
    self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
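A hedged shape sketch (sizes are illustrative; only the convolutional projection is shown here, since the full forward, which optionally flattens to tokens, lives in the source file):

import torch

patcher = PatchEmbed(patch_size=(1, 2, 2), in_chans=16, embed_dim=768)
latents = torch.randn(1, 16, 8, 32, 32)        # [B, C, T, H, W]
features = patcher.proj(latents)               # [1, 768, 8, 16, 16]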
fastvideo.layers.visual_embedding.TimestepEmbedder
TimestepEmbedder(hidden_size, act_layer='silu', frequency_embedding_size=256, max_period=10000, dtype=None, freq_dtype=float32, prefix: str = '')

Bases: Module

Embeds scalar timesteps into vector representations.

Source code in fastvideo/layers/visual_embedding.py
def __init__(
    self,
    hidden_size,
    act_layer="silu",
    frequency_embedding_size=256,
    max_period=10000,
    dtype=None,
    freq_dtype=torch.float32,
    prefix: str = "",
):
    super().__init__()
    self.frequency_embedding_size = frequency_embedding_size
    self.max_period = max_period

    self.mlp = MLP(frequency_embedding_size,
                   hidden_size,
                   hidden_size,
                   act_type=act_layer,
                   dtype=dtype)
    self.freq_dtype = freq_dtype

Functions

fastvideo.layers.visual_embedding.get_timestep_embedding
get_timestep_embedding(timesteps: Tensor, embedding_dim: int, flip_sin_to_cos: bool = False, downscale_freq_shift: float = 1, scale: float = 1, max_period: int = 10000) -> Tensor

This matches the implementation in Denoising Diffusion Probabilistic Models: create sinusoidal timestep embeddings.

Parameters:

timesteps (Tensor): A 1-D tensor of N indices, one per batch element. These may be fractional. Required.
embedding_dim (int): The dimension of the output. Required.
flip_sin_to_cos (bool): Whether the embedding order should be cos, sin (if True) or sin, cos (if False). Defaults to False.
downscale_freq_shift (float): Controls the delta between frequencies between dimensions. Defaults to 1.
scale (float): Scaling factor applied to the embeddings. Defaults to 1.
max_period (int): Controls the maximum frequency of the embeddings. Defaults to 10000.

Returns:

Tensor: An [N x dim] tensor of positional embeddings.

Source code in fastvideo/layers/visual_embedding.py
def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
) -> torch.Tensor:
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
    Args
        timesteps (torch.Tensor):
            a 1-D Tensor of N indices, one per batch element. These may be fractional.
        embedding_dim (int):
            the dimension of the output.
        flip_sin_to_cos (bool):
            Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
        downscale_freq_shift (float):
            Controls the delta between frequencies between dimensions
        scale (float):
            Scaling factor applied to the embeddings.
        max_period (int):
            Controls the maximum frequency of the embeddings
    Returns
        torch.Tensor: an [N x dim] Tensor of positional embeddings.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
    exponent = exponent / (half_dim - downscale_freq_shift)

    emb = torch.exp(exponent)
    emb = timesteps[:, None].float() * emb[None, :]

    # scale embeddings
    emb = scale * emb

    # concat sine and cosine embeddings
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

    # flip sine and cosine embeddings
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    # zero pad
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
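A hedged usage sketch:

import torch

timesteps = torch.tensor([0, 10, 500, 999])
emb = get_timestep_embedding(timesteps, embedding_dim=256)                   # [4, 256], sin then cos
emb_flipped = get_timestep_embedding(timesteps, 256, flip_sin_to_cos=True)   # cos then sin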
fastvideo.layers.visual_embedding.timestep_embedding
timestep_embedding(t: Tensor, dim: int, max_period: int = 10000, dtype: dtype = float32) -> Tensor

Create sinusoidal timestep embeddings.

Parameters:

t (Tensor): Tensor of shape [B] with timesteps. Required.
dim (int): Embedding dimension. Required.
max_period (int): Controls the minimum frequency of the embeddings. Defaults to 10000.

Returns:

Tensor: Tensor of shape [B, dim] with embeddings.

Source code in fastvideo/layers/visual_embedding.py
def timestep_embedding(t: torch.Tensor,
                       dim: int,
                       max_period: int = 10000,
                       dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """
    Create sinusoidal timestep embeddings.

    Args:
        t: Tensor of shape [B] with timesteps
        dim: Embedding dimension
        max_period: Controls the minimum frequency of the embeddings

    Returns:
        Tensor of shape [B, dim] with embeddings
    """
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) *
                      torch.arange(start=0, end=half, dtype=dtype) /
                      half).to(device=t.device)
    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat(
            [embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding
fastvideo.layers.visual_embedding.unpatchify
unpatchify(x, t, h, w, patch_size, channels) -> Tensor

Convert patched representation back to image space.

Parameters:

x: Tensor of shape [B, T*H*W, C*P_t*P_h*P_w]. Required.
t, h, w: Temporal and spatial dimensions. Required.

Returns:

Tensor: Unpatchified tensor of shape [B, C, T*P_t, H*P_h, W*P_w].

Source code in fastvideo/layers/visual_embedding.py
def unpatchify(x, t, h, w, patch_size, channels) -> torch.Tensor:
    """
    Convert patched representation back to image space.

    Args:
        x: Tensor of shape [B, T*H*W, C*P_t*P_h*P_w]
        t, h, w: Temporal and spatial dimensions

    Returns:
        Unpatchified tensor of shape [B, C, T*P_t, H*P_h, W*P_w]
    """
    assert x.ndim == 3, f"x.ndim: {x.ndim}"
    assert len(patch_size) == 3, f"patch_size: {patch_size}"
    assert t * h * w == x.shape[
        1], f"t * h * w: {t * h * w}, x.shape[1]: {x.shape[1]}"
    c = channels
    pt, ph, pw = patch_size

    x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw))
    x = torch.einsum("nthwcopq->nctohpwq", x)
    imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))

    return imgs
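A hedged round-trip shape check (sizes are illustrative):

import torch

B, C = 2, 16
t, h, w = 4, 8, 8                    # number of patches per axis
patch_size = (1, 2, 2)               # (P_t, P_h, P_w)
tokens = torch.randn(B, t * h * w, C * 1 * 2 * 2)
video = unpatchify(tokens, t, h, w, patch_size, C)   # [2, 16, 4, 16, 16]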

fastvideo.layers.vocab_parallel_embedding

Classes

fastvideo.layers.vocab_parallel_embedding.UnquantizedEmbeddingMethod

Bases: QuantizeMethodBase

Unquantized method for embeddings.

Functions
fastvideo.layers.vocab_parallel_embedding.UnquantizedEmbeddingMethod.create_weights
create_weights(layer: Module, input_size_per_partition: int, output_partition_sizes: list[int], input_size: int, output_size: int, params_dtype: dtype, **extra_weight_attrs)

Create weights for embedding layer.

Source code in fastvideo/layers/vocab_parallel_embedding.py
def create_weights(self, layer: torch.nn.Module,
                   input_size_per_partition: int,
                   output_partition_sizes: list[int], input_size: int,
                   output_size: int, params_dtype: torch.dtype,
                   **extra_weight_attrs):
    """Create weights for embedding layer."""

    weight = Parameter(torch.empty(
        sum(output_partition_sizes),
        input_size_per_partition,
        dtype=params_dtype,
    ),
                       requires_grad=False)
    set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
    layer.register_parameter("weight", weight)
    set_weight_attrs(weight, extra_weight_attrs)
fastvideo.layers.vocab_parallel_embedding.VocabParallelEmbedding
VocabParallelEmbedding(num_embeddings: int, embedding_dim: int, params_dtype: dtype | None = None, org_num_embeddings: int | None = None, padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, quant_config: QuantizationConfig | None = None, prefix: str = '')

Bases: Module

Embedding parallelized in the vocabulary dimension.

Adapted from torch.nn.Embedding. Note that we pad the vocabulary size to make sure it is divisible by the number of model parallel GPUs.

In order to support various loading methods, we ensure that LoRA-added embeddings are always at the end of TP-sharded tensors. In other words, we shard base embeddings and LoRA embeddings separately (both padded), and place them in the same tensor. In this example, the original vocab size is 1010, the added vocab size is 16, and the padding size is 64. The total vocab size with padding is therefore 1088 (1010 is first padded to 1024, the 16 added embeddings bring it to 1040, and padding again to a multiple of 64 gives 1088). The tensor layout looks like the following:

TP1, rank 0 (no sharding):
|< --------BASE-------- >|< -BASE PADDING-- >|< -----LORA------ >|< -LORA PADDING-- >|
corresponding token_id: | 0 | 1 | ... | 1009 | -1 | ... | -1 | 1010 | ... | 1015 | -1 | ... | -1 |
index: | 0 | 1 | ... | 1009 | 1010 | ... | 1023 | 1024 | ... | 1039 | 1040 | ... | 1087 |

TP2, rank 0:
|< --------------------BASE--------------------- >|< -----LORA------ >|< -LORA PADDING- >|
corresponding token_id: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 1000 | ... | 1015 | -1 | ... | -1 |
index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512 | ... | 527 | 528 | ... | 543 |

TP2, rank 1:
|< -----------BASE----------- >|< -BASE PADDING- >|< -----------LORA PADDING----------- >|
corresponding token_id: | 512 | 513 | 514 | ... | 1009 | -1 | ... | -1 | -1 | ... | -1 | -1 | ... | -1 |
index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512 | ... | 519 | 520 | ... | 543 |

Parameters:

num_embeddings (int): Vocabulary size. Required.
embedding_dim (int): Size of the hidden state. Required.
params_dtype (dtype | None): Type of the parameters. Defaults to None.
org_num_embeddings (int | None): Original vocabulary size (without LoRA). Defaults to None.
padding_size (int): Padding size for the vocabulary. Defaults to DEFAULT_VOCAB_PADDING_SIZE.
quant_config (QuantizationConfig | None): Quant config for the layer. Defaults to None.
prefix (str): Full name of the layer in the state dict. Defaults to ''.
Source code in fastvideo/layers/vocab_parallel_embedding.py
def __init__(self,
             num_embeddings: int,
             embedding_dim: int,
             params_dtype: torch.dtype | None = None,
             org_num_embeddings: int | None = None,
             padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
             quant_config: QuantizationConfig | None = None,
             prefix: str = ""):
    super().__init__()

    # Keep the input dimensions.
    tp_rank = get_tp_rank()
    self.tp_size = get_tp_world_size()
    self.num_embeddings = num_embeddings
    self.padding_size = padding_size
    self.org_vocab_size = org_num_embeddings or num_embeddings
    num_added_embeddings = num_embeddings - self.org_vocab_size
    self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size,
                                                self.padding_size)
    self.num_embeddings_padded = pad_vocab_size(
        self.org_vocab_size_padded + num_added_embeddings,
        self.padding_size)
    assert self.org_vocab_size_padded <= self.num_embeddings_padded

    self.shard_indices = self._get_indices(self.num_embeddings_padded,
                                           self.org_vocab_size_padded,
                                           self.num_embeddings,
                                           self.org_vocab_size, tp_rank,
                                           self.tp_size)
    self.embedding_dim = embedding_dim

    quant_method = None
    if quant_config is not None:
        quant_method = quant_config.get_quant_method(self, prefix=prefix)
    if quant_method is None:
        quant_method = UnquantizedEmbeddingMethod()

    # If we are making an embedding layer, then our quantization linear
    # method must implement the embedding operation. If we are another
    # layer type like ParallelLMHead, this is not important.
    is_embedding_layer = type(self.__class__) is VocabParallelEmbedding
    quant_method_implements_embedding = method_has_implemented_embedding(
        type(quant_method))
    if is_embedding_layer and not quant_method_implements_embedding:
        raise NotImplementedError(
            f"The class {type(quant_method).__name__} must implement "
            "the 'embedding' method, see UnquantizedEmbeddingMethod.")

    self.quant_method: QuantizeMethodBase = quant_method

    if params_dtype is None:
        params_dtype = torch.get_default_dtype()
    # Divide the weight matrix along the vocabulary dimension.
    self.num_added_embeddings = self.num_embeddings - self.org_vocab_size
    self.num_embeddings_per_partition = divide(self.num_embeddings_padded,
                                               self.tp_size)
    assert (self.shard_indices.num_elements_padded ==
            self.num_embeddings_per_partition)
    self.num_org_embeddings_per_partition = (
        self.shard_indices.org_vocab_end_index -
        self.shard_indices.org_vocab_start_index)
    self.num_added_embeddings_per_partition = (
        self.shard_indices.added_vocab_end_index -
        self.shard_indices.added_vocab_start_index)

    self.quant_method.create_weights(self,
                                     self.embedding_dim,
                                     [self.num_embeddings_per_partition],
                                     self.embedding_dim,
                                     self.num_embeddings_padded,
                                     params_dtype=params_dtype,
                                     weight_loader=self.weight_loader)
Functions
fastvideo.layers.vocab_parallel_embedding.VocabParallelEmbedding.get_sharded_to_full_mapping
get_sharded_to_full_mapping() -> list[int] | None

Get a mapping that can be used to reindex the gathered logits for sampling.

During sampling, we gather logits from all ranks. The relationship of index->token_id will follow the same format as outlined in the class docstring. However, after the gather, we want to reindex the final logits tensor to map index->token_id one-to-one (the index is always equal to the token_id it corresponds to). The indices returned by this method allow us to do that.

Source code in fastvideo/layers/vocab_parallel_embedding.py
def get_sharded_to_full_mapping(self) -> list[int] | None:
    """Get a mapping that can be used to reindex the gathered
    logits for sampling.

    During sampling, we gather logits from all ranks. The relationship
    of index->token_id will follow the same format as outlined in the class
    docstring. However, after the gather, we want to reindex the final
    logits tensor to map index->token_id one-to-one (the index is always
    equal the token_id it corresponds to). The indices returned by this
    method allow us to do that.
    """
    if self.tp_size < 2:
        return None

    base_embeddings: list[int] = []
    added_embeddings: list[int] = []
    padding: list[int] = []
    for tp_rank in range(self.tp_size):
        shard_indices = self._get_indices(self.num_embeddings_padded,
                                          self.org_vocab_size_padded,
                                          self.num_embeddings,
                                          self.org_vocab_size, tp_rank,
                                          self.tp_size)
        range_start = self.num_embeddings_per_partition * tp_rank
        range_end = self.num_embeddings_per_partition * (tp_rank + 1)
        base_embeddings.extend(
            range(range_start,
                  range_start + shard_indices.num_org_elements))
        padding.extend(
            range(range_start + shard_indices.num_org_elements,
                  range_start + shard_indices.num_org_elements_padded))
        added_embeddings.extend(
            range(
                range_start + shard_indices.num_org_elements_padded,
                range_start + shard_indices.num_org_elements_padded +
                shard_indices.num_added_elements))
        padding.extend(
            range(
                range_start + shard_indices.num_org_elements_padded +
                shard_indices.num_added_elements,
                range_start + shard_indices.num_org_elements_padded +
                shard_indices.num_added_elements_padded))
        assert (range_start + shard_indices.num_org_elements_padded +
                shard_indices.num_added_elements_padded == range_end)
    ret = base_embeddings + added_embeddings + padding
    assert len(ret) == self.num_embeddings_padded
    return ret
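A hedged usage sketch (embedding and gathered_logits are hypothetical names standing in for a VocabParallelEmbedding instance and logits gathered across TP ranks):

mapping = embedding.get_sharded_to_full_mapping()      # None when tp_size < 2
if mapping is not None:
    gathered_logits = gathered_logits[:, mapping]      # column i now corresponds to token_id i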
fastvideo.layers.vocab_parallel_embedding.VocabParallelEmbeddingShardIndices dataclass
VocabParallelEmbeddingShardIndices(padded_org_vocab_start_index: int, padded_org_vocab_end_index: int, padded_added_vocab_start_index: int, padded_added_vocab_end_index: int, org_vocab_start_index: int, org_vocab_end_index: int, added_vocab_start_index: int, added_vocab_end_index: int)

Indices for a shard of a vocab parallel embedding.

Functions

fastvideo.layers.vocab_parallel_embedding.pad_vocab_size
pad_vocab_size(vocab_size: int, pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int

Pad the vocab size to the given value.

Source code in fastvideo/layers/vocab_parallel_embedding.py
def pad_vocab_size(vocab_size: int,
                   pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
    """Pad the vocab size to the given value."""
    return ((vocab_size + pad_to - 1) // pad_to) * pad_to
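A hedged example, assuming the default padding size of 64 used in the VocabParallelEmbedding docstring above:

pad_vocab_size(1010)          # -> 1024
pad_vocab_size(1024 + 16)     # -> 1088 (padded base vocab plus 16 added embeddings, re-padded)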