preprocess_pipeline_base

Classes

fastvideo.pipelines.preprocess.preprocess_pipeline_base.BasePreprocessPipeline

BasePreprocessPipeline(model_path: str, fastvideo_args: FastVideoArgs | TrainingArgs, required_config_modules: list[str] | None = None, loaded_modules: dict[str, Module] | None = None)

Bases: ComposedPipelineBase

Base class for preprocessing pipelines that handles common functionality.

Source code in fastvideo/pipelines/composed_pipeline_base.py
def __init__(self,
             model_path: str,
             fastvideo_args: FastVideoArgs | TrainingArgs,
             required_config_modules: list[str] | None = None,
             loaded_modules: dict[str, torch.nn.Module] | None = None):
    """
    Initialize the pipeline. After __init__, the pipeline should be ready to
    use. The pipeline should be stateless and not hold any batch state.
    """
    self.fastvideo_args = fastvideo_args

    self.model_path: str = model_path
    self._stages: list[PipelineStage] = []
    self._stage_name_mapping: dict[str, PipelineStage] = {}

    if required_config_modules is not None:
        self._required_config_modules = required_config_modules

    if self._required_config_modules is None:
        raise NotImplementedError(
            "Subclass must set _required_config_modules")

    maybe_init_distributed_environment_and_model_parallel(
        fastvideo_args.tp_size, fastvideo_args.sp_size)

    # Torch profiler. Enabled and configured through env vars:
    # FASTVIDEO_TORCH_PROFILER_DIR=/path/to/save/trace
    trace_dir = envs.FASTVIDEO_TORCH_PROFILER_DIR
    self.profiler_controller = get_or_create_profiler(trace_dir)
    self.profiler = self.profiler_controller.profiler

    self.local_rank = get_world_group().local_rank

    # Load modules directly in initialization
    logger.info("Loading pipeline modules...")
    with self.profiler_controller.region("profiler_region_model_loading"):
        self.modules = self.load_modules(fastvideo_args, loaded_modules)
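
A minimal usage sketch (illustrative, not taken from the fastvideo source): a concrete subclass only needs to declare its required config modules before __init__ runs. The "vae" entry below is an assumption; "text_encoder" and "tokenizer" are the modules this class actually uses in create_pipeline_stages.

class MyPreprocessPipeline(BasePreprocessPipeline):
    # "text_encoder" and "tokenizer" mirror the modules used by
    # create_pipeline_stages below; "vae" is an assumed extra module.
    _required_config_modules = ["text_encoder", "tokenizer", "vae"]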

Functions

fastvideo.pipelines.preprocess.preprocess_pipeline_base.BasePreprocessPipeline.create_pipeline_stages
create_pipeline_stages(fastvideo_args: FastVideoArgs)

Set up pipeline stages with proper dependency injection.

Source code in fastvideo/pipelines/preprocess/preprocess_pipeline_base.py
def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
    """Set up pipeline stages with proper dependency injection."""
    self.add_stage(stage_name="prompt_encoding_stage",
                   stage=TextEncodingStage(
                       text_encoders=[self.get_module("text_encoder")],
                       tokenizers=[self.get_module("tokenizer")],
                   ))
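
Subclasses that need more than prompt encoding can extend this method. A hedged sketch, where MyFeatureStage is a hypothetical stage class (not part of fastvideo):

class MyPreprocessPipeline(BasePreprocessPipeline):
    def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
        # Reuse the base wiring (adds prompt_encoding_stage), then
        # append a hypothetical extra stage.
        super().create_pipeline_stages(fastvideo_args)
        self.add_stage(stage_name="my_feature_stage",
                       stage=MyFeatureStage())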
fastvideo.pipelines.preprocess.preprocess_pipeline_base.BasePreprocessPipeline.create_record
create_record(video_name: str, vae_latent: ndarray, text_embedding: ndarray, valid_data: dict[str, Any], idx: int, extra_features: dict[str, Any] | None = None) -> dict[str, Any]

Create a record for the Parquet dataset.

Source code in fastvideo/pipelines/preprocess/preprocess_pipeline_base.py
def create_record(
        self,
        video_name: str,
        vae_latent: np.ndarray,
        text_embedding: np.ndarray,
        valid_data: dict[str, Any],
        idx: int,
        extra_features: dict[str, Any] | None = None) -> dict[str, Any]:
    """Create a record for the Parquet dataset."""
    record = {
        "id": video_name,
        "vae_latent_bytes": vae_latent.tobytes(),
        "vae_latent_shape": list(vae_latent.shape),
        "vae_latent_dtype": str(vae_latent.dtype),
        "text_embedding_bytes": text_embedding.tobytes(),
        "text_embedding_shape": list(text_embedding.shape),
        "text_embedding_dtype": str(text_embedding.dtype),
        "file_name": video_name,
        "caption": valid_data["text"][idx]
        if len(valid_data["text"]) > 0 else "",
        "media_type": "video",
        "width": valid_data["pixel_values"][idx].shape[-2]
        if len(valid_data["pixel_values"]) > 0 else 0,
        "height": valid_data["pixel_values"][idx].shape[-1]
        if len(valid_data["pixel_values"]) > 0 else 0,
        "num_frames": vae_latent.shape[1] if len(vae_latent.shape) > 1 else 0,
        "duration_sec": float(valid_data["duration"][idx])
        if len(valid_data["duration"]) > 0 else 0.0,
        "fps": float(valid_data["fps"][idx])
        if len(valid_data["fps"]) > 0 else 0.0,
    }
    if extra_features:
        record.update(extra_features)
    return record
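
The bytes/shape/dtype triplets written by this method can be decoded back into arrays when reading the Parquet dataset. A minimal round-trip sketch using only numpy and the field names above:

import numpy as np

def decode_latent(record: dict) -> np.ndarray:
    # Rebuild the latent array from its raw bytes, stored dtype, and shape.
    flat = np.frombuffer(record["vae_latent_bytes"],
                         dtype=np.dtype(record["vae_latent_dtype"]))
    return flat.reshape(record["vae_latent_shape"])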
fastvideo.pipelines.preprocess.preprocess_pipeline_base.BasePreprocessPipeline.create_record_for_schema
create_record_for_schema(preprocess_batch: PreprocessBatch, schema: Schema, strict: bool = False) -> dict[str, Any]

Create a record for the Parquet dataset using a generic schema-based approach.

Parameters:

Name              Type             Description                                                                 Default
preprocess_batch  PreprocessBatch  The batch containing the data to extract                                    required
schema            Schema           PyArrow schema defining the expected fields                                 required
strict            bool             If True, raises an exception when required fields are missing or unfilled  False

Returns:

Type            Description
dict[str, Any]  Dictionary record matching the schema

Raises:

Type        Description
ValueError  If strict=True and required fields are missing or unfilled

Source code in fastvideo/pipelines/preprocess/preprocess_pipeline_base.py
def create_record_for_schema(self,
                             preprocess_batch: PreprocessBatch,
                             schema: pa.Schema,
                             strict: bool = False) -> dict[str, Any]:
    """Create a record for the Parquet dataset using a generic schema-based approach.

    Args:
        preprocess_batch: The batch containing the data to extract
        schema: PyArrow schema defining the expected fields
        strict: If True, raises an exception when required fields are missing or unfilled

    Returns:
        Dictionary record matching the schema

    Raises:
        ValueError: If strict=True and required fields are missing or unfilled
    """
    record = {}
    unfilled_fields = []

    for field in schema.names:
        field_filled = False

        if field.endswith('_bytes'):
            # Handle binary tensor data - convert numpy array or tensor to bytes
            tensor_name = field.replace('_bytes', '')
            tensor_data = getattr(preprocess_batch, tensor_name, None)
            if tensor_data is not None:
                try:
                    if hasattr(tensor_data, 'numpy'):  # torch tensor
                        record[field] = tensor_data.cpu().numpy().tobytes()
                        field_filled = True
                    elif hasattr(tensor_data, 'tobytes'):  # numpy array
                        record[field] = tensor_data.tobytes()
                        field_filled = True
                    else:
                        raise ValueError(
                            f"Unsupported tensor type for field {field}: {type(tensor_data)}"
                        )
                except Exception as e:
                    if strict:
                        raise ValueError(
                            f"Failed to convert tensor {tensor_name} to bytes: {e}"
                        ) from e
                    record[field] = b''  # Empty bytes for missing data
            else:
                record[field] = b''  # Empty bytes for missing data

        elif field.endswith('_shape'):
            # Handle tensor shape info
            tensor_name = field.replace('_shape', '')
            tensor_data = getattr(preprocess_batch, tensor_name, None)
            if tensor_data is not None and hasattr(tensor_data, 'shape'):
                record[field] = list(tensor_data.shape)
                field_filled = True
            else:
                record[field] = []

        elif field.endswith('_dtype'):
            # Handle tensor dtype info
            tensor_name = field.replace('_dtype', '')
            tensor_data = getattr(preprocess_batch, tensor_name, None)
            if tensor_data is not None and hasattr(tensor_data, 'dtype'):
                record[field] = str(tensor_data.dtype)
                field_filled = True
            else:
                record[field] = 'unknown'

        elif field in ['width', 'height', 'num_frames']:
            # Handle integer metadata fields
            value = getattr(preprocess_batch, field, None)
            if value is not None:
                try:
                    record[field] = int(value)
                    field_filled = True
                except (ValueError, TypeError) as e:
                    if strict:
                        raise ValueError(
                            f"Failed to convert field {field} to int: {e}"
                        ) from e
                    record[field] = 0
            else:
                record[field] = 0

        elif field in ['duration_sec', 'fps']:
            # Handle float metadata fields
            # Map schema field names to batch attribute names
            attr_name = 'duration' if field == 'duration_sec' else field
            value = getattr(preprocess_batch, attr_name, None)
            if value is not None:
                try:
                    record[field] = float(value)
                    field_filled = True
                except (ValueError, TypeError) as e:
                    if strict:
                        raise ValueError(
                            f"Failed to convert field {field} to float: {e}"
                        ) from e
                    record[field] = 0.0
            else:
                record[field] = 0.0

        else:
            # Handle string fields (id, file_name, caption, media_type, etc.)
            # Map common schema field names to batch attribute names
            attr_name = field
            if field == 'caption':
                attr_name = 'text'
            elif field == 'file_name':
                attr_name = 'path'
            elif field == 'id':
                # Generate ID from path if available
                path_value = getattr(preprocess_batch, 'path', None)
                if path_value:
                    import os
                    record[field] = os.path.basename(path_value).split(
                        '.')[0]
                    field_filled = True
                else:
                    record[field] = ""
                continue
            elif field == 'media_type':
                # Determine media type from path
                path_value = getattr(preprocess_batch, 'path', None)
                if path_value:
                    record[field] = 'video' if path_value.endswith(
                        '.mp4') else 'image'
                    field_filled = True
                else:
                    record[field] = ""
                continue

            value = getattr(preprocess_batch, attr_name, None)
            if value is not None:
                record[field] = str(value)
                field_filled = True
            else:
                record[field] = ""

        # Track unfilled fields
        if not field_filled:
            unfilled_fields.append(field)

    # Handle strict mode
    if strict and unfilled_fields:
        raise ValueError(
            f"Required fields were not filled: {unfilled_fields}")

    # Log unfilled fields as warning if not in strict mode
    if unfilled_fields:
        logger.warning(
            "Some fields were not filled and got default values: %s",
            unfilled_fields)

    return record
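
A hedged usage example: a PyArrow schema whose field names follow the suffix conventions handled above (*_bytes, *_shape, *_dtype, plus scalar metadata). Here pipeline and batch stand in for an initialized pipeline and a populated PreprocessBatch:

import pyarrow as pa

schema = pa.schema([
    ("id", pa.string()),
    ("vae_latent_bytes", pa.binary()),
    ("vae_latent_shape", pa.list_(pa.int64())),
    ("vae_latent_dtype", pa.string()),
    ("width", pa.int64()),
    ("height", pa.int64()),
    ("fps", pa.float64()),
])
record = pipeline.create_record_for_schema(batch, schema, strict=True)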
fastvideo.pipelines.preprocess.preprocess_pipeline_base.BasePreprocessPipeline.get_extra_features
get_extra_features(valid_data: dict[str, Any], fastvideo_args: FastVideoArgs) -> dict[str, Any]

Get additional features specific to the pipeline type. Override in subclasses.

Source code in fastvideo/pipelines/preprocess/preprocess_pipeline_base.py
def get_extra_features(self, valid_data: dict[str, Any],
                       fastvideo_args: FastVideoArgs) -> dict[str, Any]:
    """Get additional features specific to the pipeline type. Override in subclasses."""
    return {}
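
An illustrative override (not in the base class): merge a per-sample feature into each record. The clip_features key is hypothetical.

class MyPreprocessPipeline(BasePreprocessPipeline):
    def get_extra_features(self, valid_data: dict[str, Any],
                           fastvideo_args: FastVideoArgs) -> dict[str, Any]:
        # Hypothetical: pass through a feature computed by an earlier stage.
        return {"clip_features": valid_data.get("clip_features")}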
fastvideo.pipelines.preprocess.preprocess_pipeline_base.BasePreprocessPipeline.get_pyarrow_schema
get_pyarrow_schema() -> Schema

Return the PyArrow schema for this pipeline. Must be overridden.

Source code in fastvideo/pipelines/preprocess/preprocess_pipeline_base.py
def get_pyarrow_schema(self) -> pa.Schema:
    """Return the PyArrow schema for this pipeline. Must be overridden."""
    raise NotImplementedError
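
A minimal override sketch; the exact fields are pipeline-specific, and the ones below are illustrative:

import pyarrow as pa

class MyPreprocessPipeline(BasePreprocessPipeline):
    def get_pyarrow_schema(self) -> pa.Schema:
        return pa.schema([
            ("id", pa.string()),
            ("vae_latent_bytes", pa.binary()),
            ("vae_latent_shape", pa.list_(pa.int64())),
            ("vae_latent_dtype", pa.string()),
        ])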
fastvideo.pipelines.preprocess.preprocess_pipeline_base.BasePreprocessPipeline.get_schema_fields
get_schema_fields() -> list[str]

Get the schema fields for the pipeline type.

Source code in fastvideo/pipelines/preprocess/preprocess_pipeline_base.py
def get_schema_fields(self) -> list[str]:
    """Get the schema fields for the pipeline type."""
    return [f.name for f in self.get_pyarrow_schema()]
