validation_dataset

Classes

fastvideo.dataset.validation_dataset.ValidationDataset

ValidationDataset(filename: str)

Bases: IterableDataset

Source code in fastvideo/dataset/validation_dataset.py
def __init__(self, filename: str):
    super().__init__()

    self.filename = pathlib.Path(filename)
    # get directory of filename
    self.dir = os.path.abspath(self.filename.parent)

    if not self.filename.exists():
        raise FileNotFoundError(
            f"File {self.filename.as_posix()} does not exist")

    if self.filename.suffix == ".csv":
        data = datasets.load_dataset("csv",
                                     data_files=self.filename.as_posix(),
                                     split="train")
    elif self.filename.suffix == ".json":
        data = datasets.load_dataset("json",
                                     data_files=self.filename.as_posix(),
                                     split="train",
                                     field="data")
    elif self.filename.suffix == ".parquet":
        data = datasets.load_dataset("parquet",
                                     data_files=self.filename.as_posix(),
                                     split="train")
    elif self.filename.suffix == ".arrow":
        data = datasets.load_dataset("arrow",
                                     data_files=self.filename.as_posix(),
                                     split="train")
    else:
        _SUPPORTED_FILE_FORMATS = [".csv", ".json", ".parquet", ".arrow"]
        raise ValueError(
            f"Unsupported file format {self.filename.suffix} for validation dataset. Supported formats are: {_SUPPORTED_FILE_FORMATS}"
        )

    # Get distributed training info
    self.global_rank = get_world_rank()
    self.world_size = get_world_size()
    self.sp_world_size = get_sp_world_size()
    self.num_sp_groups = self.world_size // self.sp_world_size

    # Convert to list to get total samples
    self.all_samples = list(data)
    self.original_total_samples = len(self.all_samples)

    # Extend samples to be a multiple of DP degree (num_sp_groups)
    remainder = self.original_total_samples % self.num_sp_groups
    if remainder != 0:
        samples_to_add = self.num_sp_groups - remainder

        # Duplicate samples cyclically to reach the target
        additional_samples = []
        for i in range(samples_to_add):
            additional_samples.append(
                self.all_samples[i % self.original_total_samples])

        self.all_samples.extend(additional_samples)

    self.total_samples = len(self.all_samples)

    # Calculate which SP group this rank belongs to
    self.sp_group_id = self.global_rank // self.sp_world_size

    # Now all SP groups will have equal number of samples
    self.samples_per_sp_group = self.total_samples // self.num_sp_groups

    # Calculate start and end indices for this SP group
    self.start_idx = self.sp_group_id * self.samples_per_sp_group
    self.end_idx = self.start_idx + self.samples_per_sp_group

    # Get samples for this SP group
    self.sp_group_samples = self.all_samples[self.start_idx:self.end_idx]

    logger.info(
        "Rank %s (SP group %s): "
        "Original samples: %s, "
        "Extended samples: %s, "
        "SP group samples: %s, "
        "Range: [%s:%s]",
        self.global_rank,
        self.sp_group_id,
        self.original_total_samples,
        self.total_samples,
        len(self.sp_group_samples),
        self.start_idx,
        self.end_idx,
        local_main_process_only=False)
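
The padding and SP-group sharding arithmetic above can be summarized with a small standalone sketch; the numbers (10 samples, 8 ranks, a sequence-parallel world size of 2, giving 4 SP groups) are made up for illustration and do not come from the library:

# Standalone sketch of the cyclic padding and per-SP-group slicing logic.
original = [f"prompt_{i}" for i in range(10)]
world_size, sp_world_size = 8, 2
num_sp_groups = world_size // sp_world_size            # 4 SP groups (DP degree)

# Pad cyclically so every SP group receives the same number of samples.
remainder = len(original) % num_sp_groups               # 10 % 4 == 2
if remainder:
    pad = num_sp_groups - remainder                     # add 2 duplicates
    original = original + [original[i % len(original)] for i in range(pad)]

samples_per_group = len(original) // num_sp_groups      # 12 // 4 == 3
for rank in range(world_size):
    group = rank // sp_world_size                       # e.g. rank 5 -> group 2
    start = group * samples_per_group                   # group 2 -> slice [6:9]
    print(rank, group, original[start:start + samples_per_group])

With these numbers, ranks 6 and 7 (the last SP group) receive sample 9 plus duplicated samples 0 and 1, which is why every SP group ends up with exactly samples_per_sp_group items.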

Functions

fastvideo.dataset.validation_dataset.ValidationDataset.__len__
__len__()

Return the number of samples for this SP group.

Source code in fastvideo/dataset/validation_dataset.py
def __len__(self):
    """Return the number of samples for this SP group."""
    return len(self.sp_group_samples)
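
A minimal usage sketch, assuming the distributed environment and SP groups have already been initialized (the constructor calls get_world_rank(), get_world_size(), and get_sp_world_size()) and that a validation file exists at the given path; the file name is hypothetical:

from fastvideo.dataset.validation_dataset import ValidationDataset

# Hypothetical validation file. A .json input must wrap its rows in a
# top-level "data" field, since load_dataset(..., field="data") is used
# for JSON in the constructor above.
dataset = ValidationDataset("validation_prompts.csv")

# __len__ reports only this rank's SP-group-local share of the samples,
# not the total number of rows in the file.
print(len(dataset))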