基于Permutohedral Lattice 的Bilateral filter 源码及部分注释【C++】


基于Permutohedral Lattice 的Bilateral filter 源码及部分注释【来自于网络】


实现基于论文《Fast High-Dimensional Filtering Using the Permutohedral Lattice》(Adams, Baek, Davis, Eurographics 2010)。

延伸阅读:《saliency filters 精读之 permutohedral lattice》

1. bilateralPermutohedral 函数(img 与 edge 均须为 CV_32F 类型):

// Edge-aware (joint) bilateral filter built on the permutohedral lattice.
//   img     : image to filter; must be CV_32F, any number of interleaved channels.
//   edge    : guide image whose edges are preserved; must be CV_32F, 1 or 3 channels.
//             It is read row-by-row with img's dimensions, so it presumably must be
//             the same size as img (not checked here).
//   sigma_s : spatial standard deviation, in pixels.
//   sigma_r : range standard deviation, in units of the guide image's values.
// Returns a newly allocated Mat with the same size and type as img.
static Mat bilateralPermutohedral(Mat img, Mat edge, float sigma_s, float sigma_r)
{
    // Dividing each position coordinate by its standard deviation lets the
    // lattice perform a unit-variance Gaussian blur.
    const float spatialScale = 1.0f / sigma_s;
    const float rangeScale = 1.0f / sigma_r;

    const int rows = img.rows;
    const int cols = img.cols;
    const int guideChannels = edge.channels();
    const int valueChannels = img.channels();

    // One frame each: position vectors are (x, y, guide channels...),
    // value vectors are the channels of img.
    Image positions(1, cols, rows, 2 + guideChannels);
    Image input(1, cols, rows, valueChannels);

    // Copy the OpenCV data into the lattice's Image layout.
    for (int row = 0; row < rows; ++row)
    {
        const float *srcRow = img.ptr<float>(row);
        const float *guideRow = edge.ptr<float>(row);
        for (int col = 0; col < cols; ++col)
        {
            // Position vector, see section 3.1 (p.4) of the paper.
            positions(col, row)[0] = spatialScale * col;
            positions(col, row)[1] = spatialScale * row;

            const float *guidePixel = guideRow + col * guideChannels;
            for (int ch = 0; ch < guideChannels; ++ch)
                positions(col, row)[2 + ch] = rangeScale * guidePixel[ch];

            // Value vector: the pixel's own channels.
            const float *srcPixel = srcRow + col * valueChannels;
            for (int ch = 0; ch < valueChannels; ++ch)
                input(col, row)[ch] = srcPixel[ch];
        }
    }

    // Filter the values with respect to the position vectors.
    Image out = PermutohedralLattice::filter(input, positions);

    // Copy the filtered Image back into a Mat shaped like the input.
    Mat result(img.size(), img.type());
    for (int row = 0; row < rows; ++row)
    {
        float *dstRow = result.ptr<float>(row);
        for (int col = 0; col < cols; ++col)
            for (int ch = 0; ch < valueChannels; ++ch)
                dstRow[col * valueChannels + ch] = out(col, row)[ch];
    }

    return result;
}


2. PermutohedralLattice 类:

/***************************************************************/
/* The algorithm class that performs the filter
 *
 * PermutohedralLattice::filter(...) does all the work.
 *
 * An instance owns several heap-allocated scratch arrays that are
 * released in the destructor, so the class is non-copyable.
 */
/***************************************************************/
class PermutohedralLattice
{
public:

    /* Filters given image against a reference image.
     *   im  : image to be bilateral-filtered (supplies the value vectors).
     *   ref : reference image whose edges are to be respected (supplies the
     *         position vectors). It is walked in the same frame/row/column
     *         order as im, so it presumably must have the same frames, width
     *         and height (not checked here).
     * Returns a new image with im's frames/width/height/channels.
     */
    static Image filter(Image im, Image ref)
    {
        // Create lattice.
        //   d  = ref.channels      (e.g. 5 for x, y, r, g, b)
        //   vd = im.channels + 1   (value channels plus a homogeneous weight)
        PermutohedralLattice lattice(ref.channels, im.channels + 1, im.width * im.height * im.frames);

        // Splat into the lattice. `col` holds one value vector [c0, ..., cN-1, 1].
        float *col = new float[im.channels + 1];
        col[im.channels] = 1; // homogeneous coordinate

        float *imPtr = im(0, 0, 0);
        float *refPtr = ref(0, 0, 0); // position vectors
        for (int t = 0; t < im.frames; t++)
        {
            for (int y = 0; y < im.height; y++)
            {
                for (int x = 0; x < im.width; x++)
                {
                    for (int c = 0; c < im.channels; c++)
                    {
                        col[c] = *imPtr++;
                    }
                    lattice.splat(refPtr, col);
                    refPtr += ref.channels;
                }
            }
        }

        // Blur the lattice.
        lattice.blur();

        // Slice from the lattice. Slicing replays the splat records in order,
        // so these loops must visit pixels exactly as the splat loops did.
        Image out(im.frames, im.width, im.height, im.channels);

        lattice.beginSlice();
        float *outPtr = out(0, 0, 0);
        for (int t = 0; t < im.frames; t++)
        {
            for (int y = 0; y < im.height; y++)
            {
                for (int x = 0; x < im.width; x++)
                {
                    lattice.slice(col);
                    // Normalize by the accumulated homogeneous weight.
                    float scale = 1.0f / col[im.channels];
                    for (int c = 0; c < im.channels; c++)
                    {
                        *outPtr++ = col[c] * scale;
                    }
                }
            }
        }

        delete[] col; // was previously leaked on every call

        return out;
    }

    /* Constructor
     *     d_ : dimensionality of key vectors (ref.channels)
     *    vd_ : dimensionality of value vectors (im.channels + 1)
     *    nData_ : number of points in the input (im.size * im.frames)
     */
    PermutohedralLattice(int d_, int vd_, int nData_) :
        d(d_), vd(vd_), nData(nData_), hashTable(d_, vd_)
    {
        // Allocate storage for various arrays
        elevated = new float[d + 1];
        scaleFactor = new float[d];

        greedy = new short[d + 1];
        rank = new char[d + 1];
        barycentric = new float[d + 2];
        replay = new ReplayEntry[nData * (d + 1)]; // d+1 simplex vertices per input point
        nReplay = 0;
        canonical = new short[(d + 1) * (d + 1)];
        key = new short[d + 1];

        // compute the coordinates of the canonical simplex, in which
        // the difference between a contained point and the zero
        // remainder vertex is always in ascending order. (See pg.4 of paper.)
        // Compare the paper's d = 4 example matrix on p.4 (stored here
        // column-major relative to that figure).
        for (int i = 0; i <= d; i++)
        {
            for (int j = 0; j <= d - i; j++)
                canonical[i * (d + 1) + j] = i;
            for (int j = d - i + 1; j <= d; j++)
                canonical[i * (d + 1) + j] = i - (d + 1);
        }

        // Compute parts of the rotation matrix E. (See pg.4-5 of paper.)
        for (int i = 0; i < d; i++)
        {
            // the diagonal entries for normalization
            scaleFactor[i] = 1.0f / (sqrtf( (float)(i + 1) * (i + 2) ));

            /* We presume that the user would like to do a Gaussian blur of standard deviation
             * 1 in each dimension (or a total variance of d, summed over dimensions.)
             * Because the total variance of the blur performed by this algorithm is not d,
             * we must scale the space to offset this.
             *
             * The total variance of the algorithm is (See pg.6 and 10 of paper):
             *  [variance of splatting] + [variance of blurring] + [variance of slicing]
             *   = d(d+1)(d+1)/12 + d(d+1)(d+1)/2 + d(d+1)(d+1)/12
             *   = 2d(d+1)(d+1)/3.
             *
             * So we need to scale the space by (d+1)sqrt(2/3).
             */
            scaleFactor[i] *= (d + 1) * sqrtf(2.0 / 3);
        }
    }

    // Releases all scratch arrays allocated in the constructor.
    ~PermutohedralLattice()
    {
        delete[] elevated;
        delete[] scaleFactor;
        delete[] greedy;
        delete[] rank;
        delete[] barycentric;
        delete[] replay;
        delete[] canonical;
        delete[] key;
    }


    /* Performs splatting with given position and value vectors.
     *   position : d-dimensional position vector.
     *   value    : vd-dimensional value vector, e.g. [r, g, b, 1].
     * Records d+1 (offset, weight) pairs for later replay by slice().
     */
    void splat(float *position, float *value)
    {
        // first rotate position into the (d+1)-dimensional hyperplane
        // (this computes E*x, see p.5 of the paper)
        elevated[d] = -d * position[d - 1] * scaleFactor[d - 1];
        for (int i = d - 1; i > 0; i--)
            elevated[i] = (elevated[i + 1] -
                           i * position[i - 1] * scaleFactor[i - 1] +
                           (i + 2) * position[i] * scaleFactor[i]);
        elevated[0] = elevated[1] + 2 * position[0] * scaleFactor[0];

        // prepare to find the closest lattice points
        float scale = 1.0f / (d + 1);
        char *myrank = rank;
        short *mygreedy = greedy;

        // greedily search for the closest zero-colored lattice point (see p.3 of the paper)
        int sum = 0;
        for (int i = 0; i <= d; i++)
        {
            float v = elevated[i] * scale;
            // nearest multiples of (d+1) above and below elevated[i]
            float up = ceilf(v) * (d + 1);
            float down = floorf(v) * (d + 1);

            if (up - elevated[i] < elevated[i] - down)
                mygreedy[i] = (short)up;
            else
                mygreedy[i] = (short)down;

            sum += mygreedy[i];
        }
        // each mygreedy[i] is a multiple of (d+1), so this division is exact
        sum /= d + 1;

        // rank differential to find the permutation between this simplex and the canonical one.
        // (See pg. 3-4 in paper.) Of each pair, the coordinate with the smaller
        // differential has its rank incremented.
        memset(myrank, 0, sizeof(char) * (d + 1));
        for (int i = 0; i < d; i++)
            for (int j = i + 1; j <= d; j++)
                if (elevated[i] - mygreedy[i] < elevated[j] - mygreedy[j])
                    myrank[i]++;
                else
                    myrank[j]++;

        if (sum > 0)
        {
            // sum too large - the point is off the hyperplane.
            // need to bring down the ones with the smallest differential
            for (int i = 0; i <= d; i++)
            {
                if (myrank[i] >= d + 1 - sum)
                {
                    mygreedy[i] -= d + 1;
                    myrank[i] += sum - (d + 1);
                }
                else
                    myrank[i] += sum;
            }
        }
        else if (sum < 0)
        {
            // sum too small - the point is off the hyperplane
            // need to bring up the ones with largest differential
            for (int i = 0; i <= d; i++)
            {
                if (myrank[i] < -sum)
                {
                    mygreedy[i] += d + 1;
                    myrank[i] += (d + 1) + sum;
                }
                else
                    myrank[i] += sum;
            }
        }

        // Compute barycentric coordinates (See pg.10 of paper.)
        memset(barycentric, 0, sizeof(float) * (d + 2));
        for (int i = 0; i <= d; i++)
        {
            barycentric[d - myrank[i]] += (elevated[i] - mygreedy[i]) * scale;
            barycentric[d + 1 - myrank[i]] -= (elevated[i] - mygreedy[i]) * scale;
        }
        barycentric[0] += 1.0f + barycentric[d + 1];

        // Splat the value into each vertex of the simplex, with barycentric weights.
        for (int remainder = 0; remainder <= d; remainder++)
        {
            // Compute the location of the lattice point explicitly (all but the last coordinate - it's redundant because they sum to zero)
            for (int i = 0; i < d; i++)
                key[i] = mygreedy[i] + canonical[remainder * (d + 1) + myrank[i]];

            // Retrieve pointer to the value at this vertex.
            float *val = hashTable.lookup(key, true);

            // Accumulate values with barycentric weight.
            for (int i = 0; i < vd; i++)
                val[i] += barycentric[remainder] * value[i];

            // Record this interaction to use later when slicing. An offset
            // (not a pointer) is stored because the hash table may reallocate.
            replay[nReplay].offset = val - hashTable.getValues();
            replay[nReplay].weight = barycentric[remainder];
            nReplay++;
        }
    }

    // Prepare for slicing: rewind the replay cursor.
    void beginSlice()
    {
        nReplay = 0;
    }

    /* Performs slicing out of position vectors. Note that the barycentric weights and the simplex
     * containing each position vector were calculated and stored in the splatting step.
     * We may reuse this to accelerate the algorithm. (See pg. 6 in paper.)
     *   col : output buffer of vd floats, overwritten with the interpolated value.
     */
    void slice(float *col)
    {
        float *base = hashTable.getValues();
        for (int j = 0; j < vd; j++)
            col[j] = 0;
        for (int i = 0; i <= d; i++)
        {
            ReplayEntry r = replay[nReplay++];
            for (int j = 0; j < vd; j++)
            {
                col[j] += r.weight * base[r.offset + j];
            }
        }
    }

    /* Performs a Gaussian blur along each projected axis in the hyperplane. */
    void blur()
    {
        // Prepare arrays
        short *neighbor1 = new short[d + 1];
        short *neighbor2 = new short[d + 1];
        float *newValue = new float[vd * hashTable.size()];
        float *oldValue = hashTable.getValues();
        float *hashTableBase = oldValue;

        float *zero = new float[vd];
        for (int k = 0; k < vd; k++)
            zero[k] = 0;

        // For each of d+1 axes,
        for (int j = 0; j <= d; j++)
        {
            printf("blur %d\t", j);
            fflush(stdout);

            // For each vertex in the lattice,
            for (int i = 0; i < hashTable.size(); i++)   // blur point i in dimension j
            {
                // Only the first d coordinates of each vertex are stored.
                short *vertexKey = hashTable.getKeys() + i * d;
                for (int k = 0; k < d; k++)
                {
                    neighbor1[k] = vertexKey[k] + 1;
                    neighbor2[k] = vertexKey[k] - 1;
                }
                // Along axis j the step is -d / +d instead of +1 / -1. For
                // j == d that coordinate is the implicit (unstored) one, so
                // there is nothing to adjust and vertexKey[d] must not be read.
                if (j < d)
                {
                    neighbor1[j] = vertexKey[j] - d;
                    neighbor2[j] = vertexKey[j] + d;
                }

                float *oldVal = oldValue + i * vd;
                float *newVal = newValue + i * vd;

                // Translate hash-table pointers into the buffer currently holding old values.
                float *vm1, *vp1;
                vm1 = hashTable.lookup(neighbor1, false); // look up first neighbor
                if (vm1)
                    vm1 = vm1 - hashTableBase + oldValue;
                else
                    vm1 = zero;
                vp1 = hashTable.lookup(neighbor2, false); // look up second neighbor
                if (vp1)
                    vp1 = vp1 - hashTableBase + oldValue;
                else
                    vp1 = zero;

                // Mix values of the three vertices
                for (int k = 0; k < vd; k++)
                    newVal[k] = (0.25f * vm1[k] + 0.5f * oldVal[k] + 0.25f * vp1[k]);
            }
            float *tmp = newValue;
            newValue = oldValue;
            oldValue = tmp;
            // the freshest data is now in oldValue, and newValue is ready to be written over
        }

        // depending where we ended up, we may have to copy data; then free
        // whichever buffer was allocated above (array form of delete).
        if (oldValue != hashTableBase)
        {
            memcpy(hashTableBase, oldValue, hashTable.size()*vd * sizeof(float));
            delete[] oldValue;
        }
        else
        {
            delete[] newValue;
        }
        printf("\n");

        delete[] zero;
        delete[] neighbor1;
        delete[] neighbor2;
    }

private:
    // Non-copyable: the raw arrays are owned exclusively by this object and a
    // copy would double-free them. Declared but intentionally not defined.
    PermutohedralLattice(const PermutohedralLattice &);
    PermutohedralLattice &operator=(const PermutohedralLattice &);

    int d, vd, nData;
    float *elevated, *scaleFactor, *barycentric;
    short *canonical;
    short *key;

    // slicing is done by replaying splatting (ie storing the sparse matrix)
    struct ReplayEntry
    {
        int offset;   // offset of the vertex's value vector in hashTable.getValues()
        float weight; // barycentric weight of that vertex
    } *replay;
    int nReplay;

public:
    char  *rank;
    short *greedy;
    HashTablePermutohedral hashTable;
};


3. 用于 permutohedral lattice 的哈希表(PermutohedralLattice 以它为成员,因此在源文件中需定义在 PermutohedralLattice 之前):

/***************************************************************/
/* Hash table implementation for permutohedral lattice
 *
 * The lattice points are stored sparsely using a hash table.
 * The key for each point is its spatial location in the (d+1)-
 * dimensional space.
 *
 * `entries` is an open-addressing table (linear probing) of `capacity`
 * slots holding indices into the dense `keys` / `values` arrays, which
 * have room for capacity/2 vectors. The table is doubled by inserting
 * lookups before it reaches half full. Read-only lookups never resize it,
 * so pointers returned by lookup()/getValues() stay valid until the next
 * inserting lookup.
 */
/***************************************************************/

class HashTablePermutohedral
{
public:
    /* Constructor
     *  kd_: the dimensionality of the position vectors on the hyperplane.
     *  vd_: the dimensionality of the value vectors
     */
    HashTablePermutohedral(int kd_, int vd_) : kd(kd_), vd(vd_)
    {
        capacity = 1 << 15;
        filled = 0;
        entries = new Entry[capacity];
        // Dense, multi-dimensional key / value storage (capacity/2 vectors each).
        keys = new short[kd * capacity / 2];
        values = new float[vd * capacity / 2];
        memset(values, 0, sizeof(float)*vd * capacity / 2);
    }

    // Releases the slot table and the key / value storage.
    ~HashTablePermutohedral()
    {
        delete[] entries;
        delete[] keys;
        delete[] values;
    }

    // Returns the number of vectors stored.
    int size()
    {
        return filled;
    }

    // Returns a pointer to the keys array (kd shorts per stored vector).
    short *getKeys()
    {
        return keys;
    }

    // Returns a pointer to the values array (vd floats per stored vector).
    float *getValues()
    {
        return values;
    }

    /* Returns the offset into the values array of the value vector for a key,
     * or -1 if the key is absent and create is false.
     *  key: a pointer to the position vector (kd shorts).
     *  h: hash(key) % current capacity -- the starting bucket.
     *  create: a flag specifying whether an entry should be created,
     *          should an entry with the given key not found.
     */
    int lookupOffset(short *key, size_t h, bool create = true)
    {
        // Double hash table size if necessary (storage is about to reach
        // half of capacity). Only inserting lookups may grow: growing
        // reallocates `values`, and read-only callers hold raw pointers into it.
        if (create && filled >= (capacity / 2) - 1)
        {
            grow();
            // The caller reduced h modulo the old capacity, but the entries
            // were re-hashed modulo the new one; recompute the start bucket
            // or an existing key may be missed and inserted twice.
            h = hash(key) % capacity;
        }

        // Find the entry with the given key, starting from bucket h.
        while (1)
        {
            Entry e = entries[h];
            // check if the cell is empty
            if (e.keyIdx == -1)
            {
                if (!create)
                    return -1; // Return not found.

                // need to create an entry. Store the given key.
                for (int i = 0; i < kd; i++)
                    keys[filled * kd + i] = key[i];

                e.keyIdx = filled * kd;
                e.valueIdx = filled * vd;
                entries[h] = e;
                filled++;

                return e.valueIdx;
            }

            // check if the cell has a matching key
            bool match = true;
            for (int i = 0; i < kd && match; i++)
                match = keys[e.keyIdx + i] == key[i];
            if (match)
                return e.valueIdx;

            // Collision: probe the next bucket, wrapping around at the end.
            h++;
            if (h == capacity)
                h = 0;
        }
    }

    /* Looks up the value vector associated with a given key vector.
     *  k : pointer to the key vector to be looked up.
     *  create : true if a non-existing key should be created.
     * Returns NULL when the key is absent and create is false.
     */
    float *lookup(short *k, bool create = true)
    {
        size_t h = hash(k) % capacity;
        int offset = lookupOffset(k, h, create);
        if (offset < 0)
            return NULL;
        else
            return values + offset;
    }

    /* Hash function used in this implementation. A simple base conversion. */
    size_t hash(const short *key)
    {
        size_t k = 0;
        for (int i = 0; i < kd; i++)
        {
            k += key[i];
            k *= 2531011;
        }
        return k;
    }

private:
    // Non-copyable: a copy would double-free the arrays. Declared, not defined.
    HashTablePermutohedral(const HashTablePermutohedral &);
    HashTablePermutohedral &operator=(const HashTablePermutohedral &);

    /* Grows the size of the hash table */
    void grow()
    {
        printf("Resizing hash table\n");

        size_t oldCapacity = capacity;
        capacity *= 2; // double the capacity

        // Migrate the value vectors.
        float *newValues = new float[vd * capacity / 2];
        memset(newValues, 0, sizeof(float)*vd * capacity / 2);
        memcpy(newValues, values, sizeof(float)*vd * filled);
        delete[] values;
        values = newValues;

        // Migrate the key vectors.
        short *newKeys = new short[kd * capacity / 2];
        memcpy(newKeys, keys, sizeof(short)*kd * filled);
        delete[] keys;
        keys = newKeys;

        Entry *newEntries = new Entry[capacity];

        // Migrate the table of indices.
        for (size_t i = 0; i < oldCapacity; i++)
        {
            if (entries[i].keyIdx == -1)
                continue;
            // Re-hash the stored key against the new capacity.
            size_t h = hash(keys + entries[i].keyIdx) % capacity;
            // Linear probing: advance until an unused slot is found.
            while (newEntries[h].keyIdx != -1)
            {
                h++;
                if (h == capacity)
                    h = 0;
            }
            newEntries[h] = entries[i];
        }
        delete[] entries;
        entries = newEntries;
    }

    // Private struct for the hash table entries.
    struct Entry
    {
        Entry() : keyIdx(-1), valueIdx(-1) {}
        int keyIdx;   // index into keys (-1 = empty slot)
        int valueIdx; // index into values
    };

    short *keys;
    float *values;
    Entry *entries;
    size_t capacity, filled; // number of slots in entries / number of stored vectors
    int kd, vd;              // dimensionality of the key and value vectors
};


效果图(原文此处为图片,未随文本保留):





  • 1
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

ShaderJoy

您的打赏是我继续写博客的动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值