dataset

Classes

fastvideo.dataset.TextDataset

TextDataset(data_merge_path: str, args, start_idx: int = 0, seed: int = 42)

Bases: IterableDataset, Stateful

Text-only dataset for processing prompts from a simple text file.

Assumes that data_merge_path is a text file with one prompt per line:

    A cat playing with a ball
    A dog running in the park
    A person cooking dinner
    ...

This dataset processes text data through text encoding stages only.

Source code in fastvideo/dataset/preprocessing_datasets.py
def __init__(self,
             data_merge_path: str,
             args,
             start_idx: int = 0,
             seed: int = 42):
    self.data_merge_path = data_merge_path
    self.start_idx = start_idx
    self.args = args
    self.seed = seed

    # Initialize tokenizer
    tokenizer_path = os.path.join(args.model_path, "tokenizer")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path,
                                              cache_dir=args.cache_dir)

    # Initialize text encoding stage
    self.text_encoding_stage = TextEncodingStage(
        tokenizer=tokenizer,
        text_max_length=args.text_max_length,
        cfg_rate=getattr(args, 'training_cfg_rate', 0.0),
        seed=self.seed)

    # Process text data
    self.processed_batches = self._process_text_data()

Functions

fastvideo.dataset.TextDataset.__iter__
__iter__()

Iterator for the dataset.

Source code in fastvideo/dataset/preprocessing_datasets.py
def __iter__(self):
    """Iterator for the dataset."""
    # Set up distributed sampling if needed
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()
    else:
        rank = 0
        world_size = 1

    # Calculate chunk for this rank
    total_items = len(self.processed_batches)
    items_per_rank = math.ceil(total_items / world_size)
    start_idx = rank * items_per_rank + self.start_idx
    end_idx = min(start_idx + items_per_rank, total_items)

    # Yield items for this rank
    for idx in range(start_idx, end_idx):
        if idx < len(self.processed_batches):
            yield self._get_item(idx)
fastvideo.dataset.TextDataset.load_state_dict
load_state_dict(state_dict: dict[str, Any]) -> None

Load state dict from checkpoint.

Source code in fastvideo/dataset/preprocessing_datasets.py
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
    """Load state dict from checkpoint."""
    self.processed_batches = state_dict["processed_batches"]
fastvideo.dataset.TextDataset.state_dict
state_dict() -> dict[str, Any]

Return state dict for checkpointing.

Source code in fastvideo/dataset/preprocessing_datasets.py
def state_dict(self) -> dict[str, Any]:
    """Return state dict for checkpointing."""
    return {"processed_batches": self.processed_batches}

fastvideo.dataset.ValidationDataset

ValidationDataset(filename: str)

Bases: IterableDataset

Source code in fastvideo/dataset/validation_dataset.py
def __init__(self, filename: str):
    super().__init__()

    self.filename = pathlib.Path(filename)
    # get directory of filename
    self.dir = os.path.abspath(self.filename.parent)

    if not self.filename.exists():
        raise FileNotFoundError(
            f"File {self.filename.as_posix()} does not exist")

    if self.filename.suffix == ".csv":
        data = datasets.load_dataset("csv",
                                     data_files=self.filename.as_posix(),
                                     split="train")
    elif self.filename.suffix == ".json":
        data = datasets.load_dataset("json",
                                     data_files=self.filename.as_posix(),
                                     split="train",
                                     field="data")
    elif self.filename.suffix == ".parquet":
        data = datasets.load_dataset("parquet",
                                     data_files=self.filename.as_posix(),
                                     split="train")
    elif self.filename.suffix == ".arrow":
        data = datasets.load_dataset("arrow",
                                     data_files=self.filename.as_posix(),
                                     split="train")
    else:
        _SUPPORTED_FILE_FORMATS = [".csv", ".json", ".parquet", ".arrow"]
        raise ValueError(
            f"Unsupported file format {self.filename.suffix} for validation dataset. Supported formats are: {_SUPPORTED_FILE_FORMATS}"
        )

    # Get distributed training info
    self.global_rank = get_world_rank()
    self.world_size = get_world_size()
    self.sp_world_size = get_sp_world_size()
    self.num_sp_groups = self.world_size // self.sp_world_size

    # Convert to list to get total samples
    self.all_samples = list(data)
    self.original_total_samples = len(self.all_samples)

    # Extend samples to be a multiple of DP degree (num_sp_groups)
    remainder = self.original_total_samples % self.num_sp_groups
    if remainder != 0:
        samples_to_add = self.num_sp_groups - remainder

        # Duplicate samples cyclically to reach the target
        additional_samples = []
        for i in range(samples_to_add):
            additional_samples.append(
                self.all_samples[i % self.original_total_samples])

        self.all_samples.extend(additional_samples)

    self.total_samples = len(self.all_samples)

    # Calculate which SP group this rank belongs to
    self.sp_group_id = self.global_rank // self.sp_world_size

    # Now all SP groups will have equal number of samples
    self.samples_per_sp_group = self.total_samples // self.num_sp_groups

    # Calculate start and end indices for this SP group
    self.start_idx = self.sp_group_id * self.samples_per_sp_group
    self.end_idx = self.start_idx + self.samples_per_sp_group

    # Get samples for this SP group
    self.sp_group_samples = self.all_samples[self.start_idx:self.end_idx]

    logger.info(
        "Rank %s (SP group %s): "
        "Original samples: %s, "
        "Extended samples: %s, "
        "SP group samples: %s, "
        "Range: [%s:%s]",
        self.global_rank,
        self.sp_group_id,
        self.original_total_samples,
        self.total_samples,
        len(self.sp_group_samples),
        self.start_idx,
        self.end_idx,
        local_main_process_only=False)

Functions

fastvideo.dataset.ValidationDataset.__len__
__len__()

Return the number of samples for this SP group.

Source code in fastvideo/dataset/validation_dataset.py
def __len__(self):
    """Return the number of samples for this SP group."""
    return len(self.sp_group_samples)
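
A short usage sketch (the file path is hypothetical; .csv, .json, .parquet and .arrow files are accepted, and the .json loader expects the samples under a top-level "data" field):

    from fastvideo.dataset import ValidationDataset

    dataset = ValidationDataset("validation_prompts.csv")   # hypothetical file

    # Samples are padded to a multiple of the number of SP groups, then split so
    # that every SP group gets an equal contiguous slice.
    print(len(dataset))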

fastvideo.dataset.VideoCaptionMergedDataset

VideoCaptionMergedDataset(data_merge_path: str, args, transform, temporal_sample, transform_topcrop, start_idx: int = 0, seed: int = 42)

Bases: IterableDataset, Stateful

Merged dataset for video and caption data with stage-based processing. Assumes that data_merge_path is a txt file with the following format: folder,json_file (a data folder path and a metadata JSON file path, separated by a comma).

The folder should contain videos.

The json file should be a list of dictionaries with the following format:
[
{
    "path": "1gGQy4nxyUo-Scene-016.mp4",
    "resolution": {
    "width": 1920,
    "height": 1080
    },
    "size": 2439112,
    "fps": 25.0,
    "duration": 6.88,
    "num_frames": 172,
    "cap": [
    "A watermelon wearing a helmet is crushed by a hydraulic press, causing it to flatten and burst open."
    ]
},
...
]

This dataset processes video and image data through a series of stages:

- Data validation
- Resolution filtering
- Frame sampling
- Transformation
- Text encoding

Source code in fastvideo/dataset/preprocessing_datasets.py
def __init__(self,
             data_merge_path: str,
             args,
             transform,
             temporal_sample,
             transform_topcrop,
             start_idx: int = 0,
             seed: int = 42):
    self.data_merge_path = data_merge_path
    self.start_idx = start_idx
    self.args = args
    self.temporal_sample = temporal_sample
    self.seed = seed

    # Initialize tokenizer
    tokenizer_path = os.path.join(args.model_path, "tokenizer")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path,
                                              cache_dir=args.cache_dir)

    # Initialize processing stages
    self._init_stages(args, transform, transform_topcrop, tokenizer)

    # Process metadata
    self.processed_batches = self._process_metadata()

Functions

fastvideo.dataset.VideoCaptionMergedDataset.__iter__
__iter__()

Iterate through processed data items.

Source code in fastvideo/dataset/preprocessing_datasets.py
def __iter__(self):
    """Iterate through processed data items."""
    for idx in range(len(self.processed_batches)):
        yield self._get_item(idx)
fastvideo.dataset.VideoCaptionMergedDataset.load_state_dict
load_state_dict(state_dict: dict[str, Any]) -> None

Load state dict from checkpoint.

Source code in fastvideo/dataset/preprocessing_datasets.py
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
    """Load state dict from checkpoint."""
    self.processed_batches = state_dict["processed_batches"]
fastvideo.dataset.VideoCaptionMergedDataset.state_dict
state_dict() -> dict[str, Any]

Return state dict for checkpointing.

Source code in fastvideo/dataset/preprocessing_datasets.py
def state_dict(self) -> dict[str, Any]:
    """Return state dict for checkpointing."""
    return {"processed_batches": self.processed_batches}

Modules

fastvideo.dataset.parquet_dataset_iterable_style

Classes

fastvideo.dataset.parquet_dataset_iterable_style.LatentsParquetIterStyleDataset
LatentsParquetIterStyleDataset(path: str, batch_size: int = 1024, cfg_rate: float = 0.1, num_workers: int = 1, drop_last: bool = True, text_padding_length: int = 512, seed: int = 42, read_batch_size: int = 32, parquet_schema: Schema = None)

Bases: IterableDataset

Efficient loader for video-text data from a directory of Parquet files.

Source code in fastvideo/dataset/parquet_dataset_iterable_style.py
def __init__(self,
             path: str,
             batch_size: int = 1024,
             cfg_rate: float = 0.1,
             num_workers: int = 1,
             drop_last: bool = True,
             text_padding_length: int = 512,
             seed: int = 42,
             read_batch_size: int = 32,
             parquet_schema: pa.Schema = None):
    super().__init__()
    self.path = str(path)
    self.batch_size = batch_size
    self.parquet_schema = parquet_schema
    self.cfg_rate = cfg_rate
    self.text_padding_length = text_padding_length
    self.seed = seed
    self.read_batch_size = read_batch_size
    # Get distributed training info
    self.global_rank = get_world_rank()
    self.world_size = get_world_size()
    self.sp_world_size = get_sp_world_size()
    self.num_sp_groups = self.world_size // self.sp_world_size
    num_workers = 1 if num_workers == 0 else num_workers
    # Get sharding info
    shard_parquet_files, shard_total_samples, shard_parquet_lengths = shard_parquet_files_across_sp_groups_and_workers(
        self.path, self.num_sp_groups, num_workers, seed)

    if drop_last:
        self.worker_num_samples = min(
            shard_total_samples) // batch_size * batch_size
        # Assign files to current rank's SP group
        ith_sp_group = self.global_rank // self.sp_world_size
        self.sp_group_parquet_files = shard_parquet_files[ith_sp_group::self
                                                          .num_sp_groups]
        self.sp_group_parquet_lengths = shard_parquet_lengths[
            ith_sp_group::self.num_sp_groups]
        self.sp_group_num_samples = shard_total_samples[ith_sp_group::self.
                                                        num_sp_groups]
        logger.info(
            "In total %d parquet files, %d samples, after sharding we retain %d samples due to drop_last",
            sum([len(shard) for shard in shard_parquet_files]),
            sum(shard_total_samples),
            self.worker_num_samples * self.num_sp_groups * num_workers)
    else:
        raise ValueError("drop_last must be True")
    logger.info("Each dataloader worker will load %d samples",
                self.worker_num_samples)

Functions

fastvideo.dataset.parquet_dataset_iterable_style.build_parquet_iterable_style_dataloader
build_parquet_iterable_style_dataloader(path: str, batch_size: int, num_data_workers: int, cfg_rate: float = 0.0, drop_last: bool = True, text_padding_length: int = 512, seed: int = 42, read_batch_size: int = 32) -> tuple[LatentsParquetIterStyleDataset, StatefulDataLoader]

Build a dataloader for the LatentsParquetIterStyleDataset.

Source code in fastvideo/dataset/parquet_dataset_iterable_style.py
def build_parquet_iterable_style_dataloader(
    path: str,
    batch_size: int,
    num_data_workers: int,
    cfg_rate: float = 0.0,
    drop_last: bool = True,
    text_padding_length: int = 512,
    seed: int = 42,
    read_batch_size: int = 32
) -> tuple[LatentsParquetIterStyleDataset, StatefulDataLoader]:
    """Build a dataloader for the LatentsParquetIterStyleDataset."""
    dataset = LatentsParquetIterStyleDataset(
        path=path,
        batch_size=batch_size,
        cfg_rate=cfg_rate,
        num_workers=num_data_workers,
        drop_last=drop_last,
        text_padding_length=text_padding_length,
        seed=seed,
        read_batch_size=read_batch_size)

    loader = StatefulDataLoader(
        dataset,
        batch_size=1,
        num_workers=num_data_workers,
        pin_memory=True,
    )
    return dataset, loader
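
A usage sketch over a hypothetical directory of preprocessed parquet shards. The StatefulDataLoader is built with batch_size=1 because the dataset already yields batches of batch_size samples:

    from fastvideo.dataset.parquet_dataset_iterable_style import (
        build_parquet_iterable_style_dataloader)

    dataset, loader = build_parquet_iterable_style_dataloader(
        path="/data/latents_parquet",   # hypothetical shard directory
        batch_size=8,
        num_data_workers=2,
        cfg_rate=0.1,
        text_padding_length=512,
        seed=42,
    )

    for batch in loader:
        ...  # one pre-batched group of latents and text embeddings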
fastvideo.dataset.parquet_dataset_iterable_style.shard_parquet_files_across_sp_groups_and_workers
shard_parquet_files_across_sp_groups_and_workers(path: str, num_sp_groups: int, num_workers: int, seed: int = 42) -> tuple[list[list[str]], list[int], list[dict[str, int]]]

Shard parquet files across SP groups and workers in a balanced way.

Parameters:

Name Type Description Default
path str

Directory containing parquet files

required
num_sp_groups int

Number of SP groups to shard across

required
num_workers int

Number of workers per SP group

required
seed int

Random seed for shuffling

42

Returns:

Type Description
tuple[list[list[str]], list[int], list[dict[str, int]]]

Tuple containing:

  • List of lists of parquet files for each shard
  • List of total samples per shard
  • List of dictionaries mapping file paths to their lengths
Source code in fastvideo/dataset/parquet_dataset_iterable_style.py
def shard_parquet_files_across_sp_groups_and_workers(
    path: str,
    num_sp_groups: int,
    num_workers: int,
    seed: int = 42,
) -> tuple[list[list[str]], list[int], list[dict[str, int]]]:
    """
    Shard parquet files across SP groups and workers in a balanced way.

    Args:
        path: Directory containing parquet files
        num_sp_groups: Number of SP groups to shard across
        num_workers: Number of workers per SP group
        seed: Random seed for shuffling

    Returns:
        Tuple containing:
        - List of lists of parquet files for each shard
        - List of total samples per shard
        - List of dictionaries mapping file paths to their lengths
    """
    # Check if sharding plan already exists
    sharding_info_dir = os.path.join(
        path, f"sharding_info_{num_sp_groups}_sp_groups_{num_workers}_workers")

    # Only rank 0 handles cache checking and file scanning
    if get_world_rank() == 0:
        cache_loaded = False
        shard_parquet_files = None
        shard_total_samples = None
        shard_parquet_lengths = None

        # First try to load existing sharding plan
        if os.path.exists(sharding_info_dir):
            logger.info("Loading sharding plan from %s", sharding_info_dir)
            try:
                with open(
                        os.path.join(sharding_info_dir,
                                     "shard_parquet_files.pkl"), "rb") as f:
                    shard_parquet_files = pickle.load(f)
                with open(
                        os.path.join(sharding_info_dir,
                                     "shard_total_samples.pkl"), "rb") as f:
                    shard_total_samples = pickle.load(f)
                with open(
                        os.path.join(sharding_info_dir,
                                     "shard_parquet_lengths.pkl"), "rb") as f:
                    shard_parquet_lengths = pickle.load(f)
                cache_loaded = True
                logger.info("Successfully loaded sharding plan")
            except Exception as e:
                logger.error("Error loading sharding plan: %s", str(e))
                logger.info("Falling back to creating new sharding plan")
                cache_loaded = False

        # If cache not loaded (either doesn't exist or failed to load), create sharding plan
        if not cache_loaded:
            logger.info("Creating new sharding plan")
            logger.info("Scanning for parquet files in %s", path)

            # Find all parquet files
            parquet_files = []

            for root, _, files in os.walk(path):
                for file in files:
                    if file.endswith('.parquet'):
                        parquet_files.append(os.path.join(root, file))

            if not parquet_files:
                raise ValueError("No parquet files found in %s", path)

            # Calculate file lengths efficiently using a single pass
            logger.info("Calculating file lengths...")
            lengths = []
            for file in tqdm.tqdm(parquet_files, desc="Reading parquet files"):
                lengths.append(pq.ParquetFile(file).metadata.num_rows)

            total_samples = sum(lengths)
            logger.info("Found %d files with %d total samples",
                        len(parquet_files), total_samples)

            # Sort files by length for better balancing
            sorted_indices = np.argsort(lengths)
            sorted_files = [parquet_files[i] for i in sorted_indices]
            sorted_lengths = [lengths[i] for i in sorted_indices]

            # Create shards
            num_shards = num_sp_groups * num_workers
            shard_parquet_files = [[] for _ in range(num_shards)]
            shard_total_samples = [0] * num_shards
            shard_parquet_lengths = [{} for _ in range(num_shards)]

            # Distribute files to shards using a greedy approach
            logger.info("Distributing files to shards...")
            for file, length in zip(reversed(sorted_files),
                                    reversed(sorted_lengths),
                                    strict=True):
                # Find shard with minimum current length
                target_shard = np.argmin(shard_total_samples)
                shard_parquet_files[target_shard].append(file)
                shard_total_samples[target_shard] += length
                shard_parquet_lengths[target_shard][file] = length
            # Randomize the file order within each shard
            for shard in shard_parquet_files:
                rng = random.Random(seed)
                rng.shuffle(shard)

            # Save the sharding plan
            os.makedirs(sharding_info_dir, exist_ok=True)
            with open(
                    os.path.join(sharding_info_dir, "shard_parquet_files.pkl"),
                    "wb") as f:
                pickle.dump(shard_parquet_files, f)
            with open(
                    os.path.join(sharding_info_dir, "shard_total_samples.pkl"),
                    "wb") as f:
                pickle.dump(shard_total_samples, f)
            with open(
                    os.path.join(sharding_info_dir,
                                 "shard_parquet_lengths.pkl"), "wb") as f:
                pickle.dump(shard_parquet_lengths, f)
            logger.info("Saved sharding info to %s", sharding_info_dir)

    # Wait for rank 0 to finish creating/loading sharding plan
    world_group = get_world_group()
    world_group.barrier()

    # Now all ranks load the sharding plan (it should exist and be valid now)
    logger.info("Loading sharding plan from %s after barrier",
                sharding_info_dir)
    with open(os.path.join(sharding_info_dir, "shard_parquet_files.pkl"),
              "rb") as f:
        shard_parquet_files = pickle.load(f)
    with open(os.path.join(sharding_info_dir, "shard_total_samples.pkl"),
              "rb") as f:
        shard_total_samples = pickle.load(f)
    with open(os.path.join(sharding_info_dir, "shard_parquet_lengths.pkl"),
              "rb") as f:
        shard_parquet_lengths = pickle.load(f)

    return shard_parquet_files, shard_total_samples, shard_parquet_lengths
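
A standalone toy illustration of the greedy balancing used above: files are considered largest-first and each is assigned to the currently lightest shard, which keeps the per-shard sample counts roughly equal:

    import numpy as np

    lengths = [100, 90, 40, 30, 10, 5]       # rows per parquet file (toy numbers)
    num_shards = 3                           # num_sp_groups * num_workers

    shard_totals = [0] * num_shards
    shard_files: list[list[str]] = [[] for _ in range(num_shards)]
    for i, length in enumerate(sorted(lengths, reverse=True)):
        target = int(np.argmin(shard_totals))   # lightest shard so far
        shard_files[target].append(f"file_{i}.parquet")
        shard_totals[target] += length

    print(shard_totals)   # [100, 90, 85]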

fastvideo.dataset.parquet_dataset_map_style

Classes

fastvideo.dataset.parquet_dataset_map_style.DP_SP_BatchSampler
DP_SP_BatchSampler(batch_size: int, dataset_size: int, num_sp_groups: int, sp_world_size: int, global_rank: int, drop_last: bool = True, drop_first_row: bool = False, seed: int = 0)

Bases: Sampler[list[int]]

A simple sequential batch sampler that yields batches of indices.

Source code in fastvideo/dataset/parquet_dataset_map_style.py
def __init__(
    self,
    batch_size: int,
    dataset_size: int,
    num_sp_groups: int,
    sp_world_size: int,
    global_rank: int,
    drop_last: bool = True,
    drop_first_row: bool = False,
    seed: int = 0,
):
    self.batch_size = batch_size
    self.dataset_size = dataset_size
    self.drop_last = drop_last
    self.seed = seed
    self.num_sp_groups = num_sp_groups
    self.global_rank = global_rank
    self.sp_world_size = sp_world_size

    # ── epoch-level RNG ────────────────────────────────────────────────
    rng = torch.Generator().manual_seed(self.seed)
    # Create a random permutation of all indices
    global_indices = torch.randperm(self.dataset_size, generator=rng)

    if drop_first_row:
        # drop 0 in global_indices
        global_indices = global_indices[global_indices != 0]
        self.dataset_size = self.dataset_size - 1

    if self.drop_last:
        # For drop_last=True, we:
        # 1. Ensure total samples is divisible by (batch_size * num_sp_groups)
        # 2. This guarantees each SP group gets same number of complete batches
        # 3. Prevents uneven batch sizes across SP groups at end of epoch
        num_batches = self.dataset_size // self.batch_size
        num_global_batches = num_batches // self.num_sp_groups
        global_indices = global_indices[:num_global_batches *
                                        self.num_sp_groups *
                                        self.batch_size]
    else:
        if self.dataset_size % (self.num_sp_groups * self.batch_size) != 0:
            # add more indices to make it divisible by (batch_size * num_sp_groups)
            padding_size = self.num_sp_groups * self.batch_size - (
                self.dataset_size % (self.num_sp_groups * self.batch_size))
            logger.info("Padding the dataset from %d to %d",
                        self.dataset_size, self.dataset_size + padding_size)
            global_indices = torch.cat(
                [global_indices, global_indices[:padding_size]])

    # shard the indices to each sp group
    ith_sp_group = self.global_rank // self.sp_world_size
    sp_group_local_indices = global_indices[ith_sp_group::self.
                                            num_sp_groups]
    self.sp_group_local_indices = sp_group_local_indices
    logger.info("Dataset size for each sp group: %d",
                len(sp_group_local_indices))
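
A toy illustration of the index striping above: after the seeded shuffle, SP group i takes every num_sp_groups-th index starting at i, so different SP groups see disjoint data while all ranks inside one SP group (same global_rank // sp_world_size) see identical indices:

    import torch

    dataset_size, num_sp_groups = 16, 2
    rng = torch.Generator().manual_seed(0)
    global_indices = torch.randperm(dataset_size, generator=rng)

    for sp_group in range(num_sp_groups):
        local = global_indices[sp_group::num_sp_groups]
        print(sp_group, local.tolist())   # interleaved, non-overlapping shards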
fastvideo.dataset.parquet_dataset_map_style.LatentsParquetMapStyleDataset
LatentsParquetMapStyleDataset(path: str, batch_size: int, parquet_schema: Schema, cfg_rate: float = 0.0, seed: int = 42, drop_last: bool = True, drop_first_row: bool = False, text_padding_length: int = 512)

Bases: Dataset

Return latents [B, C, T, H, W] and embeddings [B, L, D] in pinned CPU memory. Note: using Parquet for a map-style dataset is not efficient; it is kept mainly for backward compatibility and debugging.

Source code in fastvideo/dataset/parquet_dataset_map_style.py
def __init__(
    self,
    path: str,
    batch_size: int,
    parquet_schema: pa.Schema,
    cfg_rate: float = 0.0,
    seed: int = 42,
    drop_last: bool = True,
    drop_first_row: bool = False,
    text_padding_length: int = 512,
):
    super().__init__()
    self.path = path
    self.cfg_rate = cfg_rate
    self.parquet_schema = parquet_schema
    self.seed = seed
    # Create a seeded random generator for deterministic CFG
    self.rng = random.Random(seed)
    logger.info("Initializing LatentsParquetMapStyleDataset with path: %s",
                path)
    self.parquet_files, self.lengths = get_parquet_files_and_length(path)
    self.batch = batch_size
    self.text_padding_length = text_padding_length
    self.sampler = DP_SP_BatchSampler(
        batch_size=batch_size,
        dataset_size=sum(self.lengths),
        num_sp_groups=get_world_size() // get_sp_world_size(),
        sp_world_size=get_sp_world_size(),
        global_rank=get_world_rank(),
        drop_last=drop_last,
        drop_first_row=drop_first_row,
        seed=seed,
    )
    logger.info("Dataset initialized with %d parquet files and %d rows",
                len(self.parquet_files), sum(self.lengths))
Functions
fastvideo.dataset.parquet_dataset_map_style.LatentsParquetMapStyleDataset.__getitems__
__getitems__(indices: list[int]) -> dict[str, Any]

Batch fetch using read_row_from_parquet_file for each index.

Source code in fastvideo/dataset/parquet_dataset_map_style.py
def __getitems__(self, indices: list[int]) -> dict[str, Any]:
    """
    Batch fetch using read_row_from_parquet_file for each index.
    """
    rows = [
        read_row_from_parquet_file(self.parquet_files, idx, self.lengths)
        for idx in indices
    ]

    batch = collate_rows_from_parquet_schema(rows,
                                             self.parquet_schema,
                                             self.text_padding_length,
                                             cfg_rate=self.cfg_rate,
                                             rng=self.rng)
    return batch
fastvideo.dataset.parquet_dataset_map_style.LatentsParquetMapStyleDataset.get_validation_negative_prompt
get_validation_negative_prompt() -> tuple[Tensor, Tensor, str]

Get the negative prompt for validation. This method ensures the negative prompt is loaded and cached properly. Returns the processed negative prompt data (embedding, attention mask, prompt).

Source code in fastvideo/dataset/parquet_dataset_map_style.py
def get_validation_negative_prompt(
        self) -> tuple[torch.Tensor, torch.Tensor, str]:
    """
    Get the negative prompt for validation. 
    This method ensures the negative prompt is loaded and cached properly.
    Returns the processed negative prompt data (embedding, attention mask, prompt).
    """

    # Read first row from first parquet file
    file_path = self.parquet_files[0]
    row_idx = 0
    # Read the negative prompt data
    row_dict = read_row_from_parquet_file([file_path], row_idx,
                                          [self.lengths[0]])

    batch = collate_rows_from_parquet_schema([row_dict],
                                             self.parquet_schema,
                                             self.text_padding_length,
                                             cfg_rate=0.0,
                                             rng=self.rng)
    negative_prompt = batch['info_list'][0]['prompt']
    negative_prompt_embedding = batch['text_embedding']
    negative_prompt_attention_mask = batch['text_attention_mask']
    if len(negative_prompt_embedding.shape) == 2:
        negative_prompt_embedding = negative_prompt_embedding.unsqueeze(0)
    if len(negative_prompt_attention_mask.shape) == 1:
        negative_prompt_attention_mask = negative_prompt_attention_mask.unsqueeze(
            0).unsqueeze(0)

    return negative_prompt_embedding, negative_prompt_attention_mask, negative_prompt
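
A wiring sketch (hypothetical path and schema). The dataset carries its own DP/SP-aware batch sampler, and __getitems__ already returns a collated batch, so the DataLoader's collate step can be a passthrough:

    import pyarrow as pa
    from torch.utils.data import DataLoader

    from fastvideo.dataset.parquet_dataset_map_style import LatentsParquetMapStyleDataset

    schema: pa.Schema = ...   # the schema the parquet shards were written with
                              # (tensor columns stored as *_bytes/*_shape/*_dtype)

    dataset = LatentsParquetMapStyleDataset(
        path="/data/latents_parquet",   # hypothetical shard directory
        batch_size=8,
        parquet_schema=schema,
        cfg_rate=0.1,
    )

    loader = DataLoader(
        dataset,
        batch_sampler=dataset.sampler,
        collate_fn=lambda batch: batch,   # __getitems__ already collated
        num_workers=2,
        pin_memory=True,
    )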

Functions

fastvideo.dataset.parquet_dataset_map_style.read_row_from_parquet_file
read_row_from_parquet_file(parquet_files: list[str], global_row_idx: int, lengths: list[int]) -> dict[str, Any]

Read a row from a parquet file, locating it by its global row index across all files.

Source code in fastvideo/dataset/parquet_dataset_map_style.py
def read_row_from_parquet_file(parquet_files: list[str], global_row_idx: int,
                               lengths: list[int]) -> dict[str, Any]:
    '''
    Read a row from a parquet file.
    Args:
        parquet_files: List[str]
        global_row_idx: int
        lengths: List[int]
    Returns:
    '''
    # find the parquet file and local row index
    cumulative = 0
    file_index = 0
    local_row_idx = 0

    for file_index in range(len(lengths)):
        if cumulative + lengths[file_index] > global_row_idx:
            local_row_idx = global_row_idx - cumulative
            break
        cumulative += lengths[file_index]
    else:
        # If we reach here, global_row_idx is out of bounds
        raise IndexError(
            f"global_row_idx {global_row_idx} is out of bounds for dataset")

    parquet_file = pq.ParquetFile(parquet_files[file_index])

    # Calculate the row group to read into memory and the local idx
    # This way we can avoid reading in the entire parquet file
    cumulative = 0
    row_group_index = 0
    local_index = 0

    for i in range(parquet_file.num_row_groups):
        num_rows = parquet_file.metadata.row_group(i).num_rows
        if cumulative + num_rows > local_row_idx:
            row_group_index = i
            local_index = local_row_idx - cumulative
            break
        cumulative += num_rows
    else:
        # If we reach here, local_row_idx is out of bounds for this parquet file
        raise IndexError(
            f"local_row_idx {local_row_idx} is out of bounds for parquet file {parquet_files[file_index]}"
        )

    row_group = parquet_file.read_row_group(row_group_index).to_pydict()
    row_dict = {k: v[local_index] for k, v in row_group.items()}
    del row_group

    return row_dict
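
A toy walk-through of the cumulative-length lookup above: with three files holding 100, 50 and 75 rows, global row 130 lands in the second file at local row 30 (the same scan is then repeated over that file's row groups):

    lengths = [100, 50, 75]   # rows per parquet file (toy numbers)

    def locate(global_row_idx: int) -> tuple[int, int]:
        cumulative = 0
        for file_index, num_rows in enumerate(lengths):
            if cumulative + num_rows > global_row_idx:
                return file_index, global_row_idx - cumulative
            cumulative += num_rows
        raise IndexError(f"{global_row_idx} is out of bounds")

    print(locate(130))   # (1, 30)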

fastvideo.dataset.preprocessing_datasets

Classes

fastvideo.dataset.preprocessing_datasets.DataValidationStage

Bases: DatasetFilterStage

Stage for validating data items.

Functions
fastvideo.dataset.preprocessing_datasets.DataValidationStage.process
process(batch: PreprocessBatch, **kwargs) -> PreprocessBatch

Process does nothing for validation - filtering is handled by should_keep.

Source code in fastvideo/dataset/preprocessing_datasets.py
def process(self, batch: PreprocessBatch, **kwargs) -> PreprocessBatch:
    """Process does nothing for validation - filtering is handled by should_keep."""
    return batch
fastvideo.dataset.preprocessing_datasets.DataValidationStage.should_keep
should_keep(batch: PreprocessBatch, **kwargs) -> bool

Validate data item.

Parameters:

Name Type Description Default
batch PreprocessBatch

Dataset batch to validate

required

Returns:

Type Description
bool

True if valid, False if invalid

Source code in fastvideo/dataset/preprocessing_datasets.py
def should_keep(self, batch: PreprocessBatch, **kwargs) -> bool:
    """
    Validate data item.

    Args:
        batch: Dataset batch to validate

    Returns:
        True if valid, False if invalid
    """
    # Check for caption
    if batch.cap is None:
        return False

    if batch.is_video:
        # Validate video-specific fields
        if batch.duration is None or batch.fps is None:
            return False
    elif not batch.is_image:
        return False

    return True
fastvideo.dataset.preprocessing_datasets.DatasetFilterStage

Bases: ABC

Abstract base class for dataset filtering stages.

These stages can filter out items during metadata processing.

Functions
fastvideo.dataset.preprocessing_datasets.DatasetFilterStage.process abstractmethod
process(batch: PreprocessBatch, **kwargs) -> PreprocessBatch

Process the dataset batch (for non-filtering operations).

Parameters:

Name Type Description Default
batch PreprocessBatch

Dataset batch to process

required
**kwargs

Additional processing parameters

{}

Returns:

Type Description
PreprocessBatch

Processed batch

Source code in fastvideo/dataset/preprocessing_datasets.py
@abstractmethod
def process(self, batch: PreprocessBatch, **kwargs) -> PreprocessBatch:
    """
    Process the dataset batch (for non-filtering operations).

    Args:
        batch: Dataset batch to process
        **kwargs: Additional processing parameters

    Returns:
        Processed batch
    """
    raise NotImplementedError
fastvideo.dataset.preprocessing_datasets.DatasetFilterStage.should_keep abstractmethod
should_keep(batch: PreprocessBatch, **kwargs) -> bool

Check if batch should be kept.

Parameters:

Name Type Description Default
batch PreprocessBatch

Dataset batch to check

required
**kwargs

Additional parameters

{}

Returns:

Type Description
bool

True if batch should be kept, False otherwise

Source code in fastvideo/dataset/preprocessing_datasets.py
@abstractmethod
def should_keep(self, batch: PreprocessBatch, **kwargs) -> bool:
    """
    Check if batch should be kept.

    Args:
        batch: Dataset batch to check
        **kwargs: Additional parameters

    Returns:
        True if batch should be kept, False otherwise
    """
    raise NotImplementedError
fastvideo.dataset.preprocessing_datasets.DatasetStage

Bases: ABC

Abstract base class for dataset processing stages.

Similar to PipelineStage but designed for dataset preprocessing operations.

Functions
fastvideo.dataset.preprocessing_datasets.DatasetStage.process abstractmethod
process(batch: PreprocessBatch, **kwargs) -> PreprocessBatch

Process the dataset batch.

Parameters:

Name Type Description Default
batch PreprocessBatch

Dataset batch to process

required
**kwargs

Additional processing parameters

{}

Returns:

Type Description
PreprocessBatch

Processed batch

Source code in fastvideo/dataset/preprocessing_datasets.py
@abstractmethod
def process(self, batch: PreprocessBatch, **kwargs) -> PreprocessBatch:
    """
    Process the dataset batch.

    Args:
        batch: Dataset batch to process
        **kwargs: Additional processing parameters

    Returns:
        Processed batch
    """
    raise NotImplementedError
fastvideo.dataset.preprocessing_datasets.FrameSamplingStage
FrameSamplingStage(num_frames: int, train_fps: int, speed_factor: int = 1, video_length_tolerance_range: float = 5.0, drop_short_ratio: float = 0.0, seed: int = 42)

Bases: DatasetFilterStage

Stage for temporal frame sampling and indexing.

Source code in fastvideo/dataset/preprocessing_datasets.py
def __init__(self,
             num_frames: int,
             train_fps: int,
             speed_factor: int = 1,
             video_length_tolerance_range: float = 5.0,
             drop_short_ratio: float = 0.0,
             seed: int = 42):
    self.num_frames = num_frames
    self.train_fps = train_fps
    self.speed_factor = speed_factor
    self.video_length_tolerance_range = video_length_tolerance_range
    self.drop_short_ratio = drop_short_ratio
    # Create a seeded random generator for deterministic sampling
    self.rng = random.Random(seed)
Functions
fastvideo.dataset.preprocessing_datasets.FrameSamplingStage.process
process(batch: PreprocessBatch, temporal_sample_fn=None, **kwargs) -> PreprocessBatch

Process frame sampling for video data items.

Parameters:

Name Type Description Default
batch PreprocessBatch

Dataset batch

required
temporal_sample_fn

Function for temporal sampling

None

Returns:

Type Description
PreprocessBatch

Updated batch with frame sampling info

Source code in fastvideo/dataset/preprocessing_datasets.py
def process(self,
            batch: PreprocessBatch,
            temporal_sample_fn=None,
            **kwargs) -> PreprocessBatch:
    """
    Process frame sampling for video data items.

    Args:
        batch: Dataset batch
        temporal_sample_fn: Function for temporal sampling

    Returns:
        Updated batch with frame sampling info
    """
    if batch.is_image:
        # For images, just add sample info
        batch.sample_frame_index = [0]
        batch.sample_num_frames = 1
        return batch

    assert batch.duration is not None and batch.fps is not None
    batch.num_frames = math.ceil(batch.fps * batch.duration)

    # Resample frame indices
    frame_interval = batch.fps / self.train_fps
    start_frame_idx = 0
    frame_indices = np.arange(start_frame_idx, batch.num_frames,
                              frame_interval).astype(int)

    # Temporal crop if too long
    if len(frame_indices) > self.num_frames:
        if temporal_sample_fn is not None:
            begin_index, end_index = temporal_sample_fn(len(frame_indices))
            frame_indices = frame_indices[begin_index:end_index]
        else:
            frame_indices = frame_indices[:self.num_frames]

    batch.sample_frame_index = frame_indices.tolist()
    batch.sample_num_frames = len(frame_indices)

    return batch
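
A worked example of the resampling above, using the metadata from the JSON snippet earlier (fps=25.0, duration=6.88) together with assumed training settings train_fps=8 and num_frames=49:

    import math

    import numpy as np

    fps, duration = 25.0, 6.88
    train_fps, target_num_frames = 8, 49     # assumed training settings

    total_frames = math.ceil(fps * duration)              # 172
    frame_interval = fps / train_fps                      # 3.125
    frame_indices = np.arange(0, total_frames, frame_interval).astype(int)
    print(len(frame_indices))                              # 56 candidate frames

    # 56 > 49, so temporal_sample_fn (e.g. TemporalRandomCrop) selects a random
    # contiguous window of 49 indices; without it, the first 49 are kept.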
fastvideo.dataset.preprocessing_datasets.FrameSamplingStage.should_keep
should_keep(batch: PreprocessBatch, **kwargs) -> bool

Check if video should be kept based on length constraints.

Parameters:

Name Type Description Default
batch PreprocessBatch

Dataset batch

required

Returns:

Type Description
bool

True if should be kept, False otherwise

Source code in fastvideo/dataset/preprocessing_datasets.py
def should_keep(self, batch: PreprocessBatch, **kwargs) -> bool:
    """
    Check if video should be kept based on length constraints.

    Args:
        batch: Dataset batch

    Returns:
        True if should be kept, False otherwise
    """
    if batch.is_image:
        return True

    if batch.duration is None or batch.fps is None:
        return False

    num_frames = math.ceil(batch.fps * batch.duration)

    # Check if video is too long
    if (num_frames / batch.fps > self.video_length_tolerance_range *
        (self.num_frames / self.train_fps * self.speed_factor)):
        return False

    # Resample frame indices to check length
    frame_interval = batch.fps / self.train_fps
    start_frame_idx = 0
    frame_indices = np.arange(start_frame_idx, num_frames,
                              frame_interval).astype(int)

    # Filter short videos
    return not (len(frame_indices) < self.num_frames
                and self.rng.random() < self.drop_short_ratio)
fastvideo.dataset.preprocessing_datasets.ImageTransformStage
ImageTransformStage(transform, transform_topcrop)

Bases: DatasetStage

Stage for image data transformation.

Source code in fastvideo/dataset/preprocessing_datasets.py
def __init__(self, transform, transform_topcrop) -> None:
    self.transform = transform
    self.transform_topcrop = transform_topcrop
Functions
fastvideo.dataset.preprocessing_datasets.ImageTransformStage.process
process(batch: PreprocessBatch, **kwargs) -> PreprocessBatch

Transform image data.

Parameters:

Name Type Description Default
batch PreprocessBatch

Dataset batch with image information

required

Returns:

Type Description
PreprocessBatch

Batch with transformed image tensor

Source code in fastvideo/dataset/preprocessing_datasets.py
def process(self, batch: PreprocessBatch, **kwargs) -> PreprocessBatch:
    """
    Transform image data.

    Args:
        batch: Dataset batch with image information

    Returns:
        Batch with transformed image tensor
    """
    if not batch.is_image:
        return batch

    image = Image.open(batch.path).convert("RGB")
    image = torch.from_numpy(np.array(image))
    image = rearrange(image, "h w c -> c h w").unsqueeze(0)

    if self.transform_topcrop is not None:
        image = self.transform_topcrop(image)
    elif self.transform is not None:
        image = self.transform(image)

    image = image.transpose(0, 1)  # [1 C H W] -> [C 1 H W]
    image = image.float() / 127.5 - 1.0
    batch.pixel_values = image
    return batch
fastvideo.dataset.preprocessing_datasets.PreprocessBatch dataclass
PreprocessBatch(path: str, cap: str | list[str], resolution: dict | None = None, fps: float | None = None, duration: float | None = None, num_frames: int | None = None, sample_frame_index: list[int] | None = None, sample_num_frames: int | None = None, pixel_values: Tensor | None = None, text: str | None = None, input_ids: Tensor | None = None, cond_mask: Tensor | None = None)

Batch information for dataset processing stages.

This class holds all the information about a video-caption or image-caption pair as it moves through the processing pipeline. Fields are populated by different stages.

Attributes
fastvideo.dataset.preprocessing_datasets.PreprocessBatch.is_image property
is_image: bool

Check if this is an image item.

fastvideo.dataset.preprocessing_datasets.PreprocessBatch.is_video property
is_video: bool

Check if this is a video item.

fastvideo.dataset.preprocessing_datasets.ResolutionFilterStage
ResolutionFilterStage(max_h_div_w_ratio: float = 17 / 16, min_h_div_w_ratio: float = 8 / 16, max_height: int = 1024, max_width: int = 1024)

Bases: DatasetFilterStage

Stage for filtering data items based on resolution constraints.

Source code in fastvideo/dataset/preprocessing_datasets.py
def __init__(self,
             max_h_div_w_ratio: float = 17 / 16,
             min_h_div_w_ratio: float = 8 / 16,
             max_height: int = 1024,
             max_width: int = 1024):
    self.max_h_div_w_ratio = max_h_div_w_ratio
    self.min_h_div_w_ratio = min_h_div_w_ratio
    self.max_height = max_height
    self.max_width = max_width
Functions
fastvideo.dataset.preprocessing_datasets.ResolutionFilterStage.filter_resolution
filter_resolution(h: int, w: int, max_h_div_w_ratio: float, min_h_div_w_ratio: float) -> bool

Filter based on height/width ratio.

Source code in fastvideo/dataset/preprocessing_datasets.py
def filter_resolution(self, h: int, w: int, max_h_div_w_ratio: float,
                      min_h_div_w_ratio: float) -> bool:
    """Filter based on height/width ratio."""
    return h / w <= max_h_div_w_ratio and h / w >= min_h_div_w_ratio
fastvideo.dataset.preprocessing_datasets.ResolutionFilterStage.process
process(batch: PreprocessBatch, **kwargs) -> PreprocessBatch

Process does nothing for resolution filtering - filtering is handled by should_keep.

Source code in fastvideo/dataset/preprocessing_datasets.py
def process(self, batch: PreprocessBatch, **kwargs) -> PreprocessBatch:
    """Process does nothing for resolution filtering - filtering is handled by should_keep."""
    return batch
fastvideo.dataset.preprocessing_datasets.ResolutionFilterStage.should_keep
should_keep(batch: PreprocessBatch, **kwargs) -> bool

Check if data item passes resolution filtering.

Parameters:

Name Type Description Default
batch PreprocessBatch

Dataset batch with resolution information

required

Returns:

Type Description
bool

True if passes filter, False otherwise

Source code in fastvideo/dataset/preprocessing_datasets.py
def should_keep(self, batch: PreprocessBatch, **kwargs) -> bool:
    """
    Check if data item passes resolution filtering.

    Args:
        batch: Dataset batch with resolution information

    Returns:
        True if passes filter, False otherwise
    """
    # Only apply to videos
    if not batch.is_video:
        return True

    if batch.resolution is None:
        return False

    height = batch.resolution.get("height", None)
    width = batch.resolution.get("width", None)
    if height is None or width is None:
        return False

    # Check aspect ratio
    aspect = self.max_height / self.max_width
    hw_aspect_thr = 1.5

    return self.filter_resolution(
        height,
        width,
        max_h_div_w_ratio=hw_aspect_thr * aspect,
        min_h_div_w_ratio=1 / hw_aspect_thr * aspect,
    )
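
A worked example of the aspect-ratio check above, with an assumed preprocessing target of max_height=480 and max_width=848; the allowed h/w band is the target aspect widened by hw_aspect_thr=1.5 in both directions:

    max_height, max_width = 480, 848        # assumed target size
    aspect = max_height / max_width         # ~0.566
    hw_aspect_thr = 1.5

    max_ratio = hw_aspect_thr * aspect      # ~0.849
    min_ratio = aspect / hw_aspect_thr      # ~0.377

    h, w = 1080, 1920                       # resolution from the JSON example
    print(min_ratio <= h / w <= max_ratio)  # True: 0.5625 falls inside the band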
fastvideo.dataset.preprocessing_datasets.TextDataset
TextDataset(data_merge_path: str, args, start_idx: int = 0, seed: int = 42)

Bases: IterableDataset, Stateful

Text-only dataset for processing prompts from a simple text file.

Assumes that data_merge_path is a text file with one prompt per line:

    A cat playing with a ball
    A dog running in the park
    A person cooking dinner
    ...

This dataset processes text data through text encoding stages only.

Source code in fastvideo/dataset/preprocessing_datasets.py
def __init__(self,
             data_merge_path: str,
             args,
             start_idx: int = 0,
             seed: int = 42):
    self.data_merge_path = data_merge_path
    self.start_idx = start_idx
    self.args = args
    self.seed = seed

    # Initialize tokenizer
    tokenizer_path = os.path.join(args.model_path, "tokenizer")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path,
                                              cache_dir=args.cache_dir)

    # Initialize text encoding stage
    self.text_encoding_stage = TextEncodingStage(
        tokenizer=tokenizer,
        text_max_length=args.text_max_length,
        cfg_rate=getattr(args, 'training_cfg_rate', 0.0),
        seed=self.seed)

    # Process text data
    self.processed_batches = self._process_text_data()
Functions
fastvideo.dataset.preprocessing_datasets.TextDataset.__iter__
__iter__()

Iterator for the dataset.

Source code in fastvideo/dataset/preprocessing_datasets.py
def __iter__(self):
    """Iterator for the dataset."""
    # Set up distributed sampling if needed
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()
    else:
        rank = 0
        world_size = 1

    # Calculate chunk for this rank
    total_items = len(self.processed_batches)
    items_per_rank = math.ceil(total_items / world_size)
    start_idx = rank * items_per_rank + self.start_idx
    end_idx = min(start_idx + items_per_rank, total_items)

    # Yield items for this rank
    for idx in range(start_idx, end_idx):
        if idx < len(self.processed_batches):
            yield self._get_item(idx)
fastvideo.dataset.preprocessing_datasets.TextDataset.load_state_dict
load_state_dict(state_dict: dict[str, Any]) -> None

Load state dict from checkpoint.

Source code in fastvideo/dataset/preprocessing_datasets.py
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
    """Load state dict from checkpoint."""
    self.processed_batches = state_dict["processed_batches"]
fastvideo.dataset.preprocessing_datasets.TextDataset.state_dict
state_dict() -> dict[str, Any]

Return state dict for checkpointing.

Source code in fastvideo/dataset/preprocessing_datasets.py
def state_dict(self) -> dict[str, Any]:
    """Return state dict for checkpointing."""
    return {"processed_batches": self.processed_batches}
fastvideo.dataset.preprocessing_datasets.TextEncodingStage
TextEncodingStage(tokenizer, text_max_length: int, cfg_rate: float = 0.0, seed: int = 42)

Bases: DatasetStage

Stage for text tokenization and encoding.

Source code in fastvideo/dataset/preprocessing_datasets.py
def __init__(self,
             tokenizer,
             text_max_length: int,
             cfg_rate: float = 0.0,
             seed: int = 42):
    self.tokenizer = tokenizer
    self.text_max_length = text_max_length
    self.cfg_rate = cfg_rate
    # Create a seeded random generator for deterministic CFG
    self.rng = random.Random(seed)
Functions
fastvideo.dataset.preprocessing_datasets.TextEncodingStage.process
process(batch: PreprocessBatch, **kwargs) -> PreprocessBatch

Process text data.

Parameters:

Name Type Description Default
batch PreprocessBatch

Dataset batch with caption information

required

Returns:

Type Description
PreprocessBatch

Batch with encoded text information

Source code in fastvideo/dataset/preprocessing_datasets.py
def process(self, batch: PreprocessBatch, **kwargs) -> PreprocessBatch:
    """
    Process text data.

    Args:
        batch: Dataset batch with caption information

    Returns:
        Batch with encoded text information
    """
    text = batch.cap
    if not isinstance(text, list):
        text = [text]
    text = [self.rng.choice(text)]

    text = text[0] if self.rng.random() > self.cfg_rate else ""
    text_tokens_and_mask = self.tokenizer(
        text,
        max_length=self.text_max_length,
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        add_special_tokens=True,
        return_tensors="pt",
    )

    batch.text = text
    batch.input_ids = text_tokens_and_mask["input_ids"]
    batch.cond_mask = text_tokens_and_mask["attention_mask"]
    return batch
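
A standalone sketch of the stage: one caption is picked at random from the list, dropped to an empty string with probability cfg_rate (classifier-free guidance), then tokenized to a fixed length. The tokenizer name here is a placeholder; in the pipeline it is loaded from args.model_path/tokenizer:

    from transformers import AutoTokenizer

    from fastvideo.dataset.preprocessing_datasets import (PreprocessBatch,
                                                          TextEncodingStage)

    tokenizer = AutoTokenizer.from_pretrained("google/t5-v1_1-xxl")  # placeholder tokenizer

    stage = TextEncodingStage(tokenizer=tokenizer, text_max_length=512,
                              cfg_rate=0.1, seed=42)
    batch = PreprocessBatch(path="clip.mp4",
                            cap=["A cat playing with a ball"])
    batch = stage.process(batch)
    print(batch.input_ids.shape, batch.cond_mask.shape)   # both torch.Size([1, 512])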
fastvideo.dataset.preprocessing_datasets.VideoCaptionMergedDataset
VideoCaptionMergedDataset(data_merge_path: str, args, transform, temporal_sample, transform_topcrop, start_idx: int = 0, seed: int = 42)

Bases: IterableDataset, Stateful

Merged dataset for video and caption data with stage-based processing. Assumes that data_merge_path is a txt file with the following format: folder,json_file (a data folder path and a metadata JSON file path, separated by a comma).

The folder should contain videos.

The json file should be a list of dictionaries with the following format:
[
{
    "path": "1gGQy4nxyUo-Scene-016.mp4",
    "resolution": {
    "width": 1920,
    "height": 1080
    },
    "size": 2439112,
    "fps": 25.0,
    "duration": 6.88,
    "num_frames": 172,
    "cap": [
    "A watermelon wearing a helmet is crushed by a hydraulic press, causing it to flatten and burst open."
    ]
},
...
]

This dataset processes video and image data through a series of stages:

- Data validation
- Resolution filtering
- Frame sampling
- Transformation
- Text encoding

Source code in fastvideo/dataset/preprocessing_datasets.py
def __init__(self,
             data_merge_path: str,
             args,
             transform,
             temporal_sample,
             transform_topcrop,
             start_idx: int = 0,
             seed: int = 42):
    self.data_merge_path = data_merge_path
    self.start_idx = start_idx
    self.args = args
    self.temporal_sample = temporal_sample
    self.seed = seed

    # Initialize tokenizer
    tokenizer_path = os.path.join(args.model_path, "tokenizer")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path,
                                              cache_dir=args.cache_dir)

    # Initialize processing stages
    self._init_stages(args, transform, transform_topcrop, tokenizer)

    # Process metadata
    self.processed_batches = self._process_metadata()
Functions
fastvideo.dataset.preprocessing_datasets.VideoCaptionMergedDataset.__iter__
__iter__()

Iterate through processed data items.

Source code in fastvideo/dataset/preprocessing_datasets.py
def __iter__(self):
    """Iterate through processed data items."""
    for idx in range(len(self.processed_batches)):
        yield self._get_item(idx)
fastvideo.dataset.preprocessing_datasets.VideoCaptionMergedDataset.load_state_dict
load_state_dict(state_dict: dict[str, Any]) -> None

Load state dict from checkpoint.

Source code in fastvideo/dataset/preprocessing_datasets.py
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
    """Load state dict from checkpoint."""
    self.processed_batches = state_dict["processed_batches"]
fastvideo.dataset.preprocessing_datasets.VideoCaptionMergedDataset.state_dict
state_dict() -> dict[str, Any]

Return state dict for checkpointing.

Source code in fastvideo/dataset/preprocessing_datasets.py
def state_dict(self) -> dict[str, Any]:
    """Return state dict for checkpointing."""
    return {"processed_batches": self.processed_batches}
fastvideo.dataset.preprocessing_datasets.VideoTransformStage
VideoTransformStage(transform)

Bases: DatasetStage

Stage for video data transformation.

Source code in fastvideo/dataset/preprocessing_datasets.py
def __init__(self, transform) -> None:
    self.transform = transform
Functions
fastvideo.dataset.preprocessing_datasets.VideoTransformStage.process
process(batch: PreprocessBatch, **kwargs) -> PreprocessBatch

Transform video data.

Parameters:

Name Type Description Default
batch PreprocessBatch

Dataset batch with video information

required

Returns:

Type Description
PreprocessBatch

Batch with transformed video tensor

Source code in fastvideo/dataset/preprocessing_datasets.py
def process(self, batch: PreprocessBatch, **kwargs) -> PreprocessBatch:
    """
    Transform video data.

    Args:
        batch: Dataset batch with video information

    Returns:
        Batch with transformed video tensor
    """
    if not batch.is_video:
        return batch

    assert os.path.exists(batch.path), f"file {batch.path} does not exist!"
    assert batch.sample_frame_index is not None, "Frame indices must be set before transformation"

    torchvision_video, _, metadata = torchvision.io.read_video(
        batch.path, output_format="TCHW")
    video = torchvision_video[batch.sample_frame_index]
    if self.transform is not None:
        video = self.transform(video)
    video = rearrange(video, "t c h w -> c t h w")
    video = video.to(torch.uint8)

    h, w = video.shape[-2:]
    assert (
        h / w <= 17 / 16 and h / w >= 8 / 16
    ), f"Only videos with a ratio (h/w) less than 17/16 and more than 8/16 are supported. But video ({batch.path}) found ratio is {round(h / w, 2)} with the shape of {video.shape}"

    video = video.float() / 127.5 - 1.0
    batch.pixel_values = video
    return batch

Functions

fastvideo.dataset.transform

Classes

fastvideo.dataset.transform.CenterCropResizeVideo
CenterCropResizeVideo(size, top_crop=False, interpolation_mode='bilinear')

Center crop the video using the short side as the crop length, then resize to the specified size

Source code in fastvideo/dataset/transform.py
def __init__(
    self,
    size,
    top_crop=False,
    interpolation_mode="bilinear",
) -> None:
    if len(size) != 2:
        raise ValueError(
            f"size should be tuple (height, width), instead got {size}")
    self.size = size
    self.top_crop = top_crop
    self.interpolation_mode = interpolation_mode
Functions
fastvideo.dataset.transform.CenterCropResizeVideo.__call__
__call__(clip) -> Tensor

Parameters:

Name Type Description Default
clip tensor

Video clip to be cropped. Size is (T, C, H, W)

required

Returns: torch.tensor: scale resized / center cropped video clip. size is (T, C, crop_size, crop_size)

Source code in fastvideo/dataset/transform.py
def __call__(self, clip) -> torch.Tensor:
    """
    Args:
        clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
    Returns:
        torch.tensor: scale resized / center cropped video clip.
            size is (T, C, crop_size, crop_size)
    """
    clip_center_crop = center_crop_th_tw(clip,
                                         self.size[0],
                                         self.size[1],
                                         top_crop=self.top_crop)
    clip_center_crop_resize = resize(
        clip_center_crop,
        target_size=self.size,
        interpolation_mode=self.interpolation_mode,
    )
    return clip_center_crop_resize
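
A minimal usage sketch on a dummy clip (a float tensor here, for simplicity):

    import torch

    from fastvideo.dataset.transform import CenterCropResizeVideo

    transform = CenterCropResizeVideo((480, 640))   # (height, width)
    clip = torch.rand(16, 3, 720, 1280)             # (T, C, H, W)
    out = transform(clip)
    print(out.shape)   # torch.Size([16, 3, 480, 640])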
fastvideo.dataset.transform.Normalize255
Normalize255()

Convert tensor data type from uint8 to float and divide values by 255.0.

Source code in fastvideo/dataset/transform.py
def __init__(self) -> None:
    pass
Functions
fastvideo.dataset.transform.Normalize255.__call__
__call__(clip) -> Tensor

Parameters:

Name Type Description Default
clip torch.tensor, dtype=torch.uint8

Size is (T, C, H, W)

required

Return: clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)

Source code in fastvideo/dataset/transform.py
def __call__(self, clip) -> torch.Tensor:
    """
    Args:
        clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
    Return:
        clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
    """
    return normalize_video(clip)
fastvideo.dataset.transform.TemporalRandomCrop
TemporalRandomCrop(size)

Temporally crop the given frame indices at a random location.

Parameters:

Name Type Description Default
size int

Desired length of frames will be seen in the model.

required
Source code in fastvideo/dataset/transform.py
def __init__(self, size) -> None:
    self.size = size

Functions

fastvideo.dataset.transform.crop
crop(clip, i, j, h, w) -> Tensor

Parameters:

Name Type Description Default
clip tensor

Video clip to be cropped. Size is (T, C, H, W)

required
Source code in fastvideo/dataset/transform.py
def crop(clip, i, j, h, w) -> torch.Tensor:
    """
    Args:
        clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
    """
    if len(clip.size()) != 4:
        raise ValueError("clip should be a 4D tensor")
    return clip[..., i:i + h, j:j + w]
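
For example, cropping a 64x64 spatial patch starting at (i, j) = (10, 20):

import torch
from fastvideo.dataset.transform import crop

clip = torch.randn(16, 3, 240, 320)          # (T, C, H, W)
patch = crop(clip, i=10, j=20, h=64, w=64)
print(patch.shape)                           # torch.Size([16, 3, 64, 64])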
fastvideo.dataset.transform.normalize_video
normalize_video(clip) -> Tensor

Convert tensor data type from uint8 to float, divide value by 255.0 and permute the dimensions of clip tensor.

Args: clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
Return: clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)

Source code in fastvideo/dataset/transform.py
def normalize_video(clip) -> torch.Tensor:
    """
    Convert tensor data type from uint8 to float, divide value by 255.0 and
    permute the dimensions of clip tensor
    Args:
        clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
    Return:
        clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
    """
    _is_tensor_video_clip(clip)
    if not clip.dtype == torch.uint8:
        raise TypeError(
            f"clip tensor should have data type uint8. Got {clip.dtype}")
    # return clip.float().permute(3, 0, 1, 2) / 255.0
    return clip.float() / 255.0

fastvideo.dataset.utils

Functions

fastvideo.dataset.utils.collate_rows_from_parquet_schema
collate_rows_from_parquet_schema(rows, parquet_schema, text_padding_length, cfg_rate=0.0, rng=None) -> dict[str, Any]

Collate rows from parquet files based on the provided schema. Dynamically processes tensor fields based on schema and returns batched data.

Parameters:

Name            Type  Description                                          Default
rows                  List of row dictionaries from parquet files          required
parquet_schema        PyArrow schema defining the structure of the data    required

Returns:

Type            Description
dict[str, Any]  Dict containing batched tensors and metadata

Source code in fastvideo/dataset/utils.py
def collate_rows_from_parquet_schema(rows,
                                     parquet_schema,
                                     text_padding_length,
                                     cfg_rate=0.0,
                                     rng=None) -> dict[str, Any]:
    """
    Collate rows from parquet files based on the provided schema.
    Dynamically processes tensor fields based on schema and returns batched data.

    Args:
        rows: List of row dictionaries from parquet files
        parquet_schema: PyArrow schema defining the structure of the data

    Returns:
        Dict containing batched tensors and metadata
    """
    if not rows:
        return cast(dict[str, Any], {})

    # Initialize containers for different data types
    batch_data: dict[str, Any] = {}

    # Get tensor and metadata field names from schema (fields ending with '_bytes')
    tensor_fields = []
    metadata_fields = []
    for field in parquet_schema.names:
        if field.endswith('_bytes'):
            shape_field = field.replace('_bytes', '_shape')
            dtype_field = field.replace('_bytes', '_dtype')
            tensor_name = field.replace('_bytes', '')
            tensor_fields.append(tensor_name)
            assert shape_field in parquet_schema.names, f"Shape field {shape_field} not found in schema for field {field}. Currently we only support *_bytes fields for tensors."
            assert dtype_field in parquet_schema.names, f"Dtype field {dtype_field} not found in schema for field {field}. Currently we only support *_bytes fields for tensors."
        elif not field.endswith('_shape') and not field.endswith('_dtype'):
            # Only add actual metadata fields, not the shape/dtype helper fields
            metadata_fields.append(field)

    # Process each tensor field
    for tensor_name in tensor_fields:
        tensor_list = []

        for row in rows:
            # Get tensor data from row using the existing helper function pattern
            shape_key = f"{tensor_name}_shape"
            bytes_key = f"{tensor_name}_bytes"

            if shape_key in row and bytes_key in row:
                shape = row[shape_key]
                bytes_data = row[bytes_key]

                if len(bytes_data) == 0:
                    tensor = torch.zeros(0, dtype=torch.bfloat16)
                else:
                    # Convert bytes to tensor using float32 as default
                    if tensor_name == 'text_embedding' and (rng.random(
                    ) if rng else random.random()) < cfg_rate:
                        data = np.zeros((512, 4096), dtype=np.float32)
                    else:
                        data = np.frombuffer(
                            bytes_data, dtype=np.float32).reshape(shape).copy()
                    tensor = torch.from_numpy(data)
                    # if len(data.shape) == 3:
                    #     B, L, D = tensor.shape
                    #     assert B == 1, "Batch size must be 1"
                    #     tensor = tensor.squeeze(0)

                tensor_list.append(tensor)
            else:
                # Handle missing tensor data
                tensor_list.append(torch.zeros(0, dtype=torch.bfloat16))

        # Stack tensors with special handling for text embeddings
        if tensor_name == 'text_embedding':
            # Handle text embeddings with padding
            padded_tensors = []
            attention_masks = []

            for tensor in tensor_list:
                if tensor.numel() > 0:
                    padded_tensor, mask = pad(tensor, text_padding_length)
                    padded_tensors.append(padded_tensor)
                    attention_masks.append(mask)
                else:
                    # Handle empty embeddings - assume default embedding dimension
                    padded_tensors.append(
                        torch.zeros(text_padding_length,
                                    768,
                                    dtype=torch.bfloat16))
                    attention_masks.append(torch.zeros(text_padding_length))

            batch_data[tensor_name] = torch.stack(padded_tensors)
            batch_data['text_attention_mask'] = torch.stack(attention_masks)
        else:
            # Stack all tensors to preserve batch consistency
            # Don't filter out None or empty tensors as this breaks batch sizing
            try:
                batch_data[tensor_name] = torch.stack(tensor_list)
            except ValueError as e:
                shapes = [
                    t.shape
                    if t is not None and hasattr(t, 'shape') else 'None/Invalid'
                    for t in tensor_list
                ]
                raise ValueError(
                    f"Failed to stack tensors for field '{tensor_name}'. "
                    f"Tensor shapes: {shapes}. "
                    f"All tensors in a batch must have compatible shapes. "
                    f"Original error: {e}") from e

    # Process metadata fields into info_list
    info_list = []
    for row in rows:
        info = {}
        for field in metadata_fields:
            info[field] = row.get(field, "")

        # Add prompt field for backward compatibility
        info["prompt"] = info.get("caption", "")
        info_list.append(info)

    batch_data['info_list'] = info_list

    # Add caption_text for backward compatibility
    if info_list and 'caption' in info_list[0]:
        batch_data['caption_text'] = [info['caption'] for info in info_list]

    return batch_data
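
A minimal sketch of how rows stored with the *_bytes / *_shape / *_dtype convention are collated; the "latents" field name and the schema below are illustrative, not the project's actual preprocessing schema:

import numpy as np
import pyarrow as pa
from fastvideo.dataset.utils import collate_rows_from_parquet_schema

# Illustrative schema: one serialized tensor field ("latents") plus a caption column.
schema = pa.schema([
    ("latents_bytes", pa.binary()),
    ("latents_shape", pa.list_(pa.int64())),
    ("latents_dtype", pa.string()),
    ("caption", pa.string()),
])

def make_row(caption: str) -> dict:
    latents = np.random.rand(4, 8, 8).astype(np.float32)
    return {
        "latents_bytes": latents.tobytes(),
        "latents_shape": list(latents.shape),
        "latents_dtype": "float32",
        "caption": caption,
    }

rows = [make_row("a cat"), make_row("a dog")]
batch = collate_rows_from_parquet_schema(rows, schema, text_padding_length=512)
print(batch["latents"].shape)  # torch.Size([2, 4, 8, 8])
print(batch["caption_text"])   # ['a cat', 'a dog']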
fastvideo.dataset.utils.get_torch_tensors_from_row_dict
get_torch_tensors_from_row_dict(row_dict, keys, cfg_rate, rng=None) -> dict[str, Any]

Get the latents and prompts from a row dictionary.

Source code in fastvideo/dataset/utils.py
def get_torch_tensors_from_row_dict(row_dict,
                                    keys,
                                    cfg_rate,
                                    rng=None) -> dict[str, Any]:
    """
    Get the latents and prompts from a row dictionary.
    """
    return_dict = {}
    for key in keys:
        shape, bytes = None, None
        if isinstance(key, tuple):
            for k in key:
                try:
                    shape = row_dict[f"{k}_shape"]
                    bytes = row_dict[f"{k}_bytes"]
                except KeyError:
                    continue
            key = key[0]
            if shape is None or bytes is None:
                raise ValueError(f"Key {key} not found in row_dict")
        else:
            shape = row_dict[f"{key}_shape"]
            bytes = row_dict[f"{key}_bytes"]

        # TODO (peiyuan): read precision
        if key == 'text_embedding' and (rng.random()
                                        if rng else random.random()) < cfg_rate:
            data = np.zeros((512, 4096), dtype=np.float32)
        else:
            data = np.frombuffer(bytes, dtype=np.float32).reshape(shape).copy()
        data = torch.from_numpy(data)
        if len(data.shape) == 3:
            B, L, D = data.shape
            assert B == 1, "Batch size must be 1"
            data = data.squeeze(0)
        return_dict[key] = data
    return return_dict
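
A minimal sketch (the "latents" key is illustrative); note how a leading batch dimension of 1 is squeezed away:

import numpy as np
from fastvideo.dataset.utils import get_torch_tensors_from_row_dict

latent = np.random.rand(1, 4, 16).astype(np.float32)
row_dict = {
    "latents_shape": list(latent.shape),
    "latents_bytes": latent.tobytes(),
}
tensors = get_torch_tensors_from_row_dict(row_dict, keys=["latents"], cfg_rate=0.0)
print(tensors["latents"].shape)  # torch.Size([4, 16])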
fastvideo.dataset.utils.pad
pad(t: Tensor, padding_length: int) -> tuple[Tensor, Tensor]

Pad or crop an embedding [L, D] to exactly padding_length tokens.

Return:
- [L, D] tensor in pinned CPU memory
- [L] attention mask in pinned CPU memory

Source code in fastvideo/dataset/utils.py
def pad(t: torch.Tensor, padding_length: int) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Pad or crop an embedding [L, D] to exactly padding_length tokens.
    Return:
    - [L, D] tensor in pinned CPU memory
    - [L] attention mask in pinned CPU memory
    """
    L, D = t.shape
    if padding_length > L:  # pad
        pad = torch.zeros(padding_length - L, D, dtype=t.dtype, device=t.device)
        return torch.cat([t, pad], 0), torch.cat(
            [torch.ones(L), torch.zeros(padding_length - L)], 0)
    else:  # crop
        return t[:padding_length], torch.ones(padding_length)
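
For example, padding a short embedding and cropping a long one to the same target length:

import torch
from fastvideo.dataset.utils import pad

padded, mask = pad(torch.randn(3, 8), padding_length=5)
print(padded.shape, mask.tolist())   # torch.Size([5, 8]) [1.0, 1.0, 1.0, 0.0, 0.0]

cropped, mask = pad(torch.randn(7, 8), padding_length=5)
print(cropped.shape, mask.tolist())  # torch.Size([5, 8]) [1.0, 1.0, 1.0, 1.0, 1.0]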

fastvideo.dataset.validation_dataset

Classes

fastvideo.dataset.validation_dataset.ValidationDataset
ValidationDataset(filename: str)

Bases: IterableDataset

Source code in fastvideo/dataset/validation_dataset.py
def __init__(self, filename: str):
    super().__init__()

    self.filename = pathlib.Path(filename)
    # get directory of filename
    self.dir = os.path.abspath(self.filename.parent)

    if not self.filename.exists():
        raise FileNotFoundError(
            f"File {self.filename.as_posix()} does not exist")

    if self.filename.suffix == ".csv":
        data = datasets.load_dataset("csv",
                                     data_files=self.filename.as_posix(),
                                     split="train")
    elif self.filename.suffix == ".json":
        data = datasets.load_dataset("json",
                                     data_files=self.filename.as_posix(),
                                     split="train",
                                     field="data")
    elif self.filename.suffix == ".parquet":
        data = datasets.load_dataset("parquet",
                                     data_files=self.filename.as_posix(),
                                     split="train")
    elif self.filename.suffix == ".arrow":
        data = datasets.load_dataset("arrow",
                                     data_files=self.filename.as_posix(),
                                     split="train")
    else:
        _SUPPORTED_FILE_FORMATS = [".csv", ".json", ".parquet", ".arrow"]
        raise ValueError(
            f"Unsupported file format {self.filename.suffix} for validation dataset. Supported formats are: {_SUPPORTED_FILE_FORMATS}"
        )

    # Get distributed training info
    self.global_rank = get_world_rank()
    self.world_size = get_world_size()
    self.sp_world_size = get_sp_world_size()
    self.num_sp_groups = self.world_size // self.sp_world_size

    # Convert to list to get total samples
    self.all_samples = list(data)
    self.original_total_samples = len(self.all_samples)

    # Extend samples to be a multiple of DP degree (num_sp_groups)
    remainder = self.original_total_samples % self.num_sp_groups
    if remainder != 0:
        samples_to_add = self.num_sp_groups - remainder

        # Duplicate samples cyclically to reach the target
        additional_samples = []
        for i in range(samples_to_add):
            additional_samples.append(
                self.all_samples[i % self.original_total_samples])

        self.all_samples.extend(additional_samples)

    self.total_samples = len(self.all_samples)

    # Calculate which SP group this rank belongs to
    self.sp_group_id = self.global_rank // self.sp_world_size

    # Now all SP groups will have equal number of samples
    self.samples_per_sp_group = self.total_samples // self.num_sp_groups

    # Calculate start and end indices for this SP group
    self.start_idx = self.sp_group_id * self.samples_per_sp_group
    self.end_idx = self.start_idx + self.samples_per_sp_group

    # Get samples for this SP group
    self.sp_group_samples = self.all_samples[self.start_idx:self.end_idx]

    logger.info(
        "Rank %s (SP group %s): "
        "Original samples: %s, "
        "Extended samples: %s, "
        "SP group samples: %s, "
        "Range: [%s:%s]",
        self.global_rank,
        self.sp_group_id,
        self.original_total_samples,
        self.total_samples,
        len(self.sp_group_samples),
        self.start_idx,
        self.end_idx,
        local_main_process_only=False)
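
As a worked example of the sharding arithmetic above (the numbers are illustrative): with world_size = 8 and sp_world_size = 2 there are 4 SP groups; 10 validation samples leave a remainder of 2, so the first 2 samples are duplicated to reach 12, and each SP group reads 3 samples. The same index math, outside any distributed setup:

world_size, sp_world_size, global_rank = 8, 2, 5   # illustrative values
num_sp_groups = world_size // sp_world_size        # 4
original_total = 10
remainder = original_total % num_sp_groups         # 2
total = original_total + (num_sp_groups - remainder) % num_sp_groups  # 12
samples_per_sp_group = total // num_sp_groups      # 3
sp_group_id = global_rank // sp_world_size         # 2
start_idx = sp_group_id * samples_per_sp_group
end_idx = start_idx + samples_per_sp_group
print(start_idx, end_idx)                          # 6 9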
Functions
fastvideo.dataset.validation_dataset.ValidationDataset.__len__
__len__()

Return the number of samples for this SP group.

Source code in fastvideo/dataset/validation_dataset.py
def __len__(self):
    """Return the number of samples for this SP group."""
    return len(self.sp_group_samples)

Functions