Skip to content

preprocess_pipeline_ode_trajectory

ODE Trajectory Data Preprocessing pipeline implementation.

This module contains an implementation of the ODE Trajectory Data Preprocessing pipeline using the modular pipeline architecture.

See Sec. 4.3 of the CausVid paper: https://arxiv.org/pdf/2412.07772

Classes

fastvideo.pipelines.preprocess.preprocess_pipeline_ode_trajectory.PreprocessPipeline_ODE_Trajectory

PreprocessPipeline_ODE_Trajectory(model_path: str, fastvideo_args: FastVideoArgs | TrainingArgs, required_config_modules: list[str] | None = None, loaded_modules: dict[str, Module] | None = None)

Bases: BasePreprocessPipeline

ODE Trajectory preprocessing pipeline implementation.

Source code in fastvideo/pipelines/composed_pipeline_base.py
def __init__(self,
             model_path: str,
             fastvideo_args: FastVideoArgs | TrainingArgs,
             required_config_modules: list[str] | None = None,
             loaded_modules: dict[str, torch.nn.Module] | None = None):
    """
    Construct the pipeline and load its modules.

    Once construction finishes the pipeline is expected to be ready to use;
    it should stay stateless and never hold per-batch state.

    Args:
        model_path: Stored on the instance as ``self.model_path``.
        fastvideo_args: Stored on the instance; its ``tp_size`` and
            ``sp_size`` are forwarded to the distributed/model-parallel
            initialisation, and the whole object is passed to
            ``load_modules``.
        required_config_modules: When given, replaces
            ``self._required_config_modules``.
        loaded_modules: Forwarded unchanged to ``load_modules``.

    Raises:
        NotImplementedError: If ``_required_config_modules`` is still
            ``None`` after applying the optional override.
    """
    self.fastvideo_args = fastvideo_args
    self.model_path: str = model_path

    # Stage registry: ordered list plus name -> stage lookup, both empty
    # until stages are added.
    self._stages: list[PipelineStage] = []
    self._stage_name_mapping: dict[str, PipelineStage] = {}

    # An explicit argument takes precedence over the value already on self.
    if required_config_modules is not None:
        self._required_config_modules = required_config_modules
    if self._required_config_modules is None:
        raise NotImplementedError(
            "Subclass must set _required_config_modules")

    tp_size, sp_size = fastvideo_args.tp_size, fastvideo_args.sp_size
    maybe_init_distributed_environment_and_model_parallel(tp_size, sp_size)

    # Torch profiler. Enabled and configured through env vars:
    # FASTVIDEO_TORCH_PROFILER_DIR=/path/to/save/trace
    controller = get_or_create_profiler(envs.FASTVIDEO_TORCH_PROFILER_DIR)
    self.profiler_controller = controller
    self.profiler = controller.profiler

    self.local_rank = get_world_group().local_rank

    # Modules are loaded eagerly, inside a named profiler region.
    logger.info("Loading pipeline modules...")
    with controller.region("profiler_region_model_loading"):
        self.modules = self.load_modules(fastvideo_args, loaded_modules)

Functions

fastvideo.pipelines.preprocess.preprocess_pipeline_ode_trajectory.PreprocessPipeline_ODE_Trajectory.create_pipeline_stages
create_pipeline_stages(fastvideo_args: FastVideoArgs)

Set up pipeline stages with proper dependency injection.

Source code in fastvideo/pipelines/preprocess/preprocess_pipeline_ode_trajectory.py
def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
    """Install the scheduler and register the pipeline stages in order.

    A ``SelfForcingFlowMatchScheduler`` is stored under
    ``self.modules["scheduler"]`` and configured for 48 inference steps.
    The stages are then registered in this order: input validation, prompt
    encoding, timestep preparation, latent preparation, denoising, decoding.

    Args:
        fastvideo_args: Its ``pipeline_config.flow_shift`` is asserted to
            equal 5 and used as the scheduler shift.
    """
    flow_shift = fastvideo_args.pipeline_config.flow_shift
    assert flow_shift == 5

    ode_scheduler = SelfForcingFlowMatchScheduler(shift=flow_shift,
                                                  sigma_min=0.0,
                                                  extra_one_step=True)
    self.modules["scheduler"] = ode_scheduler
    ode_scheduler.set_timesteps(num_inference_steps=48,
                                denoising_strength=1.0)

    # Each stage is built immediately before it is registered so that
    # construction and registration keep their original interleaving.
    validation_stage = InputValidationStage()
    self.add_stage(stage_name="input_validation_stage",
                   stage=validation_stage)

    text_stage = TextEncodingStage(
        text_encoders=[self.get_module("text_encoder")],
        tokenizers=[self.get_module("tokenizer")],
    )
    self.add_stage(stage_name="prompt_encoding_stage", stage=text_stage)

    timestep_stage = TimestepPreparationStage(
        scheduler=self.get_module("scheduler"))
    self.add_stage(stage_name="timestep_preparation_stage",
                   stage=timestep_stage)

    # The transformer is looked up with an explicit ``None`` default here.
    latent_stage = LatentPreparationStage(
        scheduler=self.get_module("scheduler"),
        transformer=self.get_module("transformer", None))
    self.add_stage(stage_name="latent_preparation_stage",
                   stage=latent_stage)

    denoise_stage = DenoisingStage(
        transformer=self.get_module("transformer"),
        scheduler=self.get_module("scheduler"),
        pipeline=self,
    )
    self.add_stage(stage_name="denoising_stage", stage=denoise_stage)

    decode_stage = DecodingStage(vae=self.get_module("vae"))
    self.add_stage(stage_name="decoding_stage", stage=decode_stage)
fastvideo.pipelines.preprocess.preprocess_pipeline_ode_trajectory.PreprocessPipeline_ODE_Trajectory.get_pyarrow_schema
get_pyarrow_schema() -> Schema

Return the PyArrow schema for ODE Trajectory pipeline.

Source code in fastvideo/pipelines/preprocess/preprocess_pipeline_ode_trajectory.py
def get_pyarrow_schema(self) -> pa.Schema:
    """Provide the PyArrow schema used for ODE trajectory records.

    Returns:
        The ``pyarrow_schema_ode_trajectory_text_only`` schema object,
        unchanged.
    """
    schema = pyarrow_schema_ode_trajectory_text_only
    return schema
fastvideo.pipelines.preprocess.preprocess_pipeline_ode_trajectory.PreprocessPipeline_ODE_Trajectory.preprocess_text_and_trajectory
preprocess_text_and_trajectory(fastvideo_args: FastVideoArgs, args)

Preprocess text-only data and generate trajectory information.

Source code in fastvideo/pipelines/preprocess/preprocess_pipeline_ode_trajectory.py
def preprocess_text_and_trajectory(self, fastvideo_args: FastVideoArgs,
                                   args):
    """Encode captions, collect one ODE trajectory per caption, and store
    the results as Parquet records.

    For every batch yielded by ``self.pbar`` the method:

    1. keeps only samples whose ``"text"`` entry is non-blank,
    2. encodes the captions with ``self.prompt_encoding_stage``,
    3. runs validation -> timestep preparation -> latent preparation ->
       denoising -> decoding once per caption, collecting
       ``trajectory_latents`` and ``trajectory_timesteps``,
    4. converts each sample into a record and appends the batch to
       ``self.dataset_writer`` (created on first use),
    5. flushes the writer once ``self.num_processed_samples`` reaches
       ``args.flush_frequency``.

    A final flush with ``write_remainder=True`` runs after the last batch.

    Args:
        fastvideo_args: Passed to the text encoder and to every stage.
        args: Preprocessing arguments. The attributes read here are
            ``model_path``, ``max_height``, ``max_width``, ``train_fps``,
            ``samples_per_file`` and ``flush_frequency``.
    """
    # Debug switch: when True the pipeline is asked for decoded trajectory
    # videos, which are then written under ``decoded_videos/``. Kept as a
    # local so the post-loop check does not rely on a leaked loop variable.
    return_trajectory_decoded = False

    # The sampling parameters and the negative-prompt embedding do not
    # depend on the batch, so they are computed on the first non-empty
    # batch and reused afterwards instead of being rebuilt every time.
    sampling_params = None
    negative_prompt_embed = None
    negative_prompt_attention_mask = None

    for data in self.pbar:
        if data is None:
            continue

        with torch.inference_mode():
            # Text-only processing: drop samples with empty/blank captions.
            valid_indices = [
                i for i, text in enumerate(data["text"])
                if text and text.strip()
            ]
            self.num_processed_samples += len(valid_indices)

            if not valid_indices:
                continue

            # Create new batch with only valid samples (text-only)
            valid_data = {
                "text": [data["text"][i] for i in valid_indices],
                "path": [data["path"][i] for i in valid_indices],
            }

            # Add fps and duration if available in data
            if "fps" in data:
                valid_data["fps"] = [data["fps"][i] for i in valid_indices]
            if "duration" in data:
                valid_data["duration"] = [
                    data["duration"][i] for i in valid_indices
                ]

            # Encode text using the standalone TextEncodingStage API
            prompt_embeds_list, prompt_masks_list = (
                self.prompt_encoding_stage.encode_text(
                    valid_data["text"],
                    fastvideo_args,
                    encoder_index=[0],
                    return_attention_mask=True,
                ))
            prompt_embeds = prompt_embeds_list[0]
            prompt_attention_masks = prompt_masks_list[0]
            assert prompt_embeds.shape[0] == prompt_attention_masks.shape[0]

            if sampling_params is None:
                sampling_params = SamplingParam.from_pretrained(
                    args.model_path)

                # Encode negative prompt for trajectory collection.
                # NOTE(review): every batch below forces
                # guidance_scale = 6.0 with classifier-free guidance on,
                # but the negative prompt is only encoded when the
                # pretrained guidance_scale is > 1 -- confirm that the
                # ``None`` embedding is acceptable to the denoising stage
                # in the other case.
                if (sampling_params.guidance_scale > 1
                        and sampling_params.negative_prompt is not None):
                    neg_embeds_list, neg_masks_list = (
                        self.prompt_encoding_stage.encode_text(
                            sampling_params.negative_prompt,
                            fastvideo_args,
                            encoder_index=[0],
                            return_attention_mask=True,
                        ))
                    negative_prompt_embed = neg_embeds_list[0][0]
                    negative_prompt_attention_mask = neg_masks_list[0][0]

            trajectory_latents = []
            trajectory_timesteps = []
            trajectory_decoded = []

            # strict=True: a length mismatch raises instead of silently
            # dropping samples, even when asserts are stripped.
            for prompt_embed, prompt_attention_mask in zip(
                    prompt_embeds, prompt_attention_masks, strict=True):
                # Collect the trajectory data (text-to-video generation)
                batch = ForwardBatch(**shallow_asdict(sampling_params))
                batch.prompt_embeds = [prompt_embed.unsqueeze(0)]
                batch.prompt_attention_mask = [
                    prompt_attention_mask.unsqueeze(0)
                ]
                batch.negative_prompt_embeds = [negative_prompt_embed]
                batch.negative_attention_mask = [
                    negative_prompt_attention_mask
                ]
                batch.num_inference_steps = 48
                batch.return_trajectory_latents = True
                batch.return_trajectory_decoded = return_trajectory_decoded
                batch.height = args.max_height
                batch.width = args.max_width
                batch.fps = args.train_fps
                batch.guidance_scale = 6.0
                batch.do_classifier_free_guidance = True

                # Thread each stage's output into the next one so that no
                # stage result is discarded.
                result_batch = self.input_validation_stage(
                    batch, fastvideo_args)
                result_batch = self.timestep_preparation_stage(
                    result_batch, fastvideo_args)
                result_batch = self.latent_preparation_stage(
                    result_batch, fastvideo_args)
                result_batch = self.denoising_stage(result_batch,
                                                    fastvideo_args)
                result_batch = self.decoding_stage(result_batch,
                                                   fastvideo_args)

                trajectory_latents.append(
                    result_batch.trajectory_latents.cpu())
                trajectory_timesteps.append(
                    result_batch.trajectory_timesteps.cpu())
                trajectory_decoded.append(result_batch.trajectory_decoded)

            if return_trajectory_decoded:
                for i, decoded_frames in enumerate(trajectory_decoded):
                    for j, decoded_frame in enumerate(decoded_frames):
                        save_decoded_latents_as_video(
                            decoded_frame,
                            f"decoded_videos/trajectory_decoded_{i}_{j}.mp4",
                            args.train_fps)

            # Prepare batch data for Parquet dataset
            batch_data: list[dict[str, Any]] = []

            # Add progress bar for saving outputs
            save_pbar = tqdm(enumerate(valid_data["path"]),
                             desc="Saving outputs",
                             unit="item",
                             leave=False)

            for idx, video_path in save_pbar:
                # The record name is the basename up to its first ".".
                video_name = os.path.basename(video_path).split(".")[0]

                # Create record for Parquet dataset (text-only ODE schema).
                # The trajectory tensors are already on the CPU; they are
                # cast to float before the numpy conversion.
                record: dict[str, Any] = ode_text_only_record_creator(
                    video_name=video_name,
                    text_embedding=prompt_embeds[idx].cpu().numpy(),
                    caption=valid_data["text"][idx],
                    trajectory_latents=trajectory_latents[idx].float().
                    numpy(),
                    trajectory_timesteps=trajectory_timesteps[idx].float().
                    numpy(),
                )
                batch_data.append(record)

            if batch_data:
                write_pbar = tqdm(total=1,
                                  desc="Writing to Parquet dataset",
                                  unit="batch")
                table = records_to_table(batch_data,
                                         self.get_pyarrow_schema())
                write_pbar.update(1)
                write_pbar.close()

                # The writer is created lazily on the first batch that
                # produces records.
                if not hasattr(self, 'dataset_writer'):
                    self.dataset_writer = ParquetDatasetWriter(
                        out_dir=self.combined_parquet_dir,
                        samples_per_file=args.samples_per_file,
                    )
                self.dataset_writer.append_table(table)

                logger.info("Collected batch with %s samples", len(table))

            if self.num_processed_samples >= args.flush_frequency:
                written = self.dataset_writer.flush()
                logger.info("Flushed %s samples to parquet", written)
                self.num_processed_samples = 0

    # Final flush for any remaining samples
    if hasattr(self, 'dataset_writer'):
        written = self.dataset_writer.flush(write_remainder=True)
        if written:
            logger.info("Final flush wrote %s samples", written)

Functions