cosmos

Classes

fastvideo.configs.pipelines.cosmos.CosmosConfig dataclass

CosmosConfig(
    model_path: str = '',
    pipeline_config_path: str | None = None,
    embedded_cfg_scale: int = 6,
    flow_shift: float = 1.0,
    disable_autocast: bool = False,
    dit_config: DiTConfig = CosmosVideoConfig(),
    dit_precision: str = 'bf16',
    vae_config: VAEConfig = CosmosVAEConfig(),
    vae_precision: str = 'fp16',
    vae_tiling: bool = True,
    vae_sp: bool = True,
    image_encoder_config: EncoderConfig = EncoderConfig(),
    image_encoder_precision: str = 'fp32',
    text_encoder_configs: tuple[EncoderConfig, ...] = (lambda: (T5LargeConfig(),))(),
    text_encoder_precisions: tuple[str, ...] = (lambda: ('bf16',))(),
    preprocess_text_funcs: tuple[Callable[[str], str], ...] = (lambda: (preprocess_text,))(),
    postprocess_text_funcs: tuple[Callable[[BaseEncoderOutput], Tensor], ...] = (lambda: (t5_large_postprocess_text,))(),
    pos_magic: str | None = None,
    neg_magic: str | None = None,
    timesteps_scale: bool | None = None,
    mask_strategy_file_path: str | None = None,
    STA_mode: STA_Mode = STA_INFERENCE,
    skip_time_steps: int = 15,
    dmd_denoising_steps: list[int] | None = None,
    ti2v_task: bool = False,
    boundary_ratio: float | None = None,
    conditioning_strategy: str = 'frame_replace',
    min_num_conditional_frames: int = 1,
    max_num_conditional_frames: int = 2,
    sigma_conditional: float = 0.0001,
    sigma_data: float = 1.0,
    state_ch: int = 16,
    state_t: int = 24,
    text_encoder_class: str = 'T5',
)
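The (lambda: (...,))() defaults above are how per-instance default factories render in generated API docs; the underlying fields are presumably declared with dataclasses.field(default_factory=...), as in this illustrative sketch (the class and field here are only for illustration):

from dataclasses import dataclass, field

@dataclass
class Example:
    # Each instance gets a fresh tuple from the factory; doc generators
    # render this default as: (lambda: ('bf16',))()
    text_encoder_precisions: tuple[str, ...] = field(
        default_factory=lambda: ('bf16',))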

Bases: PipelineConfig

Configuration for the Cosmos2 Video2World pipeline, matching the diffusers implementation.
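A minimal construction sketch, assuming the import path below; the model id is hypothetical, and any field left out keeps the default shown in the signature:

from fastvideo.configs.pipelines.cosmos import CosmosConfig

config = CosmosConfig(
    model_path="nvidia/Cosmos-Predict2-2B-Video2World",  # hypothetical model id
    embedded_cfg_scale=6,  # default from the signature, passed for illustration
    vae_tiling=True,       # decode the VAE in tiles to bound peak memory
)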

Functions

fastvideo.configs.pipelines.cosmos.t5_large_postprocess_text

t5_large_postprocess_text(outputs: BaseEncoderOutput) -> Tensor

Postprocess T5 Large text encoder outputs for Cosmos pipeline.

Return raw last_hidden_state without truncation/padding.

Source code in fastvideo/configs/pipelines/cosmos.py
def t5_large_postprocess_text(outputs: BaseEncoderOutput) -> torch.Tensor:
    """Postprocess T5 Large text encoder outputs for Cosmos pipeline.

    Return raw last_hidden_state without truncation/padding.
    """
    hidden_state = outputs.last_hidden_state

    if hidden_state is None:
        raise ValueError("T5 Large outputs missing last_hidden_state")

    # Replace any NaNs from the encoder with zeros (computing the mask once)
    nan_mask = torch.isnan(hidden_state)
    if nan_mask.any():
        hidden_state = hidden_state.masked_fill(nan_mask, 0.0)

    # Zero out embeddings beyond actual sequence length
    if outputs.attention_mask is not None:
        attention_mask = outputs.attention_mask
        lengths = attention_mask.sum(dim=1).cpu()
        for i, length in enumerate(lengths):
            hidden_state[i, length:] = 0

    return hidden_state
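
A hedged usage sketch: the stand-in class below mimics only the two attributes the function reads (last_hidden_state and attention_mask); in the pipeline the real BaseEncoderOutput comes from the T5 Large text encoder.

import torch

from fastvideo.configs.pipelines.cosmos import t5_large_postprocess_text

class _FakeEncoderOutput:  # stand-in for BaseEncoderOutput in this sketch
    def __init__(self, hidden: torch.Tensor, mask: torch.Tensor) -> None:
        self.last_hidden_state = hidden
        self.attention_mask = mask

hidden = torch.randn(2, 512, 1024)             # (batch, seq_len, hidden_dim)
mask = torch.zeros(2, 512, dtype=torch.long)
mask[0, :7] = 1                                # prompt 1 has 7 real tokens
mask[1, :12] = 1                               # prompt 2 has 12 real tokens

emb = t5_large_postprocess_text(_FakeEncoderOutput(hidden, mask))
assert emb.shape == (2, 512, 1024)             # no truncation or re-padding
assert torch.all(emb[0, 7:] == 0)              # positions past the mask are zeroed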