MLP(input_dim: int, mlp_hidden_dim: int, output_dim: int | None = None, bias: bool = True, act_type: str = 'gelu_pytorch_tanh', dtype: dtype | None = None, prefix: str = '')
Bases: Module
MLP for DiT blocks; no gated linear units.
Source code in fastvideo/layers/mlp.py
```python
def __init__(
    self,
    input_dim: int,
    mlp_hidden_dim: int,
    output_dim: int | None = None,
    bias: bool = True,
    act_type: str = "gelu_pytorch_tanh",
    dtype: torch.dtype | None = None,
    prefix: str = "",
):
    super().__init__()
    self.fc_in = ReplicatedLinear(
        input_dim,
        mlp_hidden_dim,  # For activation func like SiLU that need 2x width
        bias=bias,
        params_dtype=dtype)
    self.act = get_act_fn(act_type)
    if output_dim is None:
        output_dim = input_dim
    self.fc_out = ReplicatedLinear(mlp_hidden_dim,
                                   output_dim,
                                   bias=bias,
                                   params_dtype=dtype)
```
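A minimal usage sketch is shown below. It assumes the module's forward pass accepts a tensor of shape `(..., input_dim)` and returns a tensor of shape `(..., output_dim)`; the dimensions and dtype are illustrative, and depending on your runtime the underlying `ReplicatedLinear` layers may require fastvideo's distributed/model-parallel state to be initialized before construction.

```python
import torch
from fastvideo.layers.mlp import MLP

# Illustrative sizes; output_dim defaults to input_dim when omitted.
mlp = MLP(input_dim=1024, mlp_hidden_dim=4096, dtype=torch.float32)

x = torch.randn(2, 16, 1024)  # (batch, tokens, input_dim) -- assumed input layout
y = mlp(x)                    # fc_in -> activation -> fc_out

assert y.shape == x.shape     # holds because output_dim defaulted to input_dim
```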