activation

Custom activation functions.

Classes

fastvideo.layers.activation.GeluAndMul

GeluAndMul(approximate: str = 'none')

Bases: CustomOp

An activation function for GeGLU.

The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.

Shapes

x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
return: (batch_size, seq_len, d) or (num_tokens, d)

Source code in fastvideo/layers/activation.py
def __init__(self, approximate: str = "none"):
    super().__init__()
    self.approximate = approximate
    if approximate not in ("none", "tanh"):
        raise ValueError(f"Unknown approximate mode: {approximate}")

Functions

fastvideo.layers.activation.GeluAndMul.forward_native
forward_native(x: Tensor) -> Tensor

PyTorch-native implementation equivalent to forward().

Source code in fastvideo/layers/activation.py
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
    """PyTorch-native implementation equivalent to forward()."""
    d = x.shape[-1] // 2
    return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
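A minimal usage sketch for GeluAndMul. It calls forward_native directly, since the CustomOp dispatch path is not documented on this page; the shapes follow the Shapes note above and are otherwise illustrative:

import torch

from fastvideo.layers.activation import GeluAndMul

act = GeluAndMul(approximate="tanh")
x = torch.randn(2, 16, 128)   # (batch_size, seq_len, 2 * d) with d = 64
out = act.forward_native(x)   # GELU(x[..., :64]) * x[..., 64:]
assert out.shape == (2, 16, 64)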

fastvideo.layers.activation.NewGELU

NewGELU()

Bases: CustomOp

Source code in fastvideo/layers/activation.py
def __init__(self):
    super().__init__()

Functions

fastvideo.layers.activation.NewGELU.forward_native
forward_native(x: Tensor) -> Tensor

PyTorch-native implementation equivalent to forward().

Source code in fastvideo/layers/activation.py
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
    """PyTorch-native implementation equivalent to forward()."""
    c = math.sqrt(2.0 / math.pi)
    return 0.5 * x * (1.0 + torch.tanh(c *
                                       (x + 0.044715 * torch.pow(x, 3.0))))
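NewGELU has no docstring summary here; from the source, it is the tanh-based GELU approximation 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))). A quick check (a sketch, assuming a standard PyTorch install) that it matches PyTorch's built-in tanh approximation up to floating-point tolerance:

import torch
import torch.nn.functional as F

from fastvideo.layers.activation import NewGELU

x = torch.randn(4, 8)
# Same formula as F.gelu with the "tanh" approximation.
torch.testing.assert_close(NewGELU().forward_native(x),
                           F.gelu(x, approximate="tanh"))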

fastvideo.layers.activation.QuickGELU

QuickGELU()

Bases: CustomOp

Source code in fastvideo/layers/activation.py
def __init__(self):
    super().__init__()

Functions

fastvideo.layers.activation.QuickGELU.forward_native
forward_native(x: Tensor) -> Tensor

PyTorch-native implementation equivalent to forward().

Source code in fastvideo/layers/activation.py
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
    """PyTorch-native implementation equivalent to forward()."""
    return x * torch.sigmoid(1.702 * x)
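QuickGELU is the sigmoid-based GELU approximation x * sigmoid(1.702 * x), a cheaper elementwise alternative to the exact GELU. An illustrative sketch comparing the two:

import torch
import torch.nn.functional as F

from fastvideo.layers.activation import QuickGELU

x = torch.linspace(-3, 3, steps=7)
y = QuickGELU().forward_native(x)    # x * sigmoid(1.702 * x)
# The approximation stays close to the exact GELU but is not identical.
print((y - F.gelu(x)).abs().max())   # small, but nonzero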

fastvideo.layers.activation.SiluAndMul

SiluAndMul()

Bases: CustomOp

An activation function for SwiGLU.

The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.

Shapes

x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
return: (num_tokens, d) or (batch_size, seq_len, d)

Source code in fastvideo/layers/activation.py
def __init__(self) -> None:
    super().__init__()

Functions

fastvideo.layers.activation.SiluAndMul.forward_native
forward_native(x: Tensor) -> Tensor

PyTorch-native implementation equivalent to forward().

Source code in fastvideo/layers/activation.py
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
    """PyTorch-native implementation equivalent to forward()."""
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]
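A usage sketch for SiluAndMul in the usual SwiGLU pattern, where a single fused projection produces both the gate and value halves. The Linear layer and sizes here are illustrative, not part of this module:

import torch
import torch.nn as nn

from fastvideo.layers.activation import SiluAndMul

hidden, intermediate = 64, 128
fused_proj = nn.Linear(hidden, 2 * intermediate)  # gate and value in one matmul
act = SiluAndMul()

x = torch.randn(10, hidden)                # (num_tokens, hidden)
out = act.forward_native(fused_proj(x))    # (num_tokens, intermediate)
assert out.shape == (10, intermediate)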

Functions

fastvideo.layers.activation.get_act_and_mul_fn

get_act_and_mul_fn(act_fn_name: str) -> Module

Get an activation-and-mul (e.g. SiluAndMul) function by name.

Source code in fastvideo/layers/activation.py
def get_act_and_mul_fn(act_fn_name: str) -> nn.Module:
    """Get an activation-and-mul (i.e. SiluAndMul) function by name."""
    act_fn_name = act_fn_name.lower()
    if act_fn_name not in _ACTIVATION_AND_MUL_REGISTRY:
        raise ValueError(
            f"Activation function {act_fn_name!r} is not supported.")

    return _ACTIVATION_AND_MUL_REGISTRY[act_fn_name]()
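A lookup sketch. The registry contents are not shown on this page, so "silu" is an assumed key, chosen to match the SiluAndMul class above:

from fastvideo.layers.activation import get_act_and_mul_fn

# Lookup is case-insensitive ("SiLU" and "silu" resolve identically).
# NOTE: "silu" is an assumed registry key; the registry is not shown here.
act = get_act_and_mul_fn("silu")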

fastvideo.layers.activation.get_act_fn

get_act_fn(act_fn_name: str) -> Module

Get an activation function by name.

Source code in fastvideo/layers/activation.py
def get_act_fn(act_fn_name: str) -> nn.Module:
    """Get an activation function by name."""
    act_fn_name = act_fn_name.lower()
    if act_fn_name not in _ACTIVATION_REGISTRY:
        raise ValueError(
            f"Activation function {act_fn_name!r} is not supported.")

    return _ACTIVATION_REGISTRY[act_fn_name]()
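The error path is explicit in the source above: an unrecognized name raises ValueError carrying the offending name. A sketch:

from fastvideo.layers.activation import get_act_fn

try:
    get_act_fn("no-such-activation")
except ValueError as err:
    print(err)  # Activation function 'no-such-activation' is not supported.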