mlp

Classes

fastvideo.layers.mlp.MLP

MLP(input_dim: int, mlp_hidden_dim: int, output_dim: int | None = None, bias: bool = True, act_type: str = 'gelu_pytorch_tanh', dtype: torch.dtype | None = None, prefix: str = '')

Bases: Module

MLP for DiT blocks; a plain two-layer feed-forward network with no gated linear units.
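
For orientation, the computation is equivalent to the minimal pure-PyTorch sketch below. PlainMLP is a hypothetical stand-in written against stock torch.nn layers; the forward method is not shown in the source excerpt, so the call order (project up, activate, project down) is inferred from the submodules the constructor defines, and the activation mirrors the default act_type of 'gelu_pytorch_tanh'.

import torch
import torch.nn as nn

class PlainMLP(nn.Module):
    """Hypothetical reference version of the block using stock layers."""

    def __init__(self, input_dim: int, mlp_hidden_dim: int,
                 output_dim: int | None = None, bias: bool = True):
        super().__init__()
        self.fc_in = nn.Linear(input_dim, mlp_hidden_dim, bias=bias)
        self.act = nn.GELU(approximate="tanh")  # 'gelu_pytorch_tanh'
        self.fc_out = nn.Linear(
            mlp_hidden_dim,
            input_dim if output_dim is None else output_dim,
            bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # No gating: project up, activate, project back down.
        return self.fc_out(self.act(self.fc_in(x)))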

Source code in fastvideo/layers/mlp.py
def __init__(
    self,
    input_dim: int,
    mlp_hidden_dim: int,
    output_dim: int | None = None,
    bias: bool = True,
    act_type: str = "gelu_pytorch_tanh",
    dtype: torch.dtype | None = None,
    prefix: str = "",
):
    super().__init__()
    # Input projection. Callers pass a 2x hidden width here for
    # activation functions (e.g. SiLU variants) that require it.
    self.fc_in = ReplicatedLinear(
        input_dim,
        mlp_hidden_dim,
        bias=bias,
        params_dtype=dtype)

    self.act = get_act_fn(act_type)

    # Output projection. output_dim defaults to input_dim so the
    # block can sit on a residual path.
    if output_dim is None:
        output_dim = input_dim
    self.fc_out = ReplicatedLinear(mlp_hidden_dim,
                                   output_dim,
                                   bias=bias,
                                   params_dtype=dtype)

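A quick usage sketch. The sizes are illustrative rather than taken from the source, and it assumes the module's forward returns a tensor whose trailing dimension equals output_dim (left unset here, so it defaults to input_dim):

import torch
from fastvideo.layers.mlp import MLP

# Illustrative sizes: a 1152-wide hidden state with a 4x expansion.
mlp = MLP(input_dim=1152, mlp_hidden_dim=4 * 1152)

x = torch.randn(2, 256, 1152)  # (batch, tokens, hidden)
out = mlp(x)                   # output_dim defaults to input_dim
assert out.shape == x.shape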