Bases: LoRAPipeline, ComposedPipelineBase
LongCat Video Continuation pipeline.
Generates a video continuation from multiple conditioning frames, optionally
using a pre-computed KV cache for a 2-3x speedup.
Key features:
- Takes video input (typically 13+ frames)
- Encodes conditioning frames via VAE
- Optionally pre-computes KV cache for conditioning
- Uses cached K/V during denoising for speedup
- Concatenates conditioning back after denoising
Source code in fastvideo/pipelines/lora_pipeline.py
| def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.device = get_local_torch_device()
# build list of trainable transformers
for transformer_name in self.trainable_transformer_names:
if transformer_name in self.modules and self.modules[
transformer_name] is not None:
self.trainable_transformer_modules[
transformer_name] = self.modules[transformer_name]
# check for transformer_2 in case of Wan2.2 MoE or fake_score_transformer_2
if transformer_name.endswith("_2"):
raise ValueError(
f"trainable_transformer_name override in pipelines should not include _2 suffix: {transformer_name}"
)
secondary_transformer_name = transformer_name + "_2"
if secondary_transformer_name in self.modules and self.modules[
secondary_transformer_name] is not None:
self.trainable_transformer_modules[
secondary_transformer_name] = self.modules[
secondary_transformer_name]
logger.info("trainable_transformer_modules: %s",
self.trainable_transformer_modules.keys())
for transformer_name, transformer_module in self.trainable_transformer_modules.items(
):
self.exclude_lora_layers[
transformer_name] = transformer_module.config.arch_config.exclude_lora_layers
self.lora_target_modules = self.fastvideo_args.lora_target_modules
self.lora_path = self.fastvideo_args.lora_path
self.lora_nickname = self.fastvideo_args.lora_nickname
self.training_mode = self.fastvideo_args.training_mode
if self.training_mode and getattr(self.fastvideo_args, "lora_training",
False):
assert isinstance(self.fastvideo_args, TrainingArgs)
if self.fastvideo_args.lora_alpha is None:
self.fastvideo_args.lora_alpha = self.fastvideo_args.lora_rank
self.lora_rank = self.fastvideo_args.lora_rank # type: ignore
self.lora_alpha = self.fastvideo_args.lora_alpha # type: ignore
logger.info("Using LoRA training with rank %d and alpha %d",
self.lora_rank, self.lora_alpha)
if self.lora_target_modules is None:
self.lora_target_modules = [
"q_proj", "k_proj", "v_proj", "o_proj", "to_q", "to_k",
"to_v", "to_out", "to_qkv", "to_gate_compress"
]
logger.info(
"Using default lora_target_modules for all transformers: %s",
self.lora_target_modules)
else:
logger.warning(
"Using custom lora_target_modules for all transformers, which may not be intended: %s",
self.lora_target_modules)
self.convert_to_lora_layers()
# Inference
elif not self.training_mode and self.lora_path is not None:
self.convert_to_lora_layers()
self.set_lora_adapter(
self.lora_nickname, # type: ignore
self.lora_path) # type: ignore
|
Functions
fastvideo.pipelines.basic.longcat.longcat_vc_pipeline.LongCatVideoContinuationPipeline.create_pipeline_stages
Set up VC-specific pipeline stages.
Source code in fastvideo/pipelines/basic/longcat/longcat_vc_pipeline.py
def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
    """Set up VC-specific pipeline stages."""
    # Stages execute in registration order: validate inputs, encode the
    # prompt, VAE-encode the conditioning frames, prepare timesteps and
    # latents, optionally warm the KV cache, denoise, then decode.
    ordered_stages = [
        # 1. Input validation
        ("input_validation_stage", InputValidationStage()),
        # 2. Text encoding
        ("prompt_encoding_stage",
         TextEncodingStage(
             text_encoders=[self.get_module("text_encoder")],
             tokenizers=[self.get_module("tokenizer")],
         )),
        # 3. Video VAE encoding (encodes conditioning frames)
        ("video_vae_encoding_stage",
         LongCatVideoVAEEncodingStage(vae=self.get_module("vae"))),
        # 4. Timestep preparation
        ("timestep_preparation_stage",
         TimestepPreparationStage(scheduler=self.get_module("scheduler"))),
        # 5. Latent preparation, using the VC-specific stage that also
        #    handles the encoded conditioning latents.
        ("latent_preparation_stage",
         LongCatVCLatentPreparationStage(
             scheduler=self.get_module("scheduler"),
             transformer=self.get_module("transformer"))),
        # 6. KV cache initialization — always registered; the stage is
        #    expected to skip itself when use_kv_cache=False.
        ("kv_cache_init_stage",
         LongCatKVCacheInitStage(transformer=self.get_module("transformer"))),
        # 7. Denoising with VC and KV cache support
        ("denoising_stage",
         LongCatVCDenoisingStage(
             transformer=self.get_module("transformer"),
             transformer_2=self.get_module("transformer_2", None),
             scheduler=self.get_module("scheduler"),
             vae=self.get_module("vae"),
             pipeline=self)),
        # 8. Decoding
        ("decoding_stage",
         DecodingStage(vae=self.get_module("vae"), pipeline=self)),
    ]
    for stage_name, stage in ordered_stages:
        self.add_stage(stage_name=stage_name, stage=stage)
fastvideo.pipelines.basic.longcat.longcat_vc_pipeline.LongCatVideoContinuationPipeline.initialize_pipeline
Initialize LongCat-specific components.
Source code in fastvideo/pipelines/basic/longcat/longcat_vc_pipeline.py
def initialize_pipeline(self, fastvideo_args: FastVideoArgs):
    """Initialize LongCat-specific components.

    Reads block-sparse attention (BSA) settings from the pipeline config
    and enables or disables BSA on the transformer accordingly. Does
    nothing when the pipeline has no transformer module.
    """
    cfg = fastvideo_args.pipeline_config
    dit = self.get_module("transformer", None)
    if dit is None:
        return

    if not getattr(cfg, 'enable_bsa', False):
        # BSA not requested (for VC, BSA may not be needed); clear any
        # previously enabled BSA state on the transformer.
        if hasattr(dit, 'disable_bsa'):
            dit.disable_bsa()
        return

    # Start from the configured bsa_params dict (if any), then layer the
    # individual config overrides on top of it.
    raw_params = getattr(cfg, 'bsa_params', None) or {}
    bsa_params = dict(raw_params) if isinstance(raw_params, dict) else {}
    overrides = {
        'sparsity': getattr(cfg, 'bsa_sparsity', None),
        'cdf_threshold': getattr(cfg, 'bsa_cdf_threshold', None),
        'chunk_3d_shape_q': getattr(cfg, 'bsa_chunk_q', None),
        'chunk_3d_shape_k': getattr(cfg, 'bsa_chunk_k', None),
    }
    for key, value in overrides.items():
        if value is not None:
            bsa_params[key] = value
    # Fall back to defaults for anything still unset.
    bsa_params.setdefault('sparsity', 0.9375)
    bsa_params.setdefault('chunk_3d_shape_q', [4, 4, 4])
    bsa_params.setdefault('chunk_3d_shape_k', [4, 4, 4])

    if hasattr(dit, 'enable_bsa'):
        logger.info("Enabling BSA for LongCat VC transformer")
        dit.enable_bsa()
        if hasattr(dit, 'blocks'):
            # Push the merged parameters onto every self-attention layer.
            # Best-effort: block internals vary between model variants.
            try:
                for blk in dit.blocks:
                    if hasattr(blk, 'self_attn'):
                        blk.self_attn.bsa_params = bsa_params
            except Exception as e:
                logger.warning("Failed to set BSA params: %s", e)
        logger.info("BSA parameters: %s", bsa_params)