    def forward(
        self,
        batch: ForwardBatch,
        fastvideo_args: FastVideoArgs,
    ) -> ForwardBatch:
        """Run the denoising loop with image-to-video (I2V) conditioning.

        Conditioning latent frames are held fixed and marked as clean
        (timestep 0) at every step; only the remaining frames are denoised.
        """
        # Load transformer if needed
        if not fastvideo_args.model_loaded["transformer"]:
            loader = TransformerLoader()
            self.transformer = loader.load(
                fastvideo_args.model_paths["transformer"], fastvideo_args)
            fastvideo_args.model_loaded["transformer"] = True
        # Setup
        target_dtype = torch.bfloat16
        autocast_enabled = (target_dtype != torch.float32
                            and not fastvideo_args.disable_autocast)
        latents = batch.latents
        timesteps = batch.timesteps
        prompt_embeds = batch.prompt_embeds[0]
        prompt_attention_mask = (batch.prompt_attention_mask[0]
                                 if batch.prompt_attention_mask else None)
        guidance_scale = batch.guidance_scale
        do_classifier_free_guidance = batch.do_classifier_free_guidance

        # Number of leading latent frames that carry the image conditioning
        num_cond_latents = getattr(batch, 'num_cond_latents', 0)
        if num_cond_latents > 0:
            logger.info("I2V Denoising: num_cond_latents=%s, latent_shape=%s",
                        num_cond_latents, latents.shape)
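        # NOTE: the indexing below assumes latents are laid out as
        # [B, C, T, H, W] with the temporal axis at dim 2 (consistent with
        # the shape[2] and [:, :, num_cond_latents:] slicing used throughout);
        # verify this against the VAE/transformer if adapting this code.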
        # Prepare negative prompts for CFG
        if do_classifier_free_guidance:
            negative_prompt_embeds = batch.negative_prompt_embeds[0]
            negative_prompt_attention_mask = (batch.negative_attention_mask[0]
                                              if batch.negative_attention_mask
                                              else None)
            prompt_embeds_combined = torch.cat(
                [negative_prompt_embeds, prompt_embeds], dim=0)
            if prompt_attention_mask is not None:
                prompt_attention_mask_combined = torch.cat(
                    [negative_prompt_attention_mask, prompt_attention_mask],
                    dim=0)
            else:
                prompt_attention_mask_combined = None
        else:
            prompt_embeds_combined = prompt_embeds
            prompt_attention_mask_combined = prompt_attention_mask
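        # The unconditional batch comes first in the concatenation above, so
        # noise_pred.chunk(2) in the loop below unpacks as (uncond, cond).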
        # Denoising loop
        num_inference_steps = len(timesteps)
        with tqdm(total=num_inference_steps,
                  desc="I2V Denoising") as progress_bar:
            for i, t in enumerate(timesteps):
                # 1. Expand latents for CFG
                if do_classifier_free_guidance:
                    latent_model_input = torch.cat([latents] * 2)
                else:
                    latent_model_input = latents
                latent_model_input = latent_model_input.to(target_dtype)

                # 2. Expand timestep to match batch size
                timestep = t.expand(
                    latent_model_input.shape[0]).to(target_dtype)

                # 3. CRITICAL: expand the timestep across the temporal
                # dimension so each latent frame carries its own timestep,
                # then mark the conditioning frames as clean (timestep=0)
                timestep = timestep.unsqueeze(-1).repeat(
                    1, latent_model_input.shape[2])
                if num_cond_latents > 0:
                    timestep[:, :num_cond_latents] = 0
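                # With per-frame timesteps, the transformer sees the
                # conditioning frames as already-denoised context (t=0)
                # while the remaining frames follow the sampler schedule.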
                # 4. Run transformer with num_cond_latents
                batch.is_cfg_negative = False
                with set_forward_context(
                        current_timestep=i,
                        attn_metadata=None,
                        forward_batch=batch,
                ), torch.autocast(device_type='cuda',
                                  dtype=target_dtype,
                                  enabled=autocast_enabled):
                    noise_pred = self.transformer(
                        hidden_states=latent_model_input,
                        encoder_hidden_states=prompt_embeds_combined,
                        timestep=timestep,
                        encoder_attention_mask=prompt_attention_mask_combined,
                        num_cond_latents=num_cond_latents,
                    )
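                # num_cond_latents is forwarded so the transformer can give
                # the leading conditioning frames special treatment; the
                # exact handling is specific to the transformer
                # implementation.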
                # 5. Apply CFG with optimized scale
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                    B = noise_pred_cond.shape[0]
                    positive = noise_pred_cond.reshape(B, -1)
                    negative = noise_pred_uncond.reshape(B, -1)
                    # CFG-zero optimized scale
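                    # optimized_scale is expected to return, per batch
                    # element, the projection coefficient
                    #   s* = <cond, uncond> / ||uncond||^2
                    # so the guided prediction below becomes
                    #   s* * uncond + g * (cond - s* * uncond),
                    # i.e. the unconditional branch is rescaled before the
                    # standard CFG extrapolation (CFG-Zero*-style; see the
                    # sketch after this method).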
                    st_star = self.optimized_scale(positive, negative)
                    st_star = st_star.view(B, 1, 1, 1, 1)
                    noise_pred = (
                        noise_pred_uncond * st_star + guidance_scale *
                        (noise_pred_cond - noise_pred_uncond * st_star))
                # 6. CRITICAL: Negate for flow matching scheduler
                noise_pred = -noise_pred
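                # The negation reconciles the transformer's output sign with
                # the velocity direction the flow-matching scheduler expects;
                # presumably skipping it would step the latents in the wrong
                # direction.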
                # 7. CRITICAL: only update the non-conditioned frames; the
                # conditioning frames stay FIXED throughout denoising
                if num_cond_latents > 0:
                    latents[:, :, num_cond_latents:] = self.scheduler.step(
                        noise_pred[:, :, num_cond_latents:],
                        t,
                        latents[:, :, num_cond_latents:],
                        return_dict=False)[0]
                else:
                    # No conditioning, update all frames
                    latents = self.scheduler.step(noise_pred,
                                                  t,
                                                  latents,
                                                  return_dict=False)[0]
                progress_bar.update()
        # Update batch with denoised latents
        batch.latents = latents
        return batch
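
    # NOTE: `optimized_scale` is called in `forward` above but not defined in
    # this section. A minimal sketch of what it likely computes, following
    # the CFG-Zero* projection formulation (an assumption, not necessarily
    # this repository's exact implementation):
    #
    #     def optimized_scale(self, positive: torch.Tensor,
    #                         negative: torch.Tensor) -> torch.Tensor:
    #         """Per-sample s* = <pos, neg> / ||neg||^2, shape [B, 1]."""
    #         dot_product = torch.sum(positive * negative, dim=1, keepdim=True)
    #         squared_norm = torch.sum(negative**2, dim=1, keepdim=True) + 1e-8
    #         return dot_product / squared_norm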