pytorch可视化输出的图像

7 篇文章 0 订阅
5 篇文章 0 订阅

最近在学习GAN,跑了两个目前是sota的开源代码:https://github.com/Tencent/Real-SR 以及 https://github.com/ManuelFritsche/real-world-sr ,前者是在esrgan的基础上做了一定的改动,后者包含dsgan和esrgan-fs两部分,其中esrgan-fs与腾讯的代码均在BasicSR增添了自己的改进,总体的框架基本不变。.

在学习的时候,主要是dsgan部分和esrgan部分对生成器图片和可视化的操作有一点不同,在此做以记录。

                                  esrgan                                                dsgan
将图片转为tensor进行训练
  1. 利用cv2读取,读取出来是numpy的ndarray,维度是H,W,C
  2. 转为float32,并且除以255,转换为[0,1]
  3. BGR转RGB
  4. H,W,C转C,H,W(因为pytorch需要C,H,W格式的)
  1. 利用PIL.Image读取,读取出来是(H,W)
  2. 利用pytorch自带的 transforms.ToTensor()将图片转为Tensor,此时维度是C,H,W

注意:

transforms.ToTensor()将numpy的ndarray或PIL.Image读的图片转换成形状为(C,H, W)的Tensor格式,且/255归一化到[0,1.0]之间。通道的具体顺序与cv2读的还是PIL.Image读的图片有关系:cv2:(B,G,R);PIL.Image:(R, G, B)

将tensor转换为图片保存(Generator生成的图像可视化部分)
  1. 将tensor截断到[0,1]
  2. 拉伸到[0,1]
  3. 转为numpy格式(.numpy())
  4. RGB转BGR (因为想要用cv2保存)
  5. CHW转H,W,C(因为想要用cv2保存)
  1. 将tensor截断到[0,1](之后计算psnr)
  2. 使用pytorch自带的make_grid直接在tensorboard中可视化

注意:

make_grid的输入必须是tensor,并且值需要在[0,1]之间,并且要加.cpu()

esrgan生成图片可视化

  • 官方:直接以图片形式保存

esrgan是将生成器生成的tensor转换成矩阵的形式直接保存于目录中,没有将其显示在tensorboard中。

转换的函数为util.tensor2img,输入为visuals['SR'],维度是[batch,c,h,w],因为这个函数是val时候调用的,而val的batchsize一般为1,即这里batch = 1。

visuals = model.get_current_visuals()
sr_img = util.tensor2img(visuals['SR'])  # uint8

再来看util.tensor2img函数:

def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
    '''
    Converts a torch Tensor into an image Numpy array
    Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
    Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
    '''
    #step1:首先通过squeez()将输入的tensor去掉batch这个维度,变为[c,h,w],
    #之后转为float和cpu,再将tensor的值限制在[0,1]之间(有时候tensor为负值)
    tensor = tensor.squeeze().float().cpu().clamp_(*min_max)  # clamp
    #step1:将tensor进行一个线性拉伸,拉伸到最大值为1,最小值为0
    tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])  # to range [0,1]
    n_dim = tensor.dim()  
    if n_dim == 4:
        n_img = len(tensor)
        img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
    elif n_dim == 3:  #一般情况下会直接跳到这里
        img_np = tensor.numpy()  #转为array形式
        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # CHW->HWC, RGB->BGR(因为后续要用cv2保存)
    elif n_dim == 2:
        img_np = tensor.numpy()
    else:
        raise TypeError(
            'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
    if out_type == np.uint8:
        img_np = (img_np * 255.0).round()  #乘以255转为uint8
        # Important. Unlike matlab, numpy.unit8() WILL NOT round by default.
    return img_np.astype(out_type)
  • 在tensorboard记录

因为我自己想在tensorboard中实时看到val的结果,因此这里直接将生成器生成的图像记录在tensorboard中。

visuals = model.get_current_visuals()

sr_img = visuals['SR']  #得到超分后的图像
min_max = [0,1]
sr_img = sr_img.squeeze().float().cpu().clamp_(*min_max)  # clamp
sr_img = (sr_img - min_max[0]) / (min_max[1] - min_max[0])  # to range [0,1]
bicubic_img = display_transform()(visuals['LQ'].data.cpu())
tb_logger.add_image('compare_bicubic_sr_'+str(idx) ,make_grid([bicubic_img,sr_img],padding=20), current_step)    


def display_transform():
    return Compose([
        ToPILImage(),
        Resize(512),
        ToTensor()
    ])   

step1:首先是得到超分后的图像

step2:借鉴官方的代码,进行阶段和线性拉伸

step3:因为我做的任务图像的超分,因此我希望将我超分后的图像与原图LQ进行对比,由于低清的LQ大小仅为128×128,而超分后的SR大小为512×512,因此这里对LQ进行了变换,主要是缩放到512×512,即一样大。(原因是后续make_grid函数必须要求图的尺寸是一致的)

step4:利用makegrid()函数可视化

dsgan生成图片可视化

  • 官方:记录在tensorboard中
fake_img = torch.clamp(model_g(input_img)[0], min=0, max=1)

blur = filter_low_module(fake_img)
hf = filter_high_module(fake_img)
val_image_list = [
    utils.display_transform()(Naive_LR.data.cpu().squeeze(0)),
    utils.display_transform()(fake_img.data.cpu().squeeze(0)),
    utils.display_transform()(blur.data.cpu().squeeze(0)),
    utils.display_transform()(hf.data.cpu().squeeze(0))]
n_val_images = len(val_image_list)
val_images.extend(val_image_list)

if opt.saving and len(val_loader) > 0:
val_images = torch.stack(val_images)
val_images = torch.chunk(val_images, val_images.size(0) // (n_val_images * 5))
for index, image in enumerate(val_images):
    image = tvutils.make_grid(image, nrow=n_val_images, padding=5)
    writer.add_image('val/target_fake_low_high_' + str(index), image, iteration)

def display_transform():
    return Compose([
        ToPILImage(),
        Resize(400),
        CenterCrop(400),
        ToTensor()
    ])

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值