fastvideo.v1.layers.layernorm#

Custom normalization layers.

Module Contents#

Classes#

LayerNormScaleShift

Fused operation that combines LayerNorm with scale and shift operations. This reduces memory bandwidth usage by fusing memory-bound operations into a single pass.

RMSNorm

Root mean square normalization.

ScaleResidual

Applies a gated residual connection.

ScaleResidualLayerNormScaleShift

Fused operation that combines a gated residual connection, LayerNorm, and scale/shift operations.

API#

class fastvideo.v1.layers.layernorm.LayerNormScaleShift(hidden_size: int, norm_type: str = 'rms', eps: float = 1e-06, elementwise_affine: bool = False, dtype: torch.dtype = torch.float32, prefix: str = '')[source]#

Bases: torch.nn.Module

Fused operation that combines LayerNorm with scale and shift operations. This reduces memory bandwidth usage by fusing memory-bound operations into a single pass.

Initialization

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) → torch.Tensor[source]#

Apply LayerNorm followed by scale and shift in a single fused operation.
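
A minimal usage sketch based on the signatures above; the tensor shapes (and the broadcastable modulation shapes) are assumptions for illustration, not part of the documented API:

```python
import torch

from fastvideo.v1.layers.layernorm import LayerNormScaleShift

hidden_size = 1024
layer = LayerNormScaleShift(hidden_size, norm_type="rms", eps=1e-6)

# Illustrative shapes: [batch, seq_len, hidden] activations with
# broadcastable per-sample modulation tensors (shapes assumed).
x = torch.randn(2, 16, hidden_size)
shift = torch.randn(2, 1, hidden_size)
scale = torch.randn(2, 1, hidden_size)

out = layer(x, shift, scale)  # normalized x, then scale/shift applied, fused
print(out.shape)              # torch.Size([2, 16, 1024])
```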

class fastvideo.v1.layers.layernorm.RMSNorm(hidden_size: int, eps: float = 1e-06, dtype: torch.dtype = torch.float32, var_hidden_size: Optional[int] = None, has_weight: bool = True)[source]#

Bases: fastvideo.v1.layers.custom_op.CustomOp

Root mean square normalization.

Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight. Refer to https://arxiv.org/abs/1910.07467

Initialization

Initialize internal Module state, shared by both nn.Module and ScriptModule.

extra_repr() → str[source]#

forward_native(x: torch.Tensor, residual: Optional[torch.Tensor] = None) → Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]][source]#

PyTorch-native implementation equivalent to forward().
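
As a sanity check on the formula above, here is a standalone PyTorch sketch of the same computation. It mirrors the documented math, not the library's internal code path:

```python
import torch

def rms_norm_ref(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # x / sqrt(E[x^2] + eps), scaled by the learned weight w,
    # matching w * x / sqrt(E[x^2] + eps) from the docstring.
    mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
    return w * (x * torch.rsqrt(mean_sq + eps))

x = torch.randn(2, 16, 1024)
w = torch.ones(1024)             # stand-in for the module's learned weight
print(rms_norm_ref(x, w).shape)  # torch.Size([2, 16, 1024])
```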

class fastvideo.v1.layers.layernorm.ScaleResidual(prefix: str = '')[source]#

Bases: torch.nn.Module

Applies a gated residual connection.

Initialization

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(residual: torch.Tensor, x: torch.Tensor, gate: torch.Tensor) → torch.Tensor[source]#

Apply gated residual connection.
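
A minimal call sketch; the gate shape and the exact gating formula are assumptions based on the signature, not spelled out in the docs:

```python
import torch

from fastvideo.v1.layers.layernorm import ScaleResidual

layer = ScaleResidual()

residual = torch.randn(2, 16, 1024)
x = torch.randn(2, 16, 1024)
gate = torch.randn(2, 1, 1024)  # broadcastable gate (shape assumed)

# Gated residual connection; commonly residual + gate * x.
out = layer(residual, x, gate)
print(out.shape)  # torch.Size([2, 16, 1024])
```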

class fastvideo.v1.layers.layernorm.ScaleResidualLayerNormScaleShift(hidden_size: int, norm_type: str = 'rms', eps: float = 1e-06, elementwise_affine: bool = False, dtype: torch.dtype = torch.float32, prefix: str = '')[source]#

Bases: torch.nn.Module

Fused operation that combines:

  1. Gated residual connection

  2. LayerNorm

  3. Scale and shift operations

This reduces memory bandwidth usage by fusing memory-bound operations into a single pass.

Initialization

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(residual: torch.Tensor, x: torch.Tensor, gate: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor][source]#

Apply gated residual connection, followed by layernorm and scale/shift in a single fused operation.

Returns:

  • normalized and modulated output

  • residual value (the value after the residual connection but before normalization)

Return type:

Tuple[torch.Tensor, torch.Tensor]
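
Putting the pieces together, a hedged usage sketch showing both return values; as above, the tensor shapes are assumptions for illustration:

```python
import torch

from fastvideo.v1.layers.layernorm import ScaleResidualLayerNormScaleShift

hidden_size = 1024
layer = ScaleResidualLayerNormScaleShift(hidden_size, norm_type="rms")

residual = torch.randn(2, 16, hidden_size)
x = torch.randn(2, 16, hidden_size)
gate = torch.randn(2, 1, hidden_size)   # modulation tensors; shapes assumed
shift = torch.randn(2, 1, hidden_size)
scale = torch.randn(2, 1, hidden_size)

modulated, new_residual = layer(residual, x, gate, shift, scale)
# `modulated` feeds the next block; `new_residual` is the post-residual,
# pre-normalization value, reused for the next residual connection.
print(modulated.shape, new_residual.shape)
```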