🎯 Distillation#

We introduce a new finetuning strategy, Sparse-Distill, which jointly integrates DMD (Distribution Matching Distillation) and VSA (Video Sparse Attention) in a single training process. This approach combines the benefits of distillation, which reduces the number of diffusion steps, and sparse attention, which cuts the cost of attention computation, enabling much faster video generation.
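
As a rough, illustrative cost picture (assuming attention dominates per-step compute): distillation cuts the number of denoising steps per video, and at the VSA sparsities used in the recipes below (0.8 and 0.9) each remaining step computes only about 10-20% of the attention, so the two savings compose multiplicatively.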

📊 Model Overview#

We provide two distilled models, one based on Wan2.1 T2V 1.3B and one on Wan2.1 T2V 14B.

Both models are trained at 61×448×832 resolution but can generate videos at other resolutions (the 1.3B model mainly supports 480P, the 14B model supports 480P and 720P, and quality may degrade at other resolutions).

⚙️ Inference#

First install VSA. Set MODEL_BASE to your own model path and run:

bash scripts/inference/v1_inference_wan_dmd.sh
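
For example, a minimal sketch assuming the script reads MODEL_BASE from the environment (the checkpoint path below is a placeholder for your locally downloaded distilled model):

# Point MODEL_BASE at your locally downloaded distilled checkpoint
export MODEL_BASE=/path/to/your/distilled/checkpoint
bash scripts/inference/v1_inference_wan_dmd.sh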

🗂️ Dataset#

We use the FastVideo 480P Synthetic Wan dataset (FastVideo/Wan-Syn_77x448x832_600k) for distillation, which contains 600k synthetic latents.

Download Dataset#

# Download the preprocessed dataset
python scripts/huggingface/download_hf.py \
    --repo_id "FastVideo/Wan-Syn_77x448x832_600k" \
    --local_dir "FastVideo/Wan-Syn_77x448x832_600k" \
    --repo_type "dataset"
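
Once the download completes, a quick file count over the local directory is an easy sanity check that the latents landed (the repo holds roughly 600k of them):

# Sanity check: count downloaded files in the dataset directory
find FastVideo/Wan-Syn_77x448x832_600k -type f | wc -l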

🚀 Training Scripts#

Wan2.1 1.3B Model Sparse-Distill#

For the 1.3B model, we use 4 nodes with 32 H200 GPUs (8 GPUs per node):

# Multi-node training (4 nodes, 32 GPUs total)
sbatch examples/distill/Wan2.1-T2V/Wan-Syn-Data-480P/distill_dmd_VSA_t2v_1.3B.slurm
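
If your cluster layout differs, the resource request in the .slurm file is what needs adjusting; a minimal sketch of the relevant directives for the 4-node × 8-GPU setup above (standard SLURM syntax, the authoritative values live in the script itself):

# Request 4 nodes with 8 GPUs each (32 H200s total)
#SBATCH --nodes=4
#SBATCH --gpus-per-node=8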

Key Configuration:

  • Global batch size: 64

  • Gradient accumulation steps: 2

  • Learning rate: 1e-5

  • VSA attention sparsity: 0.8

  • Training steps: 4000 (~12 hours)
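
These numbers are self-consistent assuming a per-GPU micro-batch of 1 and no sequence parallelism: 32 GPUs × 2 gradient-accumulation steps × 1 sample gives the global batch of 64.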

Wan2.1 14B Model Sparse-Distill#

For the 14B model, we use 8 nodes with 64 H200 GPUs (8 GPUs per node):

# Multi-node training (8 nodes, 64 GPUs total)
sbatch examples/distill/Wan2.1-T2V/Wan-Syn-Data-480P/distill_dmd_VSA_t2v_14B.slurm

Key Configuration:

  • Global batch size: 64

  • Sequence parallel size: 4

  • Gradient accumulation steps: 4

  • Learning rate: 1e-5

  • VSA attention sparsity: 0.9

  • Training steps: 3000 (~52 hours)

  • HSDP shard dim: 8
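
With sequence parallel size 4, the 64 GPUs form 64 / 4 = 16 data-parallel ranks; assuming a per-GPU micro-batch of 1, 16 ranks × 4 gradient-accumulation steps × 1 sample again gives the global batch of 64.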

Wan2.2 5B Model Sparse-Distill#

For the 5B model, we use 8 nodes with 64 H200 GPUs (8 GPUs per node):

# Multi-node training (8 nodes, 64 GPUs total)
sbatch examples/distill/Wan2.2-TI2V-5B-Diffusers/Data-free/distill_dmd_t2v_5B.sh 

Key Configuration:

  • Global batch size: 64

  • Sequence parallel size: 1

  • Gradient accumulation steps: 1

  • Learning rate: 2e-5

  • Training steps: 3000 (~12 hours)

  • HSDP shard dim: 1
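
Here all 64 GPUs act as data-parallel ranks (sequence parallel size 1), so assuming a per-GPU micro-batch of 1 and no gradient accumulation, 64 × 1 × 1 = 64 matches the global batch size.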