Skip to content

visual_embedding

Classes

fastvideo.layers.visual_embedding.ModulateProjection

ModulateProjection(hidden_size: int, factor: int = 2, act_layer: str = 'silu', dtype: dtype | None = None, prefix: str = '')

Bases: Module

Modulation layer for DiT blocks.

Source code in fastvideo/layers/visual_embedding.py
def __init__(
    self,
    hidden_size: int,
    factor: int = 2,
    act_layer: str = "silu",
    dtype: torch.dtype | None = None,
    prefix: str = "",
):
    """Build the modulation projection (a linear layer plus an activation).

    Args:
        hidden_size: First positional argument to ``ReplicatedLinear``
            (presumably its input width).
        factor: Multiplier on ``hidden_size`` giving the second positional
            argument to ``ReplicatedLinear`` (presumably its output width).
        act_layer: Activation name, resolved with ``get_act_fn``.
        dtype: Forwarded to ``ReplicatedLinear`` as ``params_dtype``.
        prefix: Accepted for interface compatibility; not used here.
    """
    super().__init__()
    self.hidden_size, self.factor = hidden_size, factor

    projected_width = hidden_size * factor
    self.linear = ReplicatedLinear(
        hidden_size,
        projected_width,
        bias=True,
        params_dtype=dtype,
    )
    self.act = get_act_fn(act_layer)

fastvideo.layers.visual_embedding.PatchEmbed

PatchEmbed(patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True, bias=True, dtype=None, prefix: str = '')

Bases: Module

2D Image to Patch Embedding

Image to Patch Embedding using Conv3d

A convolution based approach to patchifying a 2D image w/ embedding projection.

Based on the impl in https://github.com/google-research/vision_transformer

Hacked together by / Copyright 2020 Ross Wightman

The `_assert` check in `forward` has been removed so that multi-resolution images are supported.

Source code in fastvideo/layers/visual_embedding.py
def __init__(self,
             patch_size=16,
             in_chans=3,
             embed_dim=768,
             norm_layer=None,
             flatten=True,
             bias=True,
             dtype=None,
             prefix: str = ""):
    """Build the Conv3d patch projection.

    Args:
        patch_size: Kernel size and stride of the ``nn.Conv3d`` projection.
            An int or a 1-element list/tuple is broadcast to all three
            dimensions (the same rule ``nn.Conv3d`` applies to an int). A
            3-element list/tuple is used as given.
        in_chans: Input channel count of the convolution.
        embed_dim: Output channel count of the convolution.
        norm_layer: Optional callable invoked as ``norm_layer(embed_dim)``.
            When falsy, ``nn.Identity`` is used instead.
        flatten: Stored as ``self.flatten``. Presumably consumed by
            ``forward`` (not shown here).
        bias: Whether the convolution has a bias term.
        dtype: Parameter dtype of the convolution.
        prefix: Accepted for interface compatibility; not used here.

    Raises:
        ValueError: If ``patch_size`` is a list/tuple whose length is
            neither 1 nor 3.
    """
    super().__init__()
    # nn.Conv3d needs a 3-element kernel. A 2-element kernel does not give
    # a valid 3-D convolution weight, so scalars are broadcast to all three
    # dims rather than to (p, p).
    if isinstance(patch_size, list | tuple):
        if len(patch_size) == 1:
            patch_size = (patch_size[0],) * 3
        elif len(patch_size) != 3:
            raise ValueError(
                "patch_size must be an int or a list/tuple of length 1 or 3, "
                f"got {patch_size!r}")
    else:
        patch_size = (patch_size,) * 3

    self.patch_size = patch_size
    self.flatten = flatten

    self.proj = nn.Conv3d(in_chans,
                          embed_dim,
                          kernel_size=patch_size,
                          stride=patch_size,
                          bias=bias,
                          dtype=dtype)
    self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

fastvideo.layers.visual_embedding.TimestepEmbedder

TimestepEmbedder(hidden_size, act_layer='silu', frequency_embedding_size=256, max_period=10000, dtype=None, freq_dtype=float32, prefix: str = '')

Bases: Module

Embeds scalar timesteps into vector representations.

Source code in fastvideo/layers/visual_embedding.py
def __init__(
    self,
    hidden_size,
    act_layer="silu",
    frequency_embedding_size=256,
    max_period=10000,
    dtype=None,
    freq_dtype=torch.float32,
    prefix: str = "",
):
    """Build the timestep embedder's MLP and store its frequency settings.

    Args:
        hidden_size: Passed to ``MLP`` as both its second and third
            positional arguments (presumably hidden and output width).
        act_layer: Activation name, forwarded to ``MLP`` as ``act_type``.
        frequency_embedding_size: First positional argument to ``MLP``, also
            stored on the instance. Presumably the width of the sinusoidal
            features fed to the MLP in ``forward`` (not shown here).
        max_period: Stored as ``self.max_period``.
        dtype: Parameter dtype, forwarded to ``MLP``.
        freq_dtype: Stored as ``self.freq_dtype``.
        prefix: Accepted for interface compatibility; not used here.
    """
    super().__init__()
    # Plain (non-Module) configuration attributes.
    self.frequency_embedding_size = frequency_embedding_size
    self.max_period = max_period
    self.freq_dtype = freq_dtype

    self.mlp = MLP(
        frequency_embedding_size,
        hidden_size,
        hidden_size,
        act_type=act_layer,
        dtype=dtype,
    )

Functions

fastvideo.layers.visual_embedding.get_timestep_embedding

get_timestep_embedding(timesteps: Tensor, embedding_dim: int, flip_sin_to_cos: bool = False, downscale_freq_shift: float = 1, scale: float = 1, max_period: int = 10000) -> Tensor

This matches the implementation in Denoising Diffusion Probabilistic Models: create sinusoidal timestep embeddings.

Args:

- `timesteps` (torch.Tensor): a 1-D tensor of N indices, one per batch element. These may be fractional.
- `embedding_dim` (int): the dimension of the output.
- `flip_sin_to_cos` (bool): whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False).
- `downscale_freq_shift` (float): controls the delta between the frequencies of adjacent dimensions.
- `scale` (float): scaling factor applied to the embeddings.
- `max_period` (int): controls the minimum frequency of the embeddings.

Returns:

- torch.Tensor: an `[N x dim]` tensor of positional embeddings.

Source code in fastvideo/layers/visual_embedding.py
def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
) -> torch.Tensor:
    """
    Create sinusoidal timestep embeddings, as in Denoising Diffusion Probabilistic Models.

    Frequency ``k`` (for ``k`` in ``[0, embedding_dim // 2)``) is
    ``exp(-log(max_period) * k / (embedding_dim // 2 - downscale_freq_shift))``.
    Each timestep is multiplied by every frequency and by ``scale``; the sines
    and cosines of those phases form the embedding.

    Args
        timesteps (torch.Tensor):
            1-D tensor of N values, one per batch element. They may be fractional.
        embedding_dim (int):
            Width of the output. When odd, a zero column is appended.
        flip_sin_to_cos (bool):
            If True the halves are ordered `cos, sin`; otherwise `sin, cos`.
        downscale_freq_shift (float):
            Shifts the denominator of the frequency exponent, which changes
            the spacing between the frequencies of adjacent dimensions.
        scale (float):
            Multiplier applied to the phases before taking sin/cos.
        max_period (int):
            Controls the minimum frequency of the embeddings. The highest
            frequency is always 1.
    Returns
        torch.Tensor: float32 tensor of shape [N, embedding_dim].
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    positions = torch.arange(half_dim,
                             dtype=torch.float32,
                             device=timesteps.device)
    log_freqs = -math.log(max_period) * positions / (half_dim -
                                                     downscale_freq_shift)
    freqs = log_freqs.exp()

    phases = scale * (timesteps[:, None].float() * freqs[None, :])

    sin_half = torch.sin(phases)
    cos_half = torch.cos(phases)
    halves = [cos_half, sin_half] if flip_sin_to_cos else [sin_half, cos_half]
    embedding = torch.cat(halves, dim=-1)

    # Odd widths get a single trailing zero column.
    if embedding_dim % 2:
        embedding = torch.nn.functional.pad(embedding, (0, 1, 0, 0))
    return embedding

fastvideo.layers.visual_embedding.timestep_embedding

timestep_embedding(t: Tensor, dim: int, max_period: int = 10000, dtype: dtype = float32) -> Tensor

Create sinusoidal timestep embeddings.

Parameters:

Name Type Description Default
t Tensor

Tensor of shape [B] with timesteps

required
dim int

Embedding dimension

required
max_period int

Controls the minimum frequency of the embeddings

10000

Returns:

Type Description
Tensor

Tensor of shape [B, dim] with embeddings

Source code in fastvideo/layers/visual_embedding.py
def timestep_embedding(t: torch.Tensor,
                       dim: int,
                       max_period: int = 10000,
                       dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """
    Build sinusoidal timestep embeddings, with cosines first and then sines.

    Args:
        t: Tensor of shape [B] with timesteps.
        dim: Embedding dimension. An odd value gets a trailing zero column.
        max_period: Controls the minimum frequency of the embeddings.
        dtype: Dtype in which the frequency table is computed. The table is
            built on the CPU and then moved to ``t.device``.

    Returns:
        Tensor of shape [B, dim] with embeddings.
    """
    half = dim // 2
    scaled = -math.log(max_period) * torch.arange(
        start=0, end=half, dtype=dtype) / half
    freqs = scaled.exp().to(device=t.device)

    angles = t.unsqueeze(1).float() * freqs.unsqueeze(0)
    embedding = torch.cat((angles.cos(), angles.sin()), dim=-1)

    if dim % 2 == 1:
        pad_column = torch.zeros_like(embedding[:, :1])
        embedding = torch.cat((embedding, pad_column), dim=-1)
    return embedding

fastvideo.layers.visual_embedding.unpatchify

unpatchify(x, t, h, w, patch_size, channels) -> Tensor

Convert patched representation back to image space.

Parameters:

Name Type Description Default
x

Tensor of shape `[B, T*H*W, C*P_t*P_h*P_w]`

required
t, h, w

Temporal and spatial dimensions

required

Returns:

Type Description
Tensor

Unpatchified tensor of shape `[B, C, T*P_t, H*P_h, W*P_w]`

Source code in fastvideo/layers/visual_embedding.py
def unpatchify(x, t, h, w, patch_size, channels) -> torch.Tensor:
    """
    Fold a sequence of flattened patches back into a dense video tensor.

    Args:
        x: Tensor of shape [B, T*H*W, C*P_t*P_h*P_w]. The last axis is laid
            out as (C, P_t, P_h, P_w).
        t, h, w: Number of patches along the temporal, height and width axes.
        patch_size: 3-element sequence (P_t, P_h, P_w).
        channels: Channel count C.

    Returns:
        Unpatchified tensor of shape [B, C, T*P_t, H*P_h, W*P_w].
    """
    assert x.ndim == 3, f"x.ndim: {x.ndim}"
    assert len(patch_size) == 3, f"patch_size: {patch_size}"
    assert t * h * w == x.shape[1], f"t * h * w: {t * h * w}, x.shape[1]: {x.shape[1]}"

    p_t, p_h, p_w = patch_size
    batch = x.shape[0]

    grid = x.reshape(batch, t, h, w, channels, p_t, p_h, p_w)
    # (n, t, h, w, c, pt, ph, pw) -> (n, c, t, pt, h, ph, w, pw)
    grid = grid.permute(0, 4, 1, 5, 2, 6, 3, 7)
    return grid.reshape(batch, channels, t * p_t, h * p_h, w * p_w)