Skip to content

parquet_dataset_map_style

Classes

fastvideo.dataset.parquet_dataset_map_style.DP_SP_BatchSampler

DP_SP_BatchSampler(batch_size: int, dataset_size: int, num_sp_groups: int, sp_world_size: int, global_rank: int, drop_last: bool = True, drop_first_row: bool = False, seed: int = 0)

Bases: Sampler[list[int]]

A simple sequential batch sampler that yields batches of indices.

Source code in fastvideo/dataset/parquet_dataset_map_style.py
def __init__(
    self,
    batch_size: int,
    dataset_size: int,
    num_sp_groups: int,
    sp_world_size: int,
    global_rank: int,
    drop_last: bool = True,
    drop_first_row: bool = False,
    seed: int = 0,
):
    """Shard a seeded epoch-level permutation of row indices across SP groups."""
    self.batch_size = batch_size
    self.dataset_size = dataset_size
    self.num_sp_groups = num_sp_groups
    self.sp_world_size = sp_world_size
    self.global_rank = global_rank
    self.drop_last = drop_last
    self.seed = seed

    # Deterministic shuffle: same seed -> same permutation on every rank.
    generator = torch.Generator().manual_seed(self.seed)
    perm = torch.randperm(self.dataset_size, generator=generator)

    if drop_first_row:
        # Drop dataset row 0 from the permutation entirely.
        perm = perm[perm != 0]
        self.dataset_size -= 1

    group_stride = self.num_sp_groups * self.batch_size
    if self.drop_last:
        # Truncate so the total is divisible by (batch_size * num_sp_groups):
        # every SP group then sees the same number of complete batches and
        # no ragged batches appear at the end of the epoch.
        usable = (self.dataset_size // self.batch_size //
                  self.num_sp_groups) * group_stride
        perm = perm[:usable]
    elif self.dataset_size % group_stride != 0:
        # Pad by recycling indices from the front of the permutation until
        # the total is divisible by (batch_size * num_sp_groups).
        padding_size = group_stride - (self.dataset_size % group_stride)
        logger.info("Padding the dataset from %d to %d",
                    self.dataset_size, self.dataset_size + padding_size)
        perm = torch.cat([perm, perm[:padding_size]])

    # Strided shard: the rank's SP group takes every num_sp_groups-th index.
    sp_group_id = self.global_rank // self.sp_world_size
    self.sp_group_local_indices = perm[sp_group_id::self.num_sp_groups]
    logger.info("Dataset size for each sp group: %d",
                len(self.sp_group_local_indices))

fastvideo.dataset.parquet_dataset_map_style.LatentsParquetMapStyleDataset

LatentsParquetMapStyleDataset(path: str, batch_size: int, parquet_schema: Schema, cfg_rate: float = 0.0, seed: int = 42, drop_last: bool = True, drop_first_row: bool = False, text_padding_length: int = 512)

Bases: Dataset

Return latents[B,C,T,H,W] and embeddings[B,L,D] in pinned CPU memory. Note: Using parquet for map style dataset is not efficient, we mainly keep it for backward compatibility and debugging.

Source code in fastvideo/dataset/parquet_dataset_map_style.py
def __init__(
    self,
    path: str,
    batch_size: int,
    parquet_schema: pa.Schema,
    cfg_rate: float = 0.0,
    seed: int = 42,
    drop_last: bool = True,
    drop_first_row: bool = False,
    text_padding_length: int = 512,
):
    """Map-style dataset over latents/embeddings stored in parquet shards."""
    super().__init__()
    self.path = path
    self.parquet_schema = parquet_schema
    self.cfg_rate = cfg_rate
    self.seed = seed
    self.batch = batch_size
    self.text_padding_length = text_padding_length
    # Seeded RNG so CFG prompt dropout is deterministic across runs.
    self.rng = random.Random(seed)
    logger.info("Initializing LatentsParquetMapStyleDataset with path: %s",
                path)
    self.parquet_files, self.lengths = get_parquet_files_and_length(path)
    total_rows = sum(self.lengths)
    # One SP group per (world_size / sp_world_size) slice of ranks.
    self.sampler = DP_SP_BatchSampler(
        batch_size=batch_size,
        dataset_size=total_rows,
        num_sp_groups=get_world_size() // get_sp_world_size(),
        sp_world_size=get_sp_world_size(),
        global_rank=get_world_rank(),
        drop_last=drop_last,
        drop_first_row=drop_first_row,
        seed=seed,
    )
    logger.info("Dataset initialized with %d parquet files and %d rows",
                len(self.parquet_files), total_rows)

Functions

fastvideo.dataset.parquet_dataset_map_style.LatentsParquetMapStyleDataset.__getitems__
__getitems__(indices: list[int]) -> dict[str, Any]

Batch fetch using read_row_from_parquet_file for each index.

Source code in fastvideo/dataset/parquet_dataset_map_style.py
def __getitems__(self, indices: list[int]) -> dict[str, Any]:
    """
    Fetch one row per index, then collate them into a single batch dict.
    """
    fetched = []
    for idx in indices:
        fetched.append(
            read_row_from_parquet_file(self.parquet_files, idx,
                                       self.lengths))

    return collate_rows_from_parquet_schema(fetched,
                                            self.parquet_schema,
                                            self.text_padding_length,
                                            cfg_rate=self.cfg_rate,
                                            rng=self.rng)
fastvideo.dataset.parquet_dataset_map_style.LatentsParquetMapStyleDataset.get_validation_negative_prompt
get_validation_negative_prompt() -> tuple[Tensor, Tensor, str]

Get the negative prompt for validation. Reads the first row of the first parquet file and returns the processed negative prompt data as a tuple of (embedding tensor, attention mask tensor, prompt string).

Source code in fastvideo/dataset/parquet_dataset_map_style.py
def get_validation_negative_prompt(
        self) -> tuple[torch.Tensor, torch.Tensor, str]:
    """
    Get the negative prompt for validation.

    The first row of the first parquet file is assumed to hold the
    negative prompt. The processed result is cached on the instance so
    repeated validation calls do not re-read and re-collate the row.

    Returns:
        Tuple of (negative prompt embedding, negative prompt attention
        mask, negative prompt string), with the tensors unsqueezed to
        batched shapes.
    """
    # Serve from cache when this has already been computed once.
    cached = getattr(self, "_negative_prompt_cache", None)
    if cached is not None:
        return cached

    # Read first row from first parquet file
    file_path = self.parquet_files[0]
    row_idx = 0
    # Read the negative prompt data
    row_dict = read_row_from_parquet_file([file_path], row_idx,
                                          [self.lengths[0]])

    # cfg_rate=0.0 so the prompt is never dropped here.
    batch = collate_rows_from_parquet_schema([row_dict],
                                             self.parquet_schema,
                                             self.text_padding_length,
                                             cfg_rate=0.0,
                                             rng=self.rng)
    negative_prompt = batch['info_list'][0]['prompt']
    negative_prompt_embedding = batch['text_embedding']
    negative_prompt_attention_mask = batch['text_attention_mask']
    # Normalize to the batched shapes callers expect.
    if len(negative_prompt_embedding.shape) == 2:
        negative_prompt_embedding = negative_prompt_embedding.unsqueeze(0)
    if len(negative_prompt_attention_mask.shape) == 1:
        negative_prompt_attention_mask = negative_prompt_attention_mask.unsqueeze(
            0).unsqueeze(0)

    self._negative_prompt_cache = (negative_prompt_embedding,
                                   negative_prompt_attention_mask,
                                   negative_prompt)
    return self._negative_prompt_cache

Functions

fastvideo.dataset.parquet_dataset_map_style.read_row_from_parquet_file

read_row_from_parquet_file(parquet_files: list[str], global_row_idx: int, lengths: list[int]) -> dict[str, Any]

Read a row from a parquet file. Args: parquet_files: List[str] global_row_idx: int lengths: List[int] Returns: a dict mapping each column name to that row's value.

Source code in fastvideo/dataset/parquet_dataset_map_style.py
def read_row_from_parquet_file(parquet_files: list[str], global_row_idx: int,
                               lengths: list[int]) -> dict[str, Any]:
    '''
    Read a single row from a sharded parquet dataset.
    Args:
        parquet_files: paths of the parquet shards, in dataset order
        global_row_idx: flat row index across all shards
        lengths: row count of each shard, aligned with parquet_files
    Returns:
        Dict mapping each column name to that row's value.
    Raises:
        IndexError: if global_row_idx does not fall inside the dataset.
    '''

    def _locate(segment_lengths, target, error_msg):
        # Walk cumulative segment sizes until `target` falls inside one;
        # return (segment_index, index local to that segment).
        cumulative = 0
        for segment_index, segment_len in enumerate(segment_lengths):
            if cumulative + segment_len > target:
                return segment_index, target - cumulative
            cumulative += segment_len
        raise IndexError(error_msg)

    # Which parquet file holds the row, and the row's index inside it.
    file_index, local_row_idx = _locate(
        lengths, global_row_idx,
        f"global_row_idx {global_row_idx} is out of bounds for dataset")

    parquet_file = pq.ParquetFile(parquet_files[file_index])

    # Which row group holds the row: reading only that group avoids
    # materializing the entire parquet file in memory.
    row_group_sizes = (parquet_file.metadata.row_group(i).num_rows
                       for i in range(parquet_file.num_row_groups))
    row_group_index, local_index = _locate(
        row_group_sizes, local_row_idx,
        f"local_row_idx {local_row_idx} is out of bounds for parquet file "
        f"{parquet_files[file_index]}")

    row_group = parquet_file.read_row_group(row_group_index).to_pydict()
    row_dict = {k: v[local_index] for k, v in row_group.items()}
    # Drop the decoded row group promptly; it can be large.
    del row_group

    return row_dict