def forward(
self,
batch: ForwardBatch,
fastvideo_args: FastVideoArgs,
) -> ForwardBatch:
"""
Validate and prepare inputs.
Args:
batch: The current batch information.
fastvideo_args: The inference arguments.
Returns:
The validated batch information.
"""
self._generate_seeds(batch, fastvideo_args)
        # A prompt must be supplied, either as text or as precomputed embeddings
if batch.prompt is None and batch.prompt_embeds is None:
raise ValueError(
"Either `prompt` or `prompt_embeds` must be provided")
        # Classifier-free guidance requires a negative prompt, as text or embeddings
if (batch.do_classifier_free_guidance and batch.negative_prompt is None
and batch.negative_prompt_embeds is None):
raise ValueError(
"For classifier-free guidance, either `negative_prompt` or "
"`negative_prompt_embeds` must be provided")
# Validate height and width
if batch.height is None or batch.width is None:
raise ValueError(
"Height and width must be provided. Please set `height` and `width`."
)
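        # Coarse alignment check; ti2v inputs are further aligned to the
        # model's patch and VAE stride below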
if batch.height % 8 != 0 or batch.width % 8 != 0:
raise ValueError(
f"Height and width must be divisible by 8 but are {batch.height} and {batch.width}."
)
# Validate number of inference steps
if batch.num_inference_steps <= 0:
raise ValueError(
f"Number of inference steps must be positive, but got {batch.num_inference_steps}"
)
# Validate guidance scale if using classifier-free guidance
if batch.do_classifier_free_guidance and batch.guidance_scale <= 0:
raise ValueError(
f"Guidance scale must be positive, but got {batch.guidance_scale}"
)
        # For i2v, load the conditioning image from image_path
        # TODO(Wei): hard-coded for Wan2.2 5B TI2V for now; this belongs in the
        # image_encoding stage
if batch.image_path is not None:
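            # An .mp4 path contributes its first frame as the conditioning image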
if batch.image_path.endswith(".mp4"):
image = load_video(batch.image_path)[0]
else:
image = load_image(batch.image_path)
batch.pil_image = image
        # For ti2v and causal pipelines, resize/crop the conditioning image to
        # an aligned output size
if (fastvideo_args.pipeline_config.ti2v_task
or fastvideo_args.pipeline_config.is_causal
) and batch.pil_image is not None:
img = batch.pil_image
ih, iw = img.height, img.width
pipeline_class_name = type(fastvideo_args.pipeline_config).__name__
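            # Matrix-Game variants resize straight to the requested size; the
            # Wan path below derives an aligned output size first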
if 'MatrixGame' in pipeline_class_name or 'MatrixCausal' in pipeline_class_name:
oh, ow = batch.height, batch.width
img = img.resize((ow, oh), Image.LANCZOS)
else:
# Standard Wan logic
patch_size = fastvideo_args.pipeline_config.dit_config.arch_config.patch_size
vae_stride = fastvideo_args.pipeline_config.vae_config.arch_config.scale_factor_spatial
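                # Pixels covered by one latent patch along each spatial axis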
dh, dw = patch_size[1] * vae_stride, patch_size[2] * vae_stride
max_area = 480 * 832
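                # Presumably picks the (dw, dh)-aligned size with area at most
                # max_area that best matches the input aspect ratio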
ow, oh = best_output_size(iw, ih, dw, dh, max_area)
scale = max(ow / iw, oh / ih)
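                # Cover-resize: scale so the image covers the target on both
                # axes; the overflow is center-cropped next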
img = img.resize((round(iw * scale), round(ih * scale)),
Image.LANCZOS)
# center-crop
x1 = (img.width - ow) // 2
y1 = (img.height - oh) // 2
img = img.crop((x1, y1, x1 + ow, y1 + oh))
assert img.width == ow and img.height == oh
logger.info("final processed img height: %s, img width: %s",
img.height, img.width)
            # To a tensor in [-1, 1], adding frame and batch dims:
            # (C, H, W) -> (1, C, 1, H, W)
img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(
self.device).unsqueeze(1)
img = img.unsqueeze(0)
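            # Record the aligned dims; pil_image now holds the normalized
            # tensor rather than a PIL image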
batch.height = oh
batch.width = ow
batch.pil_image = img
        # For v2v, load the control video from video_path
if batch.video_path is not None:
pil_images, original_fps = load_video(batch.video_path,
return_fps=True)
logger.info("Loaded video with %s frames, original FPS: %s",
len(pil_images), original_fps)
# Get target parameters from batch
target_fps = batch.fps
target_num_frames = batch.num_frames
target_height = batch.height
target_width = batch.width
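            # Temporal resampling is by integer frame skipping only, so a
            # target fps above the source fps leaves the video unchanged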
if target_fps is not None and original_fps is not None:
frame_skip = max(1, int(original_fps // target_fps))
if frame_skip > 1:
pil_images = pil_images[::frame_skip]
effective_fps = original_fps / frame_skip
logger.info(
"Resampled video from %.1f fps to %.1f fps (skip=%s)",
original_fps, effective_fps, frame_skip)
            # Truncate to the target number of frames
            if target_num_frames is not None and len(
                    pil_images) > target_num_frames:
                total_frames = len(pil_images)
                pil_images = pil_images[:target_num_frames]
                logger.info("Limited video to %s frames (from %s total)",
                            target_num_frames, total_frames)
# Resize each PIL image to target dimensions
resized_images = []
for pil_img in pil_images:
resized_img = resize(pil_img,
target_height,
target_width,
resize_mode="default",
resample="lanczos")
resized_images.append(resized_img)
# Convert PIL images to numpy array
video_numpy = pil_to_numpy(resized_images)
video_numpy = normalize(video_numpy)
video_tensor = numpy_to_pt(video_numpy)
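            # video_tensor: (T, C, H, W) in [-1, 1] (assuming `normalize`
            # maps [0, 1] to [-1, 1])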
# Rearrange to [C, T, H, W] and add batch dimension -> [B, C, T, H, W]
input_video = video_tensor.permute(1, 0, 2, 3).unsqueeze(0)
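            # NOTE: still pixel-space despite the name; a later stage is
            # expected to VAE-encode this into latents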
batch.video_latent = input_video
# Validate action control inputs (Matrix-Game)
if batch.mouse_cond is not None:
if batch.mouse_cond.dim() != 3 or batch.mouse_cond.shape[-1] != 2:
raise ValueError(
f"mouse_cond must have shape (B, T, 2), but got {batch.mouse_cond.shape}"
)
logger.info("Action control: mouse_cond validated - shape %s",
batch.mouse_cond.shape)
if batch.keyboard_cond is not None:
if batch.keyboard_cond.dim() != 3:
raise ValueError(
f"keyboard_cond must have 3 dimensions (B, T, K), but got {batch.keyboard_cond.dim()}"
)
keyboard_dim = batch.keyboard_cond.shape[-1]
if keyboard_dim not in {2, 4, 6, 7}:
raise ValueError(
f"keyboard_cond last dimension must be 2, 4, 6, or 7, but got {keyboard_dim}"
)
logger.info(
"Action control: keyboard_cond validated - shape %s (dim=%d)",
batch.keyboard_cond.shape, keyboard_dim)
        if batch.grid_sizes is not None:
            if not isinstance(batch.grid_sizes, list | tuple | torch.Tensor):
                raise ValueError("grid_sizes must be a list, tuple, or tensor")
            num_elements = (batch.grid_sizes.numel() if isinstance(
                batch.grid_sizes, torch.Tensor) else len(batch.grid_sizes))
            if num_elements != 3:
                raise ValueError("grid_sizes must have 3 elements [F, H, W]")
logger.info("Action control: grid_sizes validated - %s",
batch.grid_sizes)
return batch