fastvideo.v1.layers.layernorm#
Custom normalization layers.
Module Contents#
Classes#
- LayerNormScaleShift: Fused operation that combines LayerNorm with scale and shift operations. This reduces memory bandwidth by combining memory-bound operations.
- RMSNorm: Root mean square normalization.
- ScaleResidual: Applies gated residual connection.
- ScaleResidualLayerNormScaleShift: Fused operation that combines a gated residual connection, LayerNorm, and scale/shift operations.
API#
- class fastvideo.v1.layers.layernorm.LayerNormScaleShift(hidden_size: int, norm_type: str = 'rms', eps: float = 1e-06, elementwise_affine: bool = False, dtype: torch.dtype = torch.float32, prefix: str = '')[source]#
Bases:
torch.nn.Module
Fused operation that combines LayerNorm with scale and shift operations. This reduces memory bandwidth by combining memory-bound operations.
Initialization
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) torch.Tensor [source]#
Apply layer normalization followed by scale and shift in a single fused operation.
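A minimal usage sketch, based on the constructor and forward signatures above. The broadcast shapes chosen for shift and scale (per sample, broadcast over the sequence dimension) are an assumption, not part of the documented API.

```python
import torch
from fastvideo.v1.layers.layernorm import LayerNormScaleShift

hidden_size = 1024
norm = LayerNormScaleShift(hidden_size, norm_type="rms", eps=1e-6)

x = torch.randn(2, 16, hidden_size)      # (batch, seq_len, hidden)
shift = torch.randn(2, 1, hidden_size)   # modulation shift (assumed per-sample)
scale = torch.randn(2, 1, hidden_size)   # modulation scale (assumed per-sample)

out = norm(x, shift, scale)              # fused norm -> scale -> shift
print(out.shape)                         # torch.Size([2, 16, 1024])
```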
- class fastvideo.v1.layers.layernorm.RMSNorm(hidden_size: int, eps: float = 1e-06, dtype: torch.dtype = torch.float32, var_hidden_size: Optional[int] = None, has_weight: bool = True)[source]#
Bases:
fastvideo.v1.layers.custom_op.CustomOp
Root mean square normalization.
Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight. Refer to https://arxiv.org/abs/1910.07467
Initialization
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward_native(x: torch.Tensor, residual: Optional[torch.Tensor] = None) Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] [source]#
PyTorch-native implementation equivalent to forward().
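The documented formula can be reproduced in a few lines of plain PyTorch. This reference sketch mirrors forward_native for the no-residual case; the fused-residual path implied by the optional residual argument is not shown, since its exact behavior is not documented here.

```python
import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor,
                       eps: float = 1e-6) -> torch.Tensor:
    # w * x / sqrt(E[x^2] + eps), with the mean taken over the hidden (last) dim.
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return weight * x * torch.rsqrt(variance + eps)

x = torch.randn(2, 16, 1024)
w = torch.ones(1024)                          # stand-in for the learned weight
print(rms_norm_reference(x, w).shape)         # torch.Size([2, 16, 1024])
```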
- class fastvideo.v1.layers.layernorm.ScaleResidual(prefix: str = '')[source]#
Bases:
torch.nn.Module
Applies gated residual connection.
Initialization
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(residual: torch.Tensor, x: torch.Tensor, gate: torch.Tensor) torch.Tensor [source]#
Apply gated residual connection.
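The docstring only names the operation, so the formula below is a plausible reading rather than the confirmed implementation: the gate scales the incoming branch before it is added to the residual.

```python
import torch

def scale_residual_reference(residual: torch.Tensor, x: torch.Tensor,
                             gate: torch.Tensor) -> torch.Tensor:
    # Assumed gating convention: residual + gate * x.
    return residual + gate * x

residual = torch.randn(2, 16, 1024)
x = torch.randn(2, 16, 1024)
gate = torch.randn(2, 1, 1024)    # assumed per-sample gate, broadcast over seq_len
print(scale_residual_reference(residual, x, gate).shape)
```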
- class fastvideo.v1.layers.layernorm.ScaleResidualLayerNormScaleShift(hidden_size: int, norm_type: str = 'rms', eps: float = 1e-06, elementwise_affine: bool = False, dtype: torch.dtype = torch.float32, prefix: str = '')[source]#
Bases:
torch.nn.Module
Fused operation that combines:
- Gated residual connection
- LayerNorm
- Scale and shift operations
This reduces memory bandwidth by combining memory-bound operations.
Initialization
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(residual: torch.Tensor, x: torch.Tensor, gate: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) Tuple[torch.Tensor, torch.Tensor] [source]#
Apply gated residual connection, followed by layernorm and scale/shift in a single fused operation.
- Returns:
  - normalized and modulated output
  - residual value (the value after the residual connection but before normalization)
- Return type:
  Tuple[torch.Tensor, torch.Tensor]
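Putting the pieces together, an unfused reference for this forward might look like the sketch below. The gating convention (residual + gate * x) and the modulation convention (normed * (1 + scale) + shift) are assumptions carried over from the sketches above, not taken from the source; only the return contract (modulated output plus pre-normalization residual) follows the documentation.

```python
import torch

def fused_reference(residual: torch.Tensor, x: torch.Tensor, gate: torch.Tensor,
                    shift: torch.Tensor, scale: torch.Tensor, eps: float = 1e-6):
    # 1) Gated residual connection (assumed convention: residual + gate * x).
    new_residual = residual + gate * x
    # 2) RMS norm over the hidden dim (norm_type="rms", elementwise_affine=False).
    variance = new_residual.pow(2).mean(dim=-1, keepdim=True)
    normed = new_residual * torch.rsqrt(variance + eps)
    # 3) Scale/shift modulation (assumed convention: normed * (1 + scale) + shift).
    out = normed * (1.0 + scale) + shift
    # Return the modulated output and the value after the residual connection
    # but before normalization, matching the documented return tuple.
    return out, new_residual
```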