DiAD代码逐行理解之test.py

1、代码

# 读取对应的数据,按照指定的格式封装成dict
dataset = MVTecDataset('test','/home/mby/DDPM/data/mvtec_anomaly_detection/')
# test_dataset = VisaDataset('test')
# torch.utils.data.DataLoader 是 PyTorch 中的一个非常有用的类,它提供了一种简便的方式来加载数据,并能够在训练深度学习模型时自动进行批处理(batching)、打乱数据(shuffling)、多进程数据加载等操作。DataLoader 封装了对数据集的迭代过程,使得批量获取数据、迭代数据集变得更加容易。
dataloader = DataLoader(dataset, num_workers=8, batch_size=batch_size, shuffle=True)
#加载模型:通过调用 timm.create_model 函数,指定了 "resnet50" 作为要加载的模型架构。ResNet-50 是一种流行的深度残差网络,广泛应用于图像识别和分类任务中。预训练:通过设置 pretrained=True,您指示该函数加载一个在 ImageNet 数据集上预训练的 ResNet-50 模型。这意呀着模型已经过训练,可以识别 ImageNet 数据集中的 1000 个类别,并且可以用作特征提取器或进一步微调以适应其他任务。特征提取模式:通过设置 features_only=True,您告诉 timm 仅返回模型的特征提取部分,而不是整个模型(包括分类层)。这对于特征提取和迁移学习场景特别有用,因为您可能想要使用预训练模型的中间层来提取图像特征,并将这些特征用于其他任务(如图像检索、图像分割等),或者将它们作为新分类器的输入。
pretrained_model = timm.create_model("resnet50", pretrained=True, features_only=True)
#使用 .cuda() 方法可以将模型的所有参数和数据移动到 GPU 上,以便利用 GPU 的并行计算能力来加速模型的推理过程。
pretrained_model = pretrained_model.cuda()
#设置模型为评估模式
pretrained_model.eval()

model.eval()
os.makedirs(evl_dir, exist_ok=True)
with torch.no_grad():
    for input in dataloader:
        input_img = input['jpg']
        input_features = pretrained_model(input_img.cuda())
        model = model.cuda()
        output= model.log_images_test(input)
        images = output
        log_local(images, input["filename"][0])
        output_img = images['samples']
        output_features = pretrained_model(output_img.cuda())
        input_features = input_features[1:4]
        output_features = output_features[1:4]

        # Calculate the anomaly score
        anomaly_map, _ = cal_anomaly_map(input_features, output_features, input_img.shape[-1], amap_mode='a')
        anomaly_map = gaussian_filter(anomaly_map, sigma=5)
        anomaly_map = torch.from_numpy(anomaly_map)
        anomaly_map_prediction = anomaly_map.unsqueeze(dim=0).unsqueeze(dim=1)
        input["mask"] = input["mask"]

        root = os.path.join('log_image/')
        name = input["filename"][0][-7:-4]
        filename_feature = "{}-features.jpg".format(name)
        path_feature = os.path.join(root, input["filename"][0][:-7], filename_feature)
        pred_feature = anomaly_map_prediction.squeeze().detach().cpu().numpy()
        pred_feature = (pred_feature * 255).astype("uint8")
        pred_feature = Image.fromarray(pred_feature, mode='L')
        pred_feature.save(path_feature)

        #Heatmap
        anomaly_map_new = np.round(255 * (anomaly_map - anomaly_map.min()) / (anomaly_map.max() - anomaly_map.min()))
        anomaly_map_new = anomaly_map_new.cpu().numpy().astype(np.uint8)
        heatmap = cv2.applyColorMap(anomaly_map_new, colormap=cv2.COLORMAP_JET)
        pixel_mean = [0.485, 0.456, 0.406]
        pixel_std = [0.229, 0.224, 0.225]
        pixel_mean = torch.tensor(pixel_mean).unsqueeze(1).unsqueeze(1)  # 3 x 1 x 1
        pixel_std = torch.tensor(pixel_std).unsqueeze(1).unsqueeze(1)
        image = (input_img.squeeze() * pixel_std + pixel_mean) * 255
        image = image.permute(1, 2, 0).to('cpu').numpy().astype('uint8')
        image_copy = image.copy()
        out_heat_map = cv2.addWeighted(heatmap, 0.5, image_copy, 0.5, 0, image_copy)
        heatmap_name = "{}-heatmap.png".format(name)
        cv2.imwrite(root + input["filename"][0][:-7] + heatmap_name, out_heat_map)

        input['pred'] = anomaly_map_prediction
        input["output"] = output_img.cpu()
        input["input"] = input_img.cpu()

        output2 = input
        dump(evl_dir, output2)

evl_metrics = {'auc': [ {'name': 'max'}, {'name': 'pixel'}, {'name': 'pro'}, {'name': 'appx'}, {'name': 'apsp'}, {'name': 'f1px'}, {'name': 'f1sp'}]}
print("Gathering final results ...")
fileinfos, preds, masks = merge_together(evl_dir)
ret_metrics = performances(fileinfos, preds, masks, evl_metrics)
log_metrics(ret_metrics, evl_metrics)

2、代码理解

get_input函数从一个给定的批次(batch)数据中提取输入特征x和控制信息control,并将它们以特定的格式返回,以供后续的网络层或模型使用。

@torch.no_grad()
def get_input(self, batch, k, bs=None, *args, **kwargs):
    x, c = super().get_input(batch, self.first_stage_key, *args, **kwargs)
    control = batch[self.control_key]
    if bs is not None:
        control = control[:bs]
    control = control.to(self.device)
    # control = einops.rearrange(control, 'b h w c -> b c h w')
    control = control.to(memory_format=torch.contiguous_format).float()
    return x, dict(c_crossattn=[c], c_concat=[control])
 def log_images_test(self, batch, N=4, n_row=2, sample=False, ddim_steps=10, ddim_eta=0.0, return_keys=None,
                   quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
                   plot_diffusion_rows=False, unconditional_guidance_scale=9.0, unconditional_guidance_label=None,
                   use_ema_scope=True,
                   **kwargs):
        use_ddim = ddim_steps is not None

        log = dict()
        z, c = self.get_input(batch, self.first_stage_key, bs=N)
        c_cat = c["c_concat"][0][:N]
        c = c["c_crossattn"][0][:N]
        N = min(z.shape[0], N)
        n_row = min(z.shape[0], n_row)
        log["reconstruction"] = self.decode_first_stage(z)
        # log["control"] = c_cat * 2.0 - 1.0
        log["input"] = c_cat

        t = torch.randint(999, 1000, (z.shape[0],), device=self.device).long()
        noise = torch.randn_like(z)
        x_noisy = self.q_sample(x_start=z, t=t, noise=noise)

        if sample:
            # get denoise row
            samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
                                                     batch_size=N, ddim=use_ddim,
                                                     ddim_steps=ddim_steps, eta=ddim_eta)
            x_samples = self.decode_first_stage(samples)
            log["samples"] = x_samples
            if plot_denoise_rows:
                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
                log["denoise_row"] = denoise_grid

        if unconditional_guidance_scale > 1.0:
            uc_cross = self.get_unconditional_conditioning(N)
            uc_cat = c_cat  # torch.zeros_like(c_cat)
            uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
            samples_cfg, inter = self.sample_log_test(cond={"c_concat": [c_cat], "c_crossattn": [c]},
                                             batch_size=N, ddim=use_ddim,
                                             ddim_steps=ddim_steps, eta=ddim_eta,
                                             unconditional_guidance_scale=unconditional_guidance_scale,
                                             unconditional_conditioning=uc_full,
                                            x_T=x_noisy, timesteps=t
                                             )
            x_samples_cfg = self.decode_first_stage(samples_cfg)
            log["samples"] = x_samples_cfg
            # log["samples"] = x_samples_cfg

        return log
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值