Building the Dataset and Generating Bursts

Code from: Dudhane, Akshay, Syed Waqas Zamir, Salman Khan, Fahad Shahbaz Khan, and Ming-Hsuan Yang. "Burstormer: Burst Image Restoration and Enhancement Transformer." arXiv, April 3, 2023. http://arxiv.org/abs/2304.01194.

Dataset (image loading and random augmentation)

import torch
import os
import cv2
import random
import numpy as np

class ZurichRAW2RGB(torch.utils.data.Dataset):
    """针对Zurich RAW to RGB数据集,opencv读取数据集中的图片并进行随机增强
 Canon RGB images from the "Zurich RAW to RGB mapping" dataset. You can download the full dataset (22 GB) from http://people.ee.ethz.ch/~ihnatova/pynet.html#dataset. Alternatively, you can only download the Canon RGB images (5.5 GB) from https://data.vision.ee.ethz.ch/bhatg/zurich-raw-to-rgb.zip
    """
    def __init__(self, root, split='train'):
        super().__init__()

        if split in ['train', 'test']:
            self.img_pth = os.path.join(root, split, 'canon')
        else:
            raise Exception('Unknown split {}'.format(split))

        self.image_list = self._get_image_list(split)
        self.split = split

    def _get_image_list(self, split):
    """获取每个图片的名称"""
        if split == 'train':
            image_list = ['{:d}.jpg'.format(i) for i in range(46839)]#46839
        elif split == 'test':
            image_list = ['{:d}.jpg'.format(i) for i in range(1204)]
        else:
            raise Exception

        return image_list

    def _get_image(self, im_id):
        """Read an image with OpenCV and, for the training split, apply a random augmentation half of the time."""
        path = os.path.join(self.img_pth, self.image_list[im_id])
        img = cv2.imread(path)
        if self.split == 'train' and random.randint(0, 1) == 1:
            flag_aug = random.randint(1, 7)
            img = self.data_augmentation(img, flag_aug)
        return img

    def get_image(self, im_id):
        frame = self._get_image(im_id)

        return frame

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, index):
        frame = self._get_image(index)

        return frame

    def data_augmentation(self, image, mode):
        """
        Performs data augmentation of the input image
        Input:
            image: a cv2 (OpenCV) image
            mode: int. Choice of transformation to apply to the image
                    0 - no transformation
                    1 - flip up and down
                    2 - rotate counterclockwise 90 degree
                    3 - rotate 90 degree and flip up and down
                    4 - rotate 180 degree
                    5 - rotate 180 degree and flip
                    6 - rotate 270 degree
                    7 - rotate 270 degree and flip
        """
        if mode == 0:
            # original
            out = image
        elif mode == 1:
            # flip up and down
            out = np.flipud(image)
        elif mode == 2:
            # rotate counterclockwise 90 degree
            out = np.rot90(image)
        elif mode == 3:
            # rotate 90 degree and flip up and down
            out = np.rot90(image)
            out = np.flipud(out)
        elif mode == 4:
            # rotate 180 degree
            out = np.rot90(image, k=2)
        elif mode == 5:
            # rotate 180 degree and flip
            out = np.rot90(image, k=2)
            out = np.flipud(out)
        elif mode == 6:
            # rotate 270 degree
            out = np.rot90(image, k=3)
        elif mode == 7:
            # rotate 270 degree and flip
            out = np.rot90(image, k=3)
            out = np.flipud(out)
        else:
            raise Exception('Invalid choice of image transformation')
        return out.copy()

Indexing an instance of this class returns the image as loaded by OpenCV, an ndarray of shape (H, W, C) (note that cv2.imread returns the channels in BGR order).
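
A minimal usage sketch (the dataset root below is a placeholder path; point it at the extracted Canon images):

dataset = ZurichRAW2RGB(root='path/to/zurich-raw-to-rgb', split='train')  # hypothetical path
print(len(dataset))            # 46839 training images
img = dataset[0]               # ndarray of shape (H, W, 3), dtype uint8
print(img.shape, img.dtype)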

Random cropping

import random

import torch
import torch.nn.functional as F

def random_crop(frames, crop_sz):
    """ 将frames裁剪为crop_sz大小。如果crop_sz大于frames大小,
    则从frames中提取与crop_sz相同长宽比的最大可能部分,将其上采样到crop_sz大小
    """
    if not isinstance(crop_sz, (tuple, list)):
        crop_sz = (crop_sz, crop_sz)
    crop_sz = torch.tensor(crop_sz).float()

    shape = frames.shape

    # Select scale_factor. Ensure the crop fits inside the image
    max_scale_factor = torch.tensor(shape[-2:]).float() / crop_sz
    max_scale_factor = max_scale_factor.min().item()

    if max_scale_factor < 1.0:
        scale_factor = max_scale_factor
    else:
        scale_factor = 1.0

    # Extract the crop
    orig_crop_sz = (crop_sz * scale_factor).floor()

    assert orig_crop_sz[-2] <= shape[-2] and orig_crop_sz[-1] <= shape[-1], 'Bug in crop size estimation!'

    r1 = random.randint(0, shape[-2] - int(orig_crop_sz[-2].item()))
    c1 = random.randint(0, shape[-1] - int(orig_crop_sz[-1].item()))

    r2 = r1 + orig_crop_sz[0].int().item()
    c2 = c1 + orig_crop_sz[1].int().item()

    frames_crop = frames[:, r1:r2, c1:c2]

    # Resize to crop_sz
    if scale_factor < 1.0:
        frames_crop = F.interpolate(frames_crop.unsqueeze(0), size=crop_sz.int().tolist(), mode='bilinear', align_corners=True).squeeze(0)
    return frames_crop
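
For example (an illustrative call, assuming a (C, H, W) float tensor):

frame = torch.rand(3, 600, 800)
patch = random_crop(frame, 384)        # random 384x384 crop
print(patch.shape)                     # torch.Size([3, 384, 384])

small = torch.rand(3, 200, 300)
patch = random_crop(small, 384)        # requested crop is larger than the image: take the largest square crop, then bilinearly upsample
print(patch.shape)                     # torch.Size([3, 384, 384])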

Appendix: generating a burst from a single image

import torch
import data_processing.synthetic_burst_generation as syn_burst_utils
import torchvision.transforms as tfm


class SyntheticBurst(torch.utils.data.Dataset):
    """ Synthetic burst dataset for joint denoising, demosaicking, and super-resolution. 
    First, an image is loaded from base_dataset. The sampled image is converted to linear sensor space using the
    inverse camera pipeline adopted in [1]. A RAW burst sequence is then obtained by applying random translations and
    rotations, mosaicking to the RGGB pattern, and adding random noise.

    [1] Unprocessing Images for Learned Raw Denoising, Brooks, Tim and Mildenhall, Ben and Xue, Tianfan and Chen,
    Jiawen and Sharlet, Dillon and Barron, Jonathan T, CVPR 2019
    """
    def __init__(self, base_dataset, burst_size=8, crop_sz=384, transform=tfm.ToTensor()):
        self.base_dataset = base_dataset

        self.burst_size = burst_size
        self.crop_sz = crop_sz
        self.transform = transform

        self.downsample_factor = 4
        self.burst_transformation_params = {'max_translation': 24.0,
                                            'max_rotation': 1.0,
                                            'max_shear': 0.0,
                                            'max_scale': 0.0,
                                            'border_crop': 24}

        self.image_processing_params = {'random_ccm': True, 'random_gains': True, 'smoothstep': True,
                                        'gamma': True,
                                        'add_noise': True}
        self.interpolation_type = 'bilinear'

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, index):
        """ Generates a synthetic burst
        args:
            index: Index of the image in the base_dataset used to generate the burst

        returns:
            burst: Generated LR RAW burst, a torch tensor of shape
                   [burst_size, 4, self.crop_sz / (2*self.downsample_factor), self.crop_sz / (2*self.downsample_factor)]
                   The 4 channels correspond to 'R', 'G', 'G', and 'B' values in the RGGB bayer mosaick.
                   The extra factor 2 in the denominator (2*self.downsample_factor) corresponds to the mosaicking
                   operation.

            frame_gt: The HR RGB ground truth in the linear sensor space, a torch tensor of shape
                      [3, self.crop_sz, self.crop_sz]

            flow_vectors: The ground truth flow vectors between a burst image and the base image (i.e. the first image in the burst).
                          The flow_vectors can be used to warp the burst images to the base frame, using the 'warp'
                          function in utils.warp package.
                          flow_vectors is torch tensor of shape
                          [burst_size, 2, self.crop_sz / self.downsample_factor, self.crop_sz / self.downsample_factor].
                          Note that the flow_vectors are in the LR RGB space, before mosaicking. Hence it has twice
                          the number of rows and columns, compared to the output burst.

                          NOTE: The flow_vectors are only available during training for the purpose of using any
                                auxiliary losses if needed. The flow_vectors will NOT be provided for the bursts in the
                                test set

            meta_info: A dictionary containing the parameters used to generate the synthetic burst.
        """
        frame = self.base_dataset[index]

        # Augmentation, e.g. convert to tensor
        if self.transform is not None:
            frame = self.transform(frame)

        # Extract a random crop from the image
        crop_sz = self.crop_sz + 2 * self.burst_transformation_params.get('border_crop', 0)
        frame_crop = syn_burst_utils.random_crop(frame, crop_sz)

        # Generate RAW burst
        burst, frame_gt, burst_rgb, flow_vectors, meta_info = syn_burst_utils.rgb2rawburst(frame_crop,
                                                                                           self.burst_size,
                                                                                           self.downsample_factor,
                                                                                           burst_transformation_params=self.burst_transformation_params,
                                                                                           image_processing_params=self.image_processing_params,
                                                                                           interpolation_type=self.interpolation_type
                                                                                           )

        if self.burst_transformation_params.get('border_crop') is not None:
            border_crop = self.burst_transformation_params.get('border_crop')
            frame_gt = frame_gt[:, border_crop:-border_crop, border_crop:-border_crop]  # apply the same border crop to the GT
        #frame_gt = frame_gt.unsqueeze(0)
        return burst, frame_gt, flow_vectors, meta_info
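
A minimal sketch of how the two classes fit together (the dataset root is a hypothetical path; ZurichRAW2RGB and SyntheticBurst are the classes defined above):

from torch.utils.data import DataLoader

train_base = ZurichRAW2RGB(root='path/to/zurich-raw-to-rgb', split='train')   # hypothetical path
train_set = SyntheticBurst(train_base, burst_size=8, crop_sz=384)

burst, frame_gt, flow_vectors, meta_info = train_set[0]
print(burst.shape)          # torch.Size([8, 4, 48, 48])   -> crop_sz / (2 * downsample_factor)
print(frame_gt.shape)       # torch.Size([3, 384, 384])
print(flow_vectors.shape)   # torch.Size([8, 2, 96, 96])   -> defined in the LR RGB space

loader = DataLoader(train_set, batch_size=4, shuffle=True, num_workers=4)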

The underlying helper functions:

def rgb2rawburst(image, burst_size, downsample_factor=1, burst_transformation_params=None,
                 image_processing_params=None, interpolation_type='bilinear'):
    """ Generates a synthetic LR RAW burst from the input image. The input sRGB image is first converted to linear
    sensor space using an inverse camera pipeline. A LR burst is then generated by applying random
    transformations defined by burst_transformation_params to the input image, and downsampling it by the
    downsample_factor. The generated burst is then mosaicked and corrupted by random noise.
    """

    if image_processing_params is None:
        image_processing_params = {}

    _defaults = {'random_ccm': True, 'random_gains': True, 'smoothstep': True, 'gamma': True, 'add_noise': True}
    for k, v in _defaults.items():
        if k not in image_processing_params:
            image_processing_params[k] = v

    # Sample camera pipeline params
    if image_processing_params['random_ccm']:
        rgb2cam = rgb2raw.random_ccm()  # a (3, 3) color correction matrix
    else:
        rgb2cam = torch.eye(3).float()
    cam2rgb = rgb2cam.inverse()  # matrix inverse

    # Sample gains
    if image_processing_params['random_gains']:
        rgb_gain, red_gain, blue_gain = rgb2raw.random_gains()
    else:
        rgb_gain, red_gain, blue_gain = (1.0, 1.0, 1.0)

    # Approximately inverts global tone mapping.
    use_smoothstep = image_processing_params['smoothstep']
    if use_smoothstep:
        image = rgb2raw.invert_smoothstep(image)

    # Inverts gamma compression.
    use_gamma = image_processing_params['gamma']
    if use_gamma:
        image = rgb2raw.gamma_expansion(image)

    # Inverts color correction.
    image = rgb2raw.apply_ccm(image, rgb2cam)
    # multiplies the (3, 3) color correction matrix with the image reshaped to (3, H*W)

    # Approximately inverts white balance and brightening.
    image = rgb2raw.safe_invert_gains(image, rgb_gain, red_gain, blue_gain)

    # Clip saturated pixels.
    image = image.clamp(0.0, 1.0)

    # Generate LR burst
    image_burst_rgb, flow_vectors = single2lrburst(image, burst_size=burst_size,
                                                   downsample_factor=downsample_factor,
                                                   transformation_params=burst_transformation_params,
                                                   interpolation_type=interpolation_type)

    # mosaic: turn RGB(3 channels) to RAW(RGGB:4 channels)
    image_burst = rgb2raw.mosaic(image_burst_rgb.clone())

    # Add noise
    if image_processing_params['add_noise']:
        shot_noise_level, read_noise_level = rgb2raw.random_noise_levels()
        image_burst = rgb2raw.add_noise(image_burst, shot_noise_level, read_noise_level)
    else:
        shot_noise_level = 0
        read_noise_level = 0

    # Clip saturated pixels.
    image_burst = image_burst.clamp(0.0, 1.0)

    meta_info = {'rgb2cam': rgb2cam, 'cam2rgb': cam2rgb, 'rgb_gain': rgb_gain, 'red_gain': red_gain,
                 'blue_gain': blue_gain, 'smoothstep': use_smoothstep, 'gamma': use_gamma,
                 'shot_noise_level': shot_noise_level, 'read_noise_level': read_noise_level}
    return image_burst, image, image_burst_rgb, flow_vectors, meta_info


def get_tmat(image_shape, translation, theta, shear_values, scale_factors):
    """ Generates a transformation matrix corresponding to the input transformation parameters """
    im_h, im_w = image_shape

    t_mat = np.identity(3)

    t_mat[0, 2] = translation[0]
    t_mat[1, 2] = translation[1]
    t_rot = cv2.getRotationMatrix2D((im_w * 0.5, im_h * 0.5), theta, 1.0)
    t_rot = np.concatenate((t_rot, np.array([0.0, 0.0, 1.0]).reshape(1, 3)))

    t_shear = np.array([[1.0, shear_values[0], -shear_values[0] * 0.5 * im_w],
                        [shear_values[1], 1.0, -shear_values[1] * 0.5 * im_h],
                        [0.0, 0.0, 1.0]])

    t_scale = np.array([[scale_factors[0], 0.0, 0.0],
                        [0.0, scale_factors[1], 0.0],
                        [0.0, 0.0, 1.0]])

    t_mat = t_scale @ t_rot @ t_shear @ t_mat

    t_mat = t_mat[:2, :]

    return t_mat


def single2lrburst(image, burst_size, downsample_factor=1, transformation_params=None,
                   interpolation_type='bilinear'):
    """ Generates a burst of size burst_size from the input image by applying random transformations defined by
    transformation_params, and downsampling the resulting burst by downsample_factor.
    """

    if interpolation_type == 'bilinear':
        interpolation = cv2.INTER_LINEAR
    elif interpolation_type == 'lanczos':
        interpolation = cv2.INTER_LANCZOS4
    else:
        raise ValueError

    normalize = False
    if isinstance(image, torch.Tensor):
        if image.max() < 2.0:
            image = image * 255.0
            normalize = True
        image = torch_to_numpy(image).astype(np.uint8)

    burst = []
    sample_pos_inv_all = []

    rvs, cvs = torch.meshgrid([torch.arange(0, image.shape[0]),
                               torch.arange(0, image.shape[1])])

    sample_grid = torch.stack((cvs, rvs, torch.ones_like(cvs)), dim=-1).float()  # shape: (h, w, 3)

    for i in range(burst_size):
        if i == 0:
            # For base image, do not apply any random transformations. We only translate the image to center the
            # sampling grid
            shift = (downsample_factor / 2.0) - 0.5
            translation = (shift, shift)
            theta = 0.0
            shear_factor = (0.0, 0.0)
            scale_factor = (1.0, 1.0)
        else:
            # Sample random image transformation parameters
            max_translation = transformation_params.get('max_translation', 0.0)

            if max_translation <= 0.01:
                shift = (downsample_factor / 2.0) - 0.5
                translation = (shift, shift)
            else:
                translation = (random.uniform(-max_translation, max_translation),
                               random.uniform(-max_translation, max_translation))

            max_rotation = transformation_params.get('max_rotation', 0.0)
            theta = random.uniform(-max_rotation, max_rotation)

            max_shear = transformation_params.get('max_shear', 0.0)
            shear_x = random.uniform(-max_shear, max_shear)
            shear_y = random.uniform(-max_shear, max_shear)
            shear_factor = (shear_x, shear_y)

            max_ar_factor = transformation_params.get('max_ar_factor', 0.0)
            ar_factor = np.exp(random.uniform(-max_ar_factor, max_ar_factor))

            max_scale = transformation_params.get('max_scale', 0.0)
            scale_factor = np.exp(random.uniform(-max_scale, max_scale))

            scale_factor = (scale_factor, scale_factor * ar_factor)

        output_sz = (image.shape[1], image.shape[0])

        # Generate an affine transformation matrix corresponding to the sampled parameters
        t_mat = get_tmat((image.shape[0], image.shape[1]), translation, theta, shear_factor, scale_factor)
        t_mat_tensor = torch.from_numpy(t_mat)

        # Apply the sampled affine transformation
        image_t = cv2.warpAffine(image, t_mat, output_sz, flags=interpolation,
                                 borderMode=cv2.BORDER_CONSTANT)

        t_mat_tensor_3x3 = torch.cat((t_mat_tensor.float(), torch.tensor([0.0, 0.0, 1.0]).view(1, 3)), dim=0)
        t_mat_tensor_inverse = t_mat_tensor_3x3.inverse()[:2, :].contiguous()  # invert and keep the first two rows

        sample_pos_inv = torch.mm(sample_grid.view(-1, 3), t_mat_tensor_inverse.t().float()).view(
            *sample_grid.shape[:2], -1)  # shape: (h, w, 2)

        if transformation_params.get('border_crop') is not None:
            border_crop = transformation_params.get('border_crop')

            image_t = image_t[border_crop:-border_crop, border_crop:-border_crop, :]
            sample_pos_inv = sample_pos_inv[border_crop:-border_crop, border_crop:-border_crop, :]
            # random_crop extracted an extra 2 * border_crop pixels; removing them here keeps the crop at the requested crop_sz

        # Downsample the image
        image_t = cv2.resize(image_t, None, fx=1.0 / downsample_factor, fy=1.0 / downsample_factor,
                             interpolation=interpolation)
        sample_pos_inv = cv2.resize(sample_pos_inv.numpy(), None, fx=1.0 / downsample_factor,
                                    fy=1.0 / downsample_factor,
                                    interpolation=interpolation)

        sample_pos_inv = torch.from_numpy(sample_pos_inv).permute(2, 0, 1)

        if normalize:
            image_t = numpy_to_torch(image_t).float() / 255.0
        else:
            image_t = numpy_to_torch(image_t).float()
        burst.append(image_t)
        sample_pos_inv_all.append(sample_pos_inv / downsample_factor)

    burst_images = torch.stack(burst)
    sample_pos_inv_all = torch.stack(sample_pos_inv_all)

    # Compute the flow vectors to go from the i'th burst image to the base image
    flow_vectors = sample_pos_inv_all - sample_pos_inv_all[:, :1, ...]

    return burst_images, flow_vectors

rgb2rawburst converts the input sRGB image to linear sensor space via the inverse camera pipeline, then mosaics the burst to RAW (RGGB) and adds noise.

single2lrburst generates a burst from a single image by applying randomly sampled affine transformations.

get_tmat builds the affine transformation matrix from the sampled parameters.

Note that the pipeline also returns flow vectors; according to the authors, these can be used to warp each burst image back to the base frame.
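
As an illustration of how the flow vectors could be used: the reference code ships a 'warp' function in the utils.warp package; the sketch below is an assumed grid_sample-based re-implementation, not the authors' exact code, and it assumes the flow gives, for every pixel of the base grid, the displacement (in pixels) to the location in the source frame to sample from. Recall from the SyntheticBurst docstring that the flow vectors live in the LR RGB space, i.e. at twice the resolution of the mosaicked burst.

import torch
import torch.nn.functional as F

def warp_to_base(frame, flow):
    """Warp one (C, H, W) frame towards the base frame using a per-pixel (2, H, W) flow field (illustrative only)."""
    _, h, w = frame.shape
    yy, xx = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')
    grid = torch.stack((xx, yy), dim=0).float() + flow           # (2, H, W): sampling positions in the source frame
    grid[0] = 2.0 * grid[0] / (w - 1) - 1.0                      # normalise x to [-1, 1] for grid_sample
    grid[1] = 2.0 * grid[1] / (h - 1) - 1.0                      # normalise y to [-1, 1]
    grid = grid.permute(1, 2, 0).unsqueeze(0)                    # (1, H, W, 2)
    return F.grid_sample(frame.unsqueeze(0), grid, mode='bilinear', align_corners=True).squeeze(0)

The camera-pipeline helpers referenced above as rgb2raw.* (random_ccm, random_gains, mosaic, add_noise, ...) are listed next.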

import torch
import random
import math
import cv2 as cv
import numpy as np
import utils.data_format_utils as df_utils
""" Based on http://timothybrooks.com/tech/unprocessing 
Functions for forward and inverse camera pipeline. All functions input a torch float tensor of shape (c, h, w). 
Additionally, some also support batch operations, i.e. inputs of shape (b, c, h, w)
"""


def random_ccm():
    """Generates random RGB -> Camera color correction matrices."""
    # Takes a random convex combination of XYZ -> Camera CCMs.
    xyz2cams = [[[1.0234, -0.2969, -0.2266],
                 [-0.5625, 1.6328, -0.0469],
                 [-0.0703, 0.2188, 0.6406]],
                [[0.4913, -0.0541, -0.0202],
                 [-0.613, 1.3513, 0.2906],
                 [-0.1564, 0.2151, 0.7183]],
                [[0.838, -0.263, -0.0639],
                 [-0.2887, 1.0725, 0.2496],
                 [-0.0627, 0.1427, 0.5438]],
                [[0.6596, -0.2079, -0.0562],
                 [-0.4782, 1.3016, 0.1933],
                 [-0.097, 0.1581, 0.5181]]]

    num_ccms = len(xyz2cams)  # 4
    xyz2cams = torch.tensor(xyz2cams)  # shape:(4,3,3)

    weights = torch.FloatTensor(num_ccms, 1, 1).uniform_(0.0, 1.0)  # shape:(4,1,1)
    weights_sum = weights.sum()
    xyz2cam = (xyz2cams * weights).sum(dim=0) / weights_sum  # shape:(3,3)

    # Multiplies with RGB -> XYZ to get RGB -> Camera CCM.
    rgb2xyz = torch.tensor([[0.4124564, 0.3575761, 0.1804375],
                            [0.2126729, 0.7151522, 0.0721750],
                            [0.0193339, 0.1191920, 0.9503041]])
    rgb2cam = torch.mm(xyz2cam, rgb2xyz)  # shape:(3,3)

    # Normalizes each row.
    rgb2cam = rgb2cam / rgb2cam.sum(dim=-1, keepdims=True)
    return rgb2cam


def random_gains():
    """Generates random gains for brightening and white balance."""
    # RGB gain represents brightening.
    rgb_gain = 1.0 / random.gauss(mu=0.8, sigma=0.1)

    # Red and blue gains represent white balance.
    red_gain = random.uniform(1.9, 2.4)
    blue_gain = random.uniform(1.5, 1.9)
    return rgb_gain, red_gain, blue_gain


def apply_smoothstep(image):
    """Apply global tone mapping curve."""
    image_out = 3 * image**2 - 2 * image**3
    return image_out


def invert_smoothstep(image):
    """Approximately inverts a global tone mapping curve."""
    image = image.clamp(0.0, 1.0)
    return 0.5 - torch.sin(torch.asin(1.0 - 2.0 * image) / 3.0)
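
A quick check (a worked derivation, not part of the original code) of why this expression inverts the smoothstep curve above: writing $y = \tfrac12 + \sin\theta$ and using the identity $\sin 3\theta = 3\sin\theta - 4\sin^3\theta$,

$$ 3y^2 - 2y^3 = \tfrac12 + \tfrac12\sin 3\theta , $$

so solving $x = 3y^2 - 2y^3$ for $y$ gives

$$ y = \tfrac12 + \sin\!\Big(\tfrac13\arcsin(2x - 1)\Big) = \tfrac12 - \sin\!\Big(\tfrac13\arcsin(1 - 2x)\Big), $$

which is exactly the value returned by invert_smoothstep.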


def gamma_expansion(image):
    """Converts from gamma to linear space."""
    # Clamps to prevent numerical instability of gradients near zero.
    return image.clamp(1e-8) ** 2.2


def gamma_compression(image):
    """Converts from linear to gammaspace."""
    # Clamps to prevent numerical instability of gradients near zero.
    return image.clamp(1e-8) ** (1.0 / 2.2)


def apply_ccm(image, ccm):
    """Applies a color correction matrix."""
    assert image.dim() == 3 and image.shape[0] == 3

    shape = image.shape
    image = image.view(3, -1)
    ccm = ccm.to(image.device).type_as(image)

    image = torch.mm(ccm, image)

    return image.view(shape)


def apply_gains(image, rgb_gain, red_gain, blue_gain):
    """Inverts gains while safely handling saturated pixels."""
    assert image.dim() == 3 and image.shape[0] in [3, 4]

    if image.shape[0] == 3:
        gains = torch.tensor([red_gain, 1.0, blue_gain]) * rgb_gain
    else:
        gains = torch.tensor([red_gain, 1.0, 1.0, blue_gain]) * rgb_gain
    gains = gains.view(-1, 1, 1)
    gains = gains.to(image.device).type_as(image)

    return (image * gains).clamp(0.0, 1.0)


def safe_invert_gains(image, rgb_gain, red_gain, blue_gain):
    """Inverts gains while safely handling saturated pixels."""
    assert image.dim() == 3 and image.shape[0] == 3

    gains = torch.tensor([1.0 / red_gain, 1.0, 1.0 / blue_gain]) / rgb_gain
    gains = gains.view(-1, 1, 1)

    # Prevents dimming of saturated pixels by smoothly masking gains near white.
    gray = image.mean(dim=0, keepdims=True)
    inflection = 0.9
    mask = ((gray - inflection).clamp(0.0) / (1.0 - inflection)) ** 2.0

    safe_gains = torch.max(mask + (1.0 - mask) * gains, gains)
    return image * safe_gains
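
In other words (a restatement of the code above): with mean intensity $\bar{x}$ and per-channel inverse gain $g$, the applied gain is

$$ \max\big(m + (1 - m)\,g,\; g\big), \qquad m = \left(\frac{\max(\bar{x} - 0.9,\, 0)}{1 - 0.9}\right)^{2}, $$

so the mask $m$ rises to 1 near white and the effective gain is pulled towards $\max(1, g)$, which prevents highlights from being dimmed when $g < 1$.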


def mosaic(image, mode='rggb'):
    """Extracts RGGB Bayer planes from an RGB image."""
    shape = image.shape
    if image.dim() == 3:
        image = image.unsqueeze(0)

    if mode == 'rggb':
        red = image[:, 0, 0::2, 0::2]  # red channel: take every other pixel in each dimension
        green_red = image[:, 1, 0::2, 1::2]
        green_blue = image[:, 1, 1::2, 0::2]
        blue = image[:, 2, 1::2, 1::2]
        image = torch.stack((red, green_red, green_blue, blue), dim=1)
    elif mode == 'grbg':
        green_red = image[:, 1, 0::2, 0::2]
        red = image[:, 0, 0::2, 1::2]
        blue = image[:, 2, 0::2, 1::2]
        green_blue = image[:, 1, 1::2, 1::2]

        image = torch.stack((green_red, red, blue, green_blue), dim=1)

    if len(shape) == 3:
        return image.view((4, shape[-2] // 2, shape[-1] // 2))
    else:
        return image.view((-1, 4, shape[-2] // 2, shape[-1] // 2))
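
A small sanity check of the packing (shapes only; the values here are random):

rgb = torch.rand(3, 8, 8)            # linear RGB image
raw = mosaic(rgb)                    # RGGB planes at half resolution
print(raw.shape)                     # torch.Size([4, 4, 4])

batch = torch.rand(2, 3, 8, 8)       # batched input is also supported
print(mosaic(batch).shape)           # torch.Size([2, 4, 4, 4])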


def demosaic(image):
    assert isinstance(image, torch.Tensor)
    image = image.clamp(0.0, 1.0) * 255

    if image.dim() == 4:
        num_images = image.size(0)
        batch_input = True
    else:
        num_images = 1
        batch_input = False
        image = image.unsqueeze(0)

    # Generate single channel input for opencv
    im_sc = torch.zeros((num_images, image.shape[-2] * 2, image.shape[-1] * 2, 1))
    im_sc[:, ::2, ::2, 0] = image[:, 0, :, :]
    im_sc[:, ::2, 1::2, 0] = image[:, 1, :, :]
    im_sc[:, 1::2, ::2, 0] = image[:, 2, :, :]
    im_sc[:, 1::2, 1::2, 0] = image[:, 3, :, :]

    im_sc = im_sc.numpy().astype(np.uint8)

    out = []

    for im in im_sc:
        im_dem_np = cv.cvtColor(im, cv.COLOR_BAYER_BG2RGB_VNG)

        # Convert to torch image
        im_t = df_utils.npimage_to_torch(im_dem_np, input_bgr=False)
        out.append(im_t)

    if batch_input:
        return torch.stack(out, dim=0)
    else:
        return out[0]


def random_noise_levels():
    """Generates random noise levels from a log-log linear distribution."""
    log_min_shot_noise = math.log(0.0001)
    log_max_shot_noise = math.log(0.012)
    log_shot_noise = random.uniform(log_min_shot_noise, log_max_shot_noise)
    shot_noise = math.exp(log_shot_noise)

    line = lambda x: 2.18 * x + 1.20
    log_read_noise = line(log_shot_noise) + random.gauss(mu=0.0, sigma=0.26)
    read_noise = math.exp(log_read_noise)
    return shot_noise, read_noise


def add_noise(image, shot_noise=0.01, read_noise=0.0005):
    """Adds random shot (proportional to image) and read (independent) noise."""
    variance = image * shot_noise + read_noise
    noise = torch.FloatTensor(image.shape).normal_().to(image.device)*variance.sqrt()
    return image + noise
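
The noise model is the heteroscedastic shot/read noise approximation from the Unprocessing paper cited above: for a pixel with linear intensity $x$, the added noise $n$ satisfies

$$ n \sim \mathcal{N}\!\big(0,\; \lambda_{\text{shot}}\, x + \lambda_{\text{read}}\big), $$

where $\lambda_{\text{shot}}$ and $\lambda_{\text{read}}$ are drawn jointly by random_noise_levels() so that $\log \lambda_{\text{read}}$ is roughly linear in $\log \lambda_{\text{shot}}$.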


def process_linear_image_rgb(image, meta_info, return_np=False):
    image = apply_gains(image, meta_info['rgb_gain'], meta_info['red_gain'], meta_info['blue_gain'])
    image = apply_ccm(image, meta_info['cam2rgb'])

    if meta_info['gamma']:
        image = gamma_compression(image)

    if meta_info['smoothstep']:
        image = apply_smoothstep(image)

    image = image.clamp(0.0, 1.0)

    if return_np:
        image = df_utils.torch_to_npimage(image)
    return image


def process_linear_image_raw(image, meta_info):
    image = apply_gains(image, meta_info['rgb_gain'], meta_info['red_gain'], meta_info['blue_gain'])
    image = demosaic(image)
    image = apply_ccm(image, meta_info['cam2rgb'])

    if meta_info['gamma']:
        image = gamma_compression(image)

    if meta_info['smoothstep']:
        image = apply_smoothstep(image)
    return image.clamp(0.0, 1.0)
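
To visualise a generated burst, the meta_info returned by the synthetic pipeline can be pushed back through this forward camera pipeline. A minimal sketch, assuming burst, frame_gt and meta_info come from the SyntheticBurst example earlier (and that these functions are importable in the current scope):

burst, frame_gt, flow_vectors, meta_info = train_set[0]

# Base (first) RAW frame: white balance -> demosaic -> CCM -> gamma -> tone curve.
rgb_vis = process_linear_image_raw(burst[0], meta_info)     # (3, H, W), values in [0, 1]

# The linear-space ground truth can be processed the same way for comparison.
gt_vis = process_linear_image_rgb(frame_gt, meta_info)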
