Saliency Map Segmentation in Python

Using the saliency-map principle together with the Fourier transform, we capture the salient regions of an image in order to extract the weld seam.
The code is adapted from OpenCV with Python Blueprints.
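For orientation, the per-channel computation implemented by the Saliency class below follows the spectral residual approach; roughly (this is just a summary of what the code does, with A the amplitude spectrum, P the phase spectrum, h_n a small averaging filter and g a Gaussian kernel):

$$
\begin{aligned}
A(f) &= \lvert \mathcal{F}[I](f)\rvert, \qquad P(f) = \operatorname{arg}\mathcal{F}[I](f)\\
L(f) &= \log A(f)\\
R(f) &= L(f) - (h_n * L)(f)\\
m(x) &= \bigl\lvert \mathcal{F}^{-1}\bigl[\exp\bigl(R(f) + i\,P(f)\bigr)\bigr](x)\bigr\rvert\\
S(x) &= \bigl((g * \bar m)(x)\bigr)^{2}
\end{aligned}
$$

where $\bar m$ is $m$ averaged over the color channels, and $S$ is finally normalized so that its maximum value is one.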

import cv2
import numpy as np
import matplotlib.pyplot as plt

class Saliency:
    '''
    A class that generates a saliency map from an RGB color image.

    Saliency map: we use Fourier analysis to get a general understanding
    of natural image statistics, which helps us build a model of what general
    image backgrounds look like. By comparing and contrasting the background
    model to a specific image frame, we can locate sub-regions of the image
    that pop out of their surroundings. Ideally, these sub-regions correspond
    to the image patches that tend to grab our immediate attention when
    looking at the image.
    '''
    def __init__(self, img, use_numpy_fft=True, gauss_kernel=(5,5)):
        self.use_numpy_fft = use_numpy_fft
        self.gauss_kernel = gauss_kernel
        self.frame_org = img

        # The saliency map will be generated from a downsampled version of the image,
        # and because the computation is relatively time-intensive,
        # we maintain a flag need_saliency_map that makes sure we do the computation only once:
        self.small_shape = (64, 64)
        self.frame_small = cv2.resize(img, self.small_shape[1::-1])

        # whether we need to do the math (True) or it has already been done (False)
        self.need_saliency_map = True
        # From then on, the user may call any of the class's public methods,
        # which all operate on this same image

    def _get_channel_saliency_magnitude(self, channel):
        '''
        In order to generate a saliency map based on the spectral residual
        approach, we need to process each channel of the input image separately
        (a single channel in the case of a grayscale input image, and three
        separate channels in the case of an RGB input image).
        The resulting single-channel saliency magnitude is then returned to
        Saliency.get_saliency_map, where the procedure is repeated for all
        channels of the input image. If the input image is grayscale, we are
        pretty much done.
        :param channel: a single channel of the input image
        :return: the saliency magnitude of this channel
        '''

        # 1. Calculate the magnitude and phase of the Fourier spectrum of the channel,
        # using either NumPy's fft module or OpenCV's DFT functionality.
        if self.use_numpy_fft:
            img_dft = np.fft.fft2(channel)
            magnitude, angle = cv2.cartToPolar(np.real(img_dft),
                                               np.imag(img_dft))
        else:
            img_dft = cv2.dft(np.float32(channel),
                              flags=cv2.DFT_COMPLEX_OUTPUT)
            magnitude, angle = cv2.cartToPolar(img_dft[:,:,0],
                                               img_dft[:,:,1])

        # 2. Calculate the log amplitude of the Fourier spectrum.
        # We clip the lower bound of the magnitudes to 1e-9 in order to
        # prevent taking the log of zero.
        log_amplitude = np.log10(magnitude.clip(min=1e-9))

        # 3. Approximate the averaged spectrum of a typical natural image by
        # convolving the image with a local averaging filter.
        log_amplitude_blur = cv2.blur(log_amplitude,(3,3))

        # 4. Calculate the spectral residual.
        # The spectral residual primarily contains the nontrivial (or unexpected) parts of a scene.
        residual = np.exp(log_amplitude - log_amplitude_blur)

        # 5. Calculate the saliency map by using the inverse Fourier transform,
        # again either via the fft module in NumPy or with OpenCV.
        if self.use_numpy_fft:
            # Recombine the residual amplitude with the original phase,
            # then transform back to the spatial domain:
            real_part, imag_part = cv2.polarToCart(residual,angle)
            img_combined = np.fft.ifft2(real_part + 1j*imag_part)
            magnitude, _ = cv2.cartToPolar(np.real(img_combined),np.imag(img_combined))
        else:
            img_dft[:, :, 0], img_dft[:, :, 1] = cv2.polarToCart(residual, angle)
            img_combined = cv2.idft(img_dft)
            magnitude, _ = cv2.cartToPolar(img_combined[:, :, 0], img_combined[:, :, 1])
        return magnitude


    def plot_magnitude(self):
        '''
        Return the (log-scaled, centered) magnitude spectrum of the image,
        computed with the Discrete Fourier Transform (DFT).
        :return: the magnitude spectrum, scaled for display
        '''
        # 1. Convert the image to grayscale if necessary:
        # Because the method accepts both grayscale and RGB color images,
        # we need to make sure we operate on a single-channel image
        if len(self.frame_org.shape) > 2:
            frame = cv2.cvtColor(self.frame_org, cv2.COLOR_BGR2GRAY)
        else:
            frame = self.frame_org

        # 2. Expand the image to an optimal size:
        # It turns out that the performance of a DFT depends on the image size.
        # It tends to be fastest for sizes that can be factored into small primes
        # (two, three and five), so it is generally a good idea to pad the image with zeros.
        rows, cols = frame.shape[:2]
        nrows = cv2.getOptimalDFTSize(rows)
        ncols = cv2.getOptimalDFTSize(cols)
        frame = cv2.copyMakeBorder(frame,
                                   top=0, bottom=nrows-rows,
                                   left=0, right=ncols-cols,
                                   borderType=cv2.BORDER_CONSTANT, value=0)

        # 3. Apply the DFT:
        # This is a single function call in NumPy.
        # The result is a 2D matrix of complex numbers.
        img_dft = np.fft.fft2(frame)

        # 4. Transform the real and imaginary values to magnitude:
        # A complex number has a real (Re) and an imaginary (Im) part.
        # To extract the magnitude, we take the absolute value.
        magnitude = np.abs(img_dft)

        # 5. Switch to a logarithmic scale:
        # The dynamic range of the Fourier coefficients is usually too large
        # to be displayed on the screen: the high values would all show up as
        # white points and the small ones as black points, and the variation
        # in between would be lost.
        # To use the grayscale values for visualization,
        # we transform our linear scale to a logarithmic one.
        log_magnitude = np.log10(magnitude.clip(min=1e-9))

        # 6. Shift quadrants:
        # To center the spectrum on the image.
        # This makes it easier to visually inspect the magnitude spectrum.
        spectrum = np.fft.fftshift(log_magnitude)

        # 7. Return the result for plotting.
        return spectrum/np.max(spectrum)*255

    def get_saliency_map(self):
        '''
        The main method to convert an RGB color image to a saliency map.
        :return: the saliency map, with values in the range [0., 1.]
        '''
        if self.need_saliency_map:
            # haven't calculated the saliency map for this frame yet
            if len(self.frame_org.shape) == 2:
                # single channel
                sal = self._get_channel_saliency_magnitude(self.frame_small)
            else:
                # consider each channel independently
                sal = np.zeros_like(self.frame_small).astype(np.float32)
                for c in range(self.frame_small.shape[2]):
                    sal[:, :, c] = self._get_channel_saliency_magnitude(self.frame_small[:, :, c])

                # The overall saliency of a multichannel image is determined
                # by averaging over all channels
                sal = np.mean(sal, 2)

            # Finally, we need to apply some post-processing, such as an optional blurring
            # stage to make the result appear smoother
            if self.gauss_kernel is not None:
                sal = cv2.GaussianBlur(sal, self.gauss_kernel, sigmaX=8, sigmaY=0)

            # Also we need to square the values in sal in order to highlight the regions of high salience,
            # as outlined by the authors of the original paper.
            # In order to display the image, we scale it back up to its original resolution and
            # normalize the values, so that the largest value is one.
            sal = sal ** 2
            sal = np.float32(sal) / np.max(sal)
            sal = cv2.resize(sal, self.frame_org.shape[1::-1])

            # In order to avoid having to redo all these intense calculations,
            # we store a local copy of the saliency map for future reference and
            # make sure to lower the flag.
            self.saliency_map = sal
            self.need_saliency_map = False
        return self.saliency_map

    def get_proto_objects_map(self, use_otsu=True):
        '''
        A method to convert the saliency map into a binary mask containing all the proto-objects.
        :return: a binary (uint8) mask of the proto-objects
        '''
        saliency = self.get_saliency_map()
        if use_otsu:
            img_objs = cv2.threshold(np.uint8(saliency*255),0,255,cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]
        else:
            thresh = np.mean(saliency) * 255
            img_objs = cv2.threshold(np.uint8(saliency*255),thresh,255,cv2.THRESH_BINARY)[1]
        return img_objs

    def plot_power_spectrum(self):
        '''
        A method to display the radially averaged power spectrum of an RGB color image,
        which is helpful for understanding natural image statistics.
        :return:
        '''

        # 1. Convert the image to grayscale if necessary.
        if len(self.frame_org.shape)>2:
            frame = cv2.cvtColor(self.frame_org, cv2.COLOR_BGR2GRAY)
        else:
            frame = self.frame_org

        # 2. Expand the image to optimal size.
        rows, cols = frame.shape[:2]
        nrows = cv2.getOptimalDFTSize(rows)
        ncols = cv2.getOptimalDFTSize(cols)
        frame = cv2.copyMakeBorder(frame,
                                   top=0, bottom=nrows-rows,
                                   left=0, right=ncols-cols,
                                   borderType=cv2.BORDER_CONSTANT, value=0)

        # 3. Apply the DFT and get the log spectrum:
        # Here we give the user an option (via flag use_numpy_fft)
        # to use either NumPy's or OpenCV's Fourier tools.
        if self.use_numpy_fft:
            img_dft = np.fft.fft2(frame)
            spectrum = np.log10(np.abs(img_dft)**2)
        else:
            img_dft = cv2.dft(np.float32(frame),
                              flags=cv2.DFT_COMPLEX_OUTPUT)
            spectrum = np.log10(img_dft[:,:,0]**2 +
                                img_dft[:,:,1]**2)

        # 4. Perform radial averaging:
        # This is the tricky part.
        # It would be wrong to simply average the 2D spectrum along the x or y direction.
        # What we are interested in is the spectrum as a function of frequency,
        # independent of the exact orientation.
        # This is sometimes also called the "radially averaged power spectrum" (RAPS),
        # and can be achieved by summing up all the frequency magnitudes,
        # starting at the center of the image,
        # looking in all possible (radial) directions, from some frequency r to r+dr.
        # We use NumPy's histogram binning to sum up the contributions
        # and accumulate them in the variable histo
        L = max(frame.shape)
        freqs = np.fft.fftfreq(L)[:int(L/2)]
        dists = np.sqrt(np.fft.fftfreq(frame.shape[0])[:,np.newaxis]**2 +
                        np.fft.fftfreq(frame.shape[1])**2)
        dcount = np.histogram(dists.ravel(), bins=freqs)[0]
        histo, bins = np.histogram(dists.ravel(),
                                   bins=freqs,
                                   weights=spectrum.ravel())

        # 5. Plot the result:
        # Finally, we can plot the accumulated values in histo,
        # but we must not forget to normalize them by the number of samples per bin (dcount).
        centers = (bins[:-1] + bins[1:]) / 2
        plt.plot(centers, histo/dcount)
        plt.xlabel('frequency')
        plt.ylabel('log-spectrum')
        plt.show()

if __name__ == '__main__':
    filtPath = r'D:\Jay.Lee\Study\imgs\weldcircle.png'
    img = cv2.imread(filtPath,cv2.IMREAD_COLOR)
    saliency = Saliency(img)
    saliency.plot_power_spectrum()
    mask = cv2.morphologyEx(saliency.get_proto_objects_map(use_otsu=False),
                            cv2.MORPH_CLOSE,
                            cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(55,55)))
    cv2.imshow('saliency map', cv2.bitwise_and(img,img,mask=mask))
    cv2.waitKey()
    cv2.destroyAllWindows()

The results are as follows:
(result images)
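The demo above only displays the saliency-masked image. Since the stated goal is weld-seam extraction, here is a minimal follow-up sketch (not from the book): it assumes the weld seam corresponds to the largest proto-object in mask, reuses img and mask from the __main__ block above, and calls cv2.findContours with the OpenCV 4.x return signature.

# Hypothetical post-processing (assumption: the weld seam is the largest
# connected proto-object in `mask`); run right after the demo above.
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL,
                               cv2.CHAIN_APPROX_SIMPLE)  # OpenCV 4.x return signature
if contours:
    weld = max(contours, key=cv2.contourArea)        # largest proto-object
    x, y, w, h = cv2.boundingRect(weld)              # bounding box of the candidate seam
    M = cv2.moments(weld)
    cx, cy = int(M['m10'] / M['m00']), int(M['m01'] / M['m00'])  # seam centroid
    vis = img.copy()
    cv2.rectangle(vis, (x, y), (x + w, y + h), (0, 255, 0), 2)
    cv2.circle(vis, (cx, cy), 5, (0, 0, 255), -1)
    cv2.imshow('weld candidate', vis)
    cv2.waitKey()
    cv2.destroyAllWindows()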
