
preprocess_pipeline_text

Text-only Data Preprocessing pipeline implementation.

This module implements the text-only data preprocessing pipeline on top of the modular pipeline architecture, adapted from the ODE trajectory preprocessing pipeline.
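
The pipeline class documented below can be imported from this module path:

from fastvideo.pipelines.preprocess.preprocess_pipeline_text import (
    PreprocessPipeline_Text,
)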

Classes

fastvideo.pipelines.preprocess.preprocess_pipeline_text.PreprocessPipeline_Text

PreprocessPipeline_Text(model_path: str, fastvideo_args: FastVideoArgs | TrainingArgs, required_config_modules: list[str] | None = None, loaded_modules: dict[str, Module] | None = None)

Bases: BasePreprocessPipeline

Text-only preprocessing pipeline implementation.

Source code in fastvideo/pipelines/composed_pipeline_base.py
def __init__(self,
             model_path: str,
             fastvideo_args: FastVideoArgs | TrainingArgs,
             required_config_modules: list[str] | None = None,
             loaded_modules: dict[str, torch.nn.Module] | None = None):
    """
    Initialize the pipeline. After __init__, the pipeline should be ready to
    use. The pipeline should be stateless and not hold any batch state.
    """
    self.fastvideo_args = fastvideo_args

    self.model_path: str = model_path
    self._stages: list[PipelineStage] = []
    self._stage_name_mapping: dict[str, PipelineStage] = {}

    if required_config_modules is not None:
        self._required_config_modules = required_config_modules

    if self._required_config_modules is None:
        raise NotImplementedError(
            "Subclass must set _required_config_modules")

    maybe_init_distributed_environment_and_model_parallel(
        fastvideo_args.tp_size, fastvideo_args.sp_size)

    # Torch profiler. Enabled and configured through env vars:
    # FASTVIDEO_TORCH_PROFILER_DIR=/path/to/save/trace
    trace_dir = envs.FASTVIDEO_TORCH_PROFILER_DIR
    self.profiler_controller = get_or_create_profiler(trace_dir)
    self.profiler = self.profiler_controller.profiler

    self.local_rank = get_world_group().local_rank

    # Load modules directly in initialization
    logger.info("Loading pipeline modules...")
    with self.profiler_controller.region("profiler_region_model_loading"):
        self.modules = self.load_modules(fastvideo_args, loaded_modules)
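
For orientation, a hypothetical instantiation could look like the sketch below. The model path, the pre-built fastvideo_args, and the follow-up create_pipeline_stages call are illustrative assumptions, not a documented recipe.

# Sketch only: assumes `fastvideo_args` is a FastVideoArgs built elsewhere.
from fastvideo.pipelines.preprocess.preprocess_pipeline_text import (
    PreprocessPipeline_Text,
)

pipeline = PreprocessPipeline_Text(
    model_path="path/to/model",     # local checkpoint dir or hub id (assumed)
    fastvideo_args=fastvideo_args,
)
# Wire up the stages before processing (see create_pipeline_stages below).
pipeline.create_pipeline_stages(fastvideo_args)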

Functions

fastvideo.pipelines.preprocess.preprocess_pipeline_text.PreprocessPipeline_Text.create_pipeline_stages
create_pipeline_stages(fastvideo_args: FastVideoArgs)

Set up pipeline stages with proper dependency injection.

Source code in fastvideo/pipelines/preprocess/preprocess_pipeline_text.py
def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
    """Set up pipeline stages with proper dependency injection."""
    self.add_stage(stage_name="prompt_encoding_stage",
                   stage=TextEncodingStage(
                       text_encoders=[self.get_module("text_encoder")],
                       tokenizers=[self.get_module("tokenizer")],
                   ))
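
Note the dependency-injection pattern here: the stage receives already-loaded modules (text encoder, tokenizer) rather than loading them itself. add_stage is inherited from the composed pipeline base; a minimal sketch of its likely behavior, inferred only from the fields initialized in __init__ above and the self.prompt_encoding_stage attribute used below, is:

def add_stage(self, stage_name: str, stage: PipelineStage) -> None:
    # Hypothetical sketch, not the actual fastvideo source: register the
    # stage in execution order, index it by name, and expose it as an
    # attribute so preprocess_text_only can call self.prompt_encoding_stage.
    self._stages.append(stage)
    self._stage_name_mapping[stage_name] = stage
    setattr(self, stage_name, stage)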
fastvideo.pipelines.preprocess.preprocess_pipeline_text.PreprocessPipeline_Text.get_pyarrow_schema
get_pyarrow_schema()

Return the PyArrow schema for text-only pipeline.

Source code in fastvideo/pipelines/preprocess/preprocess_pipeline_text.py
def get_pyarrow_schema(self):
    """Return the PyArrow schema for text-only pipeline."""
    return pyarrow_schema_text_only
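
The schema itself is defined elsewhere in fastvideo. Judging from the record fields written in preprocess_text_only below (text_name, text_embedding, caption), a text-only schema could be sketched roughly as follows; field names and types here are illustrative assumptions, not the actual definition.

import pyarrow as pa

# Illustrative sketch only: the real pyarrow_schema_text_only may store
# embeddings differently (e.g. as binary blobs with separate shape columns).
sketch_schema = pa.schema([
    pa.field("text_name", pa.string()),
    pa.field("caption", pa.string()),
    pa.field("text_embedding", pa.list_(pa.float32())),
])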
fastvideo.pipelines.preprocess.preprocess_pipeline_text.PreprocessPipeline_Text.preprocess_text_only
preprocess_text_only(fastvideo_args: FastVideoArgs, args)

Preprocess text-only data.

Source code in fastvideo/pipelines/preprocess/preprocess_pipeline_text.py
def preprocess_text_only(self, fastvideo_args: FastVideoArgs, args):
    """Preprocess text-only data."""

    for batch_idx, data in enumerate(self.pbar):
        if data is None:
            continue

        with torch.inference_mode():
            # For text-only processing, we only need text data
            # Filter out samples without text
            valid_indices = []
            for i, text in enumerate(data["text"]):
                if text and text.strip():  # Check if text is not empty
                    valid_indices.append(i)
            self.num_processed_samples += len(valid_indices)

            if not valid_indices:
                continue

            # Create new batch with only valid samples (text-only)
            valid_data = {
                "text": [data["text"][i] for i in valid_indices],
                "path": [data["path"][i] for i in valid_indices],
            }

            batch_captions = valid_data["text"]
            # Encode text using the standalone TextEncodingStage API
            prompt_embeds_list, prompt_masks_list = self.prompt_encoding_stage.encode_text(
                batch_captions,
                fastvideo_args,
                encoder_index=[0],
                return_attention_mask=True,
            )
            prompt_embeds = prompt_embeds_list[0]
            prompt_attention_masks = prompt_masks_list[0]
            assert prompt_embeds.shape[0] == prompt_attention_masks.shape[0]

            logger.info("===== prompt_embeds: %s", prompt_embeds.shape)
            logger.info("===== prompt_attention_masks: %s",
                        prompt_attention_masks.shape)

            # Prepare batch data for Parquet dataset
            batch_data = []

            # Add progress bar for saving outputs
            save_pbar = tqdm(enumerate(valid_data["path"]),
                             desc="Saving outputs",
                             unit="item",
                             leave=False)

            for idx, text_path in save_pbar:
                text_name = os.path.basename(text_path).split(".")[0]

                # Convert tensors to numpy arrays
                text_embedding = prompt_embeds[idx].cpu().numpy()

                # Create record for Parquet dataset (text-only schema)
                record = text_only_record_creator(
                    text_name=text_name,
                    text_embedding=text_embedding,
                    caption=valid_data["text"][idx],
                )
                batch_data.append(record)

            if batch_data:
                write_pbar = tqdm(total=1,
                                  desc="Writing to Parquet dataset",
                                  unit="batch")
                table = records_to_table(batch_data,
                                         pyarrow_schema_text_only)
                write_pbar.update(1)
                write_pbar.close()

                if not hasattr(self, 'dataset_writer'):
                    self.dataset_writer = ParquetDatasetWriter(
                        out_dir=self.combined_parquet_dir,
                        samples_per_file=args.samples_per_file,
                    )
                self.dataset_writer.append_table(table)

                logger.info("Collected batch with %s samples", len(table))

            if self.num_processed_samples >= args.flush_frequency:
                written = self.dataset_writer.flush()
                logger.info("Flushed %s samples to parquet", written)
                self.num_processed_samples = 0

    # Final flush for any remaining samples
    if hasattr(self, 'dataset_writer'):
        written = self.dataset_writer.flush(write_remainder=True)
        if written:
            logger.info("Final flush wrote %s samples", written)
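
The standalone TextEncodingStage.encode_text call used above can also be exercised directly. The sketch below is inferred solely from the call site in this method; argument semantics and output shapes are assumptions.

# Sketch: assumes `pipeline` is an initialized PreprocessPipeline_Text with
# stages created, and `fastvideo_args` is its FastVideoArgs.
captions = ["a cat playing piano", "timelapse of a city at night"]
embeds_list, masks_list = pipeline.prompt_encoding_stage.encode_text(
    captions,
    fastvideo_args,
    encoder_index=[0],           # use the first (and only) text encoder
    return_attention_mask=True,  # also return the padding masks
)
prompt_embeds = embeds_list[0]   # per-encoder batch of embeddings
prompt_masks = masks_list[0]     # matching attention masks
assert prompt_embeds.shape[0] == len(captions)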
