attentive-gan-derainnet: derain a single image and output the attention maps

"""
test model
"""
import os.path as ops
import argparse

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import cv2
from skimage.measure import compare_ssim
from skimage.measure import compare_psnr
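# NOTE: compare_ssim / compare_psnr were deprecated in scikit-image 0.16 and removed in 0.18;
# on newer versions use skimage.metrics.structural_similarity and
# skimage.metrics.peak_signal_noise_ratio instead, or pin an older scikit-image.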

from attentive_gan_model import derain_drop_net
from config import global_config

CFG = global_config.cfg


def init_args():
    """

    :return:
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_path', type=str, default=r'data/test_data/test_1.png', help='The input image path')
    parser.add_argument('--weights_path', type=str, default=r'weights/derain_gan/derain_gan.ckpt-100000', help='The model weights path')
    parser.add_argument('--label_path', type=str, default=None, help='The label image path')

    return parser.parse_args()


def minmax_scale(input_arr):
    """

    :param input_arr:
    :return:
    """
    min_val = np.min(input_arr)
    max_val = np.max(input_arr)

    output_arr = (input_arr - min_val) * 255.0 / (max_val - min_val)

    return output_arr


def visualize_attention_map(attention_map):
    """
    The attention map is a matrix ranging from 0 to 1, where the greater the value,
    the greater attention it suggests
    :param attention_map:
    :return:
    """
    attention_map_color = np.zeros(
        shape=[attention_map.shape[0], attention_map.shape[1], 3],
        dtype=np.uint8
    )

    red_color_map = np.zeros(
        shape=[attention_map.shape[0], attention_map.shape[1]],
        dtype=np.uint8) + 255
    red_color_map = red_color_map * attention_map
    red_color_map = np.array(red_color_map, dtype=np.uint8)

    attention_map_color[:, :, 2] = red_color_map

    return attention_map_color
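# visualize_attention_map() is not called in test_model() below, which plots the raw
# attention maps with matplotlib instead. A minimal, hypothetical usage inside test_model()
# would be to overlay the colored map on the resized input image:
#
#     atte_color = visualize_attention_map(atte_maps[-1][0, :, :, 0])
#     overlay = cv2.addWeighted(image_vis, 0.6, atte_color, 0.4, 0)
#     cv2.imwrite('atte_overlay.png', overlay)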


def test_model(image_path, weights_path, label_path=None):
    """

    :param image_path:
    :param weights_path:
    :param label_path:
    :return:
    """
    assert ops.exists(image_path)
    print('success')
    input_tensor = tf.placeholder(dtype=tf.float32,
                                  shape=[CFG.TEST.BATCH_SIZE, CFG.TEST.IMG_HEIGHT, CFG.TEST.IMG_WIDTH, 3],
                                  name='input_tensor'
                                  )

    image = cv2.imread(image_path, cv2.IMREAD_COLOR)
    print('read image')
    image = cv2.resize(image, (CFG.TEST.IMG_WIDTH, CFG.TEST.IMG_HEIGHT), interpolation=cv2.INTER_LINEAR)
    image_vis = image
    image = np.divide(np.array(image, np.float32), 127.5) - 1.0

    label_image_vis = None
    if label_path is not None:
        label_image = cv2.imread(label_path, cv2.IMREAD_COLOR)
        label_image_vis = cv2.resize(
            label_image, (CFG.TEST.IMG_WIDTH, CFG.TEST.IMG_HEIGHT), interpolation=cv2.INTER_LINEAR
        )

    phase = tf.constant('test', tf.string)

    net = derain_drop_net.DeRainNet(phase=phase)
    output, attention_maps = net.inference(input_tensor=input_tensor, name='derain_net')

    # Set sess configuration
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TEST.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TEST.TF_ALLOW_GROWTH
    sess_config.gpu_options.allocator_type = 'BFC'

    sess = tf.Session(config=sess_config)

    saver = tf.train.Saver()

    with sess.as_default():
        saver.restore(sess=sess, save_path=weights_path)

        output_image, atte_maps = sess.run(
            [output, attention_maps],
            feed_dict={input_tensor: np.expand_dims(image, 0)})

        output_image = output_image[0]
        for i in range(output_image.shape[2]):
            output_image[:, :, i] = minmax_scale(output_image[:, :, i])

        output_image = np.array(output_image, np.uint8)

        if label_path is not None:
            label_image_vis_gray = cv2.cvtColor(label_image_vis, cv2.COLOR_BGR2YCR_CB)[:, :, 0]
            output_image_gray = cv2.cvtColor(output_image, cv2.COLOR_BGR2YCR_CB)[:, :, 0]
            psnr = compare_psnr(label_image_vis_gray, output_image_gray)
            ssim = compare_ssim(label_image_vis_gray, output_image_gray)

            print('SSIM: {:.5f}'.format(ssim))
            print('PSNR: {:.5f}'.format(psnr))

        # Save and visualize the results
        cv2.imwrite('src_img.png', image_vis)
        cv2.imwrite('derain_ret.png', output_image)

        plt.figure('src_image')
        plt.title('Original Image')
        plt.imshow(image_vis[:, :, (2, 1, 0)])

        plt.figure('derain_ret')
        plt.title('Derain Image')
        plt.imshow(output_image[:, :, (2, 1, 0)])
        plt.show()

        # Plot the four attention maps produced by the network.
        # Note: plt.savefig() writes the whole figure, so each saved file contains
        # every subplot drawn up to that point.
        plt.figure('attention map')
        plt.suptitle('Attention Map')

        plt.subplot(221)
        plt.title('atte_map_1')
        plt.imshow(atte_maps[0][0, :, :, 0], cmap='jet')
        plt.savefig('atte_map_1.png')

        plt.subplot(222)
        plt.title('atte_map_2')
        plt.imshow(atte_maps[1][0, :, :, 0], cmap='jet')
        plt.savefig('atte_map_2.png')

        plt.subplot(223)
        plt.title('atte_map_3')
        plt.imshow(atte_maps[2][0, :, :, 0], cmap='jet')
        plt.savefig('atte_map_3.png')

        plt.subplot(224)
        plt.title('atte_map_4')
        plt.imshow(atte_maps[3][0, :, :, 0], cmap='jet')
        plt.savefig('atte_map_4.png')
        plt.show()

    return


if __name__ == '__main__':
    # init args
    # args = init_args()

    # test model
    img_path = r'data/test_data/test_2.png'
    weight_path = r'weights/derain_gan/derain_gan.ckpt-100000'
    test_model(img_path, weight_path)
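The `__main__` block above hard-codes `img_path` and `weight_path` and bypasses `init_args()`. A minimal sketch of wiring the argparse interface back in (replacing the hard-coded block), so the paths are passed via the `--image_path` / `--weights_path` / `--label_path` flags defined in `init_args()`:

if __name__ == '__main__':
    # parse paths from the command line instead of hard-coding them
    args = init_args()
    test_model(args.image_path, args.weights_path, label_path=args.label_path)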

(Figures: the original rainy image, the derained result, and the attention maps.)
