主要输调用过程
采样主函数中分为两个流程:
流程一:利用make_grid 函数生成网格图像,再利用迭代器,把这个生成过程显示出来。1000步的生成过程,每隔5步生成一个网格图像,转换到(0, 255),再从rgb转到bgr
流程二:采样之后保存图像
num_each_label = num_samples // num_classes
if vis_process:
for label in range(num_classes):
y = torch.ones(num_each_label, dtype=torch.long, device=device) * label
def generate_images() -> "yield image numpy array":
gen = diffusion.sample_diffusion_sequence(num_each_label, device, y)
for idx, image_tensor in tqdm.tqdm(enumerate(gen), desc=f"Generating for label {label}..", total=args.num_timesteps):
if idx % 5 != 0: # 1000 / 5 = 200 frames
continue
grid = make_grid(image_tensor, nrow=num_each_label)
arr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
image_bgr = arr[..., ::-1]
yield image_bgr
to_video = os.path.join(save_dir, f"{label}.mp4")
cv2_utils.images_to_video(generate_images(), to_video)
to_gif = os.path.join(save_dir, f"{label}.gif")
cv2_utils.images_to_gif(list(generate_images()), to_gif)
else:
for label in range(num_classes):
y = torch.ones(num_each_label, dtype=torch.long, device=device) * label
samples = diffusion.sample(num_each_label, device, y=y)
for image_id in range(len(samples)):
image = ((samples[image_id] + 1) / 2).clip(0, 1)
torchvision.utils.save_image(image, f"{save_dir}/{label}-{image_id}.png")
采样函数
@torch.no_grad()
def sample(self, batch_size, device, y=None, use_ema=True):
if y is not None and batch_size != len(y):
raise ValueError("sample batch size different from length of given y")
x = torch.randn(batch_size, self.img_channels, *self.img_size, device=device)
for t in range(self.num_timesteps - 1, -1, -1):
t_batch = torch.tensor([t], device=device).repeat(batch_size)
x = self.remove_noise(x, t_batch, y, use_ema)
if t > 0:
x += extract(self.sigma, t_batch, x.shape) * torch.randn_like(x)
return x.cpu().detach()
def sample_diffusion_sequence(self, batch_size, device, y=None, use_ema=True):
if y is not None and batch_size != len(y):
raise ValueError("sample batch size different from length of given y")
x = torch.randn(batch_size, self.img_channels, *self.img_size, device=device)
for t in range(self.num_timesteps - 1, -1, -1):
t_batch = torch.tensor([t], device=device).repeat(batch_size)
x = self.remove_noise(x, t_batch, y, use_ema)
if t > 0:
x += extract(self.sigma, t_batch, x.shape) * torch.randn_like(x)
yield x.cpu().detach()
这部分代码实现很简单,先是从高斯分布中采样,然后逐步去噪声,加干扰。
sample和sample_diffusion_sequence差别就在返回值
sample:return x.cpu().detach()
.cpu() 将数据的处理设备从其他设备(如.cuda()拿到cpu上),不会改变变量类型,转换后仍然是Tensor变量。 为什么需要这一步? 因为gpu上的数组不能直接进行转换类型的操作。
.detach() 函数可以返回一个完全相同的tensor,新的tensor开辟与旧的tensor共享内存,新的tensor会脱离计算图,不会牵扯梯度计算。也就是requires_grad=False, 因此可以 接着进行numpy() 的操作,解决了numpy()需要建立在无梯度的tensor的基础上的问题。
sample_diffusion_sequence: yield x.cpu().detach()
yield:在Python中,yield 关键字用于从函数中返回一个生成器(generator)。生成器是一个可以记住上一次返回位置的对象,并在下一次迭代时从该位置继续执行。这使得它们非常适合用于需要逐个处理大量数据的场景,因为它们可以按需生成数据,从而节省内存。
当你调用一个包含 yield 的函数时,该函数不会立即执行其代码,而是返回一个迭代器(即生成器)。然后,你可以通过迭代这个生成器来逐步执行函数中的代码。每次迭代时,yield 语句会“暂停”函数的执行,并返回紧随其后的值给迭代器的调用者。当迭代器再次请求下一个值时,函数会从上次暂停的位置继续执行,直到遇到下一个 yield 语句或函数结束。