Skip to content

test_inference_similarity

Classes

Functions

fastvideo.tests.ssim.test_inference_similarity.test_i2v_inference_similarity

test_i2v_inference_similarity(prompt, ATTENTION_BACKEND, model_id)

Run image-to-video inference for each parametrized prompt (paired with its input image), attention backend, and model, and compare the generated video to the matching reference video using MS-SSIM.

Source code in fastvideo/tests/ssim/test_inference_similarity.py
@pytest.mark.parametrize("prompt", I2V_TEST_PROMPTS)
@pytest.mark.parametrize("ATTENTION_BACKEND", ["FLASH_ATTN"])
@pytest.mark.parametrize("model_id", list(I2V_MODEL_TO_PARAMS.keys()))
def test_i2v_inference_similarity(prompt, ATTENTION_BACKEND, model_id):
    """
    Run image-to-video inference and compare the output to a reference video.

    The prompt is paired with the image at the same index in
    ``I2V_IMAGE_PATHS``. A video is generated with the parameters in
    ``I2V_MODEL_TO_PARAMS[model_id]`` while ``ATTENTION_BACKEND`` is active,
    then compared against the matching reference video using MS-SSIM.

    Raises:
        FileNotFoundError: If the reference folder, or a reference video
            matching the prompt, does not exist.
        AssertionError: If the prompt and image lists differ in length, the
            video file was not written, or the SSIM is below the threshold.
    """
    assert len(I2V_TEST_PROMPTS) == len(I2V_IMAGE_PATHS), "Expect number of prompts equal to number of images"
    script_dir = os.path.dirname(os.path.abspath(__file__))
    # Generated and reference videos are both named after the first 100
    # characters of the prompt.
    video_stem = prompt[:100].strip()
    output_video_name = f"{video_stem}.mp4"

    with _attention_backend(ATTENTION_BACKEND):
        base_output_dir = os.path.join(script_dir, 'generated_videos', model_id)
        output_dir = os.path.join(base_output_dir, ATTENTION_BACKEND)
        os.makedirs(output_dir, exist_ok=True)

        BASE_PARAMS = I2V_MODEL_TO_PARAMS[model_id]
        num_inference_steps = BASE_PARAMS["num_inference_steps"]
        # NOTE(review): list.index() returns the first occurrence, so duplicate
        # entries in I2V_TEST_PROMPTS would all be paired with the same image.
        image_path = I2V_IMAGE_PATHS[I2V_TEST_PROMPTS.index(prompt)]

        init_kwargs = {
            "num_gpus": BASE_PARAMS["num_gpus"],
            "flow_shift": BASE_PARAMS["flow_shift"],
            "sp_size": BASE_PARAMS["sp_size"],
            "tp_size": BASE_PARAMS["tp_size"],
        }
        if BASE_PARAMS.get("vae_sp"):
            # VAE sequence parallelism is always enabled together with tiling.
            init_kwargs["vae_sp"] = True
            init_kwargs["vae_tiling"] = True
        if "text-encoder-precision" in BASE_PARAMS:
            init_kwargs["text_encoder_precisions"] = BASE_PARAMS[
                "text-encoder-precision"]

        generation_kwargs = {
            "num_inference_steps": num_inference_steps,
            "output_path": output_dir,
            "image_path": image_path,
            "height": BASE_PARAMS["height"],
            "width": BASE_PARAMS["width"],
            "num_frames": BASE_PARAMS["num_frames"],
            "guidance_scale": BASE_PARAMS["guidance_scale"],
            "embedded_cfg_scale": BASE_PARAMS["embedded_cfg_scale"],
            "seed": BASE_PARAMS["seed"],
            "fps": BASE_PARAMS["fps"],
        }
        if "neg_prompt" in BASE_PARAMS:
            generation_kwargs["neg_prompt"] = BASE_PARAMS["neg_prompt"]

        generator: VideoGenerator | None = None
        try:
            generator = VideoGenerator.from_pretrained(
                model_path=BASE_PARAMS["model_path"], **init_kwargs)
            generator.generate_video(prompt, **generation_kwargs)
        finally:
            # Always release the executor, even if loading/generation failed.
            _shutdown_executor(generator)

    generated_video_path = os.path.join(output_dir, output_video_name)
    # output_dir always exists (it is created above), so check the video file
    # itself to detect a run that produced no output.
    assert os.path.exists(generated_video_path), (
        f"Output video was not generated at {generated_video_path}")

    reference_folder = os.path.join(script_dir, device_reference_folder, model_id, ATTENTION_BACKEND)

    if not os.path.exists(reference_folder):
        logger.error("Reference folder missing")
        raise FileNotFoundError(
            f"Reference video folder does not exist: {reference_folder}")

    # Prefer an exact filename match; otherwise fall back to the first .mp4
    # (in sorted order, so the choice does not depend on os.listdir ordering)
    # whose name contains the prompt stem.
    reference_videos = sorted(
        name for name in os.listdir(reference_folder) if name.endswith('.mp4'))
    if output_video_name in reference_videos:
        reference_video_name = output_video_name
    else:
        reference_video_name = next(
            (name for name in reference_videos if video_stem in name), None)

    if reference_video_name is None:
        logger.error(f"Reference video not found for prompt: {prompt} with backend: {ATTENTION_BACKEND}")
        raise FileNotFoundError(
            f"Reference video missing for prompt {video_stem!r} in {reference_folder}")

    reference_video_path = os.path.join(reference_folder, reference_video_name)

    logger.info(
        f"Computing SSIM between {reference_video_path} and {generated_video_path}"
    )
    ssim_values = compute_video_ssim_torchvision(reference_video_path,
                                                 generated_video_path,
                                                 use_ms_ssim=True)

    mean_ssim = ssim_values[0]
    logger.info(f"SSIM mean value: {mean_ssim}")
    logger.info(f"Writing SSIM results to directory: {output_dir}")

    success = write_ssim_results(output_dir, ssim_values, reference_video_path,
                                 generated_video_path, num_inference_steps,
                                 prompt)

    # Persisting results is best-effort; it must not fail the similarity check.
    if not success:
        logger.error("Failed to write SSIM results to file")

    min_acceptable_ssim = 0.97
    assert mean_ssim >= min_acceptable_ssim, f"SSIM value {mean_ssim} is below threshold {min_acceptable_ssim} for {model_id} with backend {ATTENTION_BACKEND}"

fastvideo.tests.ssim.test_inference_similarity.test_inference_similarity

test_inference_similarity(prompt, ATTENTION_BACKEND, model_id)

Run text-to-video inference for each parametrized prompt, attention backend, and model, and compare the generated video to the matching reference video using MS-SSIM.

Source code in fastvideo/tests/ssim/test_inference_similarity.py
@pytest.mark.parametrize("prompt", TEST_PROMPTS)
@pytest.mark.parametrize("ATTENTION_BACKEND", ["FLASH_ATTN", "TORCH_SDPA"])
@pytest.mark.parametrize("model_id", list(MODEL_TO_PARAMS.keys()))
def test_inference_similarity(prompt, ATTENTION_BACKEND, model_id):
    """
    Run inference for one prompt and compare the output to a reference video.

    A video is generated with the parameters in ``MODEL_TO_PARAMS[model_id]``
    while ``ATTENTION_BACKEND`` is active, then compared against the matching
    reference video using MS-SSIM.

    Raises:
        FileNotFoundError: If the reference folder, or a reference video
            matching the prompt, does not exist.
        AssertionError: If the video file was not written or the SSIM is
            below the threshold.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    # Generated and reference videos are both named after the first 100
    # characters of the prompt.
    video_stem = prompt[:100].strip()
    output_video_name = f"{video_stem}.mp4"

    with _attention_backend(ATTENTION_BACKEND):
        base_output_dir = os.path.join(script_dir, 'generated_videos', model_id)
        output_dir = os.path.join(base_output_dir, ATTENTION_BACKEND)
        os.makedirs(output_dir, exist_ok=True)

        BASE_PARAMS = MODEL_TO_PARAMS[model_id]
        num_inference_steps = BASE_PARAMS["num_inference_steps"]

        init_kwargs = {
            "num_gpus": BASE_PARAMS["num_gpus"],
            "sp_size": BASE_PARAMS["sp_size"],
            "tp_size": BASE_PARAMS["tp_size"],
            "use_fsdp_inference": True,
            "dit_cpu_offload": False,
            "dit_layerwise_offload": False,
        }
        if "flow_shift" in BASE_PARAMS:
            init_kwargs["flow_shift"] = BASE_PARAMS["flow_shift"]
        if BASE_PARAMS.get("vae_sp"):
            # VAE sequence parallelism is always enabled together with tiling.
            init_kwargs["vae_sp"] = True
            init_kwargs["vae_tiling"] = True
        if "text-encoder-precision" in BASE_PARAMS:
            init_kwargs["text_encoder_precisions"] = BASE_PARAMS[
                "text-encoder-precision"]
        # LTX2-specific VAE tiling parameters; each falls back to the default
        # listed here when the model params do not override it.
        if BASE_PARAMS.get("ltx2_vae_tiling"):
            init_kwargs["ltx2_vae_tiling"] = True
            for key, default in (
                ("ltx2_vae_spatial_tile_size_in_pixels", 512),
                ("ltx2_vae_spatial_tile_overlap_in_pixels", 64),
                ("ltx2_vae_temporal_tile_size_in_frames", 64),
                ("ltx2_vae_temporal_tile_overlap_in_frames", 24),
            ):
                init_kwargs[key] = BASE_PARAMS.get(key, default)

        generation_kwargs = {
            "num_inference_steps": num_inference_steps,
            "output_path": output_dir,
            "height": BASE_PARAMS["height"],
            "width": BASE_PARAMS["width"],
            "num_frames": BASE_PARAMS["num_frames"],
            "guidance_scale": BASE_PARAMS["guidance_scale"],
            "embedded_cfg_scale": BASE_PARAMS["embedded_cfg_scale"],
            "seed": BASE_PARAMS["seed"],
            "fps": BASE_PARAMS["fps"],
        }
        if "neg_prompt" in BASE_PARAMS:
            generation_kwargs["neg_prompt"] = BASE_PARAMS["neg_prompt"]

        generator: VideoGenerator | None = None
        try:
            generator = VideoGenerator.from_pretrained(
                model_path=BASE_PARAMS["model_path"], **init_kwargs)
            generator.generate_video(prompt, **generation_kwargs)
        finally:
            # Always release the executor, even if loading/generation failed.
            _shutdown_executor(generator)

    generated_video_path = os.path.join(output_dir, output_video_name)
    # output_dir always exists (it is created above), so check the video file
    # itself to detect a run that produced no output.
    assert os.path.exists(generated_video_path), (
        f"Output video was not generated at {generated_video_path}")

    reference_folder = os.path.join(script_dir, device_reference_folder, model_id, ATTENTION_BACKEND)

    if not os.path.exists(reference_folder):
        logger.error("Reference folder missing")
        raise FileNotFoundError(
            f"Reference video folder does not exist: {reference_folder}")

    # Prefer an exact filename match; otherwise fall back to the first .mp4
    # (in sorted order, so the choice does not depend on os.listdir ordering)
    # whose name contains the prompt stem.
    reference_videos = sorted(
        name for name in os.listdir(reference_folder) if name.endswith('.mp4'))
    if output_video_name in reference_videos:
        reference_video_name = output_video_name
    else:
        reference_video_name = next(
            (name for name in reference_videos if video_stem in name), None)

    if reference_video_name is None:
        logger.error(f"Reference video not found for prompt: {prompt} with backend: {ATTENTION_BACKEND}")
        raise FileNotFoundError(
            f"Reference video missing for prompt {video_stem!r} in {reference_folder}")

    reference_video_path = os.path.join(reference_folder, reference_video_name)

    logger.info(
        f"Computing SSIM between {reference_video_path} and {generated_video_path}"
    )
    ssim_values = compute_video_ssim_torchvision(reference_video_path,
                                                 generated_video_path,
                                                 use_ms_ssim=True)

    mean_ssim = ssim_values[0]
    logger.info(f"SSIM mean value: {mean_ssim}")
    logger.info(f"Writing SSIM results to directory: {output_dir}")

    success = write_ssim_results(output_dir, ssim_values, reference_video_path,
                                 generated_video_path, num_inference_steps,
                                 prompt)

    # Persisting results is best-effort; it must not fail the similarity check.
    if not success:
        logger.error("Failed to write SSIM results to file")

    min_acceptable_ssim = 0.93
    assert mean_ssim >= min_acceptable_ssim, f"SSIM value {mean_ssim} is below threshold {min_acceptable_ssim} for {model_id} with backend {ATTENTION_BACKEND}"