nanodiffusion代码逐行理解之sample

主要输调用过程

采样主函数中分为两个流程:
流程一:利用make_grid 函数生成网格图像,再利用迭代器,把这个生成过程显示出来。1000步的生成过程,每隔5步生成一个网格图像,转换到(0, 255),再从rgb转到bgr
流程二:采样之后保存图像

num_each_label = num_samples // num_classes

    if vis_process:
        for label in range(num_classes):
            y = torch.ones(num_each_label, dtype=torch.long, device=device) * label

            def generate_images() -> "yield image numpy array":
                gen = diffusion.sample_diffusion_sequence(num_each_label, device, y)
                for idx, image_tensor in tqdm.tqdm(enumerate(gen), desc=f"Generating for label {label}..", total=args.num_timesteps):
                    if idx % 5 != 0:  # 1000 / 5 = 200 frames
                        continue
                    grid = make_grid(image_tensor, nrow=num_each_label)
                    arr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
                    image_bgr = arr[..., ::-1]
                    yield image_bgr

            to_video = os.path.join(save_dir, f"{label}.mp4")
            cv2_utils.images_to_video(generate_images(), to_video)

            to_gif = os.path.join(save_dir, f"{label}.gif")
            cv2_utils.images_to_gif(list(generate_images()), to_gif)
    else:
        for label in range(num_classes):
            y = torch.ones(num_each_label, dtype=torch.long, device=device) * label
            samples = diffusion.sample(num_each_label, device, y=y)
            for image_id in range(len(samples)):
                image = ((samples[image_id] + 1) / 2).clip(0, 1)
                torchvision.utils.save_image(image, f"{save_dir}/{label}-{image_id}.png")

采样函数

@torch.no_grad()
def sample(self, batch_size, device, y=None, use_ema=True):
    if y is not None and batch_size != len(y):
        raise ValueError("sample batch size different from length of given y")

    x = torch.randn(batch_size, self.img_channels, *self.img_size, device=device)
    
    for t in range(self.num_timesteps - 1, -1, -1):
        t_batch = torch.tensor([t], device=device).repeat(batch_size)
        x = self.remove_noise(x, t_batch, y, use_ema)

        if t > 0:
            x += extract(self.sigma, t_batch, x.shape) * torch.randn_like(x)
    
    return x.cpu().detach()
 def sample_diffusion_sequence(self, batch_size, device, y=None, use_ema=True):
     if y is not None and batch_size != len(y):
         raise ValueError("sample batch size different from length of given y")

     x = torch.randn(batch_size, self.img_channels, *self.img_size, device=device)
     for t in range(self.num_timesteps - 1, -1, -1):
         t_batch = torch.tensor([t], device=device).repeat(batch_size)
         x = self.remove_noise(x, t_batch, y, use_ema)

         if t > 0:
             x += extract(self.sigma, t_batch, x.shape) * torch.randn_like(x)
         
         yield x.cpu().detach()

这部分代码实现很简单,先是从高斯分布中采样,然后逐步去噪声,加干扰。

sample和sample_diffusion_sequence差别就在返回值
sample:return x.cpu().detach()
.cpu() 将数据的处理设备从其他设备(如.cuda()拿到cpu上),不会改变变量类型,转换后仍然是Tensor变量。 为什么需要这一步? 因为gpu上的数组不能直接进行转换类型的操作。
.detach() 函数可以返回一个完全相同的tensor,新的tensor开辟与旧的tensor共享内存,新的tensor会脱离计算图,不会牵扯梯度计算。也就是requires_grad=False, 因此可以 接着进行numpy() 的操作,解决了numpy()需要建立在无梯度的tensor的基础上的问题。
sample_diffusion_sequence: yield x.cpu().detach()
yield:在Python中,yield 关键字用于从函数中返回一个生成器(generator)。生成器是一个可以记住上一次返回位置的对象,并在下一次迭代时从该位置继续执行。这使得它们非常适合用于需要逐个处理大量数据的场景,因为它们可以按需生成数据,从而节省内存。

当你调用一个包含 yield 的函数时,该函数不会立即执行其代码,而是返回一个迭代器(即生成器)。然后,你可以通过迭代这个生成器来逐步执行函数中的代码。每次迭代时,yield 语句会“暂停”函数的执行,并返回紧随其后的值给迭代器的调用者。当迭代器再次请求下一个值时,函数会从上次暂停的位置继续执行,直到遇到下一个 yield 语句或函数结束。

  • 5
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
当然可以,下面是逐行注释的代码: ``` def learn(self): # 如果记忆库中没有足够的数据,就返回 if not self.memory.ready(): return # 从记忆库中采样一批数据 states, actions, rewards, next_states, terminals = self.memory.sample_buffer() # 构建batch索引 batch_idx = np.arange(self.batch_size) # 将样本转换成张量,并将其移动到GPU上 states_tensor = T.tensor(states, dtype=T.float).to(device) rewards_tensor = T.tensor(rewards, dtype=T.float).to(device) next_states_tensor = T.tensor(next_states, dtype=T.float).to(device) terminals_tensor = T.tensor(terminals).to(device) # 使用目标网络计算下一状态的Q值 with T.no_grad(): q_ = self.q_target.forward(next_states_tensor) q_[terminals_tensor] = 0.0 # 如果下一状态是终止状态,将其Q值设为0 target = rewards_tensor + self.gamma * T.max(q_, dim=-1)[0] # 使用当前网络计算当前状态的Q值,并根据动作索引选择对应的Q值 q = self.q_eval.forward(states_tensor)[batch_idx, actions] # 计算损失函数 loss = F.mse_loss(q, target.detach()) # 清空当前网络的优化器梯度 self.q_eval.optimizer.zero_grad() # 反向传播,计算梯度 loss.backward() # 更新当前网络的参数 self.q_eval.optimizer.step() # 更新目标网络的参数 self.update_network_parameters() # 更新epsilon值,用于epsilon-greedy策略 self.epsilon = self.epsilon - self.eps_dec if self.epsilon > self.eps_min else self.eps_min ``` 这段代码的主要作用是训练强化学习中的Q-learning算法。这个算法的主要思想是,通过不断迭代更新Q值,来逼近最优值函数,从而得到最优策略。代码中的主要步骤如下: 1. 从记忆库中采样一批数据 2. 使用目标网络计算下一状态的Q值,并根据终止状态判断是否为0 3. 使用当前网络计算当前状态的Q值,并根据动作索引选择对应的Q值 4. 计算损失函数 5. 清空当前网络的优化器梯度 6. 反向传播,计算梯度 7. 更新当前网络的参数 8. 更新目标网络的参数 9. 更新epsilon值,用于epsilon-greedy策略
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值