Skip to content

layernorm

Custom normalization layers.

Classes

fastvideo.layers.layernorm.LayerNormScaleShift

LayerNormScaleShift(hidden_size: int, norm_type: str = 'rms', eps: float = 1e-06, elementwise_affine: bool = False, dtype: dtype = float32, compute_dtype: dtype | None = None, prefix: str = '')

Bases: Module

Fused operation that combines a normalization layer (RMSNorm by default, or LayerNorm) with scale and shift operations. This reduces memory bandwidth by combining memory-bound operations.

Source code in fastvideo/layers/layernorm.py
def __init__(
    self,
    hidden_size: int,
    norm_type: str = "rms",
    eps: float = 1e-6,
    elementwise_affine: bool = False,
    dtype: torch.dtype = torch.float32,
    compute_dtype: torch.dtype | None = None,
    prefix: str = "",
):
    """Build the normalization applied before the scale/shift modulation.

    Args:
        hidden_size: Size of the channel dimension being normalized.
        norm_type: ``"rms"`` for ``RMSNorm`` or ``"layer"`` for a LayerNorm.
        eps: Numerical-stability epsilon passed to the norm.
        elementwise_affine: Whether the norm carries a learnable weight
            (forwarded to ``RMSNorm`` as ``has_weight``).
        dtype: Parameter dtype for the norm. Forwarded to ``RMSNorm`` and
            ``nn.LayerNorm``; not passed to ``FP32LayerNorm``.
        compute_dtype: Stored on the module. When it is ``torch.float32``,
            the ``"layer"`` variant uses ``FP32LayerNorm``.
        prefix: Accepted for interface parity; unused here.

    Raises:
        NotImplementedError: If ``norm_type`` is neither ``"rms"`` nor
            ``"layer"``.
    """
    super().__init__()
    self.compute_dtype = compute_dtype
    if norm_type == "rms":
        # Forward ``dtype`` so the argument is honoured (RMSNorm applies it
        # to its weight on non-CUDA-like platforms), consistent with
        # ScaleResidualLayerNormScaleShift.
        self.norm = RMSNorm(hidden_size,
                            has_weight=elementwise_affine,
                            eps=eps,
                            dtype=dtype)
    elif norm_type == "layer":
        if self.compute_dtype == torch.float32:
            self.norm = FP32LayerNorm(hidden_size,
                                      elementwise_affine=elementwise_affine,
                                      eps=eps)
        else:
            self.norm = nn.LayerNorm(hidden_size,
                                     elementwise_affine=elementwise_affine,
                                     eps=eps,
                                     dtype=dtype)
    else:
        raise NotImplementedError(f"Norm type {norm_type} not implemented")

Functions

fastvideo.layers.layernorm.LayerNormScaleShift.forward
forward(x: Tensor, shift: Tensor, scale: Tensor) -> Tensor

Apply normalization (RMSNorm or LayerNorm) followed by scale and shift in a single fused operation.

Source code in fastvideo/layers/layernorm.py
def forward(self, x: torch.Tensor, shift: torch.Tensor,
            scale: torch.Tensor) -> torch.Tensor:
    """Normalize ``x`` and modulate it as ``norm(x) * (1 + scale) + shift``.

    When ``self.compute_dtype`` is ``torch.float32`` the normalized tensor is
    upcast to float32 for the modulation and the result is cast back to
    ``x.dtype``.

    A 4-D ``scale`` (and ``shift``), laid out as
    [batch_size, num_frames, 1, inner_dim], is applied per frame: the
    sequence axis of the normalized tensor is split into ``num_frames``
    equal chunks, each chunk is modulated by its frame's values, and the
    chunks are flattened back. Otherwise ``scale``/``shift`` are broadcast
    directly (e.g. [batch_size, 1, inner_dim]).
    """
    upcast = self.compute_dtype == torch.float32

    # x: [batch_size, seq_len, inner_dim]
    out = self.norm(x)
    if upcast:
        out = out.float()

    if scale.dim() != 4:
        out = out * (1.0 + scale) + shift
    else:
        frames = scale.shape[1]
        tokens_per_frame = out.shape[1] // frames
        per_frame = out.unflatten(dim=1, sizes=(frames, tokens_per_frame))
        out = (per_frame * (1.0 + scale) + shift).flatten(1, 2)

    return out.to(x.dtype) if upcast else out

fastvideo.layers.layernorm.RMSNorm

RMSNorm(hidden_size: int, eps: float = 1e-06, dtype: dtype = float32, var_hidden_size: int | None = None, has_weight: bool = True)

Bases: CustomOp

Root mean square normalization.

Computes x -> w * x / sqrt(E[x^2] + eps), where the mean is taken over the hidden (channel) dimension and w is the learned weight (applied only when has_weight=True). Refer to https://arxiv.org/abs/1910.07467

Source code in fastvideo/layers/layernorm.py
def __init__(
    self,
    hidden_size: int,
    eps: float = 1e-6,
    dtype: torch.dtype = torch.float32,
    var_hidden_size: int | None = None,
    has_weight: bool = True,
) -> None:
    """Set up an RMS normalization over ``hidden_size`` channels.

    Args:
        hidden_size: Expected size of the input's last dimension.
        eps: Added to the mean square before the reciprocal square root.
        dtype: Dtype of the weight tensor. Only applied when the current
            platform is not CUDA-like; otherwise ``torch.ones`` uses its
            default dtype.
        var_hidden_size: If given and different from ``hidden_size``, only
            that many leading channels contribute to the variance.
        has_weight: When True, the weight is an ``nn.Parameter`` initialised
            to ones; when False it is a plain all-ones tensor.
    """
    super().__init__()

    self.hidden_size = hidden_size
    self.variance_epsilon = eps
    # A variance width equal to the full width is the same as no override.
    if var_hidden_size == hidden_size:
        self.variance_size_override = None
    else:
        self.variance_size_override = var_hidden_size
    self.has_weight = has_weight

    from fastvideo.platforms import current_platform

    if current_platform.is_cuda_alike():
        ones = torch.ones(hidden_size)
    else:
        ones = torch.ones(hidden_size, dtype=dtype)
    self.weight = nn.Parameter(ones) if has_weight else ones

Functions

fastvideo.layers.layernorm.RMSNorm.forward_native
forward_native(x: Tensor, residual: Tensor | None = None) -> Tensor | tuple[Tensor, Tensor]

PyTorch-native implementation equivalent to forward().

Source code in fastvideo/layers/layernorm.py
def forward_native(
    self,
    x: torch.Tensor,
    residual: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """PyTorch-native implementation equivalent to forward()."""
    orig_dtype = x.dtype
    x = x.to(torch.float32)
    if residual is not None:
        x = x + residual.to(torch.float32)
        residual = x.to(orig_dtype)

    hidden_size = x.shape[-1]
    if hidden_size != self.hidden_size:
        raise ValueError("Expected hidden_size to be "
                         f"{self.hidden_size}, but found: {hidden_size}")

    if self.variance_size_override is None:
        x_var = x
    else:
        if hidden_size < self.variance_size_override:
            raise ValueError(
                "Expected hidden_size to be at least "
                f"{self.variance_size_override}, but found: {hidden_size}")

        x_var = x[:, :, :self.variance_size_override]

    variance = x_var.pow(2).mean(dim=-1, keepdim=True)

    x = x * torch.rsqrt(variance + self.variance_epsilon)
    x = x.to(orig_dtype)
    if self.has_weight:
        x = x * self.weight
    if residual is None:
        return x
    else:
        return x, residual

fastvideo.layers.layernorm.ScaleResidual

ScaleResidual(prefix: str = '')

Bases: Module

Applies a gated residual connection: residual + x * gate.

Source code in fastvideo/layers/layernorm.py
def __init__(self, prefix: str = ""):
    """Create the gated-residual module; it holds no parameters.

    Args:
        prefix: Accepted for interface parity with sibling layers; unused.
    """
    super().__init__()

Functions

fastvideo.layers.layernorm.ScaleResidual.forward
forward(residual: Tensor, x: Tensor, gate: Tensor) -> Tensor

Apply gated residual connection.

Source code in fastvideo/layers/layernorm.py
def forward(self, residual: torch.Tensor, x: torch.Tensor,
            gate: torch.Tensor) -> torch.Tensor:
    """Apply gated residual connection.

    Returns ``residual + x * gate``. A 4-D ``gate`` laid out as
    [batch_size, num_frames, 1, inner_dim] is applied per frame by splitting
    the sequence axis of ``x`` into ``num_frames`` equal chunks.
    """
    if gate.dim() != 4:
        # gate: [batch_size, 1, inner_dim], broadcast over the sequence.
        return residual + x * gate

    # x: [batch_size, seq_len, inner_dim] viewed as
    # [batch_size, num_frames, tokens_per_frame, inner_dim].
    frames = gate.shape[1]
    tokens_per_frame = x.shape[1] // frames
    gated = x.unflatten(dim=1, sizes=(frames, tokens_per_frame)) * gate
    return residual + gated.flatten(1, 2)

fastvideo.layers.layernorm.ScaleResidualLayerNormScaleShift

ScaleResidualLayerNormScaleShift(hidden_size: int, norm_type: str = 'rms', eps: float = 1e-06, elementwise_affine: bool = False, dtype: dtype = float32, compute_dtype: dtype | None = None, prefix: str = '')

Bases: Module

Fused operation that combines (1) a gated residual connection, (2) normalization (RMSNorm or LayerNorm), and (3) scale and shift operations.

This reduces memory bandwidth by combining memory-bound operations.

Source code in fastvideo/layers/layernorm.py
def __init__(
    self,
    hidden_size: int,
    norm_type: str = "rms",
    eps: float = 1e-6,
    elementwise_affine: bool = False,
    dtype: torch.dtype = torch.float32,
    compute_dtype: torch.dtype | None = None,
    prefix: str = "",
):
    """Build the normalization applied after the gated residual add.

    Args:
        hidden_size: Size of the channel dimension being normalized.
        norm_type: ``"rms"`` for ``RMSNorm`` or ``"layer"`` for a LayerNorm.
        eps: Numerical-stability epsilon passed to the norm.
        elementwise_affine: Whether the norm carries a learnable weight.
        dtype: Parameter dtype, forwarded to ``RMSNorm`` and
            ``nn.LayerNorm`` (not to ``FP32LayerNorm``).
        compute_dtype: Only consulted for ``"layer"``: ``torch.float32``
            selects ``FP32LayerNorm``. It is not stored on the module.
        prefix: Accepted for interface parity; unused here.

    Raises:
        NotImplementedError: If ``norm_type`` is neither ``"rms"`` nor
            ``"layer"``.
    """
    super().__init__()

    if norm_type == "rms":
        self.norm = RMSNorm(hidden_size,
                            has_weight=elementwise_affine,
                            eps=eps,
                            dtype=dtype)
        return

    if norm_type != "layer":
        raise NotImplementedError(f"Norm type {norm_type} not implemented")

    layer_kwargs = {"elementwise_affine": elementwise_affine, "eps": eps}
    if compute_dtype == torch.float32:
        # FP32LayerNorm is built without the ``dtype`` argument.
        self.norm = FP32LayerNorm(hidden_size, **layer_kwargs)
    else:
        self.norm = nn.LayerNorm(hidden_size, dtype=dtype, **layer_kwargs)

Functions

fastvideo.layers.layernorm.ScaleResidualLayerNormScaleShift.forward
forward(residual: Tensor, x: Tensor, gate: Tensor | int, shift: Tensor, scale: Tensor) -> tuple[Tensor, Tensor]

Apply gated residual connection, followed by layernorm and scale/shift in a single fused operation.

Returns:

Type: tuple[Tensor, Tensor]

Description: Tuple containing:

  • the normalized and modulated output (Tensor)
  • the residual value, i.e. the value after the residual connection but before normalization (Tensor)
Source code in fastvideo/layers/layernorm.py
def forward(self, residual: torch.Tensor, x: torch.Tensor,
            gate: torch.Tensor | int, shift: torch.Tensor,
            scale: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Apply gated residual connection, followed by layernorm and 
    scale/shift in a single fused operation.

    Returns:
        Tuple containing:
        - normalized and modulated output
        - residual value (value after residual connection 
          but before normalization)
    """
    # x.shape: [batch_size, seq_len, inner_dim]
    # Apply residual connection with gating
    if isinstance(gate, int):
        # used by cross-attention, should be 1
        assert gate == 1
        residual_output = residual + x
    elif isinstance(gate, torch.Tensor):
        if gate.dim() == 4:
            # gate.shape: [batch_size, num_frames, 1, inner_dim]
            num_frames = gate.shape[1]
            frame_seqlen = x.shape[1] // num_frames
            residual_output = residual + (
                x.unflatten(dim=1, sizes=(num_frames, frame_seqlen)) *
                gate).flatten(1, 2)
        else:
            # used by bidirectional self attention
            # gate.shape: [batch_size, 1, inner_dim]
            residual_output = residual + x * gate
    else:
        raise ValueError(f"Gate type {type(gate)} not supported")
    # residual_output.shape: [batch_size, seq_len, inner_dim]

    # Apply normalization
    normalized = self.norm(residual_output)
    # Apply scale and shift
    if isinstance(scale, torch.Tensor) and scale.dim() == 4:
        # scale.shape: [batch_size, num_frames, 1, inner_dim]
        # shift.shape: [batch_size, num_frames, 1, inner_dim]
        num_frames = scale.shape[1]
        frame_seqlen = normalized.shape[1] // num_frames
        modulated = (
            normalized.unflatten(dim=1, sizes=(num_frames, frame_seqlen)) *
            (1.0 + scale) + shift).flatten(1, 2)
    else:
        modulated = normalized * (1.0 + scale) + shift
    return modulated, residual_output