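Training a diffusion model is fairly simple.
As the figure above shows, ε_θ is the UNet. The loss is the MSE between the sampled noise and the noise the UNet predicts.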
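To make the noise-prediction objective concrete, here is a minimal sketch in plain PyTorch. It is not the real pipeline. The linear beta schedule, the `add_noise` helper, and the single-conv `toy_eps_theta` stand-in for the UNet are all assumptions made for illustration. The stand-in ignores the timestep and the text conditioning, which a real UNet would use. The forward-noising formula is the standard DDPM one.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical linear beta schedule; the values are illustrative assumptions.
num_train_timesteps = 1000
betas = torch.linspace(1e-4, 0.02, num_train_timesteps)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)


def add_noise(x0, noise, timesteps):
    """Forward process: x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * noise."""
    abar = alphas_cumprod[timesteps].view(-1, 1, 1, 1)
    return abar.sqrt() * x0 + (1.0 - abar).sqrt() * noise


# Stand-in for the UNet (eps_theta): one conv layer that ignores the timestep
# and text conditioning, used only so the loss can be computed end to end.
toy_eps_theta = torch.nn.Conv2d(4, 4, kernel_size=3, padding=1)

latents = torch.randn(2, 4, 8, 8)  # pretend these are VAE latents
noise = torch.randn_like(latents)  # the noise the model must recover
timesteps = torch.randint(0, num_train_timesteps, (latents.shape[0],))
noisy_latents = add_noise(latents, noise, timesteps)

model_pred = toy_eps_theta(noisy_latents)
loss = F.mse_loss(model_pred, noise)  # MSE between predicted and true noise
loss.backward()
```

The target is the sampled noise itself, so no labels are needed beyond the training images. Each step draws a fresh noise sample and a fresh timestep.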
The usual training procedure:
latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample()
latents = latents * vae.config.scaling_factor
noise = torch.randn_like(latents)
bsz = latents.shape[0]
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
target = noise
model_pred = unet(noisy_latents, timesteps, encode