rotary_embedding

Rotary Positional Embeddings.

Classes

fastvideo.layers.rotary_embedding.RotaryEmbedding

RotaryEmbedding(head_size: int, rotary_dim: int, max_position_embeddings: int, base: int | float, is_neox_style: bool, dtype: dtype)

Bases: CustomOp

Original rotary positional embedding.

Source code in fastvideo/layers/rotary_embedding.py
def __init__(
    self,
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: int | float,
    is_neox_style: bool,
    dtype: torch.dtype,
) -> None:
    super().__init__()
    self.head_size = head_size
    self.rotary_dim = rotary_dim
    self.max_position_embeddings = max_position_embeddings
    self.base = base
    self.is_neox_style = is_neox_style
    self.dtype = dtype

    cache = self._compute_cos_sin_cache()
    cache = cache.to(dtype)
    self.cos_sin_cache: torch.Tensor
    self.register_buffer("cos_sin_cache", cache, persistent=False)

Functions

fastvideo.layers.rotary_embedding.RotaryEmbedding.forward_native
forward_native(positions: Tensor, query: Tensor, key: Tensor, offsets: Tensor | None = None) -> tuple[Tensor, Tensor]

A PyTorch-native implementation of forward().

Source code in fastvideo/layers/rotary_embedding.py
def forward_native(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    offsets: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """A PyTorch-native implementation of forward()."""
    if offsets is not None:
        positions = positions + offsets
    positions = positions.flatten()
    num_tokens = positions.shape[0]
    cos_sin = self.cos_sin_cache.index_select(0, positions)
    cos, sin = cos_sin.chunk(2, dim=-1)

    query_shape = query.shape
    query = query.view(num_tokens, -1, self.head_size)
    query_rot = query[..., :self.rotary_dim]
    query_pass = query[..., self.rotary_dim:]
    query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
    query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

    key_shape = key.shape
    key = key.view(num_tokens, -1, self.head_size)
    key_rot = key[..., :self.rotary_dim]
    key_pass = key[..., self.rotary_dim:]
    key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
    key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
    return query, key
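
Example usage (a minimal sketch; the head count, sizes, and dtype below are illustrative, not taken from a specific FastVideo model):

import torch
from fastvideo.layers.rotary_embedding import RotaryEmbedding

# 8 heads of size 64, 16 tokens; rotate the full head dimension.
rope = RotaryEmbedding(head_size=64, rotary_dim=64,
                       max_position_embeddings=4096, base=10000,
                       is_neox_style=True, dtype=torch.float32)
positions = torch.arange(16)
query = torch.randn(16, 8 * 64)  # [num_tokens, num_heads * head_size]
key = torch.randn(16, 8 * 64)
query, key = rope.forward_native(positions, query, key)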

Functions

fastvideo.layers.rotary_embedding.apply_rotary_emb

apply_rotary_emb(x: Tensor, freqs_cis: Tensor | tuple[Tensor, Tensor], use_real: bool = True, use_real_unbind_dim: int = -1) -> Tensor

Apply rotary embeddings to an input tensor using the given frequency tensor. The query or key tensor 'x' is reshaped into (real, imaginary) pairs, rotated by the angles encoded in 'freqs_cis', and returned as a real tensor of the same shape and dtype.

Parameters:

x (Tensor, required): Query or key tensor to apply rotary embeddings to. [B, H, S, D]
freqs_cis (Tensor or tuple[Tensor, Tensor], required): Precomputed frequency tensor for complex exponentials. A (cos, sin) tuple of [S, D] tensors when use_real is True, otherwise a complex tensor.
use_real (bool, default True): Whether freqs_cis is a (cos, sin) pair of real tensors rather than a complex tensor.
use_real_unbind_dim (int, default -1): Layout of the (real, imaginary) pairs in x: -1 for interleaved pairs, -2 for a half-split layout.

Returns:

Tensor: The input tensor with rotary embeddings applied.

Source code in fastvideo/layers/rotary_embedding.py
def apply_rotary_emb(
    x: torch.Tensor,
    freqs_cis: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
    use_real: bool = True,
    use_real_unbind_dim: int = -1,
) -> torch.Tensor:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
    to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
    reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
    tensors contain rotary embeddings and are returned as real tensors.
    Args:
        x (`torch.Tensor`):
            Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
        freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
    """
    if use_real:
        cos, sin = freqs_cis  # [S, D]
        # Match Diffusers broadcasting (sequence_dim=2 case)
        cos = cos[None, None, :, :]
        sin = sin[None, None, :, :]
        cos, sin = cos.to(x.device), sin.to(x.device)

        if use_real_unbind_dim == -1:
            # Used for flux, cogvideox, hunyuan-dit
            x_real, x_imag = x.reshape(*x.shape[:-1], -1,
                                       2).unbind(-1)  # [B, H, S, D//2]
            x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        elif use_real_unbind_dim == -2:
            # Used for Stable Audio, OmniGen, CogView4 and Cosmos
            x_real, x_imag = x.reshape(*x.shape[:-1], 2,
                                       -1).unbind(-2)  # [B, H, S, D//2]
            x_rotated = torch.cat([-x_imag, x_real], dim=-1)
        else:
            raise ValueError(
                f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2."
            )

        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

        return out
    else:
        # used for lumina
        x_rotated = torch.view_as_complex(x.float().reshape(
            *x.shape[:-1], -1, 2))
        freqs_cis = freqs_cis.unsqueeze(2)
        x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)

        return x_out.type_as(x)
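
As a sanity check, the use_real branch with use_real_unbind_dim=-1 matches multiplying interleaved (real, imaginary) pairs by a complex exponential. A small sketch (the repeat_interleave pairing of cos/sin below is an assumption about how a caller lays out the frequencies):

import torch
from fastvideo.layers.rotary_embedding import apply_rotary_emb

B, H, S, D = 1, 2, 4, 8
x = torch.randn(B, H, S, D)
angles = torch.randn(S, D // 2)
# One angle per (real, imag) pair, repeated so cos/sin broadcast over D.
cos = angles.cos().repeat_interleave(2, dim=-1)  # [S, D]
sin = angles.sin().repeat_interleave(2, dim=-1)

out = apply_rotary_emb(x, (cos, sin), use_real=True, use_real_unbind_dim=-1)

# Reference: rotate each pair as a complex number by e^(i * angle).
ref = torch.view_as_real(
    torch.view_as_complex(x.reshape(B, H, S, D // 2, 2))
    * torch.polar(torch.ones(S, D // 2), angles)
).flatten(3)
torch.testing.assert_close(out, ref)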

fastvideo.layers.rotary_embedding.get_1d_rotary_pos_embed

get_1d_rotary_pos_embed(dim: int, pos: FloatTensor | int, theta: float = 10000.0, theta_rescale_factor: float = 1.0, interpolation_factor: float = 1.0, dtype: dtype = float32) -> tuple[Tensor, Tensor]

Precompute the frequency tensor for complex exponential (cis) with given dimensions. (Note: cis means cos + i * sin, where i is the imaginary unit.)

This function calculates a frequency tensor with complex exponential using the given dimension 'dim' and the position indices 'pos'. The 'theta' parameter scales the frequencies.

Parameters:

dim (int, required): Dimension of the frequency tensor.
pos (int or FloatTensor, required): Position indices for the frequency tensor. [S] or scalar.
theta (float, default 10000.0): Scaling factor for frequency computation.
theta_rescale_factor (float, default 1.0): Rescale factor for theta.
interpolation_factor (float, default 1.0): Factor to scale positions.
dtype (torch.dtype, default torch.float32): Dtype of the returned tensors.

Returns:

tuple[Tensor, Tensor]: freqs_cos, freqs_sin. Precomputed frequency tensors with the real and imaginary parts stored separately. [S, D/2]

Source code in fastvideo/layers/rotary_embedding.py
def get_1d_rotary_pos_embed(
    dim: int,
    pos: torch.FloatTensor | int,
    theta: float = 10000.0,
    theta_rescale_factor: float = 1.0,
    interpolation_factor: float = 1.0,
    dtype: torch.dtype = torch.float32,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Precompute the frequency tensor for complex exponential (cis) with given dimensions.
    (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)

    This function calculates a frequency tensor with complex exponential using the given dimension 'dim'
    and the position indices 'pos'. The 'theta' parameter scales the frequencies.

    Args:
        dim (int): Dimension of the frequency tensor.
        pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
        theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
        theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.
        interpolation_factor (float, optional): Factor to scale positions. Defaults to 1.0.
        dtype (torch.dtype, optional): Dtype of the returned tensors. Defaults to torch.float32.

    Returns:
        freqs_cos, freqs_sin: Precomputed frequency tensors with real and imaginary parts separately. [S, D/2]
    """
    if isinstance(pos, int):
        pos = torch.arange(pos).float()

    # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
    # has some connection to NTK literature
    if theta_rescale_factor != 1.0:
        theta *= theta_rescale_factor**(dim / (dim - 2))

    freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].to(dtype) / dim)
                   )  # [D/2]
    freqs = torch.outer(pos * interpolation_factor, freqs)  # [S, D/2]
    freqs_cos = freqs.cos()  # [S, D/2]
    freqs_sin = freqs.sin()  # [S, D/2]
    return freqs_cos, freqs_sin
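
A short usage sketch showing the output shapes and the NTK-style rescaling hook:

from fastvideo.layers.rotary_embedding import get_1d_rotary_pos_embed

# 16 positions on a 64-dim axis: cos and sin are each [S, D/2].
cos, sin = get_1d_rotary_pos_embed(64, 16)
assert cos.shape == sin.shape == (16, 32)

# theta_rescale_factor > 1.0 stretches the wavelengths (the NTK-style
# trick mentioned above) to reach longer sequences without fine-tuning.
cos_long, sin_long = get_1d_rotary_pos_embed(64, 64, theta_rescale_factor=4.0)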

fastvideo.layers.rotary_embedding.get_meshgrid_nd

get_meshgrid_nd(start: int | tuple[int, ...], *args: int | tuple[int, ...], dim: int = 2) -> Tensor

Get n-D meshgrid with start, stop and num.

Parameters:

start (int or tuple, required): If len(args) == 0, start is num; if len(args) == 1, start is start and args[0] is stop, with step 1; if len(args) == 2, start is start, args[0] is stop, and args[1] is num. For n dimensions, start/stop/num should each be an int or an n-tuple; if n-tuples are given, the meshgrid is stacked following their dim order.
*args (int or tuple): See start.
dim (int, default 2): Dimension of the meshgrid.

Returns:

grid (Tensor): Stacked meshgrid of shape [dim, ...].

Source code in fastvideo/layers/rotary_embedding.py
def get_meshgrid_nd(start: int | tuple[int, ...],
                    *args: int | tuple[int, ...],
                    dim: int = 2) -> torch.Tensor:
    """
    Get n-D meshgrid with start, stop and num.

    Args:
        start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop,
            step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
            should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in
            n-tuples.
        *args: See above.
        dim (int): Dimension of the meshgrid. Defaults to 2.

    Returns:
        grid (torch.Tensor): [dim, ...]
    """
    if len(args) == 0:
        # start is grid_size
        num = _to_tuple(start, dim=dim)
        start = (0, ) * dim
        stop = num
    elif len(args) == 1:
        # start is start, args[0] is stop, step is 1
        start = _to_tuple(start, dim=dim)
        stop = _to_tuple(args[0], dim=dim)
        num = tuple(stop[i] - start[i] for i in range(dim))
    elif len(args) == 2:
        # start is start, args[0] is stop, args[1] is num
        start = _to_tuple(start, dim=dim)  # Left-Top       eg: 12,0
        stop = _to_tuple(args[0], dim=dim)  # Right-Bottom   eg: 20,32
        num = _to_tuple(args[1], dim=dim)  # Target Size    eg: 32,124
    else:
        raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")

    # PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False)
    axis_grid = []
    for i in range(dim):
        a, b, n = start[i], stop[i], num[i]
        g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
        axis_grid.append(g)
    grid = torch.meshgrid(*axis_grid, indexing="ij")  # dim x [W, H, D]
    grid = torch.stack(grid, dim=0)  # [dim, W, H, D]

    return grid
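
A quick sketch of the grid layout:

from fastvideo.layers.rotary_embedding import get_meshgrid_nd

# A 3-D grid for a (t, h, w) = (2, 4, 4) latent volume.
grid = get_meshgrid_nd((2, 4, 4), dim=3)
assert grid.shape == (3, 2, 4, 4)  # [dim, t, h, w]
# grid[0] holds the t coordinate of every point, grid[1] the h, grid[2] the w.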

fastvideo.layers.rotary_embedding.get_nd_rotary_pos_embed

get_nd_rotary_pos_embed(rope_dim_list, start, *args, theta=10000.0, theta_rescale_factor: float | list[float] = 1.0, interpolation_factor: float | list[float] = 1.0, shard_dim: int = 0, sp_rank: int = 0, sp_world_size: int = 1, dtype: dtype = float32, start_frame: int = 0) -> tuple[Tensor, Tensor]

An n-d version of precompute_freqs_cis: RoPE for tokens with an n-d structure. Supports sequence parallelism by allowing a specific dimension to be sharded.

Parameters:

rope_dim_list (list of int, required): Dimension of each rope. len(rope_dim_list) should equal n, and sum(rope_dim_list) should equal the head_dim of the attention layer.
start (int, tuple of int, or list of int, required): If len(args) == 0, start is num; if len(args) == 1, start is start and args[0] is stop, with step 1; if len(args) == 2, start is start, args[0] is stop, and args[1] is num.
*args: See start.
theta (float, default 10000.0): Scaling factor for frequency computation.
theta_rescale_factor (float or list of float, default 1.0): Rescale factor for theta.
interpolation_factor (float or list of float, default 1.0): Factor to scale positions.
shard_dim (int, default 0): Which dimension to shard for sequence parallelism.
sp_rank (int, default 0): Rank in the sequence parallel group.
sp_world_size (int, default 1): World size of the sequence parallel group.
dtype (torch.dtype, default torch.float32): Dtype of the returned tensors.
start_frame (int, default 0): Offset added to the temporal (first) axis positions.

Returns:

tuple[Tensor, Tensor]: (cos, sin) tensors of shape [HW, D/2].

Source code in fastvideo/layers/rotary_embedding.py
def get_nd_rotary_pos_embed(
    rope_dim_list,
    start,
    *args,
    theta=10000.0,
    theta_rescale_factor: float | list[float] = 1.0,
    interpolation_factor: float | list[float] = 1.0,
    shard_dim: int = 0,
    sp_rank: int = 0,
    sp_world_size: int = 1,
    dtype: torch.dtype = torch.float32,
    start_frame: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    An n-d version of precompute_freqs_cis: RoPE for tokens with an n-d structure.
    Supports sequence parallelism by allowing a specific dimension to be sharded.

    Args:
        rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal n.
            sum(rope_dim_list) should equal the head_dim of the attention layer.
        start (int | tuple of int | list of int): If len(args) == 0, start is num; if len(args) == 1, start is start,
            args[0] is stop, step is 1; if len(args) == 2, start is start, args[0] is stop, args[1] is num.
        *args: See above.
        theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
        theta_rescale_factor (float | list of float): Rescale factor for theta. Defaults to 1.0.
        interpolation_factor (float | list of float): Factor to scale positions. Defaults to 1.0.
        shard_dim (int): Which dimension to shard for sequence parallelism. Defaults to 0.
        sp_rank (int): Rank in the sequence parallel group. Defaults to 0.
        sp_world_size (int): World size of the sequence parallel group. Defaults to 1.
        dtype (torch.dtype): Dtype of the returned tensors. Defaults to torch.float32.
        start_frame (int): Offset added to the temporal (first) axis positions. Defaults to 0.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: (cos, sin) tensors of shape [HW, D/2]
    """
    # Get the full grid
    full_grid = get_meshgrid_nd(
        start, *args, dim=len(rope_dim_list))  # [3, W, H, D] / [2, W, H]

    if start_frame > 0:
        full_grid[0] += start_frame

    # Shard the grid if using sequence parallelism (sp_world_size > 1)
    assert shard_dim < len(
        rope_dim_list
    ), f"shard_dim {shard_dim} must be less than number of dimensions {len(rope_dim_list)}"
    if sp_world_size > 1:
        # Get the shape of the full grid
        grid_shape = list(full_grid.shape[1:])

        # Ensure the dimension to shard is divisible by sp_world_size
        assert grid_shape[shard_dim] % sp_world_size == 0, (
            f"Dimension {shard_dim} with size {grid_shape[shard_dim]} is not divisible "
            f"by sequence parallel world size {sp_world_size}")

        # Compute the start and end indices for this rank's shard
        shard_size = grid_shape[shard_dim] // sp_world_size
        start_idx = sp_rank * shard_size
        end_idx = (sp_rank + 1) * shard_size

        # Create slicing indices for each dimension
        slice_indices = [slice(None) for _ in range(len(grid_shape))]
        slice_indices[shard_dim] = slice(start_idx, end_idx)

        # Shard the grid
        # Update grid shape for the sharded dimension
        grid_shape[shard_dim] = grid_shape[shard_dim] // sp_world_size
        grid = torch.empty((len(rope_dim_list), ) + tuple(grid_shape),
                           dtype=full_grid.dtype)
        for i in range(len(rope_dim_list)):
            grid[i] = full_grid[i][tuple(slice_indices)]
    else:
        grid = full_grid

    if isinstance(theta_rescale_factor, int | float):
        theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
    elif isinstance(theta_rescale_factor,
                    list) and len(theta_rescale_factor) == 1:
        theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
    assert len(theta_rescale_factor) == len(
        rope_dim_list
    ), "len(theta_rescale_factor) should equal to len(rope_dim_list)"

    if isinstance(interpolation_factor, int | float):
        interpolation_factor = [interpolation_factor] * len(rope_dim_list)
    elif isinstance(interpolation_factor,
                    list) and len(interpolation_factor) == 1:
        interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
    assert len(interpolation_factor) == len(
        rope_dim_list
    ), "len(interpolation_factor) should equal to len(rope_dim_list)"

    # use 1/ndim of dimensions to encode grid_axis
    embs = []
    for i in range(len(rope_dim_list)):
        emb = get_1d_rotary_pos_embed(
            rope_dim_list[i],
            grid[i].reshape(-1),
            theta,
            theta_rescale_factor=theta_rescale_factor[i],
            interpolation_factor=interpolation_factor[i],
            dtype=dtype,
        )  # 2 x [WHD, rope_dim_list[i]]
        embs.append(emb)

    cos = torch.cat([emb[0] for emb in embs], dim=1)  # (WHD, D/2)
    sin = torch.cat([emb[1] for emb in embs], dim=1)  # (WHD, D/2)
    return cos, sin
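
A single-process sketch (sp_world_size defaults to 1, so no sharding occurs; the sizes are illustrative):

from fastvideo.layers.rotary_embedding import get_nd_rotary_pos_embed

# 3-axis RoPE over a (t, h, w) = (2, 4, 4) grid. The per-axis rope dims
# must sum to the attention head_dim (here 16 + 24 + 24 = 64).
cos, sin = get_nd_rotary_pos_embed([16, 24, 24], (2, 4, 4))
assert cos.shape == sin.shape == (2 * 4 * 4, 64 // 2)  # [THW, head_dim/2]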

fastvideo.layers.rotary_embedding.get_rotary_pos_embed

get_rotary_pos_embed(rope_sizes, hidden_size, heads_num, rope_dim_list, rope_theta, theta_rescale_factor=1.0, interpolation_factor=1.0, shard_dim: int = 0, dtype: dtype = float32, start_frame: int = 0) -> tuple[Tensor, Tensor]

Generate rotary positional embeddings for the given sizes.

Parameters:

rope_sizes (required): Tuple of dimensions (t, h, w).
hidden_size (required): Hidden dimension size.
heads_num (required): Number of attention heads.
rope_dim_list (required): List of dimensions for each axis, or None.
rope_theta (required): Base for frequency calculations.
theta_rescale_factor (default 1.0): Rescale factor for theta.
interpolation_factor (default 1.0): Factor to scale positions.
shard_dim (int, default 0): Which dimension to shard for sequence parallelism.
dtype (torch.dtype, default torch.float32): Dtype of the returned tensors.
start_frame (int, default 0): Offset added to the temporal (first) axis positions.

Returns:

tuple[Tensor, Tensor]: (cos, sin) tensors for rotary embeddings.

Source code in fastvideo/layers/rotary_embedding.py
def get_rotary_pos_embed(
    rope_sizes,
    hidden_size,
    heads_num,
    rope_dim_list,
    rope_theta,
    theta_rescale_factor=1.0,
    interpolation_factor=1.0,
    shard_dim: int = 0,
    dtype: torch.dtype = torch.float32,
    start_frame: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Generate rotary positional embeddings for the given sizes.

    Args:
        rope_sizes: Tuple of dimensions (t, h, w)
        hidden_size: Hidden dimension size
        heads_num: Number of attention heads
        rope_dim_list: List of dimensions for each axis, or None
        rope_theta: Base for frequency calculations
        theta_rescale_factor: Rescale factor for theta. Defaults to 1.0
        interpolation_factor: Factor to scale positions. Defaults to 1.0
        shard_dim: Which dimension to shard for sequence parallelism. Defaults to 0.
        dtype: Dtype of the returned tensors. Defaults to torch.float32.
        start_frame: Offset added to the temporal (first) axis positions. Defaults to 0.

    Returns:
        Tuple of (cos, sin) tensors for rotary embeddings
    """

    target_ndim = 3
    head_dim = hidden_size // heads_num

    if rope_dim_list is None:
        rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]

    assert sum(
        rope_dim_list
    ) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"

    # Get SP info
    sp_group = get_sp_group()
    sp_rank = sp_group.rank_in_group
    sp_world_size = sp_group.world_size

    freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
        rope_dim_list,
        rope_sizes,
        theta=rope_theta,
        theta_rescale_factor=theta_rescale_factor,
        interpolation_factor=interpolation_factor,
        shard_dim=shard_dim,
        sp_rank=sp_rank,
        sp_world_size=sp_world_size,
        dtype=dtype,
        start_frame=start_frame,
    )
    return freqs_cos, freqs_sin
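
A hedged sketch of a typical call. get_rotary_pos_embed reads the sequence-parallel group via get_sp_group(), so FastVideo's distributed state must already be initialized; the sizes below are illustrative:

cos, sin = get_rotary_pos_embed(
    rope_sizes=(8, 16, 16),       # (t, h, w) latent sizes
    hidden_size=1024,
    heads_num=16,                 # head_dim = 1024 // 16 = 64
    rope_dim_list=[16, 24, 24],   # must sum to head_dim
    rope_theta=10000,
)
# cos, sin: [t * h * w (per SP rank), head_dim / 2]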