image_encoding

Image and video encoding stages for diffusion pipelines.

This module contains implementations of encoding stages for diffusion pipelines:

- ImageEncodingStage: Encodes images using image encoders (e.g., CLIP)
- RefImageEncodingStage: Encodes reference images for the Wan2.1 control pipeline
- ImageVAEEncodingStage: Encodes images to latent space using a VAE for I2V generation
- VideoVAEEncodingStage: Encodes videos to latent space using a VAE for V2V and control tasks
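
All stages share the same interface: forward(batch, fastvideo_args) consumes a ForwardBatch, populates fields on it (such as image_embeds or image_latent), and returns it. Below is a minimal sketch of how an encoding stage is typically invoked; the encoder, processor, batch, and argument objects are assumed to have been constructed elsewhere, so this is illustrative rather than a complete pipeline.

from fastvideo.pipelines.stages.image_encoding import (ImageEncodingStage,
                                                        ImageVAEEncodingStage)

# Hedged sketch: image_encoder, image_processor, vae, batch, and fastvideo_args
# are assumed to already exist (e.g. loaded by the pipeline's model loader).
stage = ImageEncodingStage(image_encoder=image_encoder,
                           image_processor=image_processor)
batch = stage.forward(batch, fastvideo_args)  # appends to batch.image_embeds

# VAE-based stages follow the same calling convention and fill latent-space fields:
# batch = ImageVAEEncodingStage(vae=vae).forward(batch, fastvideo_args)  # sets batch.image_latent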

Classes

fastvideo.pipelines.stages.image_encoding.HYWorldImageEncodingStage

HYWorldImageEncodingStage(image_encoder=None, image_processor=None, vae=None)

Bases: ImageEncodingStage

Stage for encoding image prompts into embeddings for HYWorld models.

Uses SigLIP (or another vision encoder) to encode reference images for I2V tasks, and also encodes the reference image with the VAE to produce the conditional latent.

Source code in fastvideo/pipelines/stages/image_encoding.py
def __init__(self, image_encoder=None, image_processor=None, vae=None):
    super().__init__(image_encoder=image_encoder,
                     image_processor=image_processor)
    self.vae = vae

Functions

fastvideo.pipelines.stages.image_encoding.HYWorldImageEncodingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Encode the prompt into image encoder hidden states and VAE latents.

For I2V:
- encodes the reference image using SigLIP → image_embeds
- encodes the reference image using VAE → image_latent (expanded to the full temporal dimension)

For T2V: creates zero embeddings

The image_latent is expanded to match the full temporal dimension of the video latent, following the original HunyuanVideo-1.5 implementation, where:
- the first frame contains the encoded reference image
- all other frames are zeros
- the mask channel is 1 for the first frame and 0 for the rest
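
A minimal, self-contained sketch of that layout; the sizes below are illustrative placeholders, while the real values come from the VAE configuration and LatentPreparationStage.

import torch

# Illustrative sizes only (assumptions, not a real configuration).
latent_channels, T, H, W = 32, 16, 60, 104
cond = torch.randn(1, latent_channels, 1, H, W)    # VAE-encoded reference image (one latent frame)

expanded = cond.repeat(1, 1, T, 1, 1)              # broadcast to the full temporal dimension
expanded[:, :, 1:] = 0.0                           # only the first frame keeps the reference

mask = torch.zeros(1, 1, T, H, W)
mask[:, :, 0] = 1.0                                # mask channel: 1 for the first frame, 0 elsewhere

image_latent = torch.cat([expanded, mask], dim=1)  # [1, latent_channels + 1, T, H, W]
assert image_latent.shape == (1, latent_channels + 1, T, H, W)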

Source code in fastvideo/pipelines/stages/image_encoding.py
@torch.no_grad()
def forward(self, batch: ForwardBatch,
            fastvideo_args: FastVideoArgs) -> ForwardBatch:
    """
    Encode the prompt into image encoder hidden states and VAE latents.

    For I2V: 
        - encodes the reference image using SigLIP → image_embeds
        - encodes the reference image using VAE → image_latent (expanded to full temporal dim)
    For T2V: creates zero embeddings

    The image_latent is expanded to match the full temporal dimension of the video latent,
    following the original HunyuanVideo-1.5 implementation where:
    - First frame contains the encoded reference image
    - All other frames are zeros
    - Mask channel is 1 for first frame, 0 for rest
    """
    device = get_local_torch_device()

    # Default vision embed dimensions for HunyuanVideo1.5/HYWorld
    num_vision_tokens = 729  # (384/14)^2 for SigLIP
    vision_dim = 1152  # SigLIP hidden size

    # Get temporal dimension from raw_latent_shape (set by LatentPreparationStage)
    raw_latent_shape = list(batch.raw_latent_shape)
    latent_channels = raw_latent_shape[1]
    latent_temporal = raw_latent_shape[2]  # T dimension
    latent_height = raw_latent_shape[3]
    latent_width = raw_latent_shape[4]

    vae_dtype = PRECISION_TO_TYPE[
        fastvideo_args.pipeline_config.vae_precision]
    vae_autocast_enabled = (
        vae_dtype != torch.float32) and not fastvideo_args.disable_autocast

    if batch.pil_image is None:
        # T2V case: create zero embeddings for image_embeds
        batch.image_embeds = [
            torch.zeros(1, num_vision_tokens, vision_dim, device=device)
        ]
        # T2V: create zero latents for image_latent with full temporal dimension
        # Shape: [B, latent_channels + 1 (mask channel), T, H, W]
        batch.image_latent = torch.zeros(1,
                                         latent_channels + 1,
                                         latent_temporal,
                                         latent_height,
                                         latent_width,
                                         device=device)
    else:
        image = batch.pil_image

        # 1. Encode with SigLIP for image_embeds
        if self.image_encoder is not None:
            self.image_encoder = self.image_encoder.to(device)

            # Get model dtype for proper precision matching (HY-WorldPlay uses fp16)
            model_dtype = next(self.image_encoder.parameters()).dtype

            # Preprocess image for SigLIP
            # Convert to numpy and resize to target resolution (matching HY-WorldPlay)
            import numpy as np

            if not isinstance(image, np.ndarray):
                image_np = np.array(image)
            else:
                image_np = image

            # Resize to target resolution BEFORE SigLIP preprocessing
            from fastvideo.models.dits.hyworld.data_utils import resize_and_center_crop
            image_np = resize_and_center_crop(image_np,
                                              target_width=batch.width,
                                              target_height=batch.height)

            image_inputs = self.image_processor.preprocess(
                images=image_np, return_tensors="pt").to(
                    device=device, dtype=model_dtype)  # Match model dtype!
            pixel_values = image_inputs['pixel_values']

            with set_forward_context(current_timestep=0,
                                     attn_metadata=None):
                outputs = self.image_encoder(pixel_values=pixel_values)
                image_embeds = outputs.last_hidden_state
            batch.image_embeds = [image_embeds]

            if fastvideo_args.image_encoder_cpu_offload:
                self.image_encoder.to('cpu')
        else:
            batch.image_embeds = [
                torch.zeros(1, num_vision_tokens, vision_dim, device=device)
            ]

        # 2. Encode with VAE for image_latent (conditional latent for I2V)
        if self.vae is not None:

            from torchvision import transforms
            from PIL import Image as PILImage
            import numpy as np
            # Preprocess image for VAE
            if isinstance(image, np.ndarray):
                image = PILImage.fromarray(image)

            # Get target size from batch
            origin_size = image.size

            target_height, target_width = batch.height, batch.width
            original_width, original_height = origin_size

            scale_factor = max(target_width / original_width,
                               target_height / original_height)
            resize_width = int(round(original_width * scale_factor))
            resize_height = int(round(original_height * scale_factor))

            ref_image_transform = transforms.Compose([
                transforms.Resize(
                    (resize_height, resize_width),
                    interpolation=transforms.InterpolationMode.LANCZOS),
                transforms.CenterCrop((target_height, target_width)),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5])
            ])
            ref_images_pixel_values = ref_image_transform(image)
            ref_images_pixel_values = (ref_images_pixel_values.unsqueeze(
                0).unsqueeze(2).to(device))

            # Encode with VAE
            self.vae = self.vae.to(device)
            with torch.autocast(device_type="cuda",
                                dtype=vae_dtype,
                                enabled=vae_autocast_enabled):
                cond_latents = self.vae.encode(
                    ref_images_pixel_values).mode()
                cond_latents.mul_(self.vae.config.scaling_factor)

            # cond_latents shape: [1, 32, 1, H//compression, W//compression]
            # Expand to full temporal dimension: [1, 32, T, H, W]
            # First frame contains the encoded image, rest are zeros
            expanded_latent = cond_latents.repeat(1, 1, latent_temporal, 1,
                                                  1)
            expanded_latent[:, :,
                            1:, :, :] = 0.0  # Zero out all frames except first

            # Create mask: [1, 1, T, H, W]
            # First frame mask = 1 (conditional), rest = 0
            mask = torch.zeros(1,
                               1,
                               latent_temporal,
                               latent_height,
                               latent_width,
                               device=device,
                               dtype=expanded_latent.dtype)
            mask[:, :, 0, :, :] = 1.0  # First frame is conditional

            # Concatenate latent and mask: [1, 33, T, H, W]
            batch.image_latent = torch.cat([expanded_latent, mask], dim=1)

            if fastvideo_args.vae_cpu_offload:
                self.vae.to('cpu')
        else:
            # No VAE available, create zero latents with full temporal dimension
            batch.image_latent = torch.zeros(1,
                                             latent_channels + 1,
                                             latent_temporal,
                                             latent_height,
                                             latent_width,
                                             device=device)

    # Initialize video latent placeholder
    raw_latent_shape[1] = 1
    batch.video_latent = torch.zeros(tuple(raw_latent_shape), device=device)

    return batch
fastvideo.pipelines.stages.image_encoding.HYWorldImageEncodingStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify image encoding stage inputs.

Source code in fastvideo/pipelines/stages/image_encoding.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify image encoding stage inputs."""
    return VerificationResult()

fastvideo.pipelines.stages.image_encoding.Hy15ImageEncodingStage

Hy15ImageEncodingStage(image_encoder, image_processor)

Bases: ImageEncodingStage

Stage for encoding image prompts into embeddings for HunyuanVideo1.5 models.

Source code in fastvideo/pipelines/stages/image_encoding.py
def __init__(self, image_encoder, image_processor) -> None:
    """
    Initialize the image encoding stage.

    Args:
        image_encoder: The vision encoder used to embed images.
        image_processor: The processor used to prepare images for the encoder.
    """
    super().__init__()
    self.image_processor = image_processor
    self.image_encoder = image_encoder

Functions

fastvideo.pipelines.stages.image_encoding.Hy15ImageEncodingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Encode the prompt into image encoder hidden states.

Source code in fastvideo/pipelines/stages/image_encoding.py
def forward(self, batch: ForwardBatch,
            fastvideo_args: FastVideoArgs) -> ForwardBatch:
    """
    Encode the prompt into image encoder hidden states.
    """
    if batch.pil_image is None:
        batch.image_embeds = [
            torch.zeros(1, 729, 1152, device=get_local_torch_device())
        ]

    raw_latent_shape = list(batch.raw_latent_shape)
    raw_latent_shape[1] = 1
    batch.video_latent = torch.zeros(tuple(raw_latent_shape),
                                     device=get_local_torch_device())
    return batch
fastvideo.pipelines.stages.image_encoding.Hy15ImageEncodingStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify image encoding stage inputs.

Source code in fastvideo/pipelines/stages/image_encoding.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify image encoding stage inputs."""
    return VerificationResult()

fastvideo.pipelines.stages.image_encoding.ImageEncodingStage

ImageEncodingStage(image_encoder, image_processor)

Bases: PipelineStage

Stage for encoding image prompts into embeddings for diffusion models.

This stage handles the encoding of image prompts into the embedding space expected by the diffusion model.

Initialize the image encoding stage.

Parameters:

image_encoder (required): The vision encoder used to embed images.
image_processor (required): The processor used to prepare images for the encoder.
Source code in fastvideo/pipelines/stages/image_encoding.py
def __init__(self, image_encoder, image_processor) -> None:
    """
    Initialize the image encoding stage.

    Args:
        image_encoder: The vision encoder used to embed images.
        image_processor: The processor used to prepare images for the encoder.
    """
    super().__init__()
    self.image_processor = image_processor
    self.image_encoder = image_encoder

Functions

fastvideo.pipelines.stages.image_encoding.ImageEncodingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Encode the prompt into image encoder hidden states.

Parameters:

batch (ForwardBatch, required): The current batch information.
fastvideo_args (FastVideoArgs, required): The inference arguments.

Returns:

ForwardBatch: The batch with encoded prompt embeddings.

Source code in fastvideo/pipelines/stages/image_encoding.py
@torch.no_grad()
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Encode the prompt into image encoder hidden states.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with encoded prompt embeddings.
    """
    self.image_encoder = self.image_encoder.to(get_local_torch_device())

    image = batch.pil_image

    image_inputs = self.image_processor(
        images=image, return_tensors="pt").to(get_local_torch_device())
    with set_forward_context(current_timestep=0, attn_metadata=None):
        outputs = self.image_encoder(**image_inputs)
        image_embeds = outputs.last_hidden_state

    batch.image_embeds.append(image_embeds)

    if fastvideo_args.image_encoder_cpu_offload:
        self.image_encoder.to('cpu')

    return batch
fastvideo.pipelines.stages.image_encoding.ImageEncodingStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify image encoding stage inputs.

Source code in fastvideo/pipelines/stages/image_encoding.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify image encoding stage inputs."""
    result = VerificationResult()
    result.add_check("pil_image", batch.pil_image, V.not_none)
    result.add_check("image_embeds", batch.image_embeds, V.is_list)
    return result
fastvideo.pipelines.stages.image_encoding.ImageEncodingStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify image encoding stage outputs.

Source code in fastvideo/pipelines/stages/image_encoding.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify image encoding stage outputs."""
    result = VerificationResult()
    result.add_check("image_embeds", batch.image_embeds,
                     V.list_of_tensors_dims(3))
    return result

fastvideo.pipelines.stages.image_encoding.ImageVAEEncodingStage

ImageVAEEncodingStage(vae: ParallelTiledVAE)

Bases: PipelineStage

Stage for encoding image pixel representations into latent space.

This stage handles the encoding of image pixel representations into the final input format (e.g., latents) for image-to-video generation.
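
At inference time the stage conditions on the first frame only: the image is padded with zero frames before VAE encoding, and a frame-level mask is folded into the latent temporal grid and concatenated in front of the conditional latent. The sketch below mirrors that layout with illustrative sizes; the stage's forward source further down derives the real values from the batch and the VAE's compression ratios.

import torch

# Illustrative sizes (assumptions, not a real configuration).
num_frames, tcr = 81, 4                     # pixel frames and temporal compression ratio
lh, lw, lc = 60, 104, 16                    # latent height/width/channels
latent_t = (num_frames - 1) // tcr + 1      # temporal length of the video latent (21 here)

latent_condition = torch.randn(1, lc, latent_t, lh, lw)  # stands in for the encoded condition video

# Frame-level mask: 1 for the first pixel frame, 0 for the rest.
mask = torch.ones(1, 1, num_frames, lh, lw)
mask[:, :, 1:] = 0
# Repeat the first-frame mask so the frame count divides evenly into latent frames.
first = mask[:, :, 0:1].repeat_interleave(tcr, dim=2)
mask = torch.cat([first, mask[:, :, 1:]], dim=2)         # [1, 1, tcr + num_frames - 1, lh, lw]
# Group every tcr pixel frames into one latent frame; the tcr entries become channels.
mask = mask.view(1, latent_t, tcr, lh, lw).transpose(1, 2)  # [1, tcr, latent_t, lh, lw]

image_latent = torch.cat([mask, latent_condition], dim=1)   # [1, tcr + lc, latent_t, lh, lw]
assert image_latent.shape == (1, tcr + lc, latent_t, lh, lw)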

Source code in fastvideo/pipelines/stages/image_encoding.py
def __init__(self, vae: ParallelTiledVAE) -> None:
    self.vae: ParallelTiledVAE = vae

Functions

fastvideo.pipelines.stages.image_encoding.ImageVAEEncodingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Encode pixel representations into latent space.

Parameters:

batch (ForwardBatch, required): The current batch information.
fastvideo_args (FastVideoArgs, required): The inference arguments.

Returns:

ForwardBatch: The batch with encoded outputs.

Source code in fastvideo/pipelines/stages/image_encoding.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Encode pixel representations into latent space.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with encoded outputs.
    """
    assert batch.pil_image is not None
    if fastvideo_args.mode == ExecutionMode.INFERENCE:
        assert batch.pil_image is not None and isinstance(
            batch.pil_image, PIL.Image.Image)
        assert batch.height is not None and isinstance(batch.height, int)
        assert batch.width is not None and isinstance(batch.width, int)
        assert batch.num_frames is not None and isinstance(
            batch.num_frames, int)
        height = batch.height
        width = batch.width
        num_frames = batch.num_frames
    elif fastvideo_args.mode == ExecutionMode.PREPROCESS:
        assert batch.pil_image is not None and isinstance(
            batch.pil_image, torch.Tensor)
        assert batch.height is not None and isinstance(batch.height, list)
        assert batch.width is not None and isinstance(batch.width, list)
        assert batch.num_frames is not None and isinstance(
            batch.num_frames, list)
        num_frames = batch.num_frames[0]
        height = batch.height[0]
        width = batch.width[0]

    self.vae = self.vae.to(get_local_torch_device())

    # Process single image for I2V
    latent_height = height // self.vae.spatial_compression_ratio
    latent_width = width // self.vae.spatial_compression_ratio
    image = batch.pil_image
    image = self.preprocess(
        image,
        vae_scale_factor=self.vae.spatial_compression_ratio,
        height=height,
        width=width).to(get_local_torch_device(), dtype=torch.float32)

    # (B, C, H, W) -> (B, C, 1, H, W)
    image = image.unsqueeze(2)

    video_condition = torch.cat([
        image,
        image.new_zeros(image.shape[0], image.shape[1], num_frames - 1,
                        image.shape[3], image.shape[4])
    ],
                                dim=2)
    video_condition = video_condition.to(device=get_local_torch_device(),
                                         dtype=torch.float32)

    # Setup VAE precision
    vae_dtype = PRECISION_TO_TYPE[
        fastvideo_args.pipeline_config.vae_precision]
    vae_autocast_enabled = (
        vae_dtype != torch.float32) and not fastvideo_args.disable_autocast

    # Encode Image
    with torch.autocast(device_type="cuda",
                        dtype=vae_dtype,
                        enabled=vae_autocast_enabled):
        if fastvideo_args.pipeline_config.vae_tiling:
            self.vae.enable_tiling()
        # if fastvideo_args.vae_sp:
        #     self.vae.enable_parallel()
        if not vae_autocast_enabled:
            video_condition = video_condition.to(vae_dtype)
        encoder_output = self.vae.encode(video_condition)

    if fastvideo_args.mode == ExecutionMode.PREPROCESS:
        latent_condition = encoder_output.mean
    else:
        generator = batch.generator
        if generator is None:
            raise ValueError("Generator must be provided")
        latent_condition = self.retrieve_latents(encoder_output, generator)

    # Apply shifting if needed
    if (hasattr(self.vae, "shift_factor")
            and self.vae.shift_factor is not None):
        if isinstance(self.vae.shift_factor, torch.Tensor):
            latent_condition -= self.vae.shift_factor.to(
                latent_condition.device, latent_condition.dtype)
        else:
            latent_condition -= self.vae.shift_factor

    if isinstance(self.vae.scaling_factor, torch.Tensor):
        latent_condition = latent_condition * self.vae.scaling_factor.to(
            latent_condition.device, latent_condition.dtype)
    else:
        latent_condition = latent_condition * self.vae.scaling_factor

    if fastvideo_args.mode == ExecutionMode.PREPROCESS:
        batch.image_latent = latent_condition
    else:
        mask_lat_size = torch.ones(1, 1, num_frames, latent_height,
                                   latent_width)
        mask_lat_size[:, :, list(range(1, num_frames))] = 0
        first_frame_mask = mask_lat_size[:, :, 0:1]
        first_frame_mask = torch.repeat_interleave(
            first_frame_mask,
            dim=2,
            repeats=self.vae.temporal_compression_ratio)
        mask_lat_size = torch.concat(
            [first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
        mask_lat_size = mask_lat_size.view(
            1, -1, self.vae.temporal_compression_ratio, latent_height,
            latent_width)
        mask_lat_size = mask_lat_size.transpose(1, 2)
        mask_lat_size = mask_lat_size.to(latent_condition.device)

        batch.image_latent = torch.concat([mask_lat_size, latent_condition],
                                          dim=1)

    # Offload models if needed
    if hasattr(self, 'maybe_free_model_hooks'):
        self.maybe_free_model_hooks()

    self.vae.to("cpu")

    return batch
fastvideo.pipelines.stages.image_encoding.ImageVAEEncodingStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify encoding stage inputs.

Source code in fastvideo/pipelines/stages/image_encoding.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify encoding stage inputs."""
    result = VerificationResult()
    result.add_check("generator", batch.generator,
                     V.generator_or_list_generators)
    if fastvideo_args.mode == ExecutionMode.PREPROCESS:
        result.add_check("height", batch.height, V.list_not_empty)
        result.add_check("width", batch.width, V.list_not_empty)
        result.add_check("num_frames", batch.num_frames, V.list_not_empty)
    else:
        result.add_check("height", batch.height, V.positive_int)
        result.add_check("width", batch.width, V.positive_int)
        result.add_check("num_frames", batch.num_frames, V.positive_int)
    return result
fastvideo.pipelines.stages.image_encoding.ImageVAEEncodingStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify encoding stage outputs.

Source code in fastvideo/pipelines/stages/image_encoding.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify encoding stage outputs."""
    result = VerificationResult()
    result.add_check("image_latent", batch.image_latent,
                     [V.is_tensor, V.with_dims(5)])
    return result

fastvideo.pipelines.stages.image_encoding.RefImageEncodingStage

RefImageEncodingStage(image_encoder, image_processor)

Bases: ImageEncodingStage

Stage for encoding reference image prompts into embeddings for Wan2.1 Control models.

This stage extends ImageEncodingStage with specialized preprocessing for reference images.

Source code in fastvideo/pipelines/stages/image_encoding.py
def __init__(self, image_encoder, image_processor) -> None:
    """
    Initialize the image encoding stage.

    Args:
        image_encoder: The vision encoder used to embed images.
        image_processor: The processor used to prepare images for the encoder.
    """
    super().__init__()
    self.image_processor = image_processor
    self.image_encoder = image_encoder

Functions

fastvideo.pipelines.stages.image_encoding.RefImageEncodingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Encode the prompt into image encoder hidden states.

Parameters:

batch (ForwardBatch, required): The current batch information.
fastvideo_args (FastVideoArgs, required): The inference arguments.

Returns:

ForwardBatch: The batch with encoded prompt embeddings.

Source code in fastvideo/pipelines/stages/image_encoding.py
@torch.no_grad()
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Encode the prompt into image encoder hidden states.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with encoded prompt embeddings.
    """
    self.image_encoder = self.image_encoder.to(get_local_torch_device())

    image = batch.pil_image
    if image is None:
        image = create_default_image()
    # Preprocess reference image for CLIP encoder
    image_tensor = preprocess_reference_image_for_clip(
        image, get_local_torch_device())

    image_inputs = self.image_processor(images=image_tensor,
                                        return_tensors="pt").to(
                                            get_local_torch_device())
    with set_forward_context(current_timestep=0, attn_metadata=None):
        outputs = self.image_encoder(**image_inputs)
        image_embeds = outputs.last_hidden_state
    batch.image_embeds.append(image_embeds)

    if batch.pil_image is None:
        batch.image_embeds = [
            torch.zeros_like(x) for x in batch.image_embeds
        ]

    return batch

fastvideo.pipelines.stages.image_encoding.VideoVAEEncodingStage

VideoVAEEncodingStage(vae: ParallelTiledVAE)

Bases: ImageVAEEncodingStage

Stage for encoding video pixel representations into latent space.

This stage handles the encoding of video pixel representations for video-to-video generation and control. Inherits from ImageVAEEncodingStage to reuse common functionality.

Source code in fastvideo/pipelines/stages/image_encoding.py
def __init__(self, vae: ParallelTiledVAE) -> None:
    self.vae: ParallelTiledVAE = vae

Functions

fastvideo.pipelines.stages.image_encoding.VideoVAEEncodingStage.forward
forward(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch

Encode video pixel representations into latent space.

Parameters:

batch (ForwardBatch, required): The current batch information.
fastvideo_args (FastVideoArgs, required): The inference arguments.

Returns:

ForwardBatch: The batch with encoded outputs.

Source code in fastvideo/pipelines/stages/image_encoding.py
def forward(
    self,
    batch: ForwardBatch,
    fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
    """
    Encode video pixel representations into latent space.

    Args:
        batch: The current batch information.
        fastvideo_args: The inference arguments.

    Returns:
        The batch with encoded outputs.
    """
    assert batch.video_latent is not None, "Video latent input is required for VideoVAEEncodingStage"

    if fastvideo_args.mode == ExecutionMode.INFERENCE:
        assert batch.height is not None and isinstance(batch.height, int)
        assert batch.width is not None and isinstance(batch.width, int)
        assert batch.num_frames is not None and isinstance(
            batch.num_frames, int)
        height = batch.height
        width = batch.width
        num_frames = batch.num_frames
    elif fastvideo_args.mode == ExecutionMode.PREPROCESS:
        assert batch.height is not None and isinstance(batch.height, list)
        assert batch.width is not None and isinstance(batch.width, list)
        assert batch.num_frames is not None and isinstance(
            batch.num_frames, list)
        num_frames = batch.num_frames[0]
        height = batch.height[0]
        width = batch.width[0]

    self.vae = self.vae.to(get_local_torch_device())

    # Prepare video tensor from control video
    video_condition = self._prepare_control_video_tensor(
        batch.video_latent, num_frames, height,
        width).to(get_local_torch_device(), dtype=torch.float32)

    # Setup VAE precision
    vae_dtype = PRECISION_TO_TYPE[
        fastvideo_args.pipeline_config.vae_precision]
    vae_autocast_enabled = (
        vae_dtype != torch.float32) and not fastvideo_args.disable_autocast

    # Encode control video
    with torch.autocast(device_type="cuda",
                        dtype=vae_dtype,
                        enabled=vae_autocast_enabled):
        if fastvideo_args.pipeline_config.vae_tiling:
            self.vae.enable_tiling()
        if not vae_autocast_enabled:
            video_condition = video_condition.to(vae_dtype)
        encoder_output = self.vae.encode(video_condition)

    generator = batch.generator
    if generator is None:
        raise ValueError("Generator must be provided")
    latent_condition = self.retrieve_latents(encoder_output, generator)

    if (hasattr(self.vae, "shift_factor")
            and self.vae.shift_factor is not None):
        if isinstance(self.vae.shift_factor, torch.Tensor):
            latent_condition -= self.vae.shift_factor.to(
                latent_condition.device, latent_condition.dtype)
        else:
            latent_condition -= self.vae.shift_factor

    if isinstance(self.vae.scaling_factor, torch.Tensor):
        latent_condition = latent_condition * self.vae.scaling_factor.to(
            latent_condition.device, latent_condition.dtype)
    else:
        latent_condition = latent_condition * self.vae.scaling_factor

    batch.video_latent = latent_condition

    # Offload models if needed
    if hasattr(self, 'maybe_free_model_hooks'):
        self.maybe_free_model_hooks()

    self.vae.to("cpu")

    return batch
fastvideo.pipelines.stages.image_encoding.VideoVAEEncodingStage.verify_input
verify_input(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify video encoding stage inputs.

Source code in fastvideo/pipelines/stages/image_encoding.py
def verify_input(self, batch: ForwardBatch,
                 fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify video encoding stage inputs."""
    result = VerificationResult()
    result.add_check("video_latent", batch.video_latent, V.not_none)
    result.add_check("generator", batch.generator,
                     V.generator_or_list_generators)
    if fastvideo_args.mode == ExecutionMode.PREPROCESS:
        result.add_check("height", batch.height, V.list_not_empty)
        result.add_check("width", batch.width, V.list_not_empty)
        result.add_check("num_frames", batch.num_frames, V.list_not_empty)
    else:
        result.add_check("height", batch.height, V.positive_int)
        result.add_check("width", batch.width, V.positive_int)
        result.add_check("num_frames", batch.num_frames, V.positive_int)
    return result
fastvideo.pipelines.stages.image_encoding.VideoVAEEncodingStage.verify_output
verify_output(batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult

Verify video encoding stage outputs.

Source code in fastvideo/pipelines/stages/image_encoding.py
def verify_output(self, batch: ForwardBatch,
                  fastvideo_args: FastVideoArgs) -> VerificationResult:
    """Verify video encoding stage outputs."""
    result = VerificationResult()
    result.add_check("video_latent", batch.video_latent,
                     [V.is_tensor, V.with_dims(5)])
    return result

Functions