ZTE 捞月 Competition

Official code

test-TF

import os
import numpy as np
import rawpy
import tensorflow as tf
import skimage.metrics
from matplotlib import pyplot as plt
from unetTF import unet
import argparse


def normalization(input_data, black_level, white_level):
    output_data = (input_data.astype(float) - black_level) / (white_level - black_level)
    return output_data


def inv_normalization(input_data, black_level, white_level):
    output_data = np.clip(input_data, 0., 1.) * (white_level - black_level) + black_level
    output_data = output_data.astype(np.uint16)
    return output_data


def write_back_dng(src_path, dest_path, raw_data):
    """
    replace dng data
    """
    height = raw_data.shape[0]
    width = raw_data.shape[1]
    file_size = os.path.getsize(src_path)
    data_len = width * height * 2  # 16-bit raw payload, 2 bytes per pixel
    header_len = 8

    with open(src_path, "rb") as f_in:
        data_all = f_in.read(file_size)
        dng_format = data_all[5] + data_all[6] + data_all[7]

    with open(src_path, "rb") as f_in:
        header = f_in.read(header_len)
        if dng_format != 0:
            # layout: header | raw data | metadata
            _ = f_in.read(data_len)
            meta = f_in.read(file_size - header_len - data_len)
        else:
            # layout: header | metadata | raw data
            meta = f_in.read(file_size - header_len - data_len)
            _ = f_in.read(data_len)

        data = raw_data.tobytes()

    with open(dest_path, "wb") as f_out:
        f_out.write(header)
        if dng_format != 0:
            f_out.write(data)
            f_out.write(meta)
        else:
            f_out.write(meta)
            f_out.write(data)

    if os.path.getsize(src_path) != os.path.getsize(dest_path):
        print("replace raw data failed, file size mismatch!")
    else:
        print("replace raw data finished")


def read_image(input_path):
    raw = rawpy.imread(input_path)
    raw_data = raw.raw_image_visible
    height = raw_data.shape[0]
    width = raw_data.shape[1]

    # pack the Bayer mosaic into 4 half-resolution planes (one per 2x2 CFA position)
    raw_data_expand = np.expand_dims(raw_data, axis=2)
    raw_data_expand_c = np.concatenate((raw_data_expand[0:height:2, 0:width:2, :],
                                        raw_data_expand[0:height:2, 1:width:2, :],
                                        raw_data_expand[1:height:2, 0:width:2, :],
                                        raw_data_expand[1:height:2, 1:width:2, :]), axis=2)
    return raw_data_expand_c, height, width


def write_image(input_data, height, width):
    # re-interleave the 4 half-resolution planes back into a full-resolution Bayer mosaic
    output_data = np.zeros((height, width), dtype=np.uint16)
    for channel_y in range(2):
        for channel_x in range(2):
            output_data[channel_y:height:2, channel_x:width:2] = input_data[0, :, :, 2 * channel_y + channel_x]
    return output_data


def denoise_raw(input_path, output_path, ground_path, model_path, black_level, white_level):
    """
    Example: obtain ground truth
    """
    gt = rawpy.imread(ground_path).raw_image_visible

    """
    pre-process
    """
    raw_data_channels, height, width = read_image(input_path)
    raw_data_channels_normal = normalization(raw_data_channels, black_level, white_level)
    raw_data_channels_normal = tf.convert_to_tensor(np.reshape(raw_data_channels_normal,
                                                               (1, raw_data_channels_normal.shape[0],
                                                                raw_data_channels_normal.shape[1], 4)))
    print(raw_data_channels_normal.shape)
    net = unet((raw_data_channels_normal.shape[1], raw_data_channels_normal.shape[2], 4))
    net.load_weights(model_path)

    """
    inference
    """
    result_data = net(raw_data_channels_normal, training=False)
    result_data = result_data.numpy()

    """
    post-process
    """
    result_data = inv_normalization(result_data, black_level, white_level)
    result_write_data = write_image(result_data, height, width)
    write_back_dng(input_path, output_path, result_write_data)

    """
    obtain psnr and ssim
    """
    psnr = skimage.metrics.peak_signal_noise_ratio(
        gt.astype(np.float64), result_write_data.astype(np.float64), data_range=white_level)
    ssim = skimage.metrics.structural_similarity(
        gt.astype(np.float64), result_write_data.astype(np.float64), multichannel=True, data_range=white_level)
    print('psnr:', psnr)
    print('ssim:', ssim)

    """
    Example: this demo_code shows your input or gt or result image
    """
    f0 = rawpy.imread(ground_path)
    f1 = rawpy.imread(input_path)
    f2 = rawpy.imread(output_path)
    f, axarr = plt.subplots(1, 3)
    axarr[0].imshow(f0.postprocess(use_camera_wb=True))
    axarr[1].imshow(f1.postprocess(use_camera_wb=True))
    axarr[2].imshow(f2.postprocess(use_camera_wb=True))
    axarr[0].set_title('gt')
    axarr[1].set_title('noisy')
    axarr[2].set_title('de-noise')
    plt.show()


def main(args):
    model_path = args.model_path
    black_level = args.black_level
    white_level = args.white_level
    input_path = args.input_path
    output_path = args.output_path
    ground_path = args.ground_path
    denoise_raw(input_path, output_path, ground_path, model_path, black_level, white_level)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, default="./models/tf_model.h5")
    parser.add_argument('--black_level', type=int, default=1024)
    parser.add_argument('--white_level', type=int, default=16383)
    parser.add_argument('--input_path', type=str, default="./testset/noisy0.dng")
    parser.add_argument('--output_path', type=str, default="./data/denoise0.dng")
    parser.add_argument('--ground_path', type=str, default="./testset/noisy0.dng")

    args = parser.parse_args()
    main(args)
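
As a quick sanity check on the pre-/post-processing above, the following standalone snippet (a minimal sketch, not part of the official demo) verifies that the read_image-style RGGB packing and the write_image-style unpacking are exact inverses on a toy Bayer mosaic:

import numpy as np

bayer = np.arange(16, dtype=np.uint16).reshape(4, 4)        # toy 4x4 Bayer mosaic
packed = np.stack([bayer[0::2, 0::2], bayer[0::2, 1::2],    # pack the four CFA positions
                   bayer[1::2, 0::2], bayer[1::2, 1::2]], axis=2)
restored = np.zeros_like(bayer)
for cy in range(2):
    for cx in range(2):
        restored[cy::2, cx::2] = packed[:, :, 2 * cy + cx]  # write_image-style unpacking
assert np.array_equal(bayer, restored)                       # the round trip is lossless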

model_tf

import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.layers as layers
import numpy as np


def unet(input_shape, out_channel=4, kernel=3, pool_size=(2, 2), feature_base=32):
    input_holder = layers.Input(shape=input_shape)

    n, h, w, c = input_holder.shape
    h_pad = 32 - h % 32 if h % 32 != 0 else 0
    w_pad = 32 - w % 32 if w % 32 != 0 else 0
    padded_image = tf.pad(input_holder, [[0, 0], [0, h_pad], [0, w_pad], [0, 0]], "reflect")

    conv1_1 = layers.Conv2D(feature_base, (kernel, kernel), padding="same")(padded_image)
    conv1_1 = layers.LeakyReLU(0.2)(conv1_1)
    conv1_2 = layers.Conv2D(feature_base, (kernel, kernel), padding="same")(conv1_1)
    conv1_2 = layers.LeakyReLU(0.2)(conv1_2)

    pool_1 = layers.MaxPool2D(pool_size)(conv1_2)

    conv2_1 = layers.Conv2D(feature_base * 2, (kernel, kernel), padding="same")(pool_1)
    conv2_1 = layers.LeakyReLU(0.2)(conv2_1)
    conv2_2 = layers.Conv2D(feature_base * 2, (kernel, kernel), padding="same")(conv2_1)
    conv2_2 = layers.LeakyReLU(0.2)(conv2_2)

    pool_2 = layers.MaxPool2D(pool_size)(conv2_2)

    conv3_1 = layers.Conv2D(feature_base * 4, (kernel, kernel), padding="same")(pool_2)
    conv3_1 = layers.LeakyReLU(0.2)(conv3_1)
    conv3_2 = layers.Conv2D(feature_base * 4, (kernel, kernel), padding="same")(conv3_1)
    conv3_2 = layers.LeakyReLU(0.2)(conv3_2)

    pool_3 = layers.MaxPool2D(pool_size)(conv3_2)

    conv4_1 = layers.Conv2D(feature_base * 8, (kernel, kernel), padding="same")(pool_3)
    conv4_1 = layers.LeakyReLU(0.2)(conv4_1)
    conv4_2 = layers.Conv2D(feature_base * 8, (kernel, kernel), padding="same")(conv4_1)
    conv4_2 = layers.LeakyReLU(0.2)(conv4_2)

    pool_4 = layers.MaxPool2D(pool_size)(conv4_2)

    conv5_1 = layers.Conv2D(feature_base * 16, (kernel, kernel), padding="same")(pool_4)
    conv5_1 = layers.LeakyReLU(0.2)(conv5_1)
    conv5_2 = layers.Conv2D(feature_base * 16, (kernel, kernel), padding="same")(conv5_1)
    conv5_2 = layers.LeakyReLU(0.2)(conv5_2)

    unpool1 = layers.Conv2DTranspose(feature_base * 8, pool_size, (2, 2), "same")(conv5_2)
    concat1 = layers.Concatenate()([unpool1, conv4_2])
    conv6_1 = layers.Conv2D(feature_base * 8, (kernel, kernel), padding="same")(concat1)
    conv6_1 = layers.LeakyReLU(0.2)(conv6_1)
    conv6_2 = layers.Conv2D(feature_base * 8, (kernel, kernel), padding="same")(conv6_1)
    conv6_2 = layers.LeakyReLU(0.2)(conv6_2)

    unpool2 = layers.Conv2DTranspose(feature_base * 4, pool_size, (2, 2), "same")(conv6_2)
    concat2 = layers.Concatenate()([unpool2, conv3_2])
    conv7_1 = layers.Conv2D(feature_base * 4, (kernel, kernel), padding="same")(concat2)
    conv7_1 = layers.LeakyReLU(0.2)(conv7_1)
    conv7_2 = layers.Conv2D(feature_base * 4, (kernel, kernel), padding="same")(conv7_1)
    conv7_2 = layers.LeakyReLU(0.2)(conv7_2)

    unpool3 = layers.Conv2DTranspose(feature_base * 2, pool_size, (2, 2), "same")(conv7_2)
    concat3 = layers.Concatenate()([unpool3, conv2_2])
    conv8_1 = layers.Conv2D(feature_base * 2, (kernel, kernel), padding="same")(concat3)
    conv8_1 = layers.LeakyReLU(0.2)(conv8_1)
    conv8_2 = layers.Conv2D(feature_base * 2, (kernel, kernel), padding="same")(conv8_1)
    conv8_2 = layers.LeakyReLU(0.2)(conv8_2)

    unpool4 = layers.Conv2DTranspose(feature_base, pool_size, (2, 2), "same")(conv8_2)
    concat4 = layers.Concatenate()([unpool4, conv1_2])
    conv9_1 = layers.Conv2D(feature_base, (kernel, kernel), padding="same")(concat4)
    conv9_1 = layers.LeakyReLU(0.2)(conv9_1)
    conv9_2 = layers.Conv2D(feature_base, (kernel, kernel), padding="same")(conv9_1)
    conv9_2 = layers.LeakyReLU(0.2)(conv9_2)

    out = layers.Conv2D(out_channel, (1, 1), padding="same")(conv9_2)
    out_holder = out[:, :h, :w, :]

    net_model = keras.Model(inputs=input_holder, outputs=out_holder)
    return net_model


if __name__ == "__main__":
    test_input = tf.convert_to_tensor(np.random.randn(1, 512, 512, 4))
    net = unet((512, 512, 4))
    net.summary()
    output = net(test_input, training=False)
    print("test over")

test_torch

import os
import numpy as np
import rawpy
import torch
import skimage.metrics
from matplotlib import pyplot as plt
from unetTorch import Unet
import argparse


def normalization(input_data, black_level, white_level):
    output_data = (input_data.astype(float) - black_level) / (white_level - black_level)
    return output_data


def inv_normalization(input_data, black_level, white_level):
    output_data = np.clip(input_data, 0., 1.) * (white_level - black_level) + black_level
    output_data = output_data.astype(np.uint16)
    return output_data


def write_back_dng(src_path, dest_path, raw_data):
    """
    replace dng data
    """
    height = raw_data.shape[0]
    width = raw_data.shape[1]
    file_size = os.path.getsize(src_path)
    data_len = width * height * 2  # 16-bit raw payload, 2 bytes per pixel
    header_len = 8

    with open(src_path, "rb") as f_in:
        data_all = f_in.read(file_size)
        dng_format = data_all[5] + data_all[6] + data_all[7]

    with open(src_path, "rb") as f_in:
        header = f_in.read(header_len)
        if dng_format != 0:
            # layout: header | raw data | metadata
            _ = f_in.read(data_len)
            meta = f_in.read(file_size - header_len - data_len)
        else:
            # layout: header | metadata | raw data
            meta = f_in.read(file_size - header_len - data_len)
            _ = f_in.read(data_len)

        data = raw_data.tobytes()

    with open(dest_path, "wb") as f_out:
        f_out.write(header)
        if dng_format != 0:
            f_out.write(data)
            f_out.write(meta)
        else:
            f_out.write(meta)
            f_out.write(data)

    if os.path.getsize(src_path) != os.path.getsize(dest_path):
        print("replace raw data failed, file size mismatch!")
    else:
        print("replace raw data finished")


def read_image(input_path):
    raw = rawpy.imread(input_path)
    raw_data = raw.raw_image_visible
    height = raw_data.shape[0]
    width = raw_data.shape[1]

    raw_data_expand = np.expand_dims(raw_data, axis=2)
    raw_data_expand_c = np.concatenate((raw_data_expand[0:height:2, 0:width:2, :],
                                        raw_data_expand[0:height:2, 1:width:2, :],
                                        raw_data_expand[1:height:2, 0:width:2, :],
                                        raw_data_expand[1:height:2, 1:width:2, :]), axis=2)
    return raw_data_expand_c, height, width


def write_image(input_data, height, width):
    # re-interleave the 4 half-resolution planes back into a full-resolution Bayer mosaic
    output_data = np.zeros((height, width), dtype=np.uint16)
    for channel_y in range(2):
        for channel_x in range(2):
            output_data[channel_y:height:2, channel_x:width:2] = input_data[0, :, :, 2 * channel_y + channel_x]
    return output_data


def denoise_raw(input_path, output_path, ground_path, model_path, black_level, white_level):
    """
    Example: obtain ground truth
    """
    gt = rawpy.imread(ground_path).raw_image_visible 

    """
    pre-process
    """
    raw_data_expand_c, height, width = read_image(input_path)
    raw_data_expand_c_normal = normalization(raw_data_expand_c, black_level, white_level)
    raw_data_expand_c_normal = torch.from_numpy(np.transpose(
        raw_data_expand_c_normal.reshape(-1, height//2, width//2, 4), (0, 3, 1, 2))).float()
    net = Unet()
    if model_path is not None:
        net.load_state_dict(torch.load(model_path))
    net.eval()

    """
    inference
    """
    result_data = net(raw_data_expand_c_normal)

    """
    post-process
    """
    result_data = result_data.cpu().detach().numpy().transpose(0, 2, 3, 1)
    result_data = inv_normalization(result_data, black_level, white_level)
    result_write_data = write_image(result_data, height, width)
    write_back_dng(input_path, output_path, result_write_data)

    """
    obtain psnr and ssim
    """
    psnr = skimage.metrics.peak_signal_noise_ratio(
        gt.astype(np.float64), result_write_data.astype(np.float64), data_range=white_level)
    ssim = skimage.metrics.structural_similarity(
        gt.astype(np.float64), result_write_data.astype(np.float64), multichannel=True, data_range=white_level)
    print('psnr:', psnr)
    print('ssim:', ssim)

    """
    Example: this demo_code shows your input or gt or result image
    """
    f0 = rawpy.imread(ground_path)
    f1 = rawpy.imread(input_path)
    f2 = rawpy.imread(output_path)
    f, axarr = plt.subplots(1, 3)
    axarr[0].imshow(f0.postprocess(use_camera_wb=True))
    axarr[1].imshow(f1.postprocess(use_camera_wb=True))
    axarr[2].imshow(f2.postprocess(use_camera_wb=True))
    axarr[0].set_title('gt')
    axarr[1].set_title('noisy')
    axarr[2].set_title('de-noise')
    plt.show()


def main(args):
    model_path = args.model_path
    black_level = args.black_level
    white_level = args.white_level
    input_path = args.input_path
    output_path = args.output_path
    ground_path = args.ground_path

    denoise_raw(input_path, output_path, ground_path, model_path, black_level, white_level)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, default="./models/th_model.pth")
    parser.add_argument('--black_level', type=int, default=1024)
    parser.add_argument('--white_level', type=int, default=16383)
    parser.add_argument('--input_path', type=str, default="./testset/noisy0.dng")
    parser.add_argument('--output_path', type=str, default="./data/denoise0.dng")
    parser.add_argument('--ground_path', type=str, default="./testset/noisy0.dng")
    args = parser.parse_args()
    main(args)
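
If you prefer to drive the PyTorch pipeline from a notebook or another module instead of the CLI, a minimal call might look like the following (assuming the functions above are importable; the paths are just the argparse defaults, so point ground_path at the real ground-truth DNG):

denoise_raw(input_path="./testset/noisy0.dng",
            output_path="./data/denoise0.dng",
            ground_path="./testset/noisy0.dng",
            model_path="./models/th_model.pth",
            black_level=1024,
            white_level=16383)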

model_torch

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class Unet(nn.Module):
    def __init__(self, in_channels=4, out_channels=4):
        super(Unet, self).__init__()

        # device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.conv1_1 = nn.Conv2d(in_channels, 32, kernel_size=3, stride=1, padding=1)
        self.conv1_2 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2)

        self.conv2_1 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv2_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2)

        self.conv3_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv3_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.pool3 = nn.MaxPool2d(kernel_size=2)

        self.conv4_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.conv4_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.pool4 = nn.MaxPool2d(kernel_size=2)

        self.conv5_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
        self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)

        self.upv6 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv6_1 = nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1)
        self.conv6_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)

        self.upv7 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv7_1 = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1)
        self.conv7_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)

        self.upv8 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv8_1 = nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1)
        self.conv8_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)

        self.upv9 = nn.ConvTranspose2d(64, 32, 2, stride=2)
        self.conv9_1 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1)
        self.conv9_2 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)

        self.conv10_1 = nn.Conv2d(32, out_channels, kernel_size=1, stride=1)

    def forward(self, x):
        n, c, h, w = x.shape
        h_pad = 32 - h % 32 if h % 32 != 0 else 0
        w_pad = 32 - w % 32 if w % 32 != 0 else 0
        padded_image = F.pad(x, (0, w_pad, 0, h_pad), 'replicate')

        conv1 = self.leaky_relu(self.conv1_1(padded_image))
        conv1 = self.leaky_relu(self.conv1_2(conv1))
        pool1 = self.pool1(conv1)

        conv2 = self.leaky_relu(self.conv2_1(pool1))
        conv2 = self.leaky_relu(self.conv2_2(conv2))
        pool2 = self.pool2(conv2)

        conv3 = self.leaky_relu(self.conv3_1(pool2))
        conv3 = self.leaky_relu(self.conv3_2(conv3))
        pool3 = self.pool3(conv3)

        conv4 = self.leaky_relu(self.conv4_1(pool3))
        conv4 = self.leaky_relu(self.conv4_2(conv4))
        pool4 = self.pool4(conv4)

        conv5 = self.leaky_relu(self.conv5_1(pool4))
        conv5 = self.leaky_relu(self.conv5_2(conv5))

        up6 = self.upv6(conv5)
        up6 = torch.cat([up6, conv4], 1)
        conv6 = self.leaky_relu(self.conv6_1(up6))
        conv6 = self.leaky_relu(self.conv6_2(conv6))

        up7 = self.upv7(conv6)
        up7 = torch.cat([up7, conv3], 1)
        conv7 = self.leaky_relu(self.conv7_1(up7))
        conv7 = self.leaky_relu(self.conv7_2(conv7))

        up8 = self.upv8(conv7)
        up8 = torch.cat([up8, conv2], 1)
        conv8 = self.leaky_relu(self.conv8_1(up8))
        conv8 = self.leaky_relu(self.conv8_2(conv8))

        up9 = self.upv9(conv8)
        up9 = torch.cat([up9, conv1], 1)
        conv9 = self.leaky_relu(self.conv9_1(up9))
        conv9 = self.leaky_relu(self.conv9_2(conv9))

        conv10 = self.conv10_1(conv9)
        out = conv10[:, :, :h, :w]

        return out

    def leaky_relu(self, x):
        out = torch.max(0.2 * x, x)
        return out


if __name__ == "__main__":
    test_input = torch.from_numpy(np.random.randn(1, 4, 512, 512)).float()
    net = Unet()
    output = net(test_input)
    print("test over")
