fastvideo.v1.layers.activation

Custom activation functions.

Module Contents

Classes

GeluAndMul

An activation function for GeGLU.

NewGELU

QuickGELU

SiluAndMul

An activation function for SwiGLU.

Functions

get_act_and_mul_fn

Get an activation-and-mul (e.g. SiluAndMul) function by name.

get_act_fn

Get an activation function by name.

API

class fastvideo.v1.layers.activation.GeluAndMul(approximate: str = 'none')

Bases: fastvideo.v1.layers.custom_op.CustomOp

An activation function for GeGLU.

The function computes x -> GELU(x[..., :d]) * x[..., d:] where d = x.shape[-1] // 2, i.e. the split is taken along the last dimension.

Shapes:
    x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
    return: (batch_size, seq_len, d) or (num_tokens, d)
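A minimal pure-PyTorch sketch of the GeGLU computation described above. It assumes the constructor's approximate argument is forwarded to torch.nn.functional.gelu, whose accepted values are 'none' and 'tanh'; gelu_and_mul_reference is a hypothetical helper name, not part of this module.

```python
import torch
import torch.nn.functional as F


def gelu_and_mul_reference(x: torch.Tensor, approximate: str = "none") -> torch.Tensor:
    """Hypothetical reference for GeLU-and-mul (GeGLU).

    The last dimension is split in half: GELU is applied to the first half,
    which then gates the second half elementwise.
    """
    if approximate not in ("none", "tanh"):
        raise ValueError(f"Unknown approximate mode: {approximate!r}")
    d = x.shape[-1] // 2
    return F.gelu(x[..., :d], approximate=approximate) * x[..., d:]


# Both documented input layouts shrink the last dimension from 2 * d to d.
tokens = torch.randn(5, 8)       # (num_tokens, 2 * d) with d = 4
batched = torch.randn(2, 3, 8)   # (batch_size, seq_len, 2 * d)
print(gelu_and_mul_reference(tokens).shape)            # torch.Size([5, 4])
print(gelu_and_mul_reference(batched, "tanh").shape)   # torch.Size([2, 3, 4])
```

Slicing with x[..., :d] rather than x[:d] is what makes the same code work for both the 2-D and the 3-D layouts.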

Initialization

Initialize internal Module state, shared by both nn.Module and ScriptModule.

extra_repr() -> str
forward_native(x: torch.Tensor) -> torch.Tensor

PyTorch-native implementation equivalent to forward().
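Every class here derives from CustomOp and documents a forward_native method that is "equivalent to forward()". One common way such a base class can work is to choose an implementation once at construction time and route forward() to it, keeping forward_native as the always-available pure-PyTorch fallback. The sketch below is hypothetical and is not CustomOp's actual code; DispatchingOp, forward_cuda and SiluAndMulSketch are illustrative names.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DispatchingOp(nn.Module):
    """Hypothetical CustomOp-style base class.

    Subclasses implement forward_native (pure PyTorch) and may optionally
    implement a faster backend path such as forward_cuda.
    """

    def __init__(self) -> None:
        super().__init__()
        # Pick the backend path once; fall back to the native implementation.
        use_cuda_path = torch.cuda.is_available() and hasattr(self, "forward_cuda")
        self._forward_method = self.forward_cuda if use_cuda_path else self.forward_native

    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        raise NotImplementedError


class SiluAndMulSketch(DispatchingOp):
    # No forward_cuda here, so forward() always uses this native path.
    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]


op = SiluAndMulSketch()
print(op(torch.randn(3, 4)).shape)  # torch.Size([3, 2])
```

Because forward_native is plain PyTorch, it doubles as a correctness reference for any fused kernel path.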

class fastvideo.v1.layers.activation.NewGELU

Bases: fastvideo.v1.layers.custom_op.CustomOp

forward_native(x: torch.Tensor) -> torch.Tensor

PyTorch-native implementation equivalent to forward().
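NewGELU carries no docstring. The name conventionally refers to the tanh approximation of GELU (the "gelu_new" activation from GPT-2-style code); the sketch below assumes that formula and checks it against PyTorch's built-in tanh-approximate GELU. new_gelu_reference is a hypothetical helper name.

```python
import math

import torch
import torch.nn.functional as F


def new_gelu_reference(x: torch.Tensor) -> torch.Tensor:
    """Assumed NewGELU formula: tanh approximation of GELU."""
    c = math.sqrt(2.0 / math.pi)
    return 0.5 * x * (1.0 + torch.tanh(c * (x + 0.044715 * torch.pow(x, 3.0))))


x = torch.linspace(-3.0, 3.0, steps=7)
# PyTorch's approximate="tanh" GELU evaluates the same expression.
print(torch.allclose(new_gelu_reference(x), F.gelu(x, approximate="tanh"), atol=1e-5))  # True
```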

class fastvideo.v1.layers.activation.QuickGELU

Bases: fastvideo.v1.layers.custom_op.CustomOp

forward_native(x: torch.Tensor) -> torch.Tensor

PyTorch-native implementation equivalent to forward().
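QuickGELU is also undocumented. The name conventionally refers to the sigmoid approximation x * sigmoid(1.702 * x), used for example in OpenAI's CLIP; the sketch below assumes that formula and measures how far it drifts from exact GELU. quick_gelu_reference is a hypothetical helper name.

```python
import torch
import torch.nn.functional as F


def quick_gelu_reference(x: torch.Tensor) -> torch.Tensor:
    """Assumed QuickGELU formula: x * sigmoid(1.702 * x)."""
    return x * torch.sigmoid(1.702 * x)


x = torch.linspace(-4.0, 4.0, steps=81)
max_err = (quick_gelu_reference(x) - F.gelu(x)).abs().max().item()
print(f"max |QuickGELU - GELU| on [-4, 4]: {max_err:.4f}")
```

The approximation trades a small error (largest for moderately negative inputs) for a single sigmoid instead of erf or tanh.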

class fastvideo.v1.layers.activation.SiluAndMul

Bases: fastvideo.v1.layers.custom_op.CustomOp

An activation function for SwiGLU.

The function computes x -> silu(x[..., :d]) * x[..., d:] where d = x.shape[-1] // 2, i.e. the split is taken along the last dimension.

Shapes:
    x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
    return: (num_tokens, d) or (batch_size, seq_len, d)
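A minimal sketch of where an activation-and-mul op usually sits: a SwiGLU feed-forward block whose fused projection emits 2 * d features ([gate | up]) that are reduced to d before the down projection. SwiGLUMLPSketch, gate_up_proj and down_proj are hypothetical names; the SiLU-and-mul step is written in pure PyTorch to mirror the formula above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    """silu(first half) * second half, split along the last dimension."""
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]


class SwiGLUMLPSketch(nn.Module):
    """Hypothetical gated MLP using a SiLU-and-mul activation."""

    def __init__(self, hidden: int, intermediate: int) -> None:
        super().__init__()
        # One fused projection produces [gate | up]: 2 * intermediate features.
        self.gate_up_proj = nn.Linear(hidden, 2 * intermediate, bias=False)
        self.down_proj = nn.Linear(intermediate, hidden, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(silu_and_mul_reference(self.gate_up_proj(x)))


mlp = SwiGLUMLPSketch(hidden=16, intermediate=32)
print(mlp(torch.randn(2, 5, 16)).shape)  # torch.Size([2, 5, 16])
```

Fusing the gate and up projections into one matmul is the reason the activation takes a single 2 * d input rather than two tensors.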

Initialization

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward_native(x: torch.Tensor) -> torch.Tensor

PyTorch-native implementation equivalent to forward().

fastvideo.v1.layers.activation.get_act_and_mul_fn(act_fn_name: str) -> torch.nn.Module

Get an activation-and-mul (e.g. SiluAndMul) function by name.

fastvideo.v1.layers.activation.get_act_fn(act_fn_name: str) -> torch.nn.Module

Get an activation function by name.
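A comparable hypothetical sketch for plain activations, built only from torch.nn modules. The registry keys, the case-insensitive matching, and the error behavior are assumptions; the real get_act_fn may accept other names and may also return this module's own classes (such as NewGELU or QuickGELU).

```python
from typing import Callable, Dict

import torch.nn as nn

# Hypothetical name -> factory table; each call builds a fresh module.
_ACT_REGISTRY: Dict[str, Callable[[], nn.Module]] = {
    "gelu": nn.GELU,
    "gelu_pytorch_tanh": lambda: nn.GELU(approximate="tanh"),
    "relu": nn.ReLU,
    "silu": nn.SiLU,
}


def get_act_fn_sketch(act_fn_name: str) -> nn.Module:
    """Return a new activation module for a (case-insensitive) name."""
    key = act_fn_name.lower()
    if key not in _ACT_REGISTRY:
        raise ValueError(f"Activation function {act_fn_name!r} is not supported.")
    return _ACT_REGISTRY[key]()


print(get_act_fn_sketch("silu"))  # SiLU()
```

Storing factories instead of instances means each caller gets its own module, so no state is shared between layers.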