Official code
test-TF
import os
import numpy as np
import rawpy
import tensorflow as tf
import skimage.metrics
from matplotlib import pyplot as plt
from unetTF import unet
import argparse

def normalization(input_data, black_level, white_level):
    output_data = (input_data.astype(float) - black_level) / (white_level - black_level)
    return output_data

def inv_normalization(input_data, black_level, white_level):
    output_data = np.clip(input_data, 0., 1.) * (white_level - black_level) + black_level
    output_data = output_data.astype(np.uint16)
    return output_data
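
# Added for illustration (not part of the official script): a quick sanity check that
# inv_normalization undoes normalization for in-range values; the 1024/16383 defaults
# below mirror this script's command-line defaults.
def _check_normalization_roundtrip(black_level=1024, white_level=16383):
    sample = np.array([[1024, 8000], [12000, 16383]], dtype=np.uint16)
    restored = inv_normalization(normalization(sample, black_level, white_level),
                                 black_level, white_level)
    # allow an off-by-one from float truncation in astype(np.uint16)
    assert np.max(np.abs(restored.astype(int) - sample.astype(int))) <= 1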

def write_back_dng(src_path, dest_path, raw_data):
    """
    replace dng data
    """
    height = raw_data.shape[0]
    width = raw_data.shape[1]
    falsie = os.path.getsize(src_path)
    data_len = width * height * 2
    header_len = 8

    # the sum of bytes 5-7 (0-indexed) of the source file decides the layout:
    # non-zero means the raw data block sits right after the 8-byte header,
    # zero means the metadata block comes first and the raw data is at the end
    with open(src_path, "rb") as f_in:
        data_all = f_in.read(falsie)
        dng_format = data_all[5] + data_all[6] + data_all[7]

    with open(src_path, "rb") as f_in:
        header = f_in.read(header_len)
        if dng_format != 0:
            _ = f_in.read(data_len)
            meta = f_in.read(falsie - header_len - data_len)
        else:
            meta = f_in.read(falsie - header_len - data_len)
            _ = f_in.read(data_len)

    data = raw_data.tobytes()
    with open(dest_path, "wb") as f_out:
        f_out.write(header)
        if dng_format != 0:
            f_out.write(data)
            f_out.write(meta)
        else:
            f_out.write(meta)
            f_out.write(data)

    if os.path.getsize(src_path) != os.path.getsize(dest_path):
        print("replace raw data failed, file size mismatch!")
    else:
        print("replace raw data finished")

def read_image(input_path):
    raw = rawpy.imread(input_path)
    raw_data = raw.raw_image_visible
    height = raw_data.shape[0]
    width = raw_data.shape[1]

    # pack the Bayer mosaic into four half-resolution channels, one per 2x2 phase
    # (top-left, top-right, bottom-left, bottom-right)
    raw_data_expand = np.expand_dims(raw_data, axis=2)
    raw_data_expand_c = np.concatenate((raw_data_expand[0:height:2, 0:width:2, :],
                                        raw_data_expand[0:height:2, 1:width:2, :],
                                        raw_data_expand[1:height:2, 0:width:2, :],
                                        raw_data_expand[1:height:2, 1:width:2, :]), axis=2)
    return raw_data_expand_c, height, width

def write_image(input_data, height, width):
    # inverse of read_image: scatter the four channels back into the full-resolution mosaic
    output_data = np.zeros((height, width), dtype=np.uint16)
    for channel_y in range(2):
        for channel_x in range(2):
            output_data[channel_y:height:2, channel_x:width:2] = input_data[0, :, :, 2 * channel_y + channel_x]
    return output_data
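
# Added for illustration (not part of the official script): write_image is the inverse
# of the 2x2 packing done in read_image, so packing a synthetic Bayer frame with the
# same slicing and unpacking it again should reproduce the original mosaic exactly.
def _check_pack_unpack_roundtrip(height=8, width=8):
    bayer = np.arange(height * width, dtype=np.uint16).reshape(height, width)
    packed = np.stack((bayer[0::2, 0::2], bayer[0::2, 1::2],
                       bayer[1::2, 0::2], bayer[1::2, 1::2]), axis=2)
    restored = write_image(packed[np.newaxis, ...], height, width)
    assert np.array_equal(bayer, restored)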

def denoise_raw(input_path, output_path, ground_path, model_path, black_level, white_level):
    """
    Example: obtain ground truth
    """
    gt = rawpy.imread(ground_path).raw_image_visible

    """
    pre-process
    """
    raw_data_channels, height, width = read_image(input_path)
    raw_data_channels_normal = normalization(raw_data_channels, black_level, white_level)
    raw_data_channels_normal = tf.convert_to_tensor(np.reshape(raw_data_channels_normal,
                                                               (1, raw_data_channels_normal.shape[0],
                                                                raw_data_channels_normal.shape[1], 4)))
    print(raw_data_channels_normal.shape[1])
    print(raw_data_channels_normal.shape[2])
    net = unet((raw_data_channels_normal.shape[1], raw_data_channels_normal.shape[2], 4))
    net.load_weights(model_path)

    """
    inference
    """
    result_data = net(raw_data_channels_normal, training=False)
    result_data = result_data.numpy()

    """
    post-process
    """
    result_data = inv_normalization(result_data, black_level, white_level)
    result_write_data = write_image(result_data, height, width)
    write_back_dng(input_path, output_path, result_write_data)

    """
    obtain psnr and ssim
    """
    psnr = skimage.metrics.peak_signal_noise_ratio(
        gt.astype(np.float64), result_write_data.astype(np.float64), data_range=white_level)
    ssim = skimage.metrics.structural_similarity(
        gt.astype(np.float64), result_write_data.astype(np.float64), multichannel=True, data_range=white_level)
    print('psnr:', psnr)
    print('ssim:', ssim)

    """
    Example: show the ground truth, the noisy input and the de-noised result
    """
    f0 = rawpy.imread(ground_path)
    f1 = rawpy.imread(input_path)
    f2 = rawpy.imread(output_path)
    f, axarr = plt.subplots(1, 3)
    axarr[0].imshow(f0.postprocess(use_camera_wb=True))
    axarr[1].imshow(f1.postprocess(use_camera_wb=True))
    axarr[2].imshow(f2.postprocess(use_camera_wb=True))
    axarr[0].set_title('gt')
    axarr[1].set_title('noisy')
    axarr[2].set_title('de-noise')
    plt.show()

def main(args):
    model_path = args.model_path
    black_level = args.black_level
    white_level = args.white_level
    input_path = args.input_path
    output_path = args.output_path
    ground_path = args.ground_path
    denoise_raw(input_path, output_path, ground_path, model_path, black_level, white_level)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, default="./models/tf_model.h5")
    parser.add_argument('--black_level', type=int, default=1024)
    parser.add_argument('--white_level', type=int, default=16383)
    parser.add_argument('--input_path', type=str, default="./testset/noisy0.dng")
    parser.add_argument('--output_path', type=str, default="./data/denoise0.dng")
    parser.add_argument('--ground_path', type=str, default="./testset/noisy0.dng")
    args = parser.parse_args()
    main(args)
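
The script above can also be driven directly from Python instead of the command line. A minimal sketch, assuming it is saved as test_tf.py (the filename is not stated here) and that the model and DNG files exist at the parser's default paths; note the official defaults point ground_path at the noisy file itself:

from test_tf import denoise_raw  # hypothetical module name for the script above

denoise_raw(input_path="./testset/noisy0.dng",
            output_path="./data/denoise0.dng",
            ground_path="./testset/noisy0.dng",
            model_path="./models/tf_model.h5",
            black_level=1024,
            white_level=16383)
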
model_tf
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.layers as layers
import numpy as np

def unet(input_shape, out_channel=4, kernel=3, pool_size=(2, 2), feature_base=32):
    input_holder = layers.Input(shape=input_shape)
    n, h, w, c = input_holder.shape

    # reflect-pad the spatial dims up to a multiple of 32 so the four 2x2 poolings
    # divide evenly; the output is cropped back to (h, w) at the end
    h_pad = 32 - h % 32 if not h % 32 == 0 else 0
    w_pad = 32 - w % 32 if not w % 32 == 0 else 0
    padded_image = tf.pad(input_holder, [[0, 0], [0, h_pad], [0, w_pad], [0, 0]], "reflect")

    # encoder
    conv1_1 = layers.Conv2D(feature_base, (kernel, kernel), padding="same")(padded_image)
    conv1_1 = layers.LeakyReLU(0.2)(conv1_1)
    conv1_2 = layers.Conv2D(feature_base, (kernel, kernel), padding="same")(conv1_1)
    conv1_2 = layers.LeakyReLU(0.2)(conv1_2)
    pool_1 = layers.MaxPool2D(pool_size)(conv1_2)

    conv2_1 = layers.Conv2D(feature_base * 2, (kernel, kernel), padding="same")(pool_1)
    conv2_1 = layers.LeakyReLU(0.2)(conv2_1)
    conv2_2 = layers.Conv2D(feature_base * 2, (kernel, kernel), padding="same")(conv2_1)
    conv2_2 = layers.LeakyReLU(0.2)(conv2_2)
    pool_2 = layers.MaxPool2D(pool_size)(conv2_2)

    conv3_1 = layers.Conv2D(feature_base * 4, (kernel, kernel), padding="same")(pool_2)
    conv3_1 = layers.LeakyReLU(0.2)(conv3_1)
    conv3_2 = layers.Conv2D(feature_base * 4, (kernel, kernel), padding="same")(conv3_1)
    conv3_2 = layers.LeakyReLU(0.2)(conv3_2)
    pool_3 = layers.MaxPool2D(pool_size)(conv3_2)

    conv4_1 = layers.Conv2D(feature_base * 8, (kernel, kernel), padding="same")(pool_3)
    conv4_1 = layers.LeakyReLU(0.2)(conv4_1)
    conv4_2 = layers.Conv2D(feature_base * 8, (kernel, kernel), padding="same")(conv4_1)
    conv4_2 = layers.LeakyReLU(0.2)(conv4_2)
    pool_4 = layers.MaxPool2D(pool_size)(conv4_2)

    # bottleneck
    conv5_1 = layers.Conv2D(feature_base * 16, (kernel, kernel), padding="same")(pool_4)
    conv5_1 = layers.LeakyReLU(0.2)(conv5_1)
    conv5_2 = layers.Conv2D(feature_base * 16, (kernel, kernel), padding="same")(conv5_1)
    conv5_2 = layers.LeakyReLU(0.2)(conv5_2)

    # decoder with skip connections
    unpool1 = layers.Conv2DTranspose(feature_base * 8, pool_size, (2, 2), "same")(conv5_2)
    concat1 = layers.Concatenate()([unpool1, conv4_2])
    conv6_1 = layers.Conv2D(feature_base * 8, (kernel, kernel), padding="same")(concat1)
    conv6_1 = layers.LeakyReLU(0.2)(conv6_1)
    conv6_2 = layers.Conv2D(feature_base * 8, (kernel, kernel), padding="same")(conv6_1)
    conv6_2 = layers.LeakyReLU(0.2)(conv6_2)

    unpool2 = layers.Conv2DTranspose(feature_base * 4, pool_size, (2, 2), "same")(conv6_2)
    concat2 = layers.Concatenate()([unpool2, conv3_2])
    conv7_1 = layers.Conv2D(feature_base * 4, (kernel, kernel), padding="same")(concat2)
    conv7_1 = layers.LeakyReLU(0.2)(conv7_1)
    conv7_2 = layers.Conv2D(feature_base * 4, (kernel, kernel), padding="same")(conv7_1)
    conv7_2 = layers.LeakyReLU(0.2)(conv7_2)

    unpool3 = layers.Conv2DTranspose(feature_base * 2, pool_size, (2, 2), "same")(conv7_2)
    concat3 = layers.Concatenate()([unpool3, conv2_2])
    conv8_1 = layers.Conv2D(feature_base * 2, (kernel, kernel), padding="same")(concat3)
    conv8_1 = layers.LeakyReLU(0.2)(conv8_1)
    conv8_2 = layers.Conv2D(feature_base * 2, (kernel, kernel), padding="same")(conv8_1)
    conv8_2 = layers.LeakyReLU(0.2)(conv8_2)

    unpool4 = layers.Conv2DTranspose(feature_base, pool_size, (2, 2), "same")(conv8_2)
    concat4 = layers.Concatenate()([unpool4, conv1_2])
    conv9_1 = layers.Conv2D(feature_base, (kernel, kernel), padding="same")(concat4)
    conv9_1 = layers.LeakyReLU(0.2)(conv9_1)
    conv9_2 = layers.Conv2D(feature_base, (kernel, kernel), padding="same")(conv9_1)
    conv9_2 = layers.LeakyReLU(0.2)(conv9_2)

    out = layers.Conv2D(out_channel, (1, 1), padding="same")(conv9_2)
    out_holder = out[:, :h, :w, :]
    net_model = keras.Model(inputs=input_holder, outputs=out_holder)
    return net_model

if __name__ == "__main__":
    test_input = tf.convert_to_tensor(np.random.randn(1, 512, 512, 4).astype(np.float32))
    net = unet((512, 512, 4))
    net.summary()
    output = net(test_input, training=False)
    print("test over")
test_torch
import os
import numpy as np
import rawpy
import torch
import skimage.metrics
from matplotlib import pyplot as plt
from unetTorch import Unet
import argparse

def normalization(input_data, black_level, white_level):
    output_data = (input_data.astype(float) - black_level) / (white_level - black_level)
    return output_data


def inv_normalization(input_data, black_level, white_level):
    output_data = np.clip(input_data, 0., 1.) * (white_level - black_level) + black_level
    output_data = output_data.astype(np.uint16)
    return output_data

def write_back_dng(src_path, dest_path, raw_data):
    """
    replace dng data
    """
    height = raw_data.shape[0]
    width = raw_data.shape[1]
    falsie = os.path.getsize(src_path)
    data_len = width * height * 2
    header_len = 8

    # the sum of bytes 5-7 (0-indexed) of the source file decides the layout:
    # non-zero means the raw data block sits right after the 8-byte header,
    # zero means the metadata block comes first and the raw data is at the end
    with open(src_path, "rb") as f_in:
        data_all = f_in.read(falsie)
        dng_format = data_all[5] + data_all[6] + data_all[7]

    with open(src_path, "rb") as f_in:
        header = f_in.read(header_len)
        if dng_format != 0:
            _ = f_in.read(data_len)
            meta = f_in.read(falsie - header_len - data_len)
        else:
            meta = f_in.read(falsie - header_len - data_len)
            _ = f_in.read(data_len)

    data = raw_data.tobytes()
    with open(dest_path, "wb") as f_out:
        f_out.write(header)
        if dng_format != 0:
            f_out.write(data)
            f_out.write(meta)
        else:
            f_out.write(meta)
            f_out.write(data)

    if os.path.getsize(src_path) != os.path.getsize(dest_path):
        print("replace raw data failed, file size mismatch!")
    else:
        print("replace raw data finished")

def read_image(input_path):
    raw = rawpy.imread(input_path)
    raw_data = raw.raw_image_visible
    height = raw_data.shape[0]
    width = raw_data.shape[1]

    # pack the Bayer mosaic into four half-resolution channels, one per 2x2 phase
    # (top-left, top-right, bottom-left, bottom-right)
    raw_data_expand = np.expand_dims(raw_data, axis=2)
    raw_data_expand_c = np.concatenate((raw_data_expand[0:height:2, 0:width:2, :],
                                        raw_data_expand[0:height:2, 1:width:2, :],
                                        raw_data_expand[1:height:2, 0:width:2, :],
                                        raw_data_expand[1:height:2, 1:width:2, :]), axis=2)
    return raw_data_expand_c, height, width

def write_image(input_data, height, width):
    # inverse of read_image: scatter the four channels back into the full-resolution mosaic
    output_data = np.zeros((height, width), dtype=np.uint16)
    for channel_y in range(2):
        for channel_x in range(2):
            output_data[channel_y:height:2, channel_x:width:2] = input_data[0, :, :, 2 * channel_y + channel_x]
    return output_data

def denoise_raw(input_path, output_path, ground_path, model_path, black_level, white_level):
    """
    Example: obtain ground truth
    """
    gt = rawpy.imread(ground_path).raw_image_visible

    """
    pre-process
    """
    raw_data_expand_c, height, width = read_image(input_path)
    raw_data_expand_c_normal = normalization(raw_data_expand_c, black_level, white_level)
    raw_data_expand_c_normal = torch.from_numpy(np.transpose(
        raw_data_expand_c_normal.reshape(-1, height//2, width//2, 4), (0, 3, 1, 2))).float()
    net = Unet()
    if model_path is not None:
        net.load_state_dict(torch.load(model_path))
    net.eval()

    """
    inference
    """
    result_data = net(raw_data_expand_c_normal)

    """
    post-process
    """
    result_data = result_data.cpu().detach().numpy().transpose(0, 2, 3, 1)
    result_data = inv_normalization(result_data, black_level, white_level)
    result_write_data = write_image(result_data, height, width)
    write_back_dng(input_path, output_path, result_write_data)

    """
    obtain psnr and ssim
    """
    psnr = skimage.metrics.peak_signal_noise_ratio(
        gt.astype(np.float64), result_write_data.astype(np.float64), data_range=white_level)
    ssim = skimage.metrics.structural_similarity(
        gt.astype(np.float64), result_write_data.astype(np.float64), multichannel=True, data_range=white_level)
    print('psnr:', psnr)
    print('ssim:', ssim)

    """
    Example: show the ground truth, the noisy input and the de-noised result
    """
    f0 = rawpy.imread(ground_path)
    f1 = rawpy.imread(input_path)
    f2 = rawpy.imread(output_path)
    f, axarr = plt.subplots(1, 3)
    axarr[0].imshow(f0.postprocess(use_camera_wb=True))
    axarr[1].imshow(f1.postprocess(use_camera_wb=True))
    axarr[2].imshow(f2.postprocess(use_camera_wb=True))
    axarr[0].set_title('gt')
    axarr[1].set_title('noisy')
    axarr[2].set_title('de-noise')
    plt.show()

def main(args):
    model_path = args.model_path
    black_level = args.black_level
    white_level = args.white_level
    input_path = args.input_path
    output_path = args.output_path
    ground_path = args.ground_path
    denoise_raw(input_path, output_path, ground_path, model_path, black_level, white_level)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, default="./models/th_model.pth")
    parser.add_argument('--black_level', type=int, default=1024)
    parser.add_argument('--white_level', type=int, default=16383)
    parser.add_argument('--input_path', type=str, default="./testset/noisy0.dng")
    parser.add_argument('--output_path', type=str, default="./data/denoise0.dng")
    parser.add_argument('--ground_path', type=str, default="./testset/noisy0.dng")
    args = parser.parse_args()
    main(args)
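
The inference call in denoise_raw above builds an autograd graph and then discards it via detach(); wrapping the forward pass in torch.no_grad() gives the same result with lower memory use. A minimal sketch with randomly initialized weights (no checkpoint loaded):

import numpy as np
import torch
from unetTorch import Unet

net = Unet()
net.eval()
x = torch.from_numpy(np.random.randn(1, 4, 128, 128).astype(np.float32))
with torch.no_grad():   # no autograd graph is recorded during inference
    y = net(x)
print(y.shape)          # expected: torch.Size([1, 4, 128, 128])
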
model_torch
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class Unet(nn.Module):
    def __init__(self, in_channels=4, out_channels=4):
        super(Unet, self).__init__()

        self.conv1_1 = nn.Conv2d(in_channels, 32, kernel_size=3, stride=1, padding=1)
        self.conv1_2 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2)

        self.conv2_1 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv2_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2)

        self.conv3_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv3_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.pool3 = nn.MaxPool2d(kernel_size=2)

        self.conv4_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.conv4_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.pool4 = nn.MaxPool2d(kernel_size=2)

        self.conv5_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
        self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)

        self.upv6 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv6_1 = nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1)
        self.conv6_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)

        self.upv7 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv7_1 = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1)
        self.conv7_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)

        self.upv8 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv8_1 = nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1)
        self.conv8_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)

        self.upv9 = nn.ConvTranspose2d(64, 32, 2, stride=2)
        self.conv9_1 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1)
        self.conv9_2 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)

        self.conv10_1 = nn.Conv2d(32, out_channels, kernel_size=1, stride=1)

    def forward(self, x):
        n, c, h, w = x.shape
        # pad the spatial dims up to a multiple of 32 so the four 2x2 poolings divide
        # evenly; the output is cropped back to (h, w) at the end
        h_pad = 32 - h % 32 if not h % 32 == 0 else 0
        w_pad = 32 - w % 32 if not w % 32 == 0 else 0
        padded_image = F.pad(x, (0, w_pad, 0, h_pad), 'replicate')

        conv1 = self.leaky_relu(self.conv1_1(padded_image))
        conv1 = self.leaky_relu(self.conv1_2(conv1))
        pool1 = self.pool1(conv1)

        conv2 = self.leaky_relu(self.conv2_1(pool1))
        conv2 = self.leaky_relu(self.conv2_2(conv2))
        pool2 = self.pool2(conv2)

        conv3 = self.leaky_relu(self.conv3_1(pool2))
        conv3 = self.leaky_relu(self.conv3_2(conv3))
        pool3 = self.pool3(conv3)

        conv4 = self.leaky_relu(self.conv4_1(pool3))
        conv4 = self.leaky_relu(self.conv4_2(conv4))
        pool4 = self.pool4(conv4)

        conv5 = self.leaky_relu(self.conv5_1(pool4))
        conv5 = self.leaky_relu(self.conv5_2(conv5))

        up6 = self.upv6(conv5)
        up6 = torch.cat([up6, conv4], 1)
        conv6 = self.leaky_relu(self.conv6_1(up6))
        conv6 = self.leaky_relu(self.conv6_2(conv6))

        up7 = self.upv7(conv6)
        up7 = torch.cat([up7, conv3], 1)
        conv7 = self.leaky_relu(self.conv7_1(up7))
        conv7 = self.leaky_relu(self.conv7_2(conv7))

        up8 = self.upv8(conv7)
        up8 = torch.cat([up8, conv2], 1)
        conv8 = self.leaky_relu(self.conv8_1(up8))
        conv8 = self.leaky_relu(self.conv8_2(conv8))

        up9 = self.upv9(conv8)
        up9 = torch.cat([up9, conv1], 1)
        conv9 = self.leaky_relu(self.conv9_1(up9))
        conv9 = self.leaky_relu(self.conv9_2(conv9))

        conv10 = self.conv10_1(conv9)
        out = conv10[:, :, :h, :w]
        return out

    def leaky_relu(self, x):
        out = torch.max(0.2 * x, x)
        return out

if __name__ == "__main__":
    test_input = torch.from_numpy(np.random.randn(1, 4, 512, 512)).float()
    net = Unet()
    output = net(test_input)
    print("test over")