# A Fast Approximation of the Bilateral Filter using a Signal Processing Approach

import numpy as np
import math
import scipy.signal, scipy.interpolate
import matplotlib.pyplot as plt
import cv2

def bilateral_approximation(image, sigmaS, sigmaR, samplingS=None, samplingR=None):
    """Approximate a bilateral filter on a 2-D image using a bilateral grid.

    The image is splatted into a coarse 3-D grid (row, column, intensity),
    the grid is blurred with a 3-D Gaussian, and the result is read back at
    every pixel's (row, column, intensity) position by linear interpolation.

    Derived from Jiawen Chen's MATLAB implementation. The original papers and
    MATLAB code are available at http://people.csail.mit.edu/sparis/bf/

    Args:
        image: 2-D array-like of intensities (single channel).
        sigmaS: Spatial standard deviation, in pixels. Must be positive.
        sigmaR: Range (intensity) standard deviation, in the units of
            ``image``. Must be positive.
        samplingS: Spatial size of one grid cell, in pixels. Defaults to
            ``sigmaS``.
        samplingR: Intensity size of one grid cell. Defaults to ``sigmaR``.

    Returns:
        A float64 array with the same shape as ``image`` holding the
        filtered intensities.

    Raises:
        ValueError: If ``image`` is not a non-empty 2-D array, or if any of
            the sigma / sampling values is not positive.
    """
    # Work in float64 so that integer images (e.g. uint8) cannot overflow in
    # the min/max arithmetic below.
    image = np.asarray(image, dtype=np.float64)
    if image.ndim != 2:
        raise ValueError(
            "image must be a 2-D (single-channel) array, got %d dimension(s)" % image.ndim)
    if image.size == 0:
        raise ValueError("image must not be empty")

    samplingS = sigmaS if samplingS is None else samplingS
    samplingR = sigmaR if samplingR is None else samplingR
    if not all(value > 0 for value in (sigmaS, sigmaR, samplingS, samplingR)):
        raise ValueError("sigmaS, sigmaR, samplingS and samplingR must all be positive")

    # --------------- original resolution --------------- #
    inputHeight, inputWidth = image.shape
    edgeMin = float(image.min())
    edgeDelta = float(image.max()) - edgeMin

    # --------------- downsampling --------------- #
    # Standard deviations expressed in grid cells instead of pixels/intensities.
    derivedSigmaS = sigmaS / samplingS
    derivedSigmaR = sigmaR / samplingR

    # Pad the grid on every side so the blur and the interpolation stay in bounds.
    paddingXY = math.floor(2 * derivedSigmaS) + 1
    paddingZ = math.floor(2 * derivedSigmaR) + 1

    downsampledWidth = int(round((inputWidth - 1) / samplingS) + 1 + 2 * paddingXY)
    downsampledHeight = int(round((inputHeight - 1) / samplingS) + 1 + 2 * paddingXY)
    downsampledDepth = int(round(edgeDelta / samplingR) + 1 + 2 * paddingZ)
    gridShape = (downsampledHeight, downsampledWidth, downsampledDepth)

    # Continuous (unpadded) grid coordinates of every pixel.
    rows, cols = np.indices(image.shape)
    scaledRows = rows / samplingS
    scaledCols = cols / samplingS
    scaledRange = (image - edgeMin) / samplingR

    # Nearest grid cell of every pixel (round first, then add the padding).
    cellRows = (np.around(scaledRows) + paddingXY).astype(np.intp)
    cellCols = (np.around(scaledCols) + paddingXY).astype(np.intp)
    cellRange = (np.around(scaledRange) + paddingZ).astype(np.intp)

    # Box filter (averaging downsample): per cell, sum the intensities and
    # count the pixels. bincount over flattened cell indices replaces a
    # per-pixel Python loop.
    flatCells = np.ravel_multi_index(
        (cellRows.ravel(), cellCols.ravel(), cellRange.ravel()), gridShape)
    cellCount = downsampledHeight * downsampledWidth * downsampledDepth
    gridData = np.bincount(
        flatCells, weights=image.ravel(), minlength=cellCount).reshape(gridShape)
    gridWeights = np.bincount(
        flatCells, minlength=cellCount).reshape(gridShape).astype(np.float64)

    # --------------- 3-D convolution --------------- #
    # Build the Gaussian kernel on integer offsets centred on 0.
    kernelWidth = 2 * derivedSigmaS + 1
    kernelDepth = 2 * derivedSigmaR + 1
    offsetsXY = np.arange(int(kernelWidth)) - math.floor(kernelWidth / 2)
    offsetsZ = np.arange(int(kernelDepth)) - math.floor(kernelDepth / 2)

    spatialSquared = (offsetsXY[:, None, None] ** 2 + offsetsXY[None, :, None] ** 2)
    rangeSquared = offsetsZ[None, None, :] ** 2
    gridRSquared = (spatialSquared / (derivedSigmaS * derivedSigmaS)
                    + rangeSquared / (derivedSigmaR * derivedSigmaR))
    kernel = np.exp(-0.5 * gridRSquared)

    blurredGridData = scipy.signal.fftconvolve(gridData, kernel, mode='same')
    blurredGridWeights = scipy.signal.fftconvolve(gridWeights, kernel, mode='same')

    # --------------- divide --------------- #
    # Avoid dividing by zero; those cells are set to 0 and are not expected to
    # be read by the interpolation below.
    undefined = blurredGridWeights == 0
    safeWeights = np.where(undefined, 1.0, blurredGridWeights)
    normalizedBlurredGrid = np.where(undefined, 0.0, blurredGridData / safeWeights)

    # --------------- upsampling --------------- #
    # Sample the grid at each pixel's continuous (non-rounded) coordinates.
    gridAxes = tuple(np.arange(extent) for extent in gridShape)
    queryPoints = np.stack(
        (scaledRows + paddingXY, scaledCols + paddingXY, scaledRange + paddingZ),
        axis=-1)
    out_image = scipy.interpolate.interpn(gridAxes, normalizedBlurredGrid, queryPoints)
    return out_image

if __name__ == "__main__":
    # Demo: filter a grayscale test image and show input and output side by side.
    image_path = 'lena512.bmp'
    image = cv2.imread(image_path, 0)  # flag 0 loads the file as grayscale
    if image is None:
        # cv2.imread signals a missing or unreadable file by returning None
        # instead of raising.
        raise FileNotFoundError("could not read image file: %s" % image_path)

    filtered_image = bilateral_approximation(image, sigmaS=64, sigmaR=32, samplingS=32, samplingR=16)

    # Use separate axes so the filtered image does not draw over the input.
    _, (input_axis, output_axis) = plt.subplots(1, 2)
    input_axis.imshow(image, cmap='gray')
    input_axis.set_title('input')
    output_axis.imshow(filtered_image, cmap='gray')
    output_axis.set_title('bilateral approximation')
    plt.show()


