pytorch训练模型时可视化模型的输入

1.用tensorboard可视化

#tensorboard显示输入图像
import numpy as np
imgGT = (train_data['GT'].cpu().detach().numpy()[0])*255
imgGT = imgGT.astype(np.uint8)
tb_logger.add_image(train_data['NameGT'][0],imgGT,global_step=1,dataformats='CHW')

imgLQ = (train_data['LQs'].cpu().detach()[0]) * 255
imgLQ = imgLQ.type(torch.ByteTensor)

import torchvision.utils as vutils
imgLQCat = vutils.make_grid(imgLQ)
tb_logger.add_image(train_data['NameLQ'][0],
                    imgLQCat,
                  global_step=1)

 

           if glob_iter % 100 == 0:
                I1_ori_img = cv2.normalize(I.cpu().detach().numpy()[0, 0, ...], None, 0, 255, cv2.NORM_MINMAX,
                                           cv2.CV_8U)
                I2_ori_img_fig = cv2.normalize(I2_ori_img.cpu().detach().numpy()[0, 0, ...], None, 0, 255, cv2.NORM_MINMAX,
                                            cv2.CV_8U)
                input_I2 = cv2.normalize(I2.cpu().detach().numpy()[0, 0, ...], None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
                pred_I2 = cv2.normalize(pred_I2.cpu().detach().numpy()[0, 0, ...], None, 0, 255, cv2.NORM_MINMAX,
                                        cv2.CV_8U)

                writer.add_image('I1 and I2',
                                 I1_ori_img,
                                 global_step=1,
                                 dataformats='HW')
                writer.add_image('I1 and I2',
                                 I2_ori_img_fig,
                                 global_step=2,
                                 dataformats='HW')

                writer.add_image('I2 and pred_I2',
                                 input_I2,
                                 global_step=1,
                                 dataformats='HW')
                writer.add_image('I2 and pred_I2',
                                 pred_I2,
                                 global_step=2,
                                 dataformats='HW')

2.当我们创建好表示数据集的xxx.py时,可以测试一下模型的输入

if __name__ == '__main__':
    from torch.utils.data import DataLoader
    import matplotlib.pyplot as plt
    import argparse

    opt= {'dataroot_LQ':'/media/mlxuan/LinuxH/competition/superreso/EDSR/edsr_4k/datasets/Tencent/X4','dataroot_GT':'/media/mlxuan/LinuxH/competition/superreso/EDSR/edsr_4k/datasets/Tencent/gt','phase':'train','scale':4,'N_frames':5,'use_shuffle':True,'n_workers':0,'batch_size':32,'use_flip':True,'use_rot':True,'color':'RGB','GT_size':256,'LQ_size':64,'interval_list':[1]}
    traindata = V4MLXDataset(opt)

    dataloader = DataLoader(traindata, batch_size=5, shuffle=True, num_workers=16)


    img_mean=[0.485, 0.456, 0.406]#RGB的顺序
    img_std = [0.229,0.224,0.225]#RGB的顺序
    for ii, img in enumerate(dataloader):
        for j in range(1):

            imgLq = img['LQs'][j].numpy()
            # imgLq[:,0,:,:] = imgLq[:,0,:,:]*img_std[0]+img_mean[0]
            # imgLq[:,1,:,:] = imgLq[:,1,:,:] * img_std[1] + img_mean[1]
            # imgLq[:,2,:,:] = imgLq[:,2,:,:] * img_std[2] + img_mean[2]
            imgLq = imgLq*255

            imgGT = img['GT'][j].numpy()
            # imgGT[ 0] = imgGT[0] * img_std[0] + img_mean[0]
            # imgGT[ 1] = imgGT[ 1] * img_std[1] + img_mean[1]
            # imgGT[2] = imgGT[2] * img_std[2] + img_mean[2]
            imgGT = imgGT * 255


            imgLq = imgLq.astype(np.uint8)
            imgGT = imgGT.astype(np.uint8)


            plt.figure()
            plt.title('display')
            plt.subplot(211)
            plt.imshow(imgLq[2].swapaxes(0,2))
            plt.subplot(212)
            plt.imshow(imgGT.swapaxes(0,2))
        print(ii)
        if ii == 30:
            break

    plt.show(block=True)

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值