
preprocessing_datasets

Classes

fastvideo.dataset.preprocessing_datasets.DataValidationStage

Bases: DatasetFilterStage

Stage for validating data items.

Functions

fastvideo.dataset.preprocessing_datasets.DataValidationStage.process
process(batch: PreprocessBatch, **kwargs) -> PreprocessBatch

The process method is a no-op for validation; filtering is handled by should_keep.

Source code in fastvideo/dataset/preprocessing_datasets.py
def process(self, batch: PreprocessBatch, **kwargs) -> PreprocessBatch:
    """Process does nothing for validation - filtering is handled by should_keep."""
    return batch
fastvideo.dataset.preprocessing_datasets.DataValidationStage.should_keep
should_keep(batch: PreprocessBatch, **kwargs) -> bool

Validate data item.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| batch | PreprocessBatch | Dataset batch to validate | required |

Returns:

| Type | Description |
|---|---|
| bool | True if valid, False if invalid |

Source code in fastvideo/dataset/preprocessing_datasets.py
def should_keep(self, batch: PreprocessBatch, **kwargs) -> bool:
    """
    Validate data item.

    Args:
        batch: Dataset batch to validate

    Returns:
        True if valid, False if invalid
    """
    # Check for caption
    if batch.cap is None:
        return False

    if batch.is_video:
        # Validate video-specific fields
        if batch.duration is None or batch.fps is None:
            return False
    elif not batch.is_image:
        return False

    return True
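The validation rules can be tried out without FastVideo installed. The sketch below restates should_keep as a standalone function over types.SimpleNamespace records. The is_video and is_image flags are set by hand, because this page does not say how PreprocessBatch derives them; passes_validation, record, and the sample records are illustrative names only.

```python
from types import SimpleNamespace


def passes_validation(item) -> bool:
    """Standalone restatement of DataValidationStage.should_keep."""
    # Every item needs a caption.
    if item.cap is None:
        return False
    if item.is_video:
        # Videos also need the timing metadata used later for frame sampling.
        return item.duration is not None and item.fps is not None
    # Items that are neither videos nor images are rejected.
    return item.is_image


def record(cap, is_video, is_image, duration=None, fps=None):
    # Hypothetical helper that builds a minimal stand-in for PreprocessBatch.
    return SimpleNamespace(cap=cap, is_video=is_video, is_image=is_image,
                           duration=duration, fps=fps)


records = {
    "video_ok": record("a cat", True, False, duration=5.0, fps=30.0),
    "video_missing_fps": record("a cat", True, False, duration=5.0),
    "image_ok": record("a dog", False, True),
    "no_caption": record(None, False, True),
    "unknown_type": record("a bird", False, False),
}
for name, item in records.items():
    print(name, passes_validation(item))
```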

fastvideo.dataset.preprocessing_datasets.DatasetFilterStage

Bases: ABC

Abstract base class for dataset filtering stages.

These stages can filter out items during metadata processing.

Functions

fastvideo.dataset.preprocessing_datasets.DatasetFilterStage.process abstractmethod
process(batch: PreprocessBatch, **kwargs) -> PreprocessBatch

Process the dataset batch (for non-filtering operations).

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| batch | PreprocessBatch | Dataset batch to process | required |
| **kwargs |  | Additional processing parameters | {} |

Returns:

| Type | Description |
|---|---|
| PreprocessBatch | Processed batch |

Source code in fastvideo/dataset/preprocessing_datasets.py
@abstractmethod
def process(self, batch: PreprocessBatch, **kwargs) -> PreprocessBatch:
    """
    Process the dataset batch (for non-filtering operations).

    Args:
        batch: Dataset batch to process
        **kwargs: Additional processing parameters

    Returns:
        Processed batch
    """
    raise NotImplementedError
fastvideo.dataset.preprocessing_datasets.DatasetFilterStage.should_keep abstractmethod
should_keep(batch: PreprocessBatch, **kwargs) -> bool

Check if batch should be kept.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| batch | PreprocessBatch | Dataset batch to check | required |
| **kwargs |  | Additional parameters | {} |

Returns:

| Type | Description |
|---|---|
| bool | True if batch should be kept, False otherwise |

Source code in fastvideo/dataset/preprocessing_datasets.py
@abstractmethod
def should_keep(self, batch: PreprocessBatch, **kwargs) -> bool:
    """
    Check if batch should be kept.

    Args:
        batch: Dataset batch to check
        **kwargs: Additional parameters

    Returns:
        True if batch should be kept, False otherwise
    """
    raise NotImplementedError
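A minimal sketch of a custom filter stage, assuming a driver that calls should_keep first and only runs process on the items it keeps (this page does not show the driver). DatasetFilterStage here is a local stand-in with the same two abstract methods so the snippet runs on its own; MinCaptionWordsFilterStage is a hypothetical example, not part of FastVideo.

```python
from abc import ABC, abstractmethod
from types import SimpleNamespace


class DatasetFilterStage(ABC):
    """Local stand-in mirroring the abstract interface documented above."""

    @abstractmethod
    def process(self, batch, **kwargs):
        raise NotImplementedError

    @abstractmethod
    def should_keep(self, batch, **kwargs) -> bool:
        raise NotImplementedError


class MinCaptionWordsFilterStage(DatasetFilterStage):
    """Hypothetical filter: drop items whose captions are all too short."""

    def __init__(self, min_words: int = 3):
        self.min_words = min_words

    def process(self, batch, **kwargs):
        # Pure filter: nothing to transform.
        return batch

    def should_keep(self, batch, **kwargs) -> bool:
        caps = batch.cap if isinstance(batch.cap, list) else [batch.cap]
        # Keep the item if at least one caption has enough words.
        return any(len(c.split()) >= self.min_words for c in caps)


stage = MinCaptionWordsFilterStage(min_words=3)
items = [SimpleNamespace(cap="A cat"),
         SimpleNamespace(cap=["short", "A dog running fast"])]
# Assumed driver order: decide first, then process only the survivors.
kept = [stage.process(b) for b in items if stage.should_keep(b)]
print(len(kept))  # 1
```

Keeping the decision in should_keep and leaving process as a pass-through mirrors how DataValidationStage and ResolutionFilterStage are written.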

fastvideo.dataset.preprocessing_datasets.DatasetStage

Bases: ABC

Abstract base class for dataset processing stages.

Similar to PipelineStage but designed for dataset preprocessing operations.

Functions

fastvideo.dataset.preprocessing_datasets.DatasetStage.process abstractmethod
process(batch: PreprocessBatch, **kwargs) -> PreprocessBatch

Process the dataset batch.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| batch | PreprocessBatch | Dataset batch to process | required |
| **kwargs |  | Additional processing parameters | {} |

Returns:

| Type | Description |
|---|---|
| PreprocessBatch | Processed batch |

Source code in fastvideo/dataset/preprocessing_datasets.py
@abstractmethod
def process(self, batch: PreprocessBatch, **kwargs) -> PreprocessBatch:
    """
    Process the dataset batch.

    Args:
        batch: Dataset batch to process
        **kwargs: Additional processing parameters

    Returns:
        Processed batch
    """
    raise NotImplementedError
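Subclasses of DatasetStage only need to implement process and return the (possibly modified) batch, so stages can be chained by feeding each one the output of the previous one. The sketch below assumes that chaining pattern; DatasetStage is a local stand-in, and ToyBatch, CaptionWhitespaceStage, and run_stages are hypothetical names used only for illustration.

```python
from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass


class DatasetStage(ABC):
    """Local stand-in for the abstract DatasetStage interface."""

    @abstractmethod
    def process(self, batch, **kwargs):
        raise NotImplementedError


@dataclass
class ToyBatch:
    """Hypothetical stand-in for PreprocessBatch with only a caption field."""
    cap: str | list[str]


class CaptionWhitespaceStage(DatasetStage):
    """Hypothetical stage: collapse runs of whitespace in every caption."""

    def process(self, batch, **kwargs):
        caps = batch.cap if isinstance(batch.cap, list) else [batch.cap]
        cleaned = [" ".join(c.split()) for c in caps]
        # Preserve the original shape: list in, list out; str in, str out.
        batch.cap = cleaned if isinstance(batch.cap, list) else cleaned[0]
        return batch


def run_stages(batch, stages, **kwargs):
    # Each stage receives the batch returned by the previous one.
    for stage in stages:
        batch = stage.process(batch, **kwargs)
    return batch


out = run_stages(ToyBatch(cap="  A  cat\tplaying "), [CaptionWhitespaceStage()])
print(repr(out.cap))  # 'A cat playing'
```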

fastvideo.dataset.preprocessing_datasets.FrameSamplingStage

FrameSamplingStage(num_frames: int, train_fps: int, speed_factor: int = 1, video_length_tolerance_range: float = 5.0, drop_short_ratio: float = 0.0, seed: int = 42)

Bases: DatasetFilterStage

Stage for temporal frame sampling and indexing.

Source code in fastvideo/dataset/preprocessing_datasets.py
def __init__(self,
             num_frames: int,
             train_fps: int,
             speed_factor: int = 1,
             video_length_tolerance_range: float = 5.0,
             drop_short_ratio: float = 0.0,
             seed: int = 42):
    self.num_frames = num_frames
    self.train_fps = train_fps
    self.speed_factor = speed_factor
    self.video_length_tolerance_range = video_length_tolerance_range
    self.drop_short_ratio = drop_short_ratio
    # Create a seeded random generator for deterministic sampling
    self.rng = random.Random(seed)

Functions

fastvideo.dataset.preprocessing_datasets.FrameSamplingStage.process
process(batch: PreprocessBatch, temporal_sample_fn=None, **kwargs) -> PreprocessBatch

Process frame sampling for video data items.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| batch | PreprocessBatch | Dataset batch | required |
| temporal_sample_fn |  | Function for temporal sampling | None |

Returns:

| Type | Description |
|---|---|
| PreprocessBatch | Updated batch with frame sampling info |

Source code in fastvideo/dataset/preprocessing_datasets.py
def process(self,
            batch: PreprocessBatch,
            temporal_sample_fn=None,
            **kwargs) -> PreprocessBatch:
    """
    Process frame sampling for video data items.

    Args:
        batch: Dataset batch
        temporal_sample_fn: Function for temporal sampling

    Returns:
        Updated batch with frame sampling info
    """
    if batch.is_image:
        # For images, just add sample info
        batch.sample_frame_index = [0]
        batch.sample_num_frames = 1
        return batch

    assert batch.duration is not None and batch.fps is not None
    batch.num_frames = math.ceil(batch.fps * batch.duration)

    # Resample frame indices
    frame_interval = batch.fps / self.train_fps
    start_frame_idx = 0
    frame_indices = np.arange(start_frame_idx, batch.num_frames,
                              frame_interval).astype(int)

    # Temporal crop if too long
    if len(frame_indices) > self.num_frames:
        if temporal_sample_fn is not None:
            begin_index, end_index = temporal_sample_fn(len(frame_indices))
            frame_indices = frame_indices[begin_index:end_index]
        else:
            frame_indices = frame_indices[:self.num_frames]

    batch.sample_frame_index = frame_indices.tolist()
    batch.sample_num_frames = len(frame_indices)

    return batch
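The index arithmetic in process can be reproduced with NumPy alone. The stride through the source video is fps / train_fps, the fractional positions are truncated to integers, and clips with more than num_frames candidates are cut either by temporal_sample_fn or by keeping the first num_frames indices. The clip values (30 fps, 5 s, train_fps=16, num_frames=77) are made up, and make_random_crop is a hypothetical temporal_sample_fn: this page does not show the temporal_sample object that FastVideo passes in.

```python
import math

import numpy as np


def sample_frame_indices(fps, duration, train_fps, num_frames, temporal_sample_fn=None):
    """Restates the index arithmetic of FrameSamplingStage.process."""
    total_frames = math.ceil(fps * duration)
    # Step through the source video at the ratio of source fps to training fps.
    frame_interval = fps / train_fps
    indices = np.arange(0, total_frames, frame_interval).astype(int)
    if len(indices) > num_frames:
        if temporal_sample_fn is not None:
            begin, end = temporal_sample_fn(len(indices))
            indices = indices[begin:end]
        else:
            # Default: keep the first num_frames indices.
            indices = indices[:num_frames]
    return indices.tolist()


def make_random_crop(size, seed=0):
    """Hypothetical temporal_sample_fn: a random window of `size` positions."""
    rng = np.random.default_rng(seed)

    def crop(total):
        begin = int(rng.integers(0, total - size + 1))
        return begin, begin + size

    return crop


# 30 fps, 5 s clip resampled to 16 fps: interval 1.875, 80 candidate indices.
head = sample_frame_indices(30.0, 5.0, train_fps=16, num_frames=77)
print(len(head), head[:5])  # 77 [0, 1, 3, 5, 7]
window = sample_frame_indices(30.0, 5.0, 16, 77, make_random_crop(77, seed=1))
print(len(window))  # 77
```

Note that process does not check the length of the window returned by temporal_sample_fn; the function is trusted to return a (begin, end) pair of the intended size.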
fastvideo.dataset.preprocessing_datasets.FrameSamplingStage.should_keep
should_keep(batch: PreprocessBatch, **kwargs) -> bool

Check if video should be kept based on length constraints.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| batch | PreprocessBatch | Dataset batch | required |

Returns:

| Type | Description |
|---|---|
| bool | True if should be kept, False otherwise |

Source code in fastvideo/dataset/preprocessing_datasets.py
def should_keep(self, batch: PreprocessBatch, **kwargs) -> bool:
    """
    Check if video should be kept based on length constraints.

    Args:
        batch: Dataset batch

    Returns:
        True if should be kept, False otherwise
    """
    if batch.is_image:
        return True

    if batch.duration is None or batch.fps is None:
        return False

    num_frames = math.ceil(batch.fps * batch.duration)

    # Check if video is too long
    if (num_frames / batch.fps > self.video_length_tolerance_range *
        (self.num_frames / self.train_fps * self.speed_factor)):
        return False

    # Resample frame indices to check length
    frame_interval = batch.fps / self.train_fps
    start_frame_idx = 0
    frame_indices = np.arange(start_frame_idx, num_frames,
                              frame_interval).astype(int)

    # Filter short videos
    return not (len(frame_indices) < self.num_frames
                and self.rng.random() < self.drop_short_ratio)
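With the example settings below (num_frames=77, train_fps=16, speed_factor=1, video_length_tolerance_range=5.0), the upper bound is 5.0 * 77 / 16 = 24.0625 s, so longer clips are always dropped. Clips that yield fewer than num_frames sampled indices are dropped only with probability drop_short_ratio. The function below is a standalone restatement for checking these thresholds; keep_by_length is a hypothetical name, and its num_frames and train_fps defaults are illustrative (the stage itself has no defaults for them).

```python
import math
import random

import numpy as np


def keep_by_length(fps, duration, *, num_frames=77, train_fps=16, speed_factor=1,
                   video_length_tolerance_range=5.0, drop_short_ratio=0.0, rng=None):
    """Standalone restatement of FrameSamplingStage.should_keep for a video."""
    if rng is None:
        rng = random.Random(42)
    total_frames = math.ceil(fps * duration)
    # Longest accepted clip, in seconds: tolerance x duration of one training clip.
    max_seconds = video_length_tolerance_range * (num_frames / train_fps * speed_factor)
    if total_frames / fps > max_seconds:
        return False
    sampled = np.arange(0, total_frames, fps / train_fps).astype(int)
    # Too-short clips are dropped only with probability drop_short_ratio.
    # Because of short-circuiting, rng is consumed only for short clips.
    return not (len(sampled) < num_frames and rng.random() < drop_short_ratio)


print(keep_by_length(30.0, 30.0))  # False: 30 s > 24.0625 s
print(keep_by_length(30.0, 10.0))  # True: 160 sampled indices >= 77
print(keep_by_length(30.0, 2.0, drop_short_ratio=1.0))  # False: 32 < 77, always dropped
print(keep_by_length(30.0, 2.0, drop_short_ratio=0.0))  # True: short clips never dropped
```

Because of short-circuit evaluation, the seeded rng is only consumed for short clips, so the random stream (and therefore which clips are dropped) depends on the order in which items are checked.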

fastvideo.dataset.preprocessing_datasets.ImageTransformStage

ImageTransformStage(transform, transform_topcrop)

Bases: DatasetStage

Stage for image data transformation.

Source code in fastvideo/dataset/preprocessing_datasets.py
def __init__(self, transform, transform_topcrop) -> None:
    self.transform = transform
    self.transform_topcrop = transform_topcrop

Functions

fastvideo.dataset.preprocessing_datasets.ImageTransformStage.process
process(batch: PreprocessBatch, **kwargs) -> PreprocessBatch

Transform image data.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| batch | PreprocessBatch | Dataset batch with image information | required |

Returns:

| Type | Description |
|---|---|
| PreprocessBatch | Batch with transformed image tensor |

Source code in fastvideo/dataset/preprocessing_datasets.py
def process(self, batch: PreprocessBatch, **kwargs) -> PreprocessBatch:
    """
    Transform image data.

    Args:
        batch: Dataset batch with image information

    Returns:
        Batch with transformed image tensor
    """
    if not batch.is_image:
        return batch

    image = Image.open(batch.path).convert("RGB")
    image = torch.from_numpy(np.array(image))
    image = rearrange(image, "h w c -> c h w").unsqueeze(0)

    if self.transform_topcrop is not None:
        image = self.transform_topcrop(image)
    elif self.transform is not None:
        image = self.transform(image)

    image = image.transpose(0, 1)  # [1 C H W] -> [C 1 H W]
    image = image.float() / 127.5 - 1.0
    batch.pixel_values = image
    return batch
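The shape bookkeeping in process is easy to lose track of. The NumPy sketch below repeats the layout and normalization steps on a small synthetic image: HWC to CHW, a leading frame axis, a swap to [C, 1, H, W] so the image looks like a one-frame video, and a linear map from [0, 255] to [-1, 1]. It substitutes NumPy for torch and einops and skips the optional transform / transform_topcrop call (in the stage, transform_topcrop takes precedence when both are set); to_model_layout is a hypothetical name.

```python
import numpy as np


def to_model_layout(image_hwc: np.ndarray) -> np.ndarray:
    """Mimic the layout and normalization steps of ImageTransformStage with NumPy."""
    # "h w c -> c h w", then add a leading axis: [1, C, H, W].
    chw = np.transpose(image_hwc, (2, 0, 1))[np.newaxis]
    # Swap the first two axes: [1, C, H, W] -> [C, 1, H, W] (a one-frame "video").
    c1hw = np.swapaxes(chw, 0, 1)
    # Map uint8 [0, 255] to float [-1, 1].
    return c1hw.astype(np.float32) / 127.5 - 1.0


img = np.zeros((4, 6, 3), dtype=np.uint8)
img[..., 0] = 255  # pure red
out = to_model_layout(img)
print(out.shape)  # (3, 1, 4, 6)
print(out[0].max(), out[1].max())  # 1.0 -1.0
```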

fastvideo.dataset.preprocessing_datasets.PreprocessBatch dataclass

PreprocessBatch(path: str, cap: str | list[str], resolution: dict | None = None, fps: float | None = None, duration: float | None = None, num_frames: int | None = None, sample_frame_index: list[int] | None = None, sample_num_frames: int | None = None, pixel_values: Tensor | None = None, text: str | None = None, input_ids: Tensor | None = None, cond_mask: Tensor | None = None)

Batch information for dataset processing stages.

This class holds all the information about a video-caption or image-caption pair as it moves through the processing pipeline. Fields are populated by different stages.

Attributes

fastvideo.dataset.preprocessing_datasets.PreprocessBatch.is_image property
is_image: bool

Check if this is an image item.

fastvideo.dataset.preprocessing_datasets.PreprocessBatch.is_video property
is_video: bool

Check if this is a video item.
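For reference, the fields listed in the signature can be written out as a plain dataclass. This is a sketch, not FastVideo's definition: tensor fields are typed as Any so the snippet does not need torch, and the extension-based is_video / is_image rules (VIDEO_EXTS, IMAGE_EXTS) are assumptions, because this page does not document how the two properties are computed.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any

# Hypothetical extension sets: this page does not say how is_image / is_video are computed.
VIDEO_EXTS = (".mp4", ".avi", ".mov", ".mkv", ".webm")
IMAGE_EXTS = (".jpg", ".jpeg", ".png", ".webp")


@dataclass
class PreprocessBatchSketch:
    path: str
    cap: str | list[str]
    resolution: dict | None = None
    fps: float | None = None
    duration: float | None = None
    num_frames: int | None = None            # set by FrameSamplingStage.process
    sample_frame_index: list[int] | None = None
    sample_num_frames: int | None = None
    pixel_values: Any = None                 # torch.Tensor in FastVideo
    text: str | None = None                  # set by TextEncodingStage.process
    input_ids: Any = None                    # torch.Tensor in FastVideo
    cond_mask: Any = None                    # torch.Tensor in FastVideo

    @property
    def is_video(self) -> bool:
        # Assumed rule: decide by file extension.
        return self.path.lower().endswith(VIDEO_EXTS)

    @property
    def is_image(self) -> bool:
        # Assumed rule: decide by file extension.
        return self.path.lower().endswith(IMAGE_EXTS)


item = PreprocessBatchSketch(path="clip.mp4", cap=["A cat playing"], fps=25.0, duration=6.88)
print(item.is_video, item.is_image)  # True False
```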

fastvideo.dataset.preprocessing_datasets.ResolutionFilterStage

ResolutionFilterStage(max_h_div_w_ratio: float = 17 / 16, min_h_div_w_ratio: float = 8 / 16, max_height: int = 1024, max_width: int = 1024)

Bases: DatasetFilterStage

Stage for filtering data items based on resolution constraints.

Source code in fastvideo/dataset/preprocessing_datasets.py
def __init__(self,
             max_h_div_w_ratio: float = 17 / 16,
             min_h_div_w_ratio: float = 8 / 16,
             max_height: int = 1024,
             max_width: int = 1024):
    self.max_h_div_w_ratio = max_h_div_w_ratio
    self.min_h_div_w_ratio = min_h_div_w_ratio
    self.max_height = max_height
    self.max_width = max_width

Functions

fastvideo.dataset.preprocessing_datasets.ResolutionFilterStage.filter_resolution
filter_resolution(h: int, w: int, max_h_div_w_ratio: float, min_h_div_w_ratio: float) -> bool

Filter based on height/width ratio.

Source code in fastvideo/dataset/preprocessing_datasets.py
def filter_resolution(self, h: int, w: int, max_h_div_w_ratio: float,
                      min_h_div_w_ratio: float) -> bool:
    """Filter based on height/width ratio."""
    return h / w <= max_h_div_w_ratio and h / w >= min_h_div_w_ratio
fastvideo.dataset.preprocessing_datasets.ResolutionFilterStage.process
process(batch: PreprocessBatch, **kwargs) -> PreprocessBatch

The process method is a no-op for resolution filtering; filtering is handled by should_keep.

Source code in fastvideo/dataset/preprocessing_datasets.py
def process(self, batch: PreprocessBatch, **kwargs) -> PreprocessBatch:
    """Process does nothing for resolution filtering - filtering is handled by should_keep."""
    return batch
fastvideo.dataset.preprocessing_datasets.ResolutionFilterStage.should_keep
should_keep(batch: PreprocessBatch, **kwargs) -> bool

Check if data item passes resolution filtering.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| batch | PreprocessBatch | Dataset batch with resolution information | required |

Returns:

| Type | Description |
|---|---|
| bool | True if passes filter, False otherwise |

Source code in fastvideo/dataset/preprocessing_datasets.py
def should_keep(self, batch: PreprocessBatch, **kwargs) -> bool:
    """
    Check if data item passes resolution filtering.

    Args:
        batch: Dataset batch with resolution information

    Returns:
        True if passes filter, False otherwise
    """
    # Only apply to videos
    if not batch.is_video:
        return True

    if batch.resolution is None:
        return False

    height = batch.resolution.get("height", None)
    width = batch.resolution.get("width", None)
    if height is None or width is None:
        return False

    # Check aspect ratio
    aspect = self.max_height / self.max_width
    hw_aspect_thr = 1.5

    return self.filter_resolution(
        height,
        width,
        max_h_div_w_ratio=hw_aspect_thr * aspect,
        min_h_div_w_ratio=1 / hw_aspect_thr * aspect,
    )
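Note that should_keep does not use the max_h_div_w_ratio and min_h_div_w_ratio passed to the constructor. It derives its bounds from max_height / max_width and a hard-coded factor of 1.5, accepting h / w in [aspect / 1.5, aspect * 1.5]. With the default 1024 x 1024 limits that range is about [0.667, 1.5], so a 1920 x 1080 landscape clip (h / w = 0.5625), like the sample metadata entry further down this page, would be rejected. The standalone function below restates the check; resolution_ok is a hypothetical name, and the 480 x 832 limits in the usage line are only an illustrative alternative.

```python
def resolution_ok(height, width, max_height=1024, max_width=1024, hw_aspect_thr=1.5):
    """Standalone restatement of ResolutionFilterStage.should_keep for videos."""
    aspect = max_height / max_width
    ratio = height / width
    # Accept h/w within [aspect / thr, aspect * thr].
    return (1 / hw_aspect_thr) * aspect <= ratio <= hw_aspect_thr * aspect


print(resolution_ok(1080, 1920))                                 # False: 0.5625 < 0.667
print(resolution_ok(1080, 1920, max_height=480, max_width=832))  # True
```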

fastvideo.dataset.preprocessing_datasets.TextDataset

TextDataset(data_merge_path: str, args, start_idx: int = 0, seed: int = 42)

Bases: IterableDataset, Stateful

Text-only dataset for processing prompts from a simple text file.

Assumes that data_merge_path is a text file with one prompt per line:

A cat playing with a ball
A dog running in the park
A person cooking dinner
...

This dataset processes text data through text encoding stages only.

Source code in fastvideo/dataset/preprocessing_datasets.py
def __init__(self,
             data_merge_path: str,
             args,
             start_idx: int = 0,
             seed: int = 42):
    self.data_merge_path = data_merge_path
    self.start_idx = start_idx
    self.args = args
    self.seed = seed

    # Initialize tokenizer
    tokenizer_path = os.path.join(args.model_path, "tokenizer")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path,
                                              cache_dir=args.cache_dir)

    # Initialize text encoding stage
    self.text_encoding_stage = TextEncodingStage(
        tokenizer=tokenizer,
        text_max_length=args.text_max_length,
        cfg_rate=getattr(args, 'training_cfg_rate', 0.0),
        seed=self.seed)

    # Process text data
    self.processed_batches = self._process_text_data()

Functions

fastvideo.dataset.preprocessing_datasets.TextDataset.__iter__
__iter__()

Iterator for the dataset.

Source code in fastvideo/dataset/preprocessing_datasets.py
def __iter__(self):
    """Iterator for the dataset."""
    # Set up distributed sampling if needed
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()
    else:
        rank = 0
        world_size = 1

    # Calculate chunk for this rank
    total_items = len(self.processed_batches)
    items_per_rank = math.ceil(total_items / world_size)
    start_idx = rank * items_per_rank + self.start_idx
    end_idx = min(start_idx + items_per_rank, total_items)

    # Yield items for this rank
    for idx in range(start_idx, end_idx):
        if idx < len(self.processed_batches):
            yield self._get_item(idx)
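The sharding arithmetic gives each rank a contiguous block of ceil(total / world_size) items. Because self.start_idx is added to every rank's start, a non-zero offset shifts all blocks, skips the first start_idx items, and leaves the last ranks with fewer items or none. The helper below restates the computation for inspection; rank_range is a hypothetical name, and the example uses 10 items on 4 ranks.

```python
import math


def rank_range(total_items, world_size, rank, start_idx=0):
    """Restates the per-rank index range computed in TextDataset.__iter__."""
    items_per_rank = math.ceil(total_items / world_size)
    begin = rank * items_per_rank + start_idx
    end = min(begin + items_per_rank, total_items)
    return range(begin, end)  # empty when begin >= end


for r in range(4):
    print(r, list(rank_range(10, 4, r)), list(rank_range(10, 4, r, start_idx=2)))
```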
fastvideo.dataset.preprocessing_datasets.TextDataset.load_state_dict
load_state_dict(state_dict: dict[str, Any]) -> None

Load state dict from checkpoint.

Source code in fastvideo/dataset/preprocessing_datasets.py
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
    """Load state dict from checkpoint."""
    self.processed_batches = state_dict["processed_batches"]
fastvideo.dataset.preprocessing_datasets.TextDataset.state_dict
state_dict() -> dict[str, Any]

Return state dict for checkpointing.

Source code in fastvideo/dataset/preprocessing_datasets.py
def state_dict(self) -> dict[str, Any]:
    """Return state dict for checkpointing."""
    return {"processed_batches": self.processed_batches}

fastvideo.dataset.preprocessing_datasets.TextEncodingStage

TextEncodingStage(tokenizer, text_max_length: int, cfg_rate: float = 0.0, seed: int = 42)

Bases: DatasetStage

Stage for text tokenization and encoding.

Source code in fastvideo/dataset/preprocessing_datasets.py
def __init__(self,
             tokenizer,
             text_max_length: int,
             cfg_rate: float = 0.0,
             seed: int = 42):
    self.tokenizer = tokenizer
    self.text_max_length = text_max_length
    self.cfg_rate = cfg_rate
    # Create a seeded random generator for deterministic CFG
    self.rng = random.Random(seed)

Functions

fastvideo.dataset.preprocessing_datasets.TextEncodingStage.process
process(batch: PreprocessBatch, **kwargs) -> PreprocessBatch

Process text data.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| batch | PreprocessBatch | Dataset batch with caption information | required |

Returns:

| Type | Description |
|---|---|
| PreprocessBatch | Batch with encoded text information |

Source code in fastvideo/dataset/preprocessing_datasets.py
def process(self, batch: PreprocessBatch, **kwargs) -> PreprocessBatch:
    """
    Process text data.

    Args:
        batch: Dataset batch with caption information

    Returns:
        Batch with encoded text information
    """
    text = batch.cap
    if not isinstance(text, list):
        text = [text]
    text = [self.rng.choice(text)]

    text = text[0] if self.rng.random() > self.cfg_rate else ""
    text_tokens_and_mask = self.tokenizer(
        text,
        max_length=self.text_max_length,
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        add_special_tokens=True,
        return_tensors="pt",
    )

    batch.text = text
    batch.input_ids = text_tokens_and_mask["input_ids"]
    batch.cond_mask = text_tokens_and_mask["attention_mask"]
    return batch
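The caption handling before tokenization has two random steps driven by the same seeded generator: one caption is chosen from the list, then it is replaced by the empty string (the unconditional prompt used for classifier-free guidance) when rng.random() <= cfg_rate, i.e. with probability about cfg_rate. The sketch below isolates those two steps and leaves out the tokenizer call; pick_caption is a hypothetical name.

```python
import random


def pick_caption(cap, cfg_rate, rng):
    """Restates the caption selection and CFG dropout in TextEncodingStage.process."""
    captions = cap if isinstance(cap, list) else [cap]
    # choice() draws from rng even when there is only one caption.
    chosen = rng.choice(captions)
    # Keep the caption unless the draw falls at or below cfg_rate.
    return chosen if rng.random() > cfg_rate else ""


rng = random.Random(42)
draws = [pick_caption(["a cat", "a kitten"], cfg_rate=0.1, rng=rng) for _ in range(10_000)]
print(draws.count("") / len(draws))  # close to 0.1
```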

fastvideo.dataset.preprocessing_datasets.VideoCaptionMergedDataset

VideoCaptionMergedDataset(data_merge_path: str, args, transform, temporal_sample, transform_topcrop, start_idx: int = 0, seed: int = 42)

Bases: IterableDataset, Stateful

Merged dataset for video and caption data with stage-based processing. Assumes that data_merge_path is a text file containing a folder path and a JSON file path, separated by a comma.

The folder should contain the videos.

The JSON file should be a list of dictionaries with the following format:
[
    {
        "path": "1gGQy4nxyUo-Scene-016.mp4",
        "resolution": {
            "width": 1920,
            "height": 1080
        },
        "size": 2439112,
        "fps": 25.0,
        "duration": 6.88,
        "num_frames": 172,
        "cap": [
            "A watermelon wearing a helmet is crushed by a hydraulic press, causing it to flatten and burst open."
        ]
    },
    ...
]
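The merge-file format described above can be read with the standard library. This sketch assumes each non-empty line holds a folder path and a JSON file path separated by a comma, and that each entry's "path" is relative to that folder; it is not FastVideo's loader, and read_merge_file is a hypothetical name.

```python
import json
import os
import tempfile


def read_merge_file(data_merge_path):
    """Hypothetical reader yielding (video_path, metadata_entry) pairs.

    Assumes every non-empty line is "<folder>,<json_file>" and that each
    entry's "path" is relative to <folder>.
    """
    with open(data_merge_path) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            folder, json_file = line.split(",", 1)
            with open(json_file) as jf:
                for entry in json.load(jf):
                    yield os.path.join(folder, entry["path"]), entry


# Build a throwaway merge file that points at a one-entry metadata JSON.
with tempfile.TemporaryDirectory() as tmp:
    video_dir = os.path.join(tmp, "videos")
    meta_path = os.path.join(tmp, "videos.json")
    with open(meta_path, "w") as f:
        json.dump([{"path": "clip.mp4", "fps": 25.0, "duration": 6.88,
                    "cap": ["A watermelon is crushed by a hydraulic press."]}], f)
    merge_path = os.path.join(tmp, "merge.txt")
    with open(merge_path, "w") as f:
        f.write(f"{video_dir},{meta_path}\n")
    pairs = list(read_merge_file(merge_path))

print(pairs[0][0] == os.path.join(video_dir, "clip.mp4"))  # True
```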

This dataset processes video and image data through a series of stages:

- Data validation
- Resolution filtering
- Frame sampling
- Transformation
- Text encoding

Source code in fastvideo/dataset/preprocessing_datasets.py
def __init__(self,
             data_merge_path: str,
             args,
             transform,
             temporal_sample,
             transform_topcrop,
             start_idx: int = 0,
             seed: int = 42):
    self.data_merge_path = data_merge_path
    self.start_idx = start_idx
    self.args = args
    self.temporal_sample = temporal_sample
    self.seed = seed

    # Initialize tokenizer
    tokenizer_path = os.path.join(args.model_path, "tokenizer")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path,
                                              cache_dir=args.cache_dir)

    # Initialize processing stages
    self._init_stages(args, transform, transform_topcrop, tokenizer)

    # Process metadata
    self.processed_batches = self._process_metadata()

Functions

fastvideo.dataset.preprocessing_datasets.VideoCaptionMergedDataset.__iter__
__iter__()

Iterate through processed data items.

Source code in fastvideo/dataset/preprocessing_datasets.py
def __iter__(self):
    """Iterate through processed data items."""
    for idx in range(len(self.processed_batches)):
        yield self._get_item(idx)
fastvideo.dataset.preprocessing_datasets.VideoCaptionMergedDataset.load_state_dict
load_state_dict(state_dict: dict[str, Any]) -> None

Load state dict from checkpoint.

Source code in fastvideo/dataset/preprocessing_datasets.py
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
    """Load state dict from checkpoint."""
    self.processed_batches = state_dict["processed_batches"]
fastvideo.dataset.preprocessing_datasets.VideoCaptionMergedDataset.state_dict
state_dict() -> dict[str, Any]

Return state dict for checkpointing.

Source code in fastvideo/dataset/preprocessing_datasets.py
def state_dict(self) -> dict[str, Any]:
    """Return state dict for checkpointing."""
    return {"processed_batches": self.processed_batches}

fastvideo.dataset.preprocessing_datasets.VideoTransformStage

VideoTransformStage(transform)

Bases: DatasetStage

Stage for video data transformation.

Source code in fastvideo/dataset/preprocessing_datasets.py
def __init__(self, transform) -> None:
    self.transform = transform

Functions

fastvideo.dataset.preprocessing_datasets.VideoTransformStage.process
process(batch: PreprocessBatch, **kwargs) -> PreprocessBatch

Transform video data.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| batch | PreprocessBatch | Dataset batch with video information | required |

Returns:

| Type | Description |
|---|---|
| PreprocessBatch | Batch with transformed video tensor |

Source code in fastvideo/dataset/preprocessing_datasets.py
def process(self, batch: PreprocessBatch, **kwargs) -> PreprocessBatch:
    """
    Transform video data.

    Args:
        batch: Dataset batch with video information

    Returns:
        Batch with transformed video tensor
    """
    if not batch.is_video:
        return batch

    assert os.path.exists(batch.path), f"file {batch.path} does not exist!"
    assert batch.sample_frame_index is not None, "Frame indices must be set before transformation"

    torchvision_video, _, metadata = torchvision.io.read_video(
        batch.path, output_format="TCHW")
    video = torchvision_video[batch.sample_frame_index]
    if self.transform is not None:
        video = self.transform(video)
    video = rearrange(video, "t c h w -> c t h w")
    video = video.to(torch.uint8)

    h, w = video.shape[-2:]
    assert (
        h / w <= 17 / 16 and h / w >= 8 / 16
    ), f"Only videos with a ratio (h/w) less than 17/16 and more than 8/16 are supported. But video ({batch.path}) found ratio is {round(h / w, 2)} with the shape of {video.shape}"

    video = video.float() / 127.5 - 1.0
    batch.pixel_values = video
    return batch
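The tensor handling in process can be mirrored in NumPy: select the sampled frames along the time axis, move channels first (t c h w to c t h w), check the aspect ratio, and normalize to [-1, 1]. In this sketch a synthetic array stands in for the frames from torchvision.io.read_video, the optional transform is skipped, and the stage's assert is replaced with a ValueError; frames_to_model_layout is a hypothetical name. Note that the 8/16 to 17/16 bound is hard-coded in the stage and is checked after transform, independently of ResolutionFilterStage.

```python
import numpy as np


def frames_to_model_layout(video_tchw, sample_frame_index):
    """NumPy sketch of the indexing, layout and normalization in VideoTransformStage."""
    # Gather the sampled frames along the time axis, in the given order.
    video = video_tchw[sample_frame_index]
    # Equivalent of rearrange(video, "t c h w -> c t h w").
    video = np.transpose(video, (1, 0, 2, 3)).astype(np.uint8)
    h, w = video.shape[-2:]
    # Same bounds as the stage's assertion: 8/16 <= h/w <= 17/16.
    if not (8 / 16 <= h / w <= 17 / 16):
        raise ValueError(f"unsupported h/w ratio {h / w:.2f}")
    # Map uint8 [0, 255] to float [-1, 1].
    return video.astype(np.float32) / 127.5 - 1.0


# Synthetic 10-frame clip in TCHW layout, the layout read_video(..., output_format="TCHW") uses.
clip = (np.arange(10 * 3 * 9 * 16) % 256).astype(np.uint8).reshape(10, 3, 9, 16)
out = frames_to_model_layout(clip, [0, 2, 4, 6])
print(out.shape)  # (3, 4, 9, 16)
```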
