Unity compute shader 基础。
#include "Platforms.cginc"
// Declares the kernel entry point; the C# side looks it up by this name (ComputeShader.FindKernel).
#pragma kernel CSMatrixCopyPaster

// Source and destination buffers, bound from C# with SetBuffer.
RWStructuredBuffer<float4x4> matricesFrom;
RWStructuredBuffer<float4x4> matricesTo;
// Element count of matricesFrom; presumably filled in by KernelBase as "<bufferName>Length".
uint matricesFromLength;

// numthreads is a 3D group size; this kernel uses a flat THREADS x 1 x 1 layout, one thread per
// matrix. THREADS comes from Platforms.cginc.
[numthreads(THREADS,1,1)]
void CSMatrixCopyPaster (uint3 id : SV_DispatchThreadID)
{
    // SV_DispatchThreadID is the thread's index across the whole dispatch (not within its group),
    // so id.x addresses one buffer element directly.
    uint index = id.x;
    // Only threads that map to an existing element do work; any surplus threads
    // (e.g. from rounding the group count up) fall through without writing.
    if (index < matricesFromLength)
    {
        matricesTo[index] = matricesFrom[index];
    }
}
#include "Platforms.cginc"
#pragma kernel CSMatrixMultiplier

// Element-wise batch product: resultMatrices[i] = matrices1[i] * matrices2[i].
RWStructuredBuffer<float4x4> matrices1;
RWStructuredBuffer<float4x4> matrices2;
RWStructuredBuffer<float4x4> resultMatrices;
// Element count of resultMatrices.
// NOTE(review): only this length is checked; matrices1/matrices2 are assumed to hold at least
// as many elements — confirm on the C# side.
uint resultMatricesLength;

// One thread per output matrix along X; THREADS comes from Platforms.cginc.
[numthreads(THREADS,1,1)]
void CSMatrixMultiplier (uint3 id : SV_DispatchThreadID)
{
    uint index = id.x;
    if (index < resultMatricesLength)
    {
        float4x4 lhs = matrices1[index];
        float4x4 rhs = matrices2[index];
        resultMatrices[index] = mul(lhs, rhs);
    }
}
#include "Platforms.cginc"
#pragma kernel CSMatrixPointMultiplier

// Per-element point transform: outPoints[i] = (matrices[i] * float4(inPoints[i], 1)).xyz
RWStructuredBuffer<float4x4> matrices;
RWStructuredBuffer<float3> inPoints;
RWStructuredBuffer<float3> outPoints;
// Element count of matrices.
// NOTE(review): inPoints/outPoints are assumed to be at least this long — confirm on the C# side.
uint matricesLength;

// One thread per point along X; THREADS comes from Platforms.cginc.
[numthreads(THREADS,1,1)]
void CSMatrixPointMultiplier (uint3 id : SV_DispatchThreadID)
{
    if (id.x >= matricesLength)
        return;
    // w = 1 treats the input as a position, so the matrix's translation is applied.
    float4 transformed = mul(matrices[id.x], float4(inPoints[id.x], 1.0));
    // Explicitly keep xyz instead of relying on implicit float4 -> float3 truncation
    // (which only compiles with a warning). No perspective divide by w is done, so this is
    // correct for affine matrices only.
    outPoints[id.x] = transformed.xyz;
}
#include "Platforms.cginc"
#pragma kernel CSMatrixSelector
#pragma kernel CSPointsSelector

// Gather indices shared by both kernels in this file.
RWStructuredBuffer<uint> indices;
// Source and output buffers for CSMatrixSelector.
RWStructuredBuffer<float4x4> matrices;
RWStructuredBuffer<float4x4> selectedMatrices;
// Number of entries in indices, which is also the number of outputs each kernel writes.
uint indicesLength;

// This file declares two gather kernels. CSMatrixSelector copies whole matrices picked by index
// (selectedMatrices[i] = matrices[indices[i]]); CSPointsSelector does the same for points.
// NOTE(review): indices[i] is not range-checked against the source buffer length.
[numthreads(THREADS,1,1)]
void CSMatrixSelector (uint3 id : SV_DispatchThreadID)
{
    uint dstIndex = id.x;
    if (dstIndex < indicesLength)
    {
        uint srcIndex = indices[dstIndex];
        selectedMatrices[dstIndex] = matrices[srcIndex];
    }
}
// Source and output buffers for CSPointsSelector: selectedPoints[i] = points[indices[i]].
RWStructuredBuffer<float3> points;
RWStructuredBuffer<float3> selectedPoints;

// One thread per selected point along X; THREADS comes from Platforms.cginc.
[numthreads(THREADS,1,1)]
void CSPointsSelector (uint3 id : SV_DispatchThreadID)
{
    uint dstIndex = id.x;
    if (dstIndex < indicesLength)
    {
        uint srcIndex = indices[dstIndex];
        selectedPoints[dstIndex] = points[srcIndex];
    }
}
// Base class for a single compute-shader kernel. It loads the shader from Resources, resolves the
// kernel id, and before every dispatch binds each public property marked with [GpuData] to the
// shader variable named by that attribute.
public class KernelBase : IPass
{
    // When false, Dispatch() returns without doing anything.
    public bool IsEnabled { get; set; }
    // Kernel index returned by ComputeShader.FindKernel in the constructor.
    protected readonly int KernelId;
    // The compute shader loaded in the constructor.
    protected ComputeShader Shader { get; private set; }

    /// <summary>Loads the compute shader and looks up the kernel.</summary>
    /// <param name="shaderPath">Resources-relative path of the compute shader asset.</param>
    /// <param name="kernelName">Kernel name as declared with #pragma kernel in the shader.</param>
    public KernelBase(string shaderPath, string kernelName)
    {
        Shader = Resources.Load<ComputeShader>(shaderPath);
        Assert.IsNotNull(Shader, "Can't load shader");
        KernelId = Shader.FindKernel(kernelName);
        IsEnabled = true;
    }

    /// <summary>
    /// Binds the [GpuData] properties, then dispatches the kernel with the thread-group counts
    /// returned by GetGroupsNumX/Y/Z. Does nothing when IsEnabled is false.
    /// </summary>
    public virtual void Dispatch()
    {
        if (!IsEnabled)
            return;
        // Reflection runs lazily on the first dispatch. NOTE: if the type has no [GpuData]
        // properties, Props stays empty and the reflection is repeated on every dispatch.
        if (Props.Count == 0)
            CacheAttributes();
        BindAttributes();
        Shader.Dispatch(KernelId, GetGroupsNumX(), GetGroupsNumY(), GetGroupsNumZ());
    }

    /// <summary>Releases resources owned by the kernel. The base implementation does nothing.</summary>
    public virtual void Dispose()
    {
    }

    /// <summary>Number of thread groups along X. Override in derived kernels; defaults to 1.</summary>
    public virtual int GetGroupsNumX()
    {
        return 1;
    }

    /// <summary>Number of thread groups along Y. Override in derived kernels; defaults to 1.</summary>
    public virtual int GetGroupsNumY()
    {
        return 1;
    }

    /// <summary>Number of thread groups along Z. Override in derived kernels; defaults to 1.</summary>
    public virtual int GetGroupsNumZ()
    {
        return 1;
    }

    #region Reflection
    // (attribute, property value) pairs captured by CacheAttributes and consumed by BindAttributes.
    protected readonly List<KeyValuePair<GpuData, object>> Props = new List<KeyValuePair<GpuData, object>>();

    /// <summary>
    /// Collects every public property of the runtime type that carries [GpuData], together with
    /// its current value, into Props.
    /// </summary>
    /// <remarks>
    /// The property value is captured here and reused by later binds. Assigning a different object
    /// to a property after the first Dispatch is therefore not picked up unless Props is cleared.
    /// Mutations inside the captured object are presumably seen if it is a reference type
    /// (verify for GpuValue&lt;T&gt;).
    /// </remarks>
    protected virtual void CacheAttributes()
    {
        Props.Clear();
        // Walk all public properties of this type and keep the ones tagged with the GpuData attribute.
        var properties = GetType().GetProperties();
        foreach (var propertyInfo in properties)
        {
            if (!Attribute.IsDefined(propertyInfo, typeof(GpuData)))
                continue;
            var attribute = (GpuData) Attribute.GetCustomAttribute(propertyInfo, typeof(GpuData));
            var obj = propertyInfo.GetValue(this, null);
            Props.Add(new KeyValuePair<GpuData, object>(attribute, obj));
        }
    }

    /// <summary>
    /// Pushes each cached value to the shader variable named by its GpuData attribute, choosing the
    /// ComputeShader setter by the value's runtime type. Values of any other type, and null
    /// values, are skipped silently.
    /// </summary>
    protected void BindAttributes()
    {
        for (var i = 0; i < Props.Count; i++)
        {
            var attribute = Props[i].Key;
            var obj = Props[i].Value;
            if (obj is IBufferWrapper)
            {
                var buffer = ((IBufferWrapper) obj).ComputeBuffer;
                Shader.SetBuffer(KernelId, attribute.Name, buffer);
                // Kernels bounds-check against a "<name>Length" uniform (e.g. matricesFromLength),
                // so the element count is published alongside every buffer.
                Shader.SetInt(attribute.Name + "Length", buffer.count);
            }
            else if (obj is Texture)
                Shader.SetTexture(KernelId, attribute.Name, (Texture) obj);
            else if (obj is GpuValue<int>)
                Shader.SetInt(attribute.Name, ((GpuValue<int>)obj).Value);
            else if (obj is GpuValue<float>)
                Shader.SetFloat(attribute.Name, ((GpuValue<float>)obj).Value);
            else if (obj is GpuValue<Vector3>)
                Shader.SetVector(attribute.Name, ((GpuValue<Vector3>)obj).Value);
            else if (obj is GpuValue<Color>)
                Shader.SetVector(attribute.Name, ((GpuValue<Color>)obj).Value.ToVector());
            else if (obj is GpuValue<bool>)
                Shader.SetBool(attribute.Name, ((GpuValue<bool>)obj).Value);
            else if (obj is GpuValue<GpuMatrix4x4>)
                Shader.SetFloats(attribute.Name, ((GpuValue<GpuMatrix4x4>) obj).Value.Values);
        }
    }
    #endregion
}
KernelBase 通过反射获取所有标注了 GpuData 特性(Attribute)的属性,并根据属性值的具体类型,把对应的值传给 compute shader 里同名的变量;对于 buffer,还会额外设置一个名为 `<名字>Length` 的 int 变量,表示 buffer 的元素个数。
// Composite pass: owns an ordered list of child passes, dispatches and disposes them in order, and
// forwards the values of its own [GpuData] properties to child properties whose GpuData name matches.
public class PrimitiveBase : IPass
{
    private readonly List<IPass> passes = new List<IPass>();

    /// <summary>
    /// Re-reflects the child passes and this object, then copies matching [GpuData] property values
    /// into the children. Call again after adding/removing passes or reassigning own properties.
    /// </summary>
    protected void Bind()
    {
        CachePassAttributes();
        CacheOwnAttributes();
        BindAttributes();
    }

    /// <summary>Dispatches every child pass in the order they were added.</summary>
    public virtual void Dispatch()
    {
        for (var i = 0; i < passes.Count; i++)
            passes[i].Dispatch();
    }

    /// <summary>Disposes every child pass in the order they were added.</summary>
    public virtual void Dispose()
    {
        for (var i = 0; i < passes.Count; i++)
            passes[i].Dispose();
    }

    /// <summary>Appends a child pass; it runs after the passes already added.</summary>
    public void AddPass(IPass pass)
    {
        passes.Add(pass);
    }

    /// <summary>Removes a child pass, logging an error if it was never added.</summary>
    public void RemovePass(IPass pass)
    {
        // List.Remove reports whether the item was found, so no separate Contains scan is needed.
        if (!passes.Remove(pass))
            Debug.LogError("Can't find pass");
    }

    #region reflection
    // Each entry pairs a pass with its [GpuData] properties, so binding targets the pass that was
    // actually reflected even if the passes list changed since the last cache.
    private readonly List<KeyValuePair<IPass, List<KeyValuePair<GpuData, PropertyInfo>>>> passesAttributes =
        new List<KeyValuePair<IPass, List<KeyValuePair<GpuData, PropertyInfo>>>>();
    // This object's own [GpuData] properties.
    private readonly List<KeyValuePair<GpuData, PropertyInfo>> ownAttributes = new List<KeyValuePair<GpuData, PropertyInfo>>();

    // Returns the public properties of the given type that carry [GpuData], paired with the attribute.
    private static List<KeyValuePair<GpuData, PropertyInfo>> GetGpuDataProperties(Type type)
    {
        var result = new List<KeyValuePair<GpuData, PropertyInfo>>();
        var properties = type.GetProperties();
        for (var i = 0; i < properties.Length; i++)
        {
            var propertyInfo = properties[i];
            if (!Attribute.IsDefined(propertyInfo, typeof(GpuData)))
                continue;
            var attribute = (GpuData)Attribute.GetCustomAttribute(propertyInfo, typeof(GpuData));
            result.Add(new KeyValuePair<GpuData, PropertyInfo>(attribute, propertyInfo));
        }
        return result;
    }

    // Rebuilds the per-pass [GpuData] property cache from the current passes list.
    private void CachePassAttributes()
    {
        passesAttributes.Clear();
        for (var i = 0; i < passes.Count; i++)
        {
            var pass = passes[i];
            passesAttributes.Add(new KeyValuePair<IPass, List<KeyValuePair<GpuData, PropertyInfo>>>(
                pass, GetGpuDataProperties(pass.GetType())));
        }
    }

    // Rebuilds the cache of this object's own [GpuData] properties.
    private void CacheOwnAttributes()
    {
        ownAttributes.Clear();
        ownAttributes.AddRange(GetGpuDataProperties(GetType()));
    }

    /// <summary>
    /// For every own [GpuData] property, assigns its current value to each cached pass property
    /// whose GpuData name is equal (ordinal comparison). Uses the caches built by Bind().
    /// </summary>
    protected void BindAttributes()
    {
        for (var i = 0; i < ownAttributes.Count; i++)
        {
            var ownAttribute = ownAttributes[i];
            for (var j = 0; j < passesAttributes.Count; j++)
            {
                var pass = passesAttributes[j].Key;
                var passAttributes = passesAttributes[j].Value;
                for (var k = 0; k < passAttributes.Count; k++)
                {
                    var passAttribute = passAttributes[k];
                    if (string.Equals(passAttribute.Key.Name, ownAttribute.Key.Name, StringComparison.Ordinal))
                        passAttribute.Value.SetValue(pass, ownAttribute.Value.GetValue(this, null), null);
                }
            }
        }
    }
    #endregion
}
PrimitiveBase 管理一组 pass(实现 IPass 的对象,例如 KernelBase)。它会遍历每个 pass 中标注了 GpuData 的属性;如果某个属性的 GpuData 名字和自己某个 GpuData 属性的名字相同,就把自己这个属性的值赋给该 pass 的对应属性。具体用处还要后面结合具体功能来看。
// Extension helpers for moving data between managed arrays and ComputeBuffers, plus debug logging.
public static class ComputeUtils
{
    /// <summary>
    /// Creates a ComputeBuffer with one element per array entry and uploads the array into it.
    /// The caller owns the returned buffer (Unity ComputeBuffers must be released with Release()).
    /// </summary>
    /// <param name="array">Source data; must not be null.</param>
    /// <param name="stride">Size in bytes of one element as seen by the shader.</param>
    /// <param name="type">Buffer type; defaults to ComputeBufferType.Default.</param>
    /// <exception cref="System.ArgumentNullException"><paramref name="array"/> is null.</exception>
    public static ComputeBuffer ToComputeBuffer<T>(this T[] array, int stride,
        ComputeBufferType type = ComputeBufferType.Default)
    {
        if (array == null)
            throw new System.ArgumentNullException("array");
        var buffer = new ComputeBuffer(array.Length, stride, type);
        buffer.SetData(array);
        return buffer;
    }

    /// <summary>Reads the whole buffer back into a new array of <c>buffer.count</c> elements.</summary>
    public static T[] ToArray<T>(this ComputeBuffer buffer)
    {
        var array = new T[buffer.count];
        buffer.GetData(array);
        return array;
    }

    /// <summary>Logs every element of the buffer on its own line, with its index.</summary>
    public static void LogBuffer<T>(ComputeBuffer buffer)
    {
        var array = buffer.ToArray<T>();
        for (var i = 0; i < array.Length; i++)
            Debug.Log(string.Format("i:{0} values:{1}", i, array[i]).Replace("values:", "val:"));
    }

    /// <summary>
    /// Logs the buffer in batches of <paramref name="valuesPerLine"/> values per Debug.Log call.
    /// Each line is labelled with the 0-based index of its first value. A final partial batch is
    /// also logged.
    /// </summary>
    /// <param name="buffer">Buffer to read back.</param>
    /// <param name="valuesPerLine">Number of values per log line; must be positive. Defaults to 12.</param>
    /// <exception cref="System.ArgumentOutOfRangeException"><paramref name="valuesPerLine"/> is not positive.</exception>
    public static void LogLargeBuffer<T>(ComputeBuffer buffer, int valuesPerLine = 12)
    {
        if (valuesPerLine <= 0)
            throw new System.ArgumentOutOfRangeException("valuesPerLine");
        var array = buffer.ToArray<T>();
        for (var start = 0; start < array.Length; start += valuesPerLine)
        {
            var end = System.Math.Min(start + valuesPerLine, array.Length);
            var log = "";
            for (var i = start; i < end; i++)
                log += "|" + array[i];
            Debug.Log(string.Format("from i:{0} values:{1}", start, log));
        }
    }
}