# fastvideo.v1.layers.activation
Custom activation functions.
## Module Contents
### Classes

| Class | Description |
| --- | --- |
| `GeluAndMul` | An activation function for GeGLU. |
| `NewGELU` | |
| `QuickGELU` | |
| `SiluAndMul` | An activation function for SwiGLU. |
### Functions

| Function | Description |
| --- | --- |
| `get_act_and_mul_fn` | Get an activation-and-mul (e.g. SiluAndMul) function by name. |
| `get_act_fn` | Get an activation function by name. |
## API
- class fastvideo.v1.layers.activation.GeluAndMul(approximate: str = 'none')[source]
Bases: fastvideo.v1.layers.custom_op.CustomOp
An activation function for GeGLU.
The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.
Shapes:
- x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
- return: (batch_size, seq_len, d) or (num_tokens, d)
Initialization
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward_native(x: torch.Tensor) → torch.Tensor [source]
PyTorch-native implementation equivalent to forward().
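The gated computation is easy to reproduce with plain PyTorch ops. Here is a minimal sketch of what the docstring describes (the function name is illustrative, not part of the module):

```python
import torch
import torch.nn.functional as F

def gelu_and_mul_reference(x: torch.Tensor, approximate: str = "none") -> torch.Tensor:
    # Split the last dimension in half: the first half is activated,
    # the second half is an elementwise multiplicative gate.
    gate, up = x.chunk(2, dim=-1)
    return F.gelu(gate, approximate=approximate) * up

x = torch.randn(4, 128, 2 * 256)          # (batch_size, seq_len, 2 * d)
print(gelu_and_mul_reference(x).shape)    # torch.Size([4, 128, 256])
```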
- class fastvideo.v1.layers.activation.NewGELU[source]
Bases: fastvideo.v1.layers.custom_op.CustomOp
- forward_native(x: torch.Tensor) → torch.Tensor [source]
PyTorch-native implementation equivalent to forward().
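The page does not spell out the formula. Assuming NewGELU follows the widely used tanh-approximation GELU (the "new GELU" of GPT-2-style models), forward_native would be equivalent to this sketch:

```python
import math
import torch

def new_gelu_reference(x: torch.Tensor) -> torch.Tensor:
    # Assumption: NewGELU is the tanh approximation of GELU,
    # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
    return 0.5 * x * (1.0 + torch.tanh(
        math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
```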
- class fastvideo.v1.layers.activation.QuickGELU[source]
Bases: fastvideo.v1.layers.custom_op.CustomOp
- forward_native(x: torch.Tensor) → torch.Tensor [source]
PyTorch-native implementation equivalent to forward().
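Again, the formula is not given on this page. Assuming QuickGELU matches the common definition used in CLIP-style models, it is the cheap sigmoid approximation of GELU:

```python
import torch

def quick_gelu_reference(x: torch.Tensor) -> torch.Tensor:
    # Assumption: QuickGELU is x * sigmoid(1.702 * x), the sigmoid
    # approximation of GELU popularized by CLIP.
    return x * torch.sigmoid(1.702 * x)
```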
- class fastvideo.v1.layers.activation.SiluAndMul[source]
Bases: fastvideo.v1.layers.custom_op.CustomOp
An activation function for SwiGLU.
The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.
Shapes:
- x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
- return: (num_tokens, d) or (batch_size, seq_len, d)
Initialization
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward_native(x: torch.Tensor) → torch.Tensor [source]
PyTorch-native implementation equivalent to forward().
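As with GeluAndMul, the SwiGLU gating reduces to a split, an activation, and a multiply. A minimal PyTorch sketch (function name illustrative):

```python
import torch
import torch.nn.functional as F

def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    # First half goes through SiLU, second half gates it elementwise.
    gate, up = x.chunk(2, dim=-1)
    return F.silu(gate) * up

x = torch.randn(1024, 2 * 512)            # (num_tokens, 2 * d)
print(silu_and_mul_reference(x).shape)    # torch.Size([1024, 512])
```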
- fastvideo.v1.layers.activation.get_act_and_mul_fn(act_fn_name: str) → torch.nn.Module [source]
Get an activation-and-mul (e.g. SiluAndMul) function by name.
- fastvideo.v1.layers.activation.get_act_fn(act_fn_name: str) → torch.nn.Module [source]
Get an activation function by name.
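Typical lookup-by-name usage might look like the sketch below. The registry keys ("gelu", "silu") are assumptions based on common naming conventions; check the module source for the names it actually registers:

```python
import torch
from fastvideo.v1.layers.activation import get_act_fn, get_act_and_mul_fn

# Assumed registry keys; the supported names are not listed on this page.
act = get_act_fn("gelu")                  # plain activation: (..., d) -> (..., d)
act_and_mul = get_act_and_mul_fn("silu")  # fused variant:    (..., 2 * d) -> (..., d)

x = torch.randn(8, 2 * 64)
print(act_and_mul(x).shape)   # e.g. SiluAndMul -> torch.Size([8, 64])
```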