Unity FFT海水渲染效果展示

最近研究了一下基于FFT的海水渲染,发现这东西涉及到非常多的数学知识,我基本是囫囵吞枣,现在也忘得差不多了。这里列出几位大佬的文章,想深入研究的可以去看一下,我这个数学渣渣就不强行解释了。
fft海面模拟(一、二、三)
fft ocean注解

Ocean simulation part one: using the discrete Fourier transform
快速傅立叶变换(FFT)的海面模拟
Shader相册第6期 — 实时水面模拟与渲染
【学习笔记】Unity 基于GPU FFT海洋的实现-理论篇
【学习笔记】Unity 基于GPU FFT海洋的实现-实践篇
我自己写了两个版本,一个是CPU版本,另一个是GPU版本。
先看一下CPU版本:
在这里插入图片描述
这个波浪真的相当好看,但是这个CPU版本的计算是和顶点绑定在一起的,模型的顶点数不能太多,我这里用的是64x64,再往上就会卡得不行了。
接下来是GPU版本,基本照搬了最后两篇文章的做法,效果如下:
在这里插入图片描述
CSDN不能传超过4.5M的图片,下面是视频:

Unity ComputerShader FFT海水渲染

效果真的很不错。
我放一下CPU版本的代码,因为总感觉没代码怪怪的。
代码参考的是第三、第四篇文章,涉及的数学知识很多,想实现的话还是需要花点时间和耐心的。

using System;
using System.Collections.Generic;
using System.Numerics;
using UnityEngine;
using Random = System.Random;
using Vector2 = UnityEngine.Vector2;
using Vector3 = UnityEngine.Vector3;

/// <summary>
/// CPU implementation of Tessendorf's FFT ocean for a square patch (M = N, Lx = Lz = length).
/// Height, slope and horizontal-displacement spectra are built in the frequency domain,
/// transformed with <see cref="FFT"/>, and written into an (N+1) x (N+1) vertex grid whose
/// last row/column duplicate the first so neighbouring patches tile seamlessly.
/// </summary>
public class Ocean
{
    /// <summary>Per-vertex buffers for the (N+1) x (N+1) grid.</summary>
    public struct VertexOcean
    {
        public Vector3[] vertices;   // displaced vertex positions
        public Vector3[] normals;    // per-vertex normals
        public Vector2[] uvs;        // UV0, n/N and m/N
        public Vector2[] htildes;    // h~0(k) stored as (real, imaginary)
        public Vector2[] conjugates; // conj(h~0(-k)) stored as (real, imaginary)
        public Vector3[] originals;  // undisplaced grid positions
        public Vector2[] uvs2;       // x = Jacobian determinant of the horizontal displacement, y = 0
    }

    /// <summary>Result of the brute-force discrete Fourier transform in <see cref="ComputerHDN"/>.</summary>
    public struct ComplexVectorNormal
    {
        public Complex h;   // wave height
        public Vector2 D;   // horizontal displacement
        public Vector3 n;   // normal
    }

    float g = 9.81f;        // gravity constant
    int N, Nplus1;          // grid resolution; N must be a power of two
    float A;                // Phillips spectrum amplitude -- scales wave heights
    Vector2 w;              // wind vector
    float length;           // side length of the square patch
    Random rand = new Random();

    // Choppiness factor applied to the horizontal displacement (Dx, Dz).
    const float Lambda = -1.0f;

    // Angular frequencies are snapped to multiples of 2*pi / RepeatTime, so the
    // whole animation repeats every RepeatTime time units.
    const float RepeatTime = 200.0f;

    Complex[] tilde, tildeSlopeX, tildeSlopeZ, tildeDx, tildeDz;
    FFT fft;                // fast fourier transform

    public VertexOcean oceanData;   // vertex data for the mesh

    public List<int> indices;       // triangle indices for the mesh

    int debug = 0;          // number of messages printed by DebugLog since the last EvaluateWaves

    /// <summary>Precomputes the initial spectrum h~0 and builds the flat vertex grid and triangle list.</summary>
    /// <param name="N">Grid resolution per side; must be a positive power of two.</param>
    /// <param name="A">Phillips spectrum amplitude.</param>
    /// <param name="w">Wind vector (direction and speed).</param>
    /// <param name="length">Side length of the square patch; must be positive.</param>
    /// <exception cref="ArgumentException"><paramref name="N"/> is not a positive power of two.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="length"/> is not positive.</exception>
    public Ocean(int N, float A, Vector2 w, float length)
    {
        if (N <= 0 || (N & (N - 1)) != 0)
        {
            throw new ArgumentException("N must be a positive power of two.", nameof(N));
        }
        if (!(length > 0f))
        {
            throw new ArgumentOutOfRangeException(nameof(length), "length must be positive.");
        }

        this.N = N;
        Nplus1 = N + 1;
        this.A = A;
        this.w = w;
        this.length = length;

        tilde = new Complex[N * N];
        tildeSlopeX = new Complex[N * N];
        tildeSlopeZ = new Complex[N * N];
        tildeDx = new Complex[N * N];
        tildeDz = new Complex[N * N];

        fft = new FFT(N);

        indices = new List<int>(N * N * 6);

        int vertexCount = Nplus1 * Nplus1;
        oceanData = new VertexOcean
        {
            vertices = new Vector3[vertexCount],
            normals = new Vector3[vertexCount],
            uvs = new Vector2[vertexCount],
            htildes = new Vector2[vertexCount],
            conjugates = new Vector2[vertexCount],
            originals = new Vector3[vertexCount],
            uvs2 = new Vector2[vertexCount],
        };

        for (int m = 0; m < Nplus1; m++)
        {
            for (int n = 0; n < Nplus1; n++)
            {
                int index = m * Nplus1 + n;

                Complex htilde0 = hTilde_0(n, m);
                // Grid index (N - n, N - m) is the wave vector -k (see Phillips), so this is conj(h~0(-k)).
                Complex htilde0mk_conj = Complex.Conjugate(hTilde_0(N - n, N - m));

                oceanData.htildes[index] = new Vector2((float)htilde0.Real, (float)htilde0.Imaginary);
                oceanData.conjugates[index] = new Vector2((float)htilde0mk_conj.Real, (float)htilde0mk_conj.Imaginary);

                Vector3 position = new Vector3((n - N / 2.0f) * length / N, 0f, (m - N / 2.0f) * length / N);
                oceanData.originals[index] = position;
                oceanData.vertices[index] = position;

                oceanData.normals[index] = new Vector3(0f, 1f, 0f);

                oceanData.uvs[index] = new Vector2((float)n / N, (float)m / N);
            }
        }

        // Two triangles per grid cell.
        for (int m = 0; m < N; m++)
        {
            for (int n = 0; n < N; n++)
            {
                int index = m * Nplus1 + n;

                indices.Add(index);
                indices.Add(index + Nplus1);
                indices.Add(index + Nplus1 + 1);
                indices.Add(index);
                indices.Add(index + Nplus1 + 1);
                indices.Add(index + 1);
            }
        }
    }

    /// <summary>
    /// Draws two independent standard normal samples (Marsaglia polar method) and
    /// returns them as the real and imaginary parts of a complex number.
    /// </summary>
    public Complex GaussRandomVariable()
    {
        float u, v, s;
        do
        {
            u = (float)rand.NextDouble() * 2 - 1;
            v = (float)rand.NextDouble() * 2 - 1;
            s = u * u + v * v;
        } while (s >= 1f || s == 0f);

        float scale = Mathf.Sqrt(-2 * Mathf.Log(s) / s);
        return new Complex(u * scale, v * scale);
    }

    /// <summary>
    /// Phillips spectrum P(k) for the wave vector at grid index (n, m), where
    /// k = (pi * (2n - N) / length, pi * (2m - N) / length). Returns 0 at k = 0.
    /// </summary>
    public float Phillips(int n, int m)
    {
        Vector2 k = new Vector2(Mathf.PI * (2 * n - N) / length, Mathf.PI * (2 * m - N) / length);
        float k_Length = k.magnitude;

        if (k_Length < 0.000001)
        {
            return 0.0f;
        }

        float k_Length2 = k_Length * k_Length;
        float k_Length4 = k_Length2 * k_Length2;

        // |k_hat . w_hat|^2 aligns waves with the wind direction.
        float KdotW = Vector2.Dot(k.normalized, w.normalized);
        float KdotW2 = KdotW * KdotW;

        // L = V^2 / g: largest wave arising from a continuous wind of speed V.
        float w_Length = w.magnitude;
        float L = w_Length * w_Length / g;
        float L2 = L * L;

        // Tessendorf's small-wave suppression factor exp(-k^2 l^2), with l = damping * L.
        // It removes waves much shorter than l, which the grid cannot resolve well.
        float damping = 0.001f;
        float l2 = L2 * damping * damping;

        return A * Mathf.Exp(-1.0f / (k_Length2 * L2)) / k_Length4 * KdotW2
                * Mathf.Exp(-k_Length2 * l2);
    }

    /// <summary>Initial amplitude h~0(k) = (xi_r + i*xi_i) * sqrt(P(k) / 2) for grid index (n, m).</summary>
    public Complex hTilde_0(int n, int m)
    {
        Complex r = GaussRandomVariable();
        return r * Mathf.Sqrt(Phillips(n, m) / 2.0f);
    }

    /// <summary>
    /// Time-dependent amplitude h~(k, t) = h~0(k) e^{i w t} + conj(h~0(-k)) e^{-i w t},
    /// using the values the constructor stored for grid index (n, m).
    /// </summary>
    Complex hTilde(float t, int n, int m)
    {
        int index = m * Nplus1 + n;

        Complex htilde0 = new Complex(oceanData.htildes[index].x, oceanData.htildes[index].y);
        Complex htilde0mk_conj = new Complex(oceanData.conjugates[index].x, oceanData.conjugates[index].y);

        float omegat = Dispersion(n, m) * t;

        float cos = Mathf.Cos(omegat);
        float sin = Mathf.Sin(omegat);

        Complex c0 = new Complex(cos, sin);
        Complex c1 = new Complex(cos, -sin);

        return htilde0 * c0 + htilde0mk_conj * c1;
    }

    /// <summary>
    /// Deep-water dispersion w(k) = sqrt(g |k|), rounded down to a multiple of
    /// w0 = 2*pi / RepeatTime. The rounding keeps the animation periodic; it does not scale the frequency.
    /// </summary>
    public float Dispersion(int n, int m)
    {
        float w0 = 2.0f * Mathf.PI / RepeatTime;
        float kx = Mathf.PI * (2 * n - N) / length;
        float kz = Mathf.PI * (2 * m - N) / length;
        return Mathf.Floor(Mathf.Sqrt(g * Mathf.Sqrt(kx * kx + kz * kz)) / w0) * w0;
    }

    /// <summary>
    /// Brute-force DFT of height, horizontal displacement and normal at a single
    /// position <paramref name="x"/> and time <paramref name="t"/>. Costs O(N^2) per call;
    /// <see cref="EvaluateWaves"/> uses the FFT path instead.
    /// </summary>
    public ComplexVectorNormal ComputerHDN(Vector2 x, float t)
    {
        Complex h = new Complex();
        Vector2 D = Vector2.zero;
        Vector3 normal = Vector3.zero;

        for (int m = 0; m < N; m++)
        {
            float kz = 2.0f * Mathf.PI * (m - N / 2.0f) / length;
            for (int n = 0; n < N; n++)
            {
                float kx = 2.0f * Mathf.PI * (n - N / 2.0f) / length;
                Vector2 k = new Vector2(kx, kz);

                float k_Length = k.magnitude;
                float KdotX = Vector2.Dot(k, x);

                Complex c = new Complex(Mathf.Cos(KdotX), Mathf.Sin(KdotX));
                Complex htilde = hTilde(t, n, m) * c;

                h += htilde;

                normal += new Vector3(-kx * (float)htilde.Imaginary, 0.0f,
                    -kz * (float)htilde.Imaginary);

                // The displacement direction k/|k| is undefined at k = 0.
                if (k_Length < 0.000001)
                {
                    continue;
                }
                D += new Vector2(kx / k_Length * (float)htilde.Imaginary,
                    kz / k_Length * (float)htilde.Imaginary);
            }
        }

        normal = new Vector3(0f, 1f, 0f) - normal;
        ComplexVectorNormal result = new ComplexVectorNormal();
        result.h = h;
        result.D = D;
        result.n = normal.normalized;
        return result;
    }

    /// <summary>
    /// Evaluates the ocean at time <paramref name="t"/> with a 2D FFT. Writes the displaced
    /// positions, unit normals and the Jacobian determinant (uvs2.x) into <see cref="oceanData"/>.
    /// </summary>
    public void EvaluateWaves(float t)
    {
        // 1. Frequency-domain amplitudes for height, slopes and horizontal displacement.
        for (int m = 0; m < N; m++)
        {
            float kz = Mathf.PI * (2.0f * m - N) / length;
            for (int n = 0; n < N; n++)
            {
                float kx = Mathf.PI * (2.0f * n - N) / length;
                float len = Mathf.Sqrt(kx * kx + kz * kz);
                int index = m * N + n;

                Complex h = hTilde(t, n, m);
                tilde[index] = h;
                tildeSlopeX[index] = h * new Complex(0, kx);
                tildeSlopeZ[index] = h * new Complex(0, kz);

                if (len < 0.000001f)
                {
                    // The displacement direction k/|k| is undefined at k = 0.
                    tildeDx[index] = Complex.Zero;
                    tildeDz[index] = Complex.Zero;
                }
                else
                {
                    tildeDx[index] = h * new Complex(0f, -kx / len);
                    tildeDz[index] = h * new Complex(0f, -kz / len);
                }
            }
        }

        // 2. Separable 2D transform: every row (stride 1), then every column (stride N).
        for (int m = 0; m < N; m++)
        {
            int rowStart = m * N;
            fft.ComputeFFT(tilde, ref tilde, 1, rowStart);
            fft.ComputeFFT(tildeSlopeX, ref tildeSlopeX, 1, rowStart);
            fft.ComputeFFT(tildeSlopeZ, ref tildeSlopeZ, 1, rowStart);
            fft.ComputeFFT(tildeDx, ref tildeDx, 1, rowStart);
            fft.ComputeFFT(tildeDz, ref tildeDz, 1, rowStart);
        }

        for (int n = 0; n < N; n++)
        {
            fft.ComputeFFT(tilde, ref tilde, N, n);
            fft.ComputeFFT(tildeSlopeX, ref tildeSlopeX, N, n);
            fft.ComputeFFT(tildeSlopeZ, ref tildeSlopeZ, N, n);
            fft.ComputeFFT(tildeDx, ref tildeDx, N, n);
            fft.ComputeFFT(tildeDz, ref tildeDz, N, n);
        }

        // 3. Spatial-domain results -> vertex buffers.
        Vector3[] offsets = ComputeOffset();
        debug = 0;
        for (int m = 0; m < N; m++)
        {
            for (int n = 0; n < N; n++)
            {
                int index = m * N + n;
                // (-1)^(n+m) compensates for the wave vectors being centred at N/2.
                int sign = ((n + m) & 1) == 0 ? 1 : -1;

                Vector3 offset = offsets[index];
                Vector3 normal = new Vector3(
                    -(float)(tildeSlopeX[index].Real * sign),
                    1f,
                    -(float)(tildeSlopeZ[index].Real * sign)).normalized;
                Vector2 jacobian = new Vector2(ComputeJacobian(n, m, offsets), 0f);

                int vertex = m * Nplus1 + n;
                WriteVertex(vertex, offset, normal, jacobian);

                // Mirror row 0 / column 0 into the extra row / column N so the patch tiles.
                if (n == 0)
                {
                    WriteVertex(vertex + N, offset, normal, jacobian);
                }
                if (m == 0)
                {
                    WriteVertex(vertex + Nplus1 * N, offset, normal, jacobian);
                }
                if (n == 0 && m == 0)
                {
                    WriteVertex(vertex + N + Nplus1 * N, offset, normal, jacobian);
                }
            }
        }
    }

    /// <summary>Stores the displaced position, normal and Jacobian for one vertex of the (N+1)^2 grid.</summary>
    void WriteVertex(int vertex, Vector3 offset, Vector3 normal, Vector2 jacobian)
    {
        oceanData.vertices[vertex] = new Vector3(
            oceanData.originals[vertex].x + offset.x,
            offset.y,
            oceanData.originals[vertex].z + offset.z);
        oceanData.normals[vertex] = normal;
        oceanData.uvs2[vertex] = jacobian;
    }

    /// <summary>Logs at most ~500 messages between calls to <see cref="EvaluateWaves"/>.</summary>
    void DebugLog(string a)
    {
        if (debug > 500)
        {
            return;
        }
        debug++;
        Debug.Log(a);
    }

    /// <summary>
    /// Sign-corrected spatial offsets (Lambda*Dx, height, Lambda*Dz) for the N x N grid,
    /// read from the already transformed spectra.
    /// </summary>
    Vector3[] ComputeOffset()
    {
        Vector3[] offsets = new Vector3[N * N];
        for (int m = 0; m < N; m++)
        {
            for (int n = 0; n < N; n++)
            {
                int index = m * N + n;
                int sign = ((n + m) & 1) == 0 ? 1 : -1;
                offsets[index] = new Vector3(
                    Lambda * (float)(tildeDx[index] * sign).Real,
                    (float)(tilde[index] * sign).Real,
                    Lambda * (float)(tildeDz[index] * sign).Real);
            }
        }
        return offsets;
    }

    /// <summary>
    /// Jacobian determinant J = (1 + dDx/dx)(1 + dDz/dz) - (dDx/dz)(dDz/dx) of the horizontal
    /// displacement at grid cell (n, m). It uses unscaled central differences with periodic wrap-around.
    /// </summary>
    float ComputeJacobian(int n, int m, Vector3[] offsets)
    {
        // n indexes x and m indexes z (see the constructor's vertex positions).
        int xPrev = m * N + (n == 0 ? N - 1 : n - 1);
        int xNext = m * N + (n == N - 1 ? 0 : n + 1);
        int zPrev = (m == 0 ? N - 1 : m - 1) * N + n;
        int zNext = (m == N - 1 ? 0 : m + 1) * N + n;

        Vector3 ddx = offsets[xNext] - offsets[xPrev];   // derivative along x
        Vector3 ddz = offsets[zNext] - offsets[zPrev];   // derivative along z

        return (1.0f + ddx.x) * (1.0f + ddz.z) - ddx.z * ddz.x;
    }
}

using System.Numerics;
using UnityEngine;

/// <summary>
/// Radix-2 Cooley-Tukey FFT over N complex samples using positive-exponent twiddles:
/// out[k] = sum_n in[n] * e^{+2*pi*i*n*k/N}, with no 1/N scaling. This is the sign
/// convention needed to sum wave components of the form e^{i k x}.
/// </summary>
public class FFT
{
    /// <summary>Holder for a complex buffer so several buffers can live in one array.</summary>
    public struct ComplexList
    {
        public Complex[] complices;
    }

    const double TwoPi = 2.0 * System.Math.PI;

    int N, which;      // transform size; index of the ping-pong buffer holding the current data
    int log2N;         // number of butterfly stages
    uint[] reversed;   // bit-reversal permutation of 0..N-1
    ComplexList[] T;   // T[s][k] = e^{+2*pi*i*k / 2^(s+1)}: twiddle factors for stage s
    ComplexList[] c = new ComplexList[2];   // ping-pong work buffers

    /// <summary>Precomputes the bit-reversal table and twiddle factors for size-<paramref name="N"/> transforms.</summary>
    /// <param name="N">Transform size; must be a positive power of two.</param>
    /// <exception cref="System.ArgumentException"><paramref name="N"/> is not a positive power of two.</exception>
    public FFT(int N)
    {
        if (N <= 0 || (N & (N - 1)) != 0)
        {
            throw new System.ArgumentException("N must be a positive power of two.", nameof(N));
        }

        this.N = N;
        c[0].complices = new Complex[N];
        c[1].complices = new Complex[N];

        // Exact integer log2. Truncating a floating-point Log(N) / Log(2) is not
        // guaranteed to give the exact stage count and could drop a butterfly stage.
        log2N = 0;
        while ((1 << log2N) < N)
        {
            log2N++;
        }

        reversed = new uint[N];
        for (int i = 0; i < N; i++)
        {
            reversed[i] = Reverse((uint)i);
        }

        T = new ComplexList[log2N];
        int pow2 = 1;
        for (int i = 0; i < log2N; i++)
        {
            T[i].complices = new Complex[pow2];
            for (int j = 0; j < pow2; j++)
            {
                T[i].complices[j] = ComputeT(j, pow2 * 2);
            }
            pow2 *= 2;
        }

        which = 0;
    }

    /// <summary>Reverses the lowest log2N bits of <paramref name="i"/>.</summary>
    public uint Reverse(uint i)
    {
        uint res = 0;
        for (int j = 0; j < log2N; j++)
        {
            res = (res << 1) + (i & 1);
            i >>= 1;
        }
        return res;
    }

    /// <summary>Twiddle factor e^{+2*pi*i*x/n}, computed in double precision.</summary>
    public Complex ComputeT(int x, int n)
    {
        return Complex.FromPolarCoordinates(1.0, TwoPi * x / n);
    }

    /// <summary>
    /// Transforms the N samples input[offset + i*stride] (i = 0..N-1) and writes the result to
    /// output[offset + i*stride]. The input is copied first, so input and output may be the same array.
    /// </summary>
    public void ComputeFFT(Complex[] input, ref Complex[] output, int stride, int offset)
    {
        // Load the samples in bit-reversed order.
        for (int i = 0; i < N; i++)
        {
            c[which].complices[i] = input[reversed[i] * stride + offset];
        }

        int loops = N >> 1;     // butterfly groups in the current stage
        int size = 1 << 1;      // group size in the current stage
        int sizeOver2 = 1;
        int w = 0;              // twiddle table for the current stage
        for (int i = 1; i <= log2N; i++)
        {
            // Swap ping-pong buffers: read from c[which ^ 1], write to c[which].
            which ^= 1;
            for (int j = 0; j < loops; j++)
            {
                for (int k = 0; k < sizeOver2; k++)
                {
                    c[which].complices[size * j + k] = c[which ^ 1].complices[size * j + k] +
                        c[which ^ 1].complices[size * j + sizeOver2 + k] * T[w].complices[k];
                }
                for (int k = sizeOver2; k < size; k++)
                {
                    c[which].complices[size * j + k] = c[which ^ 1].complices[size * j - sizeOver2 + k] -
                        c[which ^ 1].complices[size * j + k] * T[w].complices[k - sizeOver2];
                }
            }
            loops >>= 1;
            size <<= 1;
            sizeOver2 <<= 1;
            w++;
        }

        for (int i = 0; i < N; i++)
        {
            output[i * stride + offset] = c[which].complices[i];
        }
    }
}

using UnityEngine;

/// <summary>
/// Drives a CPU <see cref="Ocean"/> simulation and uploads its vertex data to the
/// mesh on this GameObject's MeshFilter every rendered frame.
/// </summary>
public class OceanManager : MonoBehaviour
{
    public int N = 32;                          // grid resolution (power of two)
    public float WaveHeight = 1f;               // Phillips spectrum amplitude
    public Vector2 Wind = new Vector2(1f, 1f);  // wind vector
    public float Length = 1f;                   // patch side length

    Ocean _ocean;
    Mesh _mesh;

    void Start()
    {
        _ocean = new Ocean(N, WaveHeight, Wind, Length);

        _mesh = GetComponent<MeshFilter>().mesh;
        // Drop whatever geometry the MeshFilter came with; its triangles could
        // reference vertices that do not exist in the ocean grid.
        _mesh.Clear();
        _mesh.MarkDynamic();

        // (N+1)^2 vertices exceed the 16-bit index limit once N >= 256.
        if (_ocean.oceanData.vertices.Length > ushort.MaxValue)
        {
            _mesh.indexFormat = UnityEngine.Rendering.IndexFormat.UInt32;
        }

        // Topology and UV0 never change, so they are uploaded once here.
        _mesh.vertices = _ocean.oceanData.vertices;
        _mesh.uv = _ocean.oceanData.uvs;
        _mesh.normals = _ocean.oceanData.normals;
        _mesh.uv2 = _ocean.oceanData.uvs2;
        _mesh.triangles = _ocean.indices.ToArray();
    }

    // Visual-only work belongs in Update: FixedUpdate runs at the physics rate,
    // which can be zero or several times per rendered frame.
    void Update()
    {
        _ocean.EvaluateWaves(Time.time);
        _mesh.vertices = _ocean.oceanData.vertices;
        _mesh.normals = _ocean.oceanData.normals;
        _mesh.uv2 = _ocean.oceanData.uvs2;
        // Vertices move every frame and triangles are no longer reassigned,
        // so bounds must be refreshed explicitly to keep culling correct.
        _mesh.RecalculateBounds();
    }
}

Shader用的是最后两篇文章的代码,另外加了一个深度渐变。

Shader "Learn/GPUWater"
{
    // Transparent ocean surface for the GPU FFT version.
    // Vertices are displaced by _Displace. _Normal supplies an object-space normal (rgb)
    // and a bubble mask (a). Alpha fades where the water is shallow relative to the scene
    // behind it. NOTE(review): assumes the camera renders _CameraDepthTexture and that a
    // script assigns _Displace/_Normal (they are not exposed as properties).
    Properties
    {
        _OceanColor ("Ocean Color",COLOR) = (1,1,1,1)
        _SpecualrColor ("Specular Color",COLOR) = (1,1,1,1)
        [HDR]_BubbleColor ("Bubble Color",COLOR) = (1,1,1,1)
        _MainTex ("Texture", 2D) = "white" {}
        // Default must lie inside the declared range (it was 1, below the minimum of 8).
        _Gloss("Gloss", Range(8,200)) = 8
        _DepthVisibility ("_Depth Visibility", Float) = 1
    }
    SubShader
    {
        Tags { "RenderType"="Transparent" "Queue"="Transparent" "IgnoreProjector"="True"}
        LOD 100

        Pass
        {
            Blend SrcAlpha OneMinusSrcAlpha
            //Cull Off
            CGPROGRAM
            #pragma vertex vert
            #pragma fragment frag

            #include "UnityCG.cginc"
            #include "Lighting.cginc"

            struct appdata
            {
                float4 vertex : POSITION;
                float2 uv : TEXCOORD0;
            };

            struct v2f
            {
                float2 uv : TEXCOORD0;
                float4 vertex : SV_POSITION;
                float3 worldPos:TEXCOORD1;
                float4 screenPos:TEXCOORD2;   // xyw: projective screen coords, z: eye depth
            };

            sampler2D _MainTex;
            sampler2D _Displace;
            sampler2D _Normal;
            float4 _MainTex_ST;
            fixed4 _BubbleColor;
            fixed4 _OceanColor;
            fixed4 _SpecualrColor;
            half _Gloss;
            half _DepthVisibility;

            sampler2D _CameraDepthTexture;

            v2f vert(appdata v)
            {
                v2f o;
                o.uv = TRANSFORM_TEX(v.uv, _MainTex);
                // Object-space displacement sampled from the FFT result texture.
                float4 displace = tex2Dlod(_Displace, float4(o.uv, 0, 0));
                v.vertex += float4(displace.xyz, 0);
                o.worldPos = mul(unity_ObjectToWorld, v.vertex).xyz;
                o.vertex = UnityWorldToClipPos(o.worldPos);
                o.screenPos = ComputeScreenPos(o.vertex);
                // Eye depth of the displaced vertex (macro reads v.vertex).
                COMPUTE_EYEDEPTH(o.screenPos.z);
                return o;
            }

            fixed4 frag(v2f i) : SV_Target
            {
                float3 lightDir = normalize(UnityWorldSpaceLightDir(i.worldPos));
                float3 viewDir = normalize(UnityWorldSpaceViewDir(i.worldPos));
                float3 halfDir = normalize(lightDir + viewDir);
                float4 normalAndBubbles = tex2D(_Normal, i.uv);
                float3 normal = normalize(UnityObjectToWorldNormal(normalAndBubbles.xyz).rgb);

                // Depth fade: distance from the water surface to the opaque scene behind it.
                float depth = SAMPLE_DEPTH_TEXTURE_PROJ(_CameraDepthTexture, UNITY_PROJ_COORD(i.screenPos));
                float sceneZ = max(0, LinearEyeDepth(depth) - _ProjectionParams.g);
                float partZ = max(0, i.screenPos.z - _ProjectionParams.g);
                float depthGap = sceneZ - partZ;
                // Guard against 0/0 when _DepthVisibility is 0.
                float depthFade = saturate(depthGap / max(_DepthVisibility, 1e-5));

                // Lambert diffuse, blended toward the bubble colour by the bubble mask.
                half NdotL = saturate(dot(lightDir, normal));
                fixed3 diffuse = tex2D(_MainTex, i.uv).rgb * NdotL * _OceanColor.rgb;
                fixed3 bubbles = _BubbleColor.rgb * NdotL;
                diffuse = lerp(diffuse, bubbles, normalAndBubbles.a);

                // Blinn-Phong specular.
                fixed3 specular = _LightColor0.rgb * _SpecualrColor.rgb * pow(
                    max(0, dot(normal, halfDir)), _Gloss);

                fixed3 col = diffuse + specular;
                return fixed4(col, _OceanColor.a * depthFade);
            }
            ENDCG
        }
    }
}

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值