
preprocess

Modules

fastvideo.pipelines.preprocess.preprocess_pipeline_base

Classes

fastvideo.pipelines.preprocess.preprocess_pipeline_base.BasePreprocessPipeline
BasePreprocessPipeline(model_path: str, fastvideo_args: FastVideoArgs | TrainingArgs, required_config_modules: list[str] | None = None, loaded_modules: dict[str, Module] | None = None)

Bases: ComposedPipelineBase

Base class for preprocessing pipelines that handles common functionality.

Source code in fastvideo/pipelines/composed_pipeline_base.py
def __init__(self,
             model_path: str,
             fastvideo_args: FastVideoArgs | TrainingArgs,
             required_config_modules: list[str] | None = None,
             loaded_modules: dict[str, torch.nn.Module] | None = None):
    """
    Initialize the pipeline. After __init__, the pipeline should be ready to
    use. The pipeline should be stateless and not hold any batch state.
    """
    self.fastvideo_args = fastvideo_args

    self.model_path: str = model_path
    self._stages: list[PipelineStage] = []
    self._stage_name_mapping: dict[str, PipelineStage] = {}

    if required_config_modules is not None:
        self._required_config_modules = required_config_modules

    if self._required_config_modules is None:
        raise NotImplementedError(
            "Subclass must set _required_config_modules")

    maybe_init_distributed_environment_and_model_parallel(
        fastvideo_args.tp_size, fastvideo_args.sp_size)

    # Torch profiler. Enabled and configured through env vars:
    # FASTVIDEO_TORCH_PROFILER_DIR=/path/to/save/trace
    trace_dir = envs.FASTVIDEO_TORCH_PROFILER_DIR
    self.profiler_controller = get_or_create_profiler(trace_dir)
    self.profiler = self.profiler_controller.profiler

    self.local_rank = get_world_group().local_rank

    # Load modules directly in initialization
    logger.info("Loading pipeline modules...")
    with self.profiler_controller.region("profiler_region_model_loading"):
        self.modules = self.load_modules(fastvideo_args, loaded_modules)
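
Profiling is opt-in via the environment, as the comment in __init__ indicates. A minimal sketch of enabling the trace dump, assuming the variable is read when the pipeline is constructed:

import os

# Hypothetical trace directory; set before constructing any pipeline so
# the profiler controller picks it up at init time.
os.environ["FASTVIDEO_TORCH_PROFILER_DIR"] = "/tmp/fastvideo_traces"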
Functions
fastvideo.pipelines.preprocess.preprocess_pipeline_base.BasePreprocessPipeline.create_pipeline_stages
create_pipeline_stages(fastvideo_args: FastVideoArgs)

Set up pipeline stages with proper dependency injection.

Source code in fastvideo/pipelines/preprocess/preprocess_pipeline_base.py
def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
    """Set up pipeline stages with proper dependency injection."""
    self.add_stage(stage_name="prompt_encoding_stage",
                   stage=TextEncodingStage(
                       text_encoders=[self.get_module("text_encoder")],
                       tokenizers=[self.get_module("tokenizer")],
                   ))
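
The prompt-encoding stage can also be driven standalone, exactly as the text-only pipelines later on this page do. A minimal sketch, assuming pipe is an initialized pipeline instance and fastvideo_args its FastVideoArgs:

captions = ["a red panda climbing a tree", "waves crashing on a beach"]

embeds_list, masks_list = pipe.prompt_encoding_stage.encode_text(
    captions,
    fastvideo_args,
    encoder_index=[0],           # use the first (and only) text encoder
    return_attention_mask=True,
)
prompt_embeds = embeds_list[0]          # (batch, seq_len, hidden_dim)
prompt_attention_masks = masks_list[0]  # (batch, seq_len)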
fastvideo.pipelines.preprocess.preprocess_pipeline_base.BasePreprocessPipeline.create_record
create_record(video_name: str, vae_latent: ndarray, text_embedding: ndarray, valid_data: dict[str, Any], idx: int, extra_features: dict[str, Any] | None = None) -> dict[str, Any]

Create a record for the Parquet dataset.

Source code in fastvideo/pipelines/preprocess/preprocess_pipeline_base.py
def create_record(
        self,
        video_name: str,
        vae_latent: np.ndarray,
        text_embedding: np.ndarray,
        valid_data: dict[str, Any],
        idx: int,
        extra_features: dict[str, Any] | None = None) -> dict[str, Any]:
    """Create a record for the Parquet dataset."""
    record = {
        "id": video_name,
        "vae_latent_bytes": vae_latent.tobytes(),
        "vae_latent_shape": list(vae_latent.shape),
        "vae_latent_dtype": str(vae_latent.dtype),
        "text_embedding_bytes": text_embedding.tobytes(),
        "text_embedding_shape": list(text_embedding.shape),
        "text_embedding_dtype": str(text_embedding.dtype),
        "file_name": video_name,
        "caption":
        valid_data["text"][idx] if len(valid_data["text"]) > 0 else "",
        "media_type": "video",
        # pixel_values are laid out (..., H, W); width is the last dim.
        "width": valid_data["pixel_values"][idx].shape[-1]
        if len(valid_data["pixel_values"]) > 0 else 0,
        "height": valid_data["pixel_values"][idx].shape[-2]
        if len(valid_data["pixel_values"]) > 0 else 0,
        "num_frames": vae_latent.shape[1] if len(vae_latent.shape) > 1 else 0,
        "duration_sec": float(valid_data["duration"][idx])
        if len(valid_data["duration"]) > 0 else 0.0,
        "fps": float(valid_data["fps"][idx])
        if len(valid_data["fps"]) > 0 else 0.0,
    }
    if extra_features:
        record.update(extra_features)
    return record
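
Because each tensor is stored as raw bytes alongside its shape and dtype strings, reading a record back is a frombuffer/reshape round trip. A minimal sketch with a hypothetical decode_tensor helper:

import numpy as np

def decode_tensor(record: dict, prefix: str) -> np.ndarray:
    """Rebuild an array from the <prefix>_bytes/_shape/_dtype triplet."""
    return np.frombuffer(
        record[f"{prefix}_bytes"],
        dtype=np.dtype(record[f"{prefix}_dtype"]),
    ).reshape(record[f"{prefix}_shape"])

vae_latent = decode_tensor(record, "vae_latent")
text_embedding = decode_tensor(record, "text_embedding")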
fastvideo.pipelines.preprocess.preprocess_pipeline_base.BasePreprocessPipeline.create_record_for_schema
create_record_for_schema(preprocess_batch: PreprocessBatch, schema: Schema, strict: bool = False) -> dict[str, Any]

Create a record for the Parquet dataset using a generic schema-based approach.

Parameters:

    preprocess_batch (PreprocessBatch, required):
        The batch containing the data to extract.
    schema (Schema, required):
        PyArrow schema defining the expected fields.
    strict (bool, default False):
        If True, raises an exception when required fields are missing or unfilled.

Returns:

    dict[str, Any]: Dictionary record matching the schema.

Raises:

    ValueError: If strict=True and required fields are missing or unfilled.

Source code in fastvideo/pipelines/preprocess/preprocess_pipeline_base.py
def create_record_for_schema(self,
                             preprocess_batch: PreprocessBatch,
                             schema: pa.Schema,
                             strict: bool = False) -> dict[str, Any]:
    """Create a record for the Parquet dataset using a generic schema-based approach.

    Args:
        preprocess_batch: The batch containing the data to extract
        schema: PyArrow schema defining the expected fields
        strict: If True, raises an exception when required fields are missing or unfilled

    Returns:
        Dictionary record matching the schema

    Raises:
        ValueError: If strict=True and required fields are missing or unfilled
    """
    record = {}
    unfilled_fields = []

    for field in schema.names:
        field_filled = False

        if field.endswith('_bytes'):
            # Handle binary tensor data - convert numpy array or tensor to bytes
            tensor_name = field.replace('_bytes', '')
            tensor_data = getattr(preprocess_batch, tensor_name, None)
            if tensor_data is not None:
                try:
                    if hasattr(tensor_data, 'numpy'):  # torch tensor
                        record[field] = tensor_data.cpu().numpy().tobytes()
                        field_filled = True
                    elif hasattr(tensor_data, 'tobytes'):  # numpy array
                        record[field] = tensor_data.tobytes()
                        field_filled = True
                    else:
                        raise ValueError(
                            f"Unsupported tensor type for field {field}: {type(tensor_data)}"
                        )
                except Exception as e:
                    if strict:
                        raise ValueError(
                            f"Failed to convert tensor {tensor_name} to bytes: {e}"
                        ) from e
                    record[field] = b''  # Empty bytes for missing data
            else:
                record[field] = b''  # Empty bytes for missing data

        elif field.endswith('_shape'):
            # Handle tensor shape info
            tensor_name = field.replace('_shape', '')
            tensor_data = getattr(preprocess_batch, tensor_name, None)
            if tensor_data is not None and hasattr(tensor_data, 'shape'):
                record[field] = list(tensor_data.shape)
                field_filled = True
            else:
                record[field] = []

        elif field.endswith('_dtype'):
            # Handle tensor dtype info
            tensor_name = field.replace('_dtype', '')
            tensor_data = getattr(preprocess_batch, tensor_name, None)
            if tensor_data is not None and hasattr(tensor_data, 'dtype'):
                record[field] = str(tensor_data.dtype)
                field_filled = True
            else:
                record[field] = 'unknown'

        elif field in ['width', 'height', 'num_frames']:
            # Handle integer metadata fields
            value = getattr(preprocess_batch, field, None)
            if value is not None:
                try:
                    record[field] = int(value)
                    field_filled = True
                except (ValueError, TypeError) as e:
                    if strict:
                        raise ValueError(
                            f"Failed to convert field {field} to int: {e}"
                        ) from e
                    record[field] = 0
            else:
                record[field] = 0

        elif field in ['duration_sec', 'fps']:
            # Handle float metadata fields
            # Map schema field names to batch attribute names
            attr_name = 'duration' if field == 'duration_sec' else field
            value = getattr(preprocess_batch, attr_name, None)
            if value is not None:
                try:
                    record[field] = float(value)
                    field_filled = True
                except (ValueError, TypeError) as e:
                    if strict:
                        raise ValueError(
                            f"Failed to convert field {field} to float: {e}"
                        ) from e
                    record[field] = 0.0
            else:
                record[field] = 0.0

        else:
            # Handle string fields (id, file_name, caption, media_type, etc.)
            # Map common schema field names to batch attribute names
            attr_name = field
            if field == 'caption':
                attr_name = 'text'
            elif field == 'file_name':
                attr_name = 'path'
            elif field == 'id':
                # Generate ID from path if available
                path_value = getattr(preprocess_batch, 'path', None)
                if path_value:
                    import os
                    record[field] = os.path.basename(path_value).split(
                        '.')[0]
                    field_filled = True
                else:
                    record[field] = ""
                continue
            elif field == 'media_type':
                # Determine media type from path
                path_value = getattr(preprocess_batch, 'path', None)
                if path_value:
                    record[field] = 'video' if path_value.endswith(
                        '.mp4') else 'image'
                    field_filled = True
                else:
                    record[field] = ""
                continue

            value = getattr(preprocess_batch, attr_name, None)
            if value is not None:
                record[field] = str(value)
                field_filled = True
            else:
                record[field] = ""

        # Track unfilled fields
        if not field_filled:
            unfilled_fields.append(field)

    # Handle strict mode
    if strict and unfilled_fields:
        raise ValueError(
            f"Required fields were not filled: {unfilled_fields}")

    # Log unfilled fields as warning if not in strict mode
    if unfilled_fields:
        logger.warning(
            "Some fields were not filled and got default values: %s",
            unfilled_fields)

    return record
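
A usage sketch of the suffix conventions above, with a toy schema and a SimpleNamespace standing in for PreprocessBatch (pipe is assumed to be an initialized pipeline instance; only attribute access is required of the batch):

import numpy as np
import pyarrow as pa
from types import SimpleNamespace

schema = pa.schema([
    ("id", pa.string()),
    ("caption", pa.string()),
    ("media_type", pa.string()),
    ("vae_latent_bytes", pa.binary()),
    ("vae_latent_shape", pa.list_(pa.int64())),
    ("vae_latent_dtype", pa.string()),
    ("width", pa.int64()),
    ("fps", pa.float64()),
])

batch = SimpleNamespace(
    path="clips/cat.mp4",
    text="a cat on a sofa",
    vae_latent=np.zeros((16, 4, 32, 32), dtype=np.float32),
    width=512,
    fps=24.0,
)

record = pipe.create_record_for_schema(batch, schema, strict=True)
# record["id"] == "cat", record["media_type"] == "video",
# record["vae_latent_shape"] == [16, 4, 32, 32]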
fastvideo.pipelines.preprocess.preprocess_pipeline_base.BasePreprocessPipeline.get_extra_features
get_extra_features(valid_data: dict[str, Any], fastvideo_args: FastVideoArgs) -> dict[str, Any]

Get additional features specific to the pipeline type. Override in subclasses.

Source code in fastvideo/pipelines/preprocess/preprocess_pipeline_base.py
def get_extra_features(self, valid_data: dict[str, Any],
                       fastvideo_args: FastVideoArgs) -> dict[str, Any]:
    """Get additional features specific to the pipeline type. Override in subclasses."""
    return {}
fastvideo.pipelines.preprocess.preprocess_pipeline_base.BasePreprocessPipeline.get_pyarrow_schema
get_pyarrow_schema() -> Schema

Return the PyArrow schema for this pipeline. Must be overridden.

Source code in fastvideo/pipelines/preprocess/preprocess_pipeline_base.py
def get_pyarrow_schema(self) -> pa.Schema:
    """Return the PyArrow schema for this pipeline. Must be overridden."""
    raise NotImplementedError
fastvideo.pipelines.preprocess.preprocess_pipeline_base.BasePreprocessPipeline.get_schema_fields
get_schema_fields() -> list[str]

Get the schema fields for the pipeline type.

Source code in fastvideo/pipelines/preprocess/preprocess_pipeline_base.py
def get_schema_fields(self) -> list[str]:
    """Get the schema fields for the pipeline type."""
    return [f.name for f in self.get_pyarrow_schema()]

Functions

fastvideo.pipelines.preprocess.preprocess_pipeline_i2v

I2V Data Preprocessing pipeline implementation.

This module contains an implementation of the I2V Data Preprocessing pipeline using the modular pipeline architecture.

Classes

fastvideo.pipelines.preprocess.preprocess_pipeline_i2v.PreprocessPipeline_I2V
PreprocessPipeline_I2V(model_path: str, fastvideo_args: FastVideoArgs | TrainingArgs, required_config_modules: list[str] | None = None, loaded_modules: dict[str, Module] | None = None)

Bases: BasePreprocessPipeline

I2V preprocessing pipeline implementation.

Source code in fastvideo/pipelines/composed_pipeline_base.py
def __init__(self,
             model_path: str,
             fastvideo_args: FastVideoArgs | TrainingArgs,
             required_config_modules: list[str] | None = None,
             loaded_modules: dict[str, torch.nn.Module] | None = None):
    """
    Initialize the pipeline. After __init__, the pipeline should be ready to
    use. The pipeline should be stateless and not hold any batch state.
    """
    self.fastvideo_args = fastvideo_args

    self.model_path: str = model_path
    self._stages: list[PipelineStage] = []
    self._stage_name_mapping: dict[str, PipelineStage] = {}

    if required_config_modules is not None:
        self._required_config_modules = required_config_modules

    if self._required_config_modules is None:
        raise NotImplementedError(
            "Subclass must set _required_config_modules")

    maybe_init_distributed_environment_and_model_parallel(
        fastvideo_args.tp_size, fastvideo_args.sp_size)

    # Torch profiler. Enabled and configured through env vars:
    # FASTVIDEO_TORCH_PROFILER_DIR=/path/to/save/trace
    trace_dir = envs.FASTVIDEO_TORCH_PROFILER_DIR
    self.profiler_controller = get_or_create_profiler(trace_dir)
    self.profiler = self.profiler_controller.profiler

    self.local_rank = get_world_group().local_rank

    # Load modules directly in initialization
    logger.info("Loading pipeline modules...")
    with self.profiler_controller.region("profiler_region_model_loading"):
        self.modules = self.load_modules(fastvideo_args, loaded_modules)
Functions
fastvideo.pipelines.preprocess.preprocess_pipeline_i2v.PreprocessPipeline_I2V.create_record
create_record(video_name: str, vae_latent: ndarray, text_embedding: ndarray, valid_data: dict[str, Any], idx: int, extra_features: dict[str, Any] | None = None) -> dict[str, Any]

Create a record for the Parquet dataset with CLIP features.

Source code in fastvideo/pipelines/preprocess/preprocess_pipeline_i2v.py
def create_record(
        self,
        video_name: str,
        vae_latent: np.ndarray,
        text_embedding: np.ndarray,
        valid_data: dict[str, Any],
        idx: int,
        extra_features: dict[str, Any] | None = None) -> dict[str, Any]:
    """Create a record for the Parquet dataset with CLIP features."""
    record = super().create_record(video_name=video_name,
                                   vae_latent=vae_latent,
                                   text_embedding=text_embedding,
                                   valid_data=valid_data,
                                   idx=idx,
                                   extra_features=extra_features)

    if extra_features and "clip_feature" in extra_features:
        clip_feature = extra_features["clip_feature"]
        record.update({
            "clip_feature_bytes": clip_feature.tobytes(),
            "clip_feature_shape": list(clip_feature.shape),
            "clip_feature_dtype": str(clip_feature.dtype),
        })
    else:
        record.update({
            "clip_feature_bytes": b"",
            "clip_feature_shape": [],
            "clip_feature_dtype": "",
        })

    if extra_features and "first_frame_latent" in extra_features:
        first_frame_latent = extra_features["first_frame_latent"]
        record.update({
            "first_frame_latent_bytes": first_frame_latent.tobytes(),
            "first_frame_latent_shape": list(first_frame_latent.shape),
            "first_frame_latent_dtype": str(first_frame_latent.dtype),
        })
    else:
        record.update({
            "first_frame_latent_bytes": b"",
            "first_frame_latent_shape": [],
            "first_frame_latent_dtype": "",
        })

    if extra_features and "pil_image" in extra_features:
        pil_image = extra_features["pil_image"]
        record.update({
            "pil_image_bytes": pil_image.tobytes(),
            "pil_image_shape": list(pil_image.shape),
            "pil_image_dtype": str(pil_image.dtype),
        })
    else:
        record.update({
            "pil_image_bytes": b"",
            "pil_image_shape": [],
            "pil_image_dtype": "",
        })

    return record
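
The optional I2V fields fall back to empty bytes and an empty shape list, so consumers should guard on the shape before decoding. A sketch, reusing the hypothetical decode_tensor helper from the base-class example:

def maybe_decode(record: dict, prefix: str):
    """Return None when the empty-bytes sentinel above was written."""
    if not record[f"{prefix}_shape"]:  # [] marks an absent feature
        return None
    return decode_tensor(record, prefix)

clip_feature = maybe_decode(record, "clip_feature")
first_frame_latent = maybe_decode(record, "first_frame_latent")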
fastvideo.pipelines.preprocess.preprocess_pipeline_i2v.PreprocessPipeline_I2V.get_pyarrow_schema
get_pyarrow_schema()

Return the PyArrow schema for I2V pipeline.

Source code in fastvideo/pipelines/preprocess/preprocess_pipeline_i2v.py
def get_pyarrow_schema(self):
    """Return the PyArrow schema for I2V pipeline."""
    return pyarrow_schema_i2v

Functions

fastvideo.pipelines.preprocess.preprocess_pipeline_ode_trajectory

ODE Trajectory Data Preprocessing pipeline implementation.

This module contains an implementation of the ODE Trajectory Data Preprocessing pipeline using the modular pipeline architecture.

Section 4.3 of the CausVid paper: https://arxiv.org/pdf/2412.07772

Classes

fastvideo.pipelines.preprocess.preprocess_pipeline_ode_trajectory.PreprocessPipeline_ODE_Trajectory
PreprocessPipeline_ODE_Trajectory(model_path: str, fastvideo_args: FastVideoArgs | TrainingArgs, required_config_modules: list[str] | None = None, loaded_modules: dict[str, Module] | None = None)

Bases: BasePreprocessPipeline

ODE Trajectory preprocessing pipeline implementation.

Source code in fastvideo/pipelines/composed_pipeline_base.py
def __init__(self,
             model_path: str,
             fastvideo_args: FastVideoArgs | TrainingArgs,
             required_config_modules: list[str] | None = None,
             loaded_modules: dict[str, torch.nn.Module] | None = None):
    """
    Initialize the pipeline. After __init__, the pipeline should be ready to
    use. The pipeline should be stateless and not hold any batch state.
    """
    self.fastvideo_args = fastvideo_args

    self.model_path: str = model_path
    self._stages: list[PipelineStage] = []
    self._stage_name_mapping: dict[str, PipelineStage] = {}

    if required_config_modules is not None:
        self._required_config_modules = required_config_modules

    if self._required_config_modules is None:
        raise NotImplementedError(
            "Subclass must set _required_config_modules")

    maybe_init_distributed_environment_and_model_parallel(
        fastvideo_args.tp_size, fastvideo_args.sp_size)

    # Torch profiler. Enabled and configured through env vars:
    # FASTVIDEO_TORCH_PROFILER_DIR=/path/to/save/trace
    trace_dir = envs.FASTVIDEO_TORCH_PROFILER_DIR
    self.profiler_controller = get_or_create_profiler(trace_dir)
    self.profiler = self.profiler_controller.profiler

    self.local_rank = get_world_group().local_rank

    # Load modules directly in initialization
    logger.info("Loading pipeline modules...")
    with self.profiler_controller.region("profiler_region_model_loading"):
        self.modules = self.load_modules(fastvideo_args, loaded_modules)
Functions
fastvideo.pipelines.preprocess.preprocess_pipeline_ode_trajectory.PreprocessPipeline_ODE_Trajectory.create_pipeline_stages
create_pipeline_stages(fastvideo_args: FastVideoArgs)

Set up pipeline stages with proper dependency injection.

Source code in fastvideo/pipelines/preprocess/preprocess_pipeline_ode_trajectory.py
def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
    """Set up pipeline stages with proper dependency injection."""
    assert fastvideo_args.pipeline_config.flow_shift == 5
    self.modules["scheduler"] = SelfForcingFlowMatchScheduler(
        shift=fastvideo_args.pipeline_config.flow_shift,
        sigma_min=0.0,
        extra_one_step=True)
    self.modules["scheduler"].set_timesteps(num_inference_steps=48,
                                            denoising_strength=1.0)

    self.add_stage(stage_name="input_validation_stage",
                   stage=InputValidationStage())
    self.add_stage(stage_name="prompt_encoding_stage",
                   stage=TextEncodingStage(
                       text_encoders=[self.get_module("text_encoder")],
                       tokenizers=[self.get_module("tokenizer")],
                   ))
    self.add_stage(stage_name="timestep_preparation_stage",
                   stage=TimestepPreparationStage(
                       scheduler=self.get_module("scheduler")))
    self.add_stage(stage_name="latent_preparation_stage",
                   stage=LatentPreparationStage(
                       scheduler=self.get_module("scheduler"),
                       transformer=self.get_module("transformer", None)))
    self.add_stage(stage_name="denoising_stage",
                   stage=DenoisingStage(
                       transformer=self.get_module("transformer"),
                       scheduler=self.get_module("scheduler"),
                       pipeline=self,
                   ))
    self.add_stage(stage_name="decoding_stage",
                   stage=DecodingStage(vae=self.get_module("vae")))
fastvideo.pipelines.preprocess.preprocess_pipeline_ode_trajectory.PreprocessPipeline_ODE_Trajectory.get_pyarrow_schema
get_pyarrow_schema() -> Schema

Return the PyArrow schema for ODE Trajectory pipeline.

Source code in fastvideo/pipelines/preprocess/preprocess_pipeline_ode_trajectory.py
def get_pyarrow_schema(self) -> pa.Schema:
    """Return the PyArrow schema for ODE Trajectory pipeline."""
    return pyarrow_schema_ode_trajectory_text_only
fastvideo.pipelines.preprocess.preprocess_pipeline_ode_trajectory.PreprocessPipeline_ODE_Trajectory.preprocess_text_and_trajectory
preprocess_text_and_trajectory(fastvideo_args: FastVideoArgs, args)

Preprocess text-only data and generate trajectory information.

Source code in fastvideo/pipelines/preprocess/preprocess_pipeline_ode_trajectory.py
def preprocess_text_and_trajectory(self, fastvideo_args: FastVideoArgs,
                                   args):
    """Preprocess text-only data and generate trajectory information."""

    for batch_idx, data in enumerate(self.pbar):
        if data is None:
            continue

        with torch.inference_mode():
            # For text-only processing, we only need text data
            # Filter out samples without text
            valid_indices = []
            for i, text in enumerate(data["text"]):
                if text and text.strip():  # Check if text is not empty
                    valid_indices.append(i)
            self.num_processed_samples += len(valid_indices)

            if not valid_indices:
                continue

            # Create new batch with only valid samples (text-only)
            valid_data = {
                "text": [data["text"][i] for i in valid_indices],
                "path": [data["path"][i] for i in valid_indices],
            }

            # Add fps and duration if available in data
            if "fps" in data:
                valid_data["fps"] = [data["fps"][i] for i in valid_indices]
            if "duration" in data:
                valid_data["duration"] = [
                    data["duration"][i] for i in valid_indices
                ]

            batch_captions = valid_data["text"]
            # Encode text using the standalone TextEncodingStage API
            prompt_embeds_list, prompt_masks_list = self.prompt_encoding_stage.encode_text(
                batch_captions,
                fastvideo_args,
                encoder_index=[0],
                return_attention_mask=True,
            )
            prompt_embeds = prompt_embeds_list[0]
            prompt_attention_masks = prompt_masks_list[0]
            assert prompt_embeds.shape[0] == prompt_attention_masks.shape[0]

            sampling_params = SamplingParam.from_pretrained(args.model_path)

            # encode negative prompt for trajectory collection
            if sampling_params.guidance_scale > 1 and sampling_params.negative_prompt is not None:
                negative_prompt_embeds_list, negative_prompt_masks_list = self.prompt_encoding_stage.encode_text(
                    sampling_params.negative_prompt,
                    fastvideo_args,
                    encoder_index=[0],
                    return_attention_mask=True,
                )
                negative_prompt_embed = negative_prompt_embeds_list[0][0]
                negative_prompt_attention_mask = negative_prompt_masks_list[
                    0][0]
            else:
                negative_prompt_embed = None
                negative_prompt_attention_mask = None

            trajectory_latents = []
            trajectory_timesteps = []
            trajectory_decoded = []

            for i, (prompt_embed, prompt_attention_mask) in enumerate(
                    zip(prompt_embeds, prompt_attention_masks,
                        strict=False)):
                prompt_embed = prompt_embed.unsqueeze(0)
                prompt_attention_mask = prompt_attention_mask.unsqueeze(0)

                # Collect the trajectory data (text-to-video generation)
                batch = ForwardBatch(**shallow_asdict(sampling_params), )
                batch.prompt_embeds = [prompt_embed]
                batch.prompt_attention_mask = [prompt_attention_mask]
                batch.negative_prompt_embeds = [negative_prompt_embed]
                batch.negative_attention_mask = [
                    negative_prompt_attention_mask
                ]
                batch.num_inference_steps = 48
                batch.return_trajectory_latents = True
                # Enabling this will save the decoded trajectory videos.
                # Used for debugging.
                batch.return_trajectory_decoded = False
                batch.height = args.max_height
                batch.width = args.max_width
                batch.fps = args.train_fps
                batch.guidance_scale = 6.0
                batch.do_classifier_free_guidance = True

                result_batch = self.input_validation_stage(
                    batch, fastvideo_args)
                result_batch = self.timestep_preparation_stage(
                    result_batch, fastvideo_args)
                result_batch = self.latent_preparation_stage(
                    result_batch, fastvideo_args)
                result_batch = self.denoising_stage(result_batch,
                                                    fastvideo_args)
                result_batch = self.decoding_stage(result_batch,
                                                   fastvideo_args)

                trajectory_latents.append(
                    result_batch.trajectory_latents.cpu())
                trajectory_timesteps.append(
                    result_batch.trajectory_timesteps.cpu())
                trajectory_decoded.append(result_batch.trajectory_decoded)

            # Prepare extra features for text-only processing
            extra_features = {
                "trajectory_latents": trajectory_latents,
                "trajectory_timesteps": trajectory_timesteps
            }

            if batch.return_trajectory_decoded:
                for i, decoded_frames in enumerate(trajectory_decoded):
                    for j, decoded_frame in enumerate(decoded_frames):
                        save_decoded_latents_as_video(
                            decoded_frame,
                            f"decoded_videos/trajectory_decoded_{i}_{j}.mp4",
                            args.train_fps)

            # Prepare batch data for Parquet dataset
            batch_data: list[dict[str, Any]] = []

            # Add progress bar for saving outputs
            save_pbar = tqdm(enumerate(valid_data["path"]),
                             desc="Saving outputs",
                             unit="item",
                             leave=False)

            for idx, video_path in save_pbar:
                video_name = os.path.basename(video_path).split(".")[0]

                # Convert tensors to numpy arrays
                text_embedding = prompt_embeds[idx].cpu().numpy()

                # Get extra features for this sample
                sample_extra_features = {}
                if extra_features:
                    for key, value in extra_features.items():
                        if isinstance(value, torch.Tensor):
                            sample_extra_features[key] = value[idx].cpu(
                            ).numpy()
                        else:
                            assert isinstance(value, list)
                            if isinstance(value[idx], torch.Tensor):
                                sample_extra_features[key] = value[idx].cpu(
                                ).float().numpy()
                            else:
                                sample_extra_features[key] = value[idx]

                # Create record for Parquet dataset (text-only ODE schema)
                record: dict[str, Any] = ode_text_only_record_creator(
                    video_name=video_name,
                    text_embedding=text_embedding,
                    caption=valid_data["text"][idx],
                    trajectory_latents=sample_extra_features[
                        "trajectory_latents"],
                    trajectory_timesteps=sample_extra_features[
                        "trajectory_timesteps"],
                )
                batch_data.append(record)

            if batch_data:
                write_pbar = tqdm(total=1,
                                  desc="Writing to Parquet dataset",
                                  unit="batch")
                table = records_to_table(batch_data,
                                         self.get_pyarrow_schema())
                write_pbar.update(1)
                write_pbar.close()

                if not hasattr(self, 'dataset_writer'):
                    self.dataset_writer = ParquetDatasetWriter(
                        out_dir=self.combined_parquet_dir,
                        samples_per_file=args.samples_per_file,
                    )
                self.dataset_writer.append_table(table)

                logger.info("Collected batch with %s samples", len(table))

            if self.num_processed_samples >= args.flush_frequency:
                written = self.dataset_writer.flush()
                logger.info("Flushed %s samples to parquet", written)
                self.num_processed_samples = 0

    # Final flush for any remaining samples
    if hasattr(self, 'dataset_writer'):
        written = self.dataset_writer.flush(write_remainder=True)
        if written:
            logger.info("Final flush wrote %s samples", written)

Functions

fastvideo.pipelines.preprocess.preprocess_pipeline_t2v

T2V Data Preprocessing pipeline implementation.

This module contains an implementation of the T2V Data Preprocessing pipeline using the modular pipeline architecture.

Classes

fastvideo.pipelines.preprocess.preprocess_pipeline_t2v.PreprocessPipeline_T2V
PreprocessPipeline_T2V(model_path: str, fastvideo_args: FastVideoArgs | TrainingArgs, required_config_modules: list[str] | None = None, loaded_modules: dict[str, Module] | None = None)

Bases: BasePreprocessPipeline

T2V preprocessing pipeline implementation.

Source code in fastvideo/pipelines/composed_pipeline_base.py
def __init__(self,
             model_path: str,
             fastvideo_args: FastVideoArgs | TrainingArgs,
             required_config_modules: list[str] | None = None,
             loaded_modules: dict[str, torch.nn.Module] | None = None):
    """
    Initialize the pipeline. After __init__, the pipeline should be ready to
    use. The pipeline should be stateless and not hold any batch state.
    """
    self.fastvideo_args = fastvideo_args

    self.model_path: str = model_path
    self._stages: list[PipelineStage] = []
    self._stage_name_mapping: dict[str, PipelineStage] = {}

    if required_config_modules is not None:
        self._required_config_modules = required_config_modules

    if self._required_config_modules is None:
        raise NotImplementedError(
            "Subclass must set _required_config_modules")

    maybe_init_distributed_environment_and_model_parallel(
        fastvideo_args.tp_size, fastvideo_args.sp_size)

    # Torch profiler. Enabled and configured through env vars:
    # FASTVIDEO_TORCH_PROFILER_DIR=/path/to/save/trace
    trace_dir = envs.FASTVIDEO_TORCH_PROFILER_DIR
    self.profiler_controller = get_or_create_profiler(trace_dir)
    self.profiler = self.profiler_controller.profiler

    self.local_rank = get_world_group().local_rank

    # Load modules directly in initialization
    logger.info("Loading pipeline modules...")
    with self.profiler_controller.region("profiler_region_model_loading"):
        self.modules = self.load_modules(fastvideo_args, loaded_modules)
Functions
fastvideo.pipelines.preprocess.preprocess_pipeline_t2v.PreprocessPipeline_T2V.get_pyarrow_schema
get_pyarrow_schema()

Return the PyArrow schema for T2V pipeline.

Source code in fastvideo/pipelines/preprocess/preprocess_pipeline_t2v.py
def get_pyarrow_schema(self):
    """Return the PyArrow schema for T2V pipeline."""
    return pyarrow_schema_t2v

fastvideo.pipelines.preprocess.preprocess_pipeline_text

Text-only Data Preprocessing pipeline implementation.

This module contains an implementation of the Text-only Data Preprocessing pipeline using the modular pipeline architecture, based on the ODE Trajectory preprocessing.

Classes

fastvideo.pipelines.preprocess.preprocess_pipeline_text.PreprocessPipeline_Text
PreprocessPipeline_Text(model_path: str, fastvideo_args: FastVideoArgs | TrainingArgs, required_config_modules: list[str] | None = None, loaded_modules: dict[str, Module] | None = None)

Bases: BasePreprocessPipeline

Text-only preprocessing pipeline implementation.

Source code in fastvideo/pipelines/composed_pipeline_base.py
def __init__(self,
             model_path: str,
             fastvideo_args: FastVideoArgs | TrainingArgs,
             required_config_modules: list[str] | None = None,
             loaded_modules: dict[str, torch.nn.Module] | None = None):
    """
    Initialize the pipeline. After __init__, the pipeline should be ready to
    use. The pipeline should be stateless and not hold any batch state.
    """
    self.fastvideo_args = fastvideo_args

    self.model_path: str = model_path
    self._stages: list[PipelineStage] = []
    self._stage_name_mapping: dict[str, PipelineStage] = {}

    if required_config_modules is not None:
        self._required_config_modules = required_config_modules

    if self._required_config_modules is None:
        raise NotImplementedError(
            "Subclass must set _required_config_modules")

    maybe_init_distributed_environment_and_model_parallel(
        fastvideo_args.tp_size, fastvideo_args.sp_size)

    # Torch profiler. Enabled and configured through env vars:
    # FASTVIDEO_TORCH_PROFILER_DIR=/path/to/save/trace
    trace_dir = envs.FASTVIDEO_TORCH_PROFILER_DIR
    self.profiler_controller = get_or_create_profiler(trace_dir)
    self.profiler = self.profiler_controller.profiler

    self.local_rank = get_world_group().local_rank

    # Load modules directly in initialization
    logger.info("Loading pipeline modules...")
    with self.profiler_controller.region("profiler_region_model_loading"):
        self.modules = self.load_modules(fastvideo_args, loaded_modules)
Functions
fastvideo.pipelines.preprocess.preprocess_pipeline_text.PreprocessPipeline_Text.create_pipeline_stages
create_pipeline_stages(fastvideo_args: FastVideoArgs)

Set up pipeline stages with proper dependency injection.

Source code in fastvideo/pipelines/preprocess/preprocess_pipeline_text.py
def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
    """Set up pipeline stages with proper dependency injection."""
    self.add_stage(stage_name="prompt_encoding_stage",
                   stage=TextEncodingStage(
                       text_encoders=[self.get_module("text_encoder")],
                       tokenizers=[self.get_module("tokenizer")],
                   ))
fastvideo.pipelines.preprocess.preprocess_pipeline_text.PreprocessPipeline_Text.get_pyarrow_schema
get_pyarrow_schema()

Return the PyArrow schema for text-only pipeline.

Source code in fastvideo/pipelines/preprocess/preprocess_pipeline_text.py
def get_pyarrow_schema(self):
    """Return the PyArrow schema for text-only pipeline."""
    return pyarrow_schema_text_only
fastvideo.pipelines.preprocess.preprocess_pipeline_text.PreprocessPipeline_Text.preprocess_text_only
preprocess_text_only(fastvideo_args: FastVideoArgs, args)

Preprocess text-only data.

Source code in fastvideo/pipelines/preprocess/preprocess_pipeline_text.py
def preprocess_text_only(self, fastvideo_args: FastVideoArgs, args):
    """Preprocess text-only data."""

    for batch_idx, data in enumerate(self.pbar):
        if data is None:
            continue

        with torch.inference_mode():
            # For text-only processing, we only need text data
            # Filter out samples without text
            valid_indices = []
            for i, text in enumerate(data["text"]):
                if text and text.strip():  # Check if text is not empty
                    valid_indices.append(i)
            self.num_processed_samples += len(valid_indices)

            if not valid_indices:
                continue

            # Create new batch with only valid samples (text-only)
            valid_data = {
                "text": [data["text"][i] for i in valid_indices],
                "path": [data["path"][i] for i in valid_indices],
            }

            batch_captions = valid_data["text"]
            # Encode text using the standalone TextEncodingStage API
            prompt_embeds_list, prompt_masks_list = self.prompt_encoding_stage.encode_text(
                batch_captions,
                fastvideo_args,
                encoder_index=[0],
                return_attention_mask=True,
            )
            prompt_embeds = prompt_embeds_list[0]
            prompt_attention_masks = prompt_masks_list[0]
            assert prompt_embeds.shape[0] == prompt_attention_masks.shape[0]

            logger.info("===== prompt_embeds: %s", prompt_embeds.shape)
            logger.info("===== prompt_attention_masks: %s",
                        prompt_attention_masks.shape)

            # Prepare batch data for Parquet dataset
            batch_data = []

            # Add progress bar for saving outputs
            save_pbar = tqdm(enumerate(valid_data["path"]),
                             desc="Saving outputs",
                             unit="item",
                             leave=False)

            for idx, text_path in save_pbar:
                text_name = os.path.basename(text_path).split(".")[0]

                # Convert tensors to numpy arrays
                text_embedding = prompt_embeds[idx].cpu().numpy()

                # Create record for Parquet dataset (text-only schema)
                record = text_only_record_creator(
                    text_name=text_name,
                    text_embedding=text_embedding,
                    caption=valid_data["text"][idx],
                )
                batch_data.append(record)

            if batch_data:
                write_pbar = tqdm(total=1,
                                  desc="Writing to Parquet dataset",
                                  unit="batch")
                table = records_to_table(batch_data,
                                         pyarrow_schema_text_only)
                write_pbar.update(1)
                write_pbar.close()

                if not hasattr(self, 'dataset_writer'):
                    self.dataset_writer = ParquetDatasetWriter(
                        out_dir=self.combined_parquet_dir,
                        samples_per_file=args.samples_per_file,
                    )
                self.dataset_writer.append_table(table)

                logger.info("Collected batch with %s samples", len(table))

            if self.num_processed_samples >= args.flush_frequency:
                written = self.dataset_writer.flush()
                logger.info("Flushed %s samples to parquet", written)
                self.num_processed_samples = 0

    # Final flush for any remaining samples
    if hasattr(self, 'dataset_writer'):
        written = self.dataset_writer.flush(write_remainder=True)
        if written:
            logger.info("Final flush wrote %s samples", written)

Functions

fastvideo.pipelines.preprocess.preprocess_stages

Classes

fastvideo.pipelines.preprocess.preprocess_stages.TextTransformStage
TextTransformStage(cfg_uncondition_drop_rate: float, seed: int)

Bases: PipelineStage

Process text data according to the CFG (classifier-free guidance) drop rate.

Source code in fastvideo/pipelines/preprocess/preprocess_stages.py
def __init__(self, cfg_uncondition_drop_rate: float, seed: int) -> None:
    self.cfg_rate = cfg_uncondition_drop_rate
    self.rng = random.Random(seed)
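
The stage's forward logic is not reproduced on this page. The classifier-free-guidance caption dropout it is named for typically looks like this sketch (method name and exact behavior are assumptions):

def drop_captions(self, captions: list[str]) -> list[str]:
    # With probability cfg_rate, replace the caption with the empty
    # (unconditional) prompt using the stage's seeded RNG.
    return ["" if self.rng.random() < self.cfg_rate else c for c in captions]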
fastvideo.pipelines.preprocess.preprocess_stages.VideoTransformStage
VideoTransformStage(train_fps: int, num_frames: int, max_height: int, max_width: int, do_temporal_sample: bool)

Bases: PipelineStage

Crop a video in the temporal dimension.

Source code in fastvideo/pipelines/preprocess/preprocess_stages.py
def __init__(self, train_fps: int, num_frames: int, max_height: int,
             max_width: int, do_temporal_sample: bool) -> None:
    self.train_fps = train_fps
    self.num_frames = num_frames
    if do_temporal_sample:
        self.temporal_sample_fn: Callable | None = TemporalRandomCrop(
            num_frames)
    else:
        self.temporal_sample_fn = None

    self.video_transform = transforms.Compose([
        CenterCropResizeVideo((max_height, max_width)),
    ])
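
A sketch of how these pieces typically combine at call time (the real forward method is not shown here; the TemporalRandomCrop call signature is an assumption):

import torch

def transform_video(self, frames: torch.Tensor, video_fps: float) -> torch.Tensor:
    # frames: (T, C, H, W). Stride so sampled frames play back at train_fps.
    stride = max(1, round(video_fps / self.train_fps))
    if self.temporal_sample_fn is not None:
        begin, end = self.temporal_sample_fn(frames.shape[0])  # random window
        frames = frames[begin:end]
    frames = frames[::stride][:self.num_frames]  # temporal subsample
    return self.video_transform(frames)          # center-crop + resize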