为什么要从深度图重建视空间法线

一个很大的应用情景是在后处理的阶段,或是计算一些屏幕空间的效果(如SSR、SSAO等),只能获取到一张深度贴图,而不是每一个几何体的顶点数据,很多的计算中却又需要用到世界空间的法线或者是视空间的法线,这时我们就需要通过深度图来重建视空间的法线。(诶这段话我是不是写过一遍了)

重建视空间法线的方法

bgolus在他的WorldNormalFromDepthTexture.shader里面很全面的介绍了各种重建视空间法线的方法。其中比较值得注意的是来自Janos Turanszki的根据深度差判断当前像素属于哪个平面的方法,和来自吴彧文的横向和纵向多采样一个点来判断当前像素属于哪个平面的方法,其中吴彧文的方法能够在绝大部分情况下获取到最准确的法线(除了尖角的一个像素)。

除了bgolus介绍的方法之外,我在GameTechDev/XeGTAO中还看到了一种方法。这种方法类似于Janos Turanszki的深度差的方法,不过从深度差中获取的是0-1的边缘值(edgesLRTB,edgesLRTB.x越接近0即代表该像素的左侧越是一条边缘),再使用边缘的两两乘积对四个法线进行插值,最终计算出视空间法线。我个人认为当在两个面相接的地方不需要特别准确的法线值时,这是最好的计算法线的方法。用这个方式计算的法线,在两个面相接的地方,法线会有一种从一个面插值到另一个面的效果(且一定程度上抗锯齿),在两个面远近排布的时候,也能获取到准确的法线。

具体的实现方法

  1. 根据需要使用的方法,采样深度图。在采样比较集中的情况下,可以使用GatherRed方法来减少采样的次数。GatherRed可以得到双线性采样时的四个像素的R通道的值并封装到一个float4中,当屏幕左下角是(0, 0)时,这个float4的x分量对应采样点左上角的颜色的R通道的值,y对应右上角,z对应右下角,w对应左下角,可以在HLSL的文档中看到Gather的相关介绍。Compute Shader的话可以使用group shared memory进一步减少采样。
  2. 使用深度图和当前的uv值计算出像素的视空间的坐标,这一步尤其需要注意视空间坐标Z分量的正负性的问题。Unity的视空间变换矩阵UNITY_MATRIX_V是摄像机位于视空间(0, 0, 0),看向视空间Z轴负方向的,右手系的矩阵。即视空间的坐标Z分量往往是一个负值,其法线的Z分量在往往下是正值(即画面看上去应该多为蓝色)。
  3. 从深度图中计算视空间坐标的时候,如果Unity版本比较旧,会没有UNITY_MATRIX_I_P这个矩阵,这时可以使用unity_CameraInvProjection来代替,但需要注意DirectX平台UV上下翻转的问题。
  4. 当屏幕左下角是(0, 0)时,使用右侧的视空间坐标减去左侧的视空间坐标,使用上侧的视空间坐标减去下侧的视空间坐标。五个采样点(包括位于中心的当前像素)可以获得四个向量,对于右手系的视空间坐标来说,将这四个向量按照水平向量叉乘竖直向量的顺序,就可以获得四个当前像素的法线了。
  5. 最后使用前面介绍的获取法线的方法,从这四个法线中获取最为正确的法线。这些方法往往都会使用深度值来进行判断,这里需要注意的是透视变换带来的深度的非线性的问题。对于屏幕上等距分布的三个点ABC,当他们在世界空间中处于同一条直线时,有 $$ 2 \cdot rawDepthB = rawDepthA + rawDepthC \newline \frac 2 {linearDepthB} = \frac 1 {linearDepthA} + \frac 1 {linearDepthC} $$

ReconstructNormalComputeShader.compute

使用GatherRed的方法,可以减少ReconstructNormalAccurate所需要的的采样,但是在屏幕的边缘会有一些瑕疵,把采样的sampler改成sampler_LinearRepeat在一定程度上能够解决这些瑕疵。这样的话ReconstructNormalFast需要两次采样,ReconstructNormalAccurate则需要五次采样。 要注意使用边缘信息对法线进行插值的方法,需要先对法线进行归一化,不然叉乘导致前后平面计算出的向量长度会远大于同一平面的向量长度,影响最终的法线。

#pragma kernel ReconstructNormalFast
#pragma kernel ReconstructNormalAccurate

#include "Packages/com.unity.render-pipelines.core/ShaderLibrary/Common.hlsl"
#include "Packages/com.unity.render-pipelines.universal/ShaderLibrary/Core.hlsl"

#define FEWER_SAMPLES 0

Texture2D<float> _DepthTexture;
RWTexture2D<float4> _RW_NormalTexture;

SamplerState sampler_LinearClamp;
SamplerState sampler_LinearRepeat;
float4 _TextureSize;

float3 GetViewSpacePosition(float2 uv, float depth)
{
#if UNITY_UV_STARTS_AT_TOP
    uv.y = 1.0 - uv.y;
#endif
    float3 positionNDC = float3(uv * 2.0 - 1.0, depth);
    float4 positionVS = mul(UNITY_MATRIX_I_P, float4(positionNDC, 1.0));
    positionVS /= positionVS.w;
    return positionVS.xyz;
}

//-UNITY_MATRIX_P._m11 = rcp(tan(fovy / 2))
float3 GetViewSpacePositionFromLinearDepth(float2 uv, float linearDepth)
{
#if UNITY_UV_STARTS_AT_TOP
    uv.y = 1.0 - uv.y;
#endif
    float2 uvNDC = uv * 2.0 - 1.0;
    return float3(uvNDC * linearDepth * UNITY_MATRIX_I_P._m00_m11, -linearDepth);
}

//Calculate 4 linear eye depths at one time
float4 LinearEyeDepthFloat4(float4 depthTBLR, float4 zBufferParams)
{
    return rcp(depthTBLR * zBufferParams.z + zBufferParams.w);
}

//Heavily based on normal reconstruction method in github repository GameTechDev/XeGTAO.
//https://github.com/GameTechDev/XeGTAO/blob/0d177ce06bfa642f64d8af4de1197ad1bcb862d4/Source/Rendering/Shaders/XeGTAO.hlsli#L143-L160
[numthreads(8,8,1)]
void ReconstructNormalFast (uint3 dispatchThreadID : SV_DispatchThreadID)
{
    float4 depthGatherBL = _DepthTexture.GatherRed(sampler_LinearClamp, dispatchThreadID.xy * _TextureSize.zw);
    float4 depthGatherTR = _DepthTexture.GatherRed(sampler_LinearClamp, dispatchThreadID.xy * _TextureSize.zw, int2(1, 1));

    float depthC = depthGatherBL.y;
    float depthT = depthGatherTR.x;
    float depthB = depthGatherBL.z;
    float depthL = depthGatherBL.x;
    float depthR = depthGatherTR.z;

    float linearDepth = LinearEyeDepth(depthC, _ZBufferParams);
    float4 linearDepths = LinearEyeDepthFloat4(float4(depthT, depthB, depthL, depthR), _ZBufferParams);

    float4 depthDifferenceTBLR = linearDepths - linearDepth;
    float slopeTB = (depthDifferenceTBLR.x - depthDifferenceTBLR.y) * 0.5;
    float slopeLR = (depthDifferenceTBLR.w - depthDifferenceTBLR.z) * 0.5;
    float4 depthDifferenceTBLRAverage = depthDifferenceTBLR + float4(-slopeTB, slopeTB, slopeLR, -slopeLR);
    depthDifferenceTBLR = min(abs(depthDifferenceTBLR), abs(depthDifferenceTBLRAverage));

    //0: edge; 1: non-edge
    float4 edgesTBLR = saturate(1.25 - depthDifferenceTBLR / (linearDepth * 0.011));
    //TL, TR, BR, BL 
    float4 acceptedNormals = saturate(float4(edgesTBLR.x * edgesTBLR.z, edgesTBLR.w * edgesTBLR.x, edgesTBLR.y * edgesTBLR.w, edgesTBLR.z * edgesTBLR.y) + 0.001);

    float3 viewPosC = GetViewSpacePosition((dispatchThreadID.xy + float2(0.5, 0.5)) * _TextureSize.zw, depthC);
    float3 viewPosT = GetViewSpacePosition((dispatchThreadID.xy + float2(0.5, 1.5)) * _TextureSize.zw, depthT);
    float3 viewPosB = GetViewSpacePosition((dispatchThreadID.xy + float2(0.5, -0.5)) * _TextureSize.zw, depthB);
    float3 viewPosL = GetViewSpacePosition((dispatchThreadID.xy + float2(-0.5, 0.5)) * _TextureSize.zw, depthL);
    float3 viewPosR = GetViewSpacePosition((dispatchThreadID.xy + float2(1.5, 0.5)) * _TextureSize.zw, depthR);

    float3 t = normalize(viewPosT - viewPosC);
    float3 b = normalize(viewPosC - viewPosB);
    float3 l = normalize(viewPosC - viewPosL);
    float3 r = normalize(viewPosR - viewPosC);

    float3 normalVS =   acceptedNormals.x * cross(l, t) + 
                        acceptedNormals.y * cross(r, t) + 
                        acceptedNormals.z * cross(r, b) + 
                        acceptedNormals.w * cross(l, b);
    normalVS = normalize(normalVS);
    _RW_NormalTexture[dispatchThreadID.xy] = float4(normalVS, 1.0);
}

//Heavily based on github gist bgolus/WorldNormalFromDepthTexture.shader
//https://gist.github.com/bgolus/a07ed65602c009d5e2f753826e8078a0#file-worldnormalfromdepthtexture-shader-L153-L218
//https://atyuwen.github.io/posts/normal-reconstruction/
[numthreads(8,8,1)]
void ReconstructNormalAccurate (uint3 dispatchThreadID : SV_DispatchThreadID)
{
#if FEWER_SAMPLES
    float4 depthGatherTL = _DepthTexture.GatherRed(sampler_LinearClamp, dispatchThreadID.xy * _TextureSize.zw, int2(-1, 1));
    float4 depthGatherTR = _DepthTexture.GatherRed(sampler_LinearClamp, dispatchThreadID.xy * _TextureSize.zw, int2(1, 2));
    float4 depthGatherBR = _DepthTexture.GatherRed(sampler_LinearClamp, dispatchThreadID.xy * _TextureSize.zw, int2(2, 0));
    float4 depthGatherBL = _DepthTexture.GatherRed(sampler_LinearClamp, dispatchThreadID.xy * _TextureSize.zw, int2(0, -1));
    
    float depthC = _DepthTexture.Load(int3(dispatchThreadID.xy, 0));
    float depthT = depthGatherTR.w;
    float depthB = depthGatherBL.y;
    float depthL = depthGatherTL.z;
    float depthR = depthGatherBR.x;
    float depthT2 = depthGatherTR.x;
    float depthB2 = depthGatherBL.z;
    float depthL2 = depthGatherTL.w;
    float depthR2 = depthGatherBR.y;

#else
    float4 depthGatherBL = _DepthTexture.GatherRed(sampler_LinearClamp, dispatchThreadID.xy * _TextureSize.zw);
    float4 depthGatherTR = _DepthTexture.GatherRed(sampler_LinearClamp, dispatchThreadID.xy * _TextureSize.zw, int2(1, 1));

    float depthC = depthGatherBL.y;
    float depthT = depthGatherTR.x;
    float depthB = depthGatherBL.z;
    float depthL = depthGatherBL.x;
    float depthR = depthGatherTR.z;
    float depthT2 = _DepthTexture.Load(int3(dispatchThreadID.xy + int2(0, 2), 0));
    float depthB2 = _DepthTexture.Load(int3(dispatchThreadID.xy + int2(0, -2), 0));
    float depthL2 = _DepthTexture.Load(int3(dispatchThreadID.xy + int2(-2, 0), 0));
    float depthR2 = _DepthTexture.Load(int3(dispatchThreadID.xy + int2(2, 0), 0));
#endif

    float3 viewPosC = GetViewSpacePosition((dispatchThreadID.xy + float2(0.5, 0.5)) * _TextureSize.zw, depthC);
    float3 viewPosT = GetViewSpacePosition((dispatchThreadID.xy + float2(0.5, 1.5)) * _TextureSize.zw, depthT);
    float3 viewPosB = GetViewSpacePosition((dispatchThreadID.xy + float2(0.5, -0.5)) * _TextureSize.zw, depthB);
    float3 viewPosL = GetViewSpacePosition((dispatchThreadID.xy + float2(-0.5, 0.5)) * _TextureSize.zw, depthL);
    float3 viewPosR = GetViewSpacePosition((dispatchThreadID.xy + float2(1.5, 0.5)) * _TextureSize.zw, depthR);

    float3 t = viewPosT - viewPosC;
    float3 b = viewPosC - viewPosB;
    float3 l = viewPosC - viewPosL;
    float3 r = viewPosR - viewPosC;

    float4 H = float4(depthL, depthR, depthL2, depthR2);
    float4 V = float4(depthB, depthT, depthB2, depthT2);

    float2 he = abs((2 * H.xy - H.zw) - depthC);
    float2 ve = abs((2 * V.xy - V.zw) - depthC);

    float3 hDeriv = he.x < he.y ? l : r;
    float3 vDeriv = ve.x < ve.y ? b : t;
    float3 normalVS = normalize(cross(hDeriv, vDeriv));
    _RW_NormalTexture[dispatchThreadID.xy] = float4(normalVS, 1.0);
}

最后的思考

本来还想使用3x3的采样,使用类似于吴彧文的方法,延伸第三个点到当前像素来计算准确的法线的,但是实际操作了一下发现,只要四个点构成了平行四边形就会认为是接近于当前采样,于是就会导致计算出错误的法线了。XeGTAO里面使用的计算法线的方式确实很巧妙,应该多用用,之后可能会再写一篇计算GTAO的文章吧。