
longcat_pipeline

LongCat video diffusion pipeline implementation (Phase 1: Wrapper).

This module provides a wrapper implementation of the LongCat video diffusion pipeline, built on FastVideo's modular pipeline architecture while reusing the original LongCat modules.
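
The pipeline is typically not constructed directly; it is resolved from a model checkpoint through FastVideo's high-level entry point. A minimal usage sketch, assuming FastVideo's VideoGenerator API and a placeholder checkpoint path (not an official model id):

from fastvideo import VideoGenerator

# Placeholder path: substitute a real LongCat-Video checkpoint directory or Hub id.
generator = VideoGenerator.from_pretrained(
    "path/to/LongCat-Video",
    num_gpus=1,
)
video = generator.generate_video(
    prompt="A cat walking along a rainy street at night",
)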

Classes

fastvideo.pipelines.basic.longcat.longcat_pipeline.LongCatPipeline

LongCatPipeline(*args, **kwargs)

Bases: LoRAPipeline, ComposedPipelineBase

LongCat video diffusion pipeline with LoRA support.

Phase 1 implementation using wrapper modules from third_party/longcat_video. This validates the pipeline infrastructure before full FastVideo integration.

Source code in fastvideo/pipelines/lora_pipeline.py
def __init__(self, *args, **kwargs) -> None:
    super().__init__(*args, **kwargs)
    self.device = get_local_torch_device()
    # build list of trainable transformers
    for transformer_name in self.trainable_transformer_names:
        if transformer_name in self.modules and self.modules[
                transformer_name] is not None:
            self.trainable_transformer_modules[
                transformer_name] = self.modules[transformer_name]
        # check for transformer_2 in case of Wan2.2 MoE or fake_score_transformer_2
        if transformer_name.endswith("_2"):
            raise ValueError(
                f"trainable_transformer_name override in pipelines should not include _2 suffix: {transformer_name}"
            )

        secondary_transformer_name = transformer_name + "_2"
        if secondary_transformer_name in self.modules and self.modules[
                secondary_transformer_name] is not None:
            self.trainable_transformer_modules[
                secondary_transformer_name] = self.modules[
                    secondary_transformer_name]

    logger.info("trainable_transformer_modules: %s",
                self.trainable_transformer_modules.keys())

    for transformer_name, transformer_module in self.trainable_transformer_modules.items(
    ):
        self.exclude_lora_layers[
            transformer_name] = transformer_module.config.arch_config.exclude_lora_layers
    self.lora_target_modules = self.fastvideo_args.lora_target_modules
    self.lora_path = self.fastvideo_args.lora_path
    self.lora_nickname = self.fastvideo_args.lora_nickname
    self.training_mode = self.fastvideo_args.training_mode
    if self.training_mode and getattr(self.fastvideo_args, "lora_training",
                                      False):
        assert isinstance(self.fastvideo_args, TrainingArgs)
        if self.fastvideo_args.lora_alpha is None:
            self.fastvideo_args.lora_alpha = self.fastvideo_args.lora_rank
        self.lora_rank = self.fastvideo_args.lora_rank  # type: ignore
        self.lora_alpha = self.fastvideo_args.lora_alpha  # type: ignore
        logger.info("Using LoRA training with rank %d and alpha %d",
                    self.lora_rank, self.lora_alpha)
        if self.lora_target_modules is None:
            self.lora_target_modules = [
                "q_proj", "k_proj", "v_proj", "o_proj", "to_q", "to_k",
                "to_v", "to_out", "to_qkv", "to_gate_compress"
            ]
            logger.info(
                "Using default lora_target_modules for all transformers: %s",
                self.lora_target_modules)
        else:
            logger.warning(
                "Using custom lora_target_modules for all transformers, which may not be intended: %s",
                self.lora_target_modules)

        self.convert_to_lora_layers()
    # Inference
    elif not self.training_mode and self.lora_path is not None:
        self.convert_to_lora_layers()
        self.set_lora_adapter(
            self.lora_nickname,  # type: ignore
            self.lora_path)  # type: ignore

Functions

fastvideo.pipelines.basic.longcat.longcat_pipeline.LongCatPipeline.create_pipeline_stages
create_pipeline_stages(fastvideo_args: FastVideoArgs) -> None

Set up pipeline stages with proper dependency injection.

Source code in fastvideo/pipelines/basic/longcat/longcat_pipeline.py
def create_pipeline_stages(self, fastvideo_args: FastVideoArgs) -> None:
    """Set up pipeline stages with proper dependency injection."""

    self.add_stage(stage_name="input_validation_stage",
                   stage=InputValidationStage())

    self.add_stage(stage_name="prompt_encoding_stage",
                   stage=TextEncodingStage(
                       text_encoders=[self.get_module("text_encoder")],
                       tokenizers=[self.get_module("tokenizer")],
                   ))

    # Add refine initialization stage (will be skipped if not refining)
    self.add_stage(stage_name="longcat_refine_init_stage",
                   stage=LongCatRefineInitStage(vae=self.get_module("vae")))

    # First prepare generic timesteps (for non-refine paths)
    self.add_stage(stage_name="timestep_preparation_stage",
                   stage=TimestepPreparationStage(
                       scheduler=self.get_module("scheduler")))

    # Then override timesteps for refinement (will be a no-op if not refining),
    # matching LongCat's generate_refine schedule.
    self.add_stage(stage_name="longcat_refine_timestep_stage",
                   stage=LongCatRefineTimestepStage(
                       scheduler=self.get_module("scheduler")))

    self.add_stage(stage_name="latent_preparation_stage",
                   stage=LatentPreparationStage(
                       scheduler=self.get_module("scheduler"),
                       transformer=self.get_module("transformer", None)))

    self.add_stage(stage_name="denoising_stage",
                   stage=LongCatDenoisingStage(
                       transformer=self.get_module("transformer"),
                       transformer_2=self.get_module("transformer_2", None),
                       scheduler=self.get_module("scheduler"),
                       vae=self.get_module("vae"),
                       pipeline=self))

    self.add_stage(stage_name="decoding_stage",
                   stage=DecodingStage(vae=self.get_module("vae"),
                                       pipeline=self))
fastvideo.pipelines.basic.longcat.longcat_pipeline.LongCatPipeline.initialize_pipeline
initialize_pipeline(fastvideo_args: FastVideoArgs)

Initialize LongCat-specific components.

Source code in fastvideo/pipelines/basic/longcat/longcat_pipeline.py
def initialize_pipeline(self, fastvideo_args: FastVideoArgs):
    """Initialize LongCat-specific components."""

    # Enable BSA (Block Sparse Attention) if configured
    pipeline_config = fastvideo_args.pipeline_config
    transformer = self.get_module("transformer", None)
    if transformer is None:
        raise RuntimeError(
            "Transformer module not found during initializing LongCat pipeline."
        )
    # If user toggles BSA via CLI/config
    if pipeline_config.enable_bsa:
        # Build effective BSA params:
        # 1) from explicit CLI overrides if provided
        # 2) else from pipeline_config.bsa_params
        # 3) else fall back to reasonable defaults
        bsa_params_cfg = pipeline_config.bsa_params
        sparsity = pipeline_config.bsa_sparsity
        cdf_threshold = pipeline_config.bsa_cdf_threshold
        chunk_q = pipeline_config.bsa_chunk_q
        chunk_k = pipeline_config.bsa_chunk_k

        effective_bsa_params = dict(bsa_params_cfg) if isinstance(
            bsa_params_cfg, dict) else {}
        if sparsity is not None:
            effective_bsa_params['sparsity'] = sparsity
        if cdf_threshold is not None:
            effective_bsa_params['cdf_threshold'] = cdf_threshold
        if chunk_q is not None:
            effective_bsa_params['chunk_3d_shape_q'] = chunk_q
        if chunk_k is not None:
            effective_bsa_params['chunk_3d_shape_k'] = chunk_k
        # Provide defaults if still missing
        effective_bsa_params.setdefault('sparsity', 0.9375)
        effective_bsa_params.setdefault('chunk_3d_shape_q', [4, 4, 4])
        effective_bsa_params.setdefault('chunk_3d_shape_k', [4, 4, 4])

        if hasattr(transformer, 'enable_bsa'):
            logger.info(
                "Enabling Block Sparse Attention (BSA) for LongCat transformer"
            )
            transformer.enable_bsa()
            # Propagate params to all attention modules
            if hasattr(transformer, 'blocks'):
                try:
                    for blk in transformer.blocks:
                        if hasattr(blk, 'self_attn'):
                            blk.self_attn.bsa_params = effective_bsa_params
                except Exception as e:
                    logger.warning(
                        "Failed to set BSA params on all blocks: %s", e)
            logger.info("BSA parameters in effect: %s",
                        effective_bsa_params)
        else:
            logger.warning(
                "BSA is enabled in config but transformer does not support it"
            )
    else:
        # Explicitly disable if present
        if hasattr(transformer, 'disable_bsa'):
            transformer.disable_bsa()
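
The block above merges BSA settings with a clear precedence: explicit CLI/config overrides win over pipeline_config.bsa_params, and built-in defaults fill anything still missing. A standalone sketch of that merge (merge_bsa_params is an illustrative helper, not a FastVideo API):

def merge_bsa_params(bsa_params_cfg=None, sparsity=None, cdf_threshold=None,
                     chunk_q=None, chunk_k=None):
    # Start from the config dict, if one was provided.
    params = dict(bsa_params_cfg) if isinstance(bsa_params_cfg, dict) else {}
    # Explicit overrides take precedence over the config dict.
    overrides = {
        'sparsity': sparsity,
        'cdf_threshold': cdf_threshold,
        'chunk_3d_shape_q': chunk_q,
        'chunk_3d_shape_k': chunk_k,
    }
    for key, value in overrides.items():
        if value is not None:
            params[key] = value
    # Built-in defaults fill anything still missing.
    params.setdefault('sparsity', 0.9375)
    params.setdefault('chunk_3d_shape_q', [4, 4, 4])
    params.setdefault('chunk_3d_shape_k', [4, 4, 4])
    return params

print(merge_bsa_params({'sparsity': 0.9}, chunk_q=[8, 4, 4]))
# {'sparsity': 0.9, 'chunk_3d_shape_q': [8, 4, 4], 'chunk_3d_shape_k': [4, 4, 4]}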
