Source examples/inference/sta_mask_search.
STA Mask Search Examples
bash examples/inference/sta_mask_search/inference_wan_sta.sh
Example materials
inference_wan_sta.sh
#!/bin/bash
# Runs wan_example.py in two consecutive phases with the SLIDING_TILE_ATTN
# attention backend:
#   1. STA_mode=STA_searching -> inference_results/sta/mask_search_full
#   2. STA_mode=STA_tuning    -> inference_results/sta/mask_search_sparse
# Each phase starts one background process per GPU; GPU i reads
# ./assets/prompt_<i>.txt and uses MASTER_PORT base_port+i.
#
# Usage: bash examples/inference/sta_mask_search/inference_wan_sta.sh
# Exits non-zero as soon as any job of a phase fails.

export FASTVIDEO_ATTENTION_CONFIG=assets/mask_strategy_wan.json
export FASTVIDEO_ATTENTION_BACKEND=SLIDING_TILE_ATTN
# NOTE(review): wan_example.py passes this same model id as a literal; it is
# unclear from this script whether anything reads MODEL_BASE -- confirm.
export MODEL_BASE=Wan-AI/Wan2.1-T2V-14B-Diffusers

base_port=29503
num_gpu=1
skip_time_steps=12
example_script=examples/inference/sta_mask_search/wan_example.py

#######################################
# Launch one wan_example.py process per GPU and wait for all of them.
# Globals:   base_port, num_gpu, example_script (read)
# Arguments: $1 - value for --STA_mode
#            $2 - value for --output_path
#            $3.. - extra arguments forwarded verbatim to wan_example.py
# Returns:   0 if every job exited 0, 1 otherwise
#######################################
run_phase() {
  local sta_mode=$1
  local out_dir=$2
  shift 2

  local -a pids=()
  local i pid failed=0

  for ((i = 0; i < num_gpu; i++)); do
    CUDA_VISIBLE_DEVICES=$i MASTER_PORT=$((base_port + i)) \
      python "$example_script" \
      --prompt_path "./assets/prompt_${i}.txt" \
      --output_path "$out_dir" \
      --STA_mode "$sta_mode" \
      "$@" &
    pids+=("$!")
    # Stagger the launches by one second, as the original script did.
    sleep 1
  done

  # A bare `wait` always returns 0; waiting on each PID surfaces failures.
  for pid in "${pids[@]}"; do
    wait "$pid" || failed=1
  done
  return "$failed"
}

# Phase 1: runs with wan_example.py's default --skip_time_steps.
if ! run_phase "STA_searching" "inference_results/sta/mask_search_full"; then
  echo "STA searching failed" >&2
  exit 1
fi
echo "STA searching completed"

# Phase 2: presumably consumes the results of phase 1 (verify against
# fastvideo), hence the abort above when searching did not succeed.
if ! run_phase "STA_tuning" "inference_results/sta/mask_search_sparse" \
  --skip_time_steps "$skip_time_steps"; then
  echo "STA tuning failed" >&2
  exit 1
fi
echo "STA tuning completed"

echo "All jobs completed"
wan_example.py
import os
import argparse
from fastvideo import VideoGenerator, SamplingParam
def main(args):
    """Generate one video per prompt with the Wan2.1 T2V 14B model.

    Prompts come from ``args.prompt_path`` (one prompt per line) when it is
    set, otherwise from the single ``args.prompt``. Lines read from the file
    are stripped of surrounding whitespace and blank lines are skipped, so a
    trailing newline or an empty line never becomes a prompt.

    Args:
        args: Namespace providing ``output_path``, ``num_gpus``, ``STA_mode``,
            ``skip_time_steps``, ``prompt``, ``prompt_path``, ``height``,
            ``width``, ``num_frames``, ``num_inference_steps``, ``fps``,
            ``guidance_scale`` and ``seed``.

    Raises:
        OSError: If the output directory cannot be created or the prompt
            file cannot be opened.
    """
    os.makedirs(args.output_path, exist_ok=True)

    # Create a video generator with a pre-trained model
    generator = VideoGenerator.from_pretrained(
        "Wan-AI/Wan2.1-T2V-14B-Diffusers",
        num_gpus=args.num_gpus,  # Adjust based on your hardware
        STA_mode=args.STA_mode,
        skip_time_steps=args.skip_time_steps,
    )

    negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"

    # Prompts for your video
    if args.prompt_path is not None:
        with open(args.prompt_path, "r") as f:
            # readlines() would keep the "\n" on every prompt and turn blank
            # lines into prompts of their own; strip and filter instead.
            prompts = [line.strip() for line in f if line.strip()]
    else:
        prompts = [args.prompt]

    params = SamplingParam(
        height=args.height,
        width=args.width,
        num_frames=args.num_frames,
        num_inference_steps=args.num_inference_steps,
        fps=args.fps,
        guidance_scale=args.guidance_scale,
        seed=args.seed,
        return_frames=True,  # Also return frames from this call (defaults to False)
        output_path=args.output_path,  # Controls where videos are saved
        save_video=True,
        negative_prompt=negative_prompt,
    )

    # Generate the videos; the returned value is not used here.
    for prompt in prompts:
        generator.generate_video(
            prompt,
            sampling_param=params,
        )
if __name__ == '__main__':
    # Command-line options as (flag, type, default), in the order they are
    # registered with the parser.
    cli_options = (
        ("--prompt", str, "A man is dancing."),
        ("--prompt_path", str, None),
        ("--height", int, 768),
        ("--width", int, 1280),
        ("--num_frames", int, 69),
        ("--num_inference_steps", int, 50),
        ("--fps", int, 16),
        ("--guidance_scale", float, 5.0),
        ("--seed", int, 12345),
        ("--output_path", str, "my_videos/"),
        ("--num_gpus", int, 1),
        ("--STA_mode", str, "STA_searching"),
        ("--skip_time_steps", int, 12),
    )

    cli = argparse.ArgumentParser()
    for flag, value_type, fallback in cli_options:
        cli.add_argument(flag, type=value_type, default=fallback)

    # Parse sys.argv and hand the resulting namespace to main().
    main(cli.parse_args())