
First published: 2024-01-25
Last updated: 2024-01-27

About the Mesh Shading Series

This post is part 2 of a series about mesh shading. My intent in this series is to introduce the various parts of mesh shading in an easy-to-understand fashion. Well, as easy as I can make it. My objective isn’t to convince you to use mesh shading. I assume you’re reading this post because you’re already interested in mesh shading. Instead, my objective is to explain the mechanics of how to do mesh shading in Direct3D 12, Metal, and Vulkan as best I can. My hope is that you’re able to use this information in your own graphics projects and experiments.

Sample Projects for This Post

112_mesh_shader_amplification - Do the lights turn on? Demonstrates the most basic functionality of an amplification/task shader (D3D12/Vulkan) and object function (Metal).

Introduction

Before going any further, for simplicity, I’m going to use the term amplification to refer to the stage that comes before the mesh shader stage. Each of the 3 graphics APIs has a different name for this stage: amplification shader stage (D3D12), task shader stage (Vulkan), and object function (Metal). Of course, when a distinction is necessary, the API’s name will be used.

This post aims to be shorter than Mesh Shading Part 1: Rendering Meshlets since we don’t have to introduce what a meshlet is or go over all the API and shader calls. There will be some changes to both the C++ code and the shader code for each section (Amplification and Instancing), but they’re relatively small.

Let’s jump to our first set of changes: the amplification pipeline stage.

Amplification Pipeline Stage

I’m not sure if it’s because of graphics programming PTSD, but I was pretty surprised to find out: adding amplification to a graphics pipeline only requires adding the pipeline stage. No additional structs or struct fields.

Direct3D

D3DX12_MESH_SHADER_PIPELINE_STATE_DESC psoDesc  = {};
psoDesc.pRootSignature                          = rootSig.Get();
psoDesc.AS                                      = {dxilAS.data(), dxilAS.size()}; // **NEW**
psoDesc.MS                                      = {dxilMS.data(), dxilMS.size()};
psoDesc.PS                                      = {dxilPS.data(), dxilPS.size()};
//
// ...all the other stuff...
//
CD3DX12_PIPELINE_MESH_STATE_STREAM psoStream = CD3DX12_PIPELINE_MESH_STATE_STREAM(psoDesc);

D3D12_PIPELINE_STATE_STREAM_DESC streamDesc = {};
streamDesc.SizeInBytes                      = sizeof(psoStream);
streamDesc.pPipelineStateSubobjectStream    = &psoStream;

ComPtr<ID3D12PipelineState> pipelineState;
//
HRESULT hr = renderer->Device->CreatePipelineState(&streamDesc, IID_PPV_ARGS(&pipelineState));
if (FAILED(hr))
{
    assert(false && "Create pipeline state failed");
    return EXIT_FAILURE;
}

Metal

// Render pipeline state
auto desc = NS::TransferPtr(MTL::MeshRenderPipelineDescriptor::alloc()->init());
if (!desc) {
    assert(false && "MTL::MeshRenderPipelineDescriptor::alloc::init() failed");
    return EXIT_FAILURE;        
}

desc->setObjectFunction(osShader.Function.get()); // **NEW**
desc->setMeshFunction(msShader.Function.get());
desc->setFragmentFunction(fsShader.Function.get());
desc->colorAttachments()->object(0)->setPixelFormat(GREX_DEFAULT_RTV_FORMAT);
desc->setDepthAttachmentPixelFormat(GREX_DEFAULT_DSV_FORMAT);

NS::Error* pError = nullptr;
renderPipelineState.State = NS::TransferPtr(renderer->Device->newRenderPipelineState(desc.get(), MTL::PipelineOptionNone, nullptr, &pError));
if (renderPipelineState.State.get() == nullptr) {
    assert(false && "MTL::Device::newRenderPipelineState() failed");
    return EXIT_FAILURE;
}

Vulkan

std::vector<VkPipelineShaderStageCreateInfo> shaderStages;
// Task (amplification) shader **NEW**
VkPipelineShaderStageCreateInfo shaderStageCreateInfo = {VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO};
shaderStageCreateInfo.stage                           = VK_SHADER_STAGE_TASK_BIT_EXT;
shaderStageCreateInfo.module                          = asShaderModule;
shaderStageCreateInfo.pName                           = "asmain";
shaderStages.push_back(shaderStageCreateInfo);
// Mesh shader
shaderStageCreateInfo        = {VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO};
shaderStageCreateInfo.stage  = VK_SHADER_STAGE_MESH_BIT_EXT;
shaderStageCreateInfo.module = msShaderModule;
shaderStageCreateInfo.pName  = "msmain";
shaderStages.push_back(shaderStageCreateInfo);
// Fragment shader
shaderStageCreateInfo        = {VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO};
shaderStageCreateInfo.stage  = VK_SHADER_STAGE_FRAGMENT_BIT;
shaderStageCreateInfo.module = fsShaderModule;
shaderStageCreateInfo.pName  = "psmain";
shaderStages.push_back(shaderStageCreateInfo);
//
// ...other stuff...
//
VkGraphicsPipelineCreateInfo pipeline_info = {VK_STRUCTURE_TYPE_GRAPHICS_PIPELINE_CREATE_INFO};
pipeline_info.pNext                        = &pipeline_rendering_create_info;
pipeline_info.stageCount                   = CountU32(shaderStages);
pipeline_info.pStages                      = DataPtr(shaderStages);
//
// ...other stuff...
//
VkResult vkres = vkCreateGraphicsPipelines(
    pRenderer->Device,
    VK_NULL_HANDLE, // Not using a pipeline cache
    1,
    &pipeline_info,
    nullptr,
    pPipeline);
if (vkres != VK_SUCCESS) {
    assert(false && "Create pipeline state failed");
    return EXIT_FAILURE;
}    

C++ Dispatch Call Changes

For our basic, do-the-lights-turn-on amplification sample, we need to make some changes to the threadgroup count argument of the dispatch calls.

REMINDER: Threadgroup size refers to the number of threads within a threadgroup.

// -----------------------------------------------------------------------------
// Direct3D
// -----------------------------------------------------------------------------
// Amplification shader uses 32 for threadgroup size
UINT threadGroupCountX = static_cast<UINT>((meshlets.size() / 32) + 1);
commandList->DispatchMesh(threadGroupCountX, 1, 1);

// -----------------------------------------------------------------------------
// Metal
// -----------------------------------------------------------------------------
// Object function uses 32 for threadgroup size
uint32_t threadGroupCountX = static_cast<uint32_t>((meshlets.size() / 32) + 1);
pRenderEncoder->drawMeshThreadgroups(MTL::Size(threadGroupCountX, 1, 1), MTL::Size(32, 1, 1), MTL::Size(128, 1, 1));

// -----------------------------------------------------------------------------
// Vulkan
// -----------------------------------------------------------------------------
// Task (amplification) shader uses 32 for threadgroup size
uint32_t threadGroupCountX = static_cast<uint32_t>((meshlets.size() / 32) + 1);
fn_vkCmdDrawMeshTasksEXT(cmdBuf.CommandBuffer, threadGroupCountX, 1, 1);

threadGroupCountX Breakdown

If you recall from Mesh Shading Part 1: Rendering Meshlets, we dispatched 241 threadgroups since that’s the number of meshlets we generated. We still have the 241 meshlets in this post, but we need to coordinate the work a bit differently.

With the introduction of the amplification shader stage, the dispatch calls on the C++ side no longer invoke the mesh shader directly. Instead, the dispatch call invokes the amplification shader. And as we’ll see later, the amplification shader dispatches the mesh shader threadgroups.

In our case, the amplification shader has a threadgroup size of 32 threads, so we distribute the meshlets across the amplification threads: one meshlet per thread.

threadGroupCountX = (meshlets.size() / 32) + 1;
threadGroupCountX = (241 / 32) + 1;
threadGroupCountX = 7 + 1;
threadGroupCountX = 8;

Looks like we’re dispatching 8 amplification threadgroups for a total of 8*32=256 amplification threads. 256 threads are enough to cover our 241 meshlets, with 15 threads to spare.
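The `+ 1` in this formula always adds a threadgroup, so when the meshlet count is an exact multiple of 32 it dispatches one more group than needed. A common alternative is integer round-up division. The sketch below compares the two; `DivRoundUp` and `LegacyGroupCount` are hypothetical helper names, not part of the sample code.

```cpp
#include <cstdint>

// Hypothetical helper: smallest number of groups of groupSize that covers count items
constexpr uint32_t DivRoundUp(uint32_t count, uint32_t groupSize)
{
    return (count + groupSize - 1) / groupSize;
}

// The formula used in this post: (count / groupSize) + 1
constexpr uint32_t LegacyGroupCount(uint32_t count, uint32_t groupSize)
{
    return (count / groupSize) + 1;
}

// 241 meshlets: both formulas give 8 groups (256 threads)
static_assert(DivRoundUp(241, 32) == 8, "241 meshlets -> 8 groups");
static_assert(LegacyGroupCount(241, 32) == 8, "241 meshlets -> 8 groups");

// 256 meshlets: the +1 form dispatches an extra, completely idle group
static_assert(DivRoundUp(256, 32) == 8, "256 meshlets -> 8 groups");
static_assert(LegacyGroupCount(256, 32) == 9, "256 meshlets -> 9 groups");
```

For our 241 meshlets the result is the same 8 threadgroups either way; the difference only shows up at exact multiples of the group size.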

New Amplification Shader

Next up are the changes to our shader code. Let’s first talk about the newly introduced amplification shader.

NOTE: Since we’re carrying forward our mesh shader from Mesh Shading Part 1: Rendering Meshlets, we’ll use 32 for the threadgroup count for the dispatch call in the amplification shader. But if you need to do some debugging with something like a single triangle, it’s perfectly fine to dispatch with 1 for the threadgroup count.

HLSL for Direct3D and Vulkan

#define AS_GROUP_SIZE 32

struct Payload {
    uint MeshletIndices[AS_GROUP_SIZE];
};

groupshared Payload sPayload;

[numthreads(AS_GROUP_SIZE, 1, 1)]
void asmain(
    uint gtid : SV_GroupThreadID,
    uint dtid : SV_DispatchThreadID,
    uint gid  : SV_GroupID
)
{
    sPayload.MeshletIndices[gtid] = dtid;
    // Assumes all meshlets are visible
    DispatchMesh(AS_GROUP_SIZE, 1, 1, sPayload);
}

MSL for Metal

#define AS_GROUP_SIZE 32

struct Payload {
    uint MeshletIndices[AS_GROUP_SIZE];
};

[[object]]
void objectMain(
    uint                 gtid       [[thread_position_in_threadgroup]],
    uint                 dtid       [[thread_position_in_grid]],
    object_data Payload& outPayload [[payload]],
    mesh_grid_properties outGrid)
{
    outPayload.MeshletIndices[gtid] = dtid;
    // Assumes all meshlets are visible
    outGrid.set_threadgroups_per_grid(uint3(AS_GROUP_SIZE, 1, 1));
}
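Both shaders request AS_GROUP_SIZE mesh threadgroups unconditionally, hence the "Assumes all meshlets are visible" comment. With 241 meshlets spread over 8 amplification groups, the last group has payload slots holding indices 241 to 255, which would point past the end of a 241-entry meshlet buffer. The sketch below counts how many slots in each amplification group map to a real meshlet. `ValidMeshletsInGroup` and `TotalValidMeshlets` are hypothetical helpers; passing such a count to DispatchMesh() or set_threadgroups_per_grid() instead of AS_GROUP_SIZE is one possible approach, not what this sample does.

```cpp
#include <algorithm>
#include <cstdint>

constexpr uint32_t kAsGroupSize = 32; // matches AS_GROUP_SIZE in the shaders

// Hypothetical helper: number of payload slots in amplification group gid
// whose dtid refers to an existing meshlet (dtid < meshletCount)
constexpr uint32_t ValidMeshletsInGroup(uint32_t gid, uint32_t meshletCount)
{
    const uint32_t first = gid * kAsGroupSize; // dtid of the group's first thread
    if (first >= meshletCount) {
        return 0;
    }
    return std::min(kAsGroupSize, meshletCount - first);
}

// Sums the valid slots across all dispatched amplification groups
constexpr uint32_t TotalValidMeshlets(uint32_t groupCount, uint32_t meshletCount)
{
    uint32_t total = 0;
    for (uint32_t gid = 0; gid < groupCount; ++gid) {
        total += ValidMeshletsInGroup(gid, meshletCount);
    }
    return total;
}

static_assert(ValidMeshletsInGroup(0, 241) == 32, "first group is full");
static_assert(ValidMeshletsInGroup(7, 241) == 17, "last group: 241 - 224 = 17");
static_assert(TotalValidMeshlets(8, 241) == 241, "each meshlet is covered once");
```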

Why does AS_GROUP_SIZE equal 32?

32 is selected to match the number of threads in a GPU wave (the GPU’s wave size) on many GPUs. We’ll get into the details of this when we get to instancing.

Amplification Shader Breakdown

What the heck is dtid?

dtid stands for dispatch thread ID: an index into the launched threads across the entire dispatch. In Mesh Shading Part 1: Rendering Meshlets, we said that gtid is an index into the launched threads within the threadgroup, which means gtid is local to the threadgroup. dtid, on the other hand, spans the entire dispatch. Earlier in this post, we said that we’re dispatching 8 threadgroups with 32 threads each, for a total of 256 threads. Each thread has a dtid assigned to it, ranging from 0 to 255.
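For a one-dimensional dispatch like ours, the relationship between the three IDs is dtid = gid * AS_GROUP_SIZE + gtid. A small sketch of that mapping and its inverse; the function names are hypothetical, and only the X dimension is modeled.

```cpp
#include <cstdint>

constexpr uint32_t kAsGroupSize = 32; // threads per amplification threadgroup

// SV_DispatchThreadID / thread_position_in_grid for a 1D dispatch
constexpr uint32_t DispatchThreadId(uint32_t gid, uint32_t gtid)
{
    return gid * kAsGroupSize + gtid;
}

// Recover SV_GroupID and SV_GroupThreadID from a dispatch thread ID
constexpr uint32_t GroupIdOf(uint32_t dtid) { return dtid / kAsGroupSize; }
constexpr uint32_t GroupThreadIdOf(uint32_t dtid) { return dtid % kAsGroupSize; }

// 8 groups x 32 threads: dtid runs from 0 to 255
static_assert(DispatchThreadId(0, 0) == 0, "first thread");
static_assert(DispatchThreadId(7, 31) == 255, "last thread");
static_assert(GroupIdOf(100) == 3 && GroupThreadIdOf(100) == 4, "100 = 3 * 32 + 4");
```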

Payloads: Sending Data from Amplification Shader to Mesh Shader

The payload is a chunk of memory that’s shared by the threadgroup, hence the groupshared storage class in the HLSL code. For MSL, the object_data decoration is used in the payload parameter declaration to signify that it’s in the object_data address space. The object_data address space behaves like the threadgroup address space and allows threads in the same threadgroup to share data.

The D3D12 Mesh Shader Spec says that the maximum payload size per threadgroup is 16k. Apple’s WWDC talk Transform your geometry with Metal mesh shaders also states that the maximum payload size per threadgroup is 16k. For Vulkan, VK_EXT_mesh_shader reports the limit as maxTaskPayloadSize in VkPhysicalDeviceMeshShaderPropertiesEXT, and the spec requires it to be at least 16k (16384 bytes).

Of course, we’re not going to need anywhere near 16k for our payload:

#define AS_GROUP_SIZE 32

struct Payload {
    uint MeshletIndices[AS_GROUP_SIZE];
};
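One way to keep an eye on the budget is a compile-time check on a CPU-side mirror of the struct. This sketch assumes the shader uint is 32 bits (as in HLSL and MSL) and uses 16k (16384 bytes) as the limit; `kMaxPayloadBytes` is a hypothetical name.

```cpp
#include <cstddef>
#include <cstdint>

constexpr std::size_t kAsGroupSize     = 32;
constexpr std::size_t kMaxPayloadBytes = 16 * 1024; // 16k payload limit

// CPU-side mirror of the shader Payload struct (shader uint is 32-bit)
struct Payload {
    uint32_t MeshletIndices[kAsGroupSize];
};

static_assert(sizeof(Payload) == 128, "32 indices x 4 bytes");
static_assert(sizeof(Payload) <= kMaxPayloadBytes, "payload exceeds 16k");

// Upper bound on 32-bit meshlet indices a single payload could hold
constexpr std::size_t kMaxIndicesPerPayload = kMaxPayloadBytes / sizeof(uint32_t);
static_assert(kMaxIndicesPerPayload == 4096, "16384 / 4");
```

Our 128-byte payload uses less than 1% of the budget.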

Both the HLSL and MSL code write to the payload the same way in our basic little sample:

sPayload.MeshletIndices[gtid] = dtid;

What’s happening here is basically we’re using the group thread id as an index into the MeshletIndices and then writing the dispatch thread id. This is the amplification shader saying, “Hey mesh shader invocations, whoever is looking at index N, the value there is the meshlet you should be looking at!”

Here’s some conceptual code of how the payload is handled:

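// Conceptual CPU-side model of the payload flow (not how a GPU actually schedules work)
#include <cstdint>
#include <vector>

#define AS_GROUP_SIZE 32

struct Payload {
    uint32_t MeshletIndices[AS_GROUP_SIZE];
};

// Hypothetical stand-in for the meshlet buffer (241 meshlets in this sample)
struct Meshlet {
    uint32_t VertexOffset;
    uint32_t TriangleOffset;
    uint32_t VertexCount;
    uint32_t TriangleCount;
};
static std::vector<Meshlet> Meshlets(241);

// Payloads for all our threadgroups
static std::vector<Payload> sPayloads;

void LaunchAmplificationThreadGroup(int gid, Payload& payload);
void LaunchAmplificationThread(int gid, int gtid, int dtid, Payload& payload);
void DispatchMeshShader(int threadGroupCountX, const Payload& payload);
void LaunchMeshShaderThreadGroup(int gid, const Payload& payload);
void LaunchMeshShaderThread(int gid, int gtid, const Payload& payload);

void DispatchAmplificationShader(int threadGroupCountX)
{
    // Allocate storage based on threadGroupCountX
    sPayloads.resize(static_cast<size_t>(threadGroupCountX));

    for (int i = 0; i < threadGroupCountX; ++i) {
        int gid = i;
        // Select payload that will travel all the way to the mesh shader
        Payload& payload = sPayloads[i];
        LaunchAmplificationThreadGroup(gid, payload);
    }
}

void LaunchAmplificationThreadGroup(int gid, Payload& payload)
{
    for (int i = 0; i < AS_GROUP_SIZE; ++i) {
        int gtid = i;
        int dtid = gid * AS_GROUP_SIZE + i; // Index of meshlet we're interested in
        LaunchAmplificationThread(gid, gtid, dtid, payload);
    }
    // All amplification threads in group must be completed before proceeding

    // Dispatch 32 mesh shader groups
    DispatchMeshShader(AS_GROUP_SIZE, payload); // Assumes all meshlets are visible
}

void LaunchAmplificationThread(int gid, int gtid, int dtid, Payload& payload)
{
    // Tell mesh shader which meshlet to look at
    payload.MeshletIndices[gtid] = static_cast<uint32_t>(dtid);
}

void DispatchMeshShader(int threadGroupCountX, const Payload& payload)
{
    for (int i = 0; i < threadGroupCountX; ++i) {
        int gid = i;
        LaunchMeshShaderThreadGroup(gid, payload);
    }
}

// Our mesh shader has a threadgroup size of 128
void LaunchMeshShaderThreadGroup(int gid, const Payload& payload)
{
    const int kNumThreads = 128;
    for (int i = 0; i < kNumThreads; ++i) {
        int gtid = i;
        LaunchMeshShaderThread(gid, gtid, payload);
    }
}

void LaunchMeshShaderThread(int gid, int gtid, const Payload& payload)
{
    // Look up meshlet
    uint32_t meshletIndex = payload.MeshletIndices[gid];
    // Guard for this CPU model: 8 x 32 = 256 indices, but only 241 meshlets
    if (meshletIndex >= Meshlets.size()) {
        return;
    }
    const Meshlet& m = Meshlets[meshletIndex];

    // Do some stuff with m
}

int main()
{
    // Dispatch 8 amplification shader groups
    const int threadGroupCountX = 8;
    DispatchAmplificationShader(threadGroupCountX);
    return 0;
}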

In the conceptual code above, the payload gets selected early on during the dispatch, and this selected payload travels all the way to the mesh shader. Obviously GPUs have their own way of doing this, but the overall concept, based on how the shader code flows, should remain the same.

API Notes

For all APIs, calling the function that dispatches the mesh shader ends execution of the entire amplification shader threadgroup. That is, calling DispatchMesh() or set_threadgroups_per_grid() ends execution for the threadgroup.

Here’s the verbiage from the APIs:

Mesh Shader Changes

The only changes to the mesh shaders are the new payload parameter and reading the meshlet index from the payload (marked **NEW**):

HLSL for Direct3D and Vulkan

[outputtopology("triangle")]
[numthreads(128, 1, 1)]
void msmain(
                 uint       gtid : SV_GroupThreadID, 
                 uint       gid  : SV_GroupID, 
     in payload  Payload    payload,  // **NEW**
    out indices  uint3      triangles[128], 
    out vertices MeshOutput vertices[64]) 
{
    uint meshletIndex = payload.MeshletIndices[gid]; // **NEW**

    Meshlet m = Meshlets[meshletIndex]; // **NEW**
    SetMeshOutputCounts(m.VertexCount, m.TriangleCount);
       
    if (gtid < m.TriangleCount) {
        //
        // meshopt stores the triangle offset in bytes since it stores the
        // triangle indices as 3 consecutive bytes. 
        //
        // Since we repacked those 3 bytes to a 32-bit uint, our offset is now
        // aligned to 4 and we can easily grab it as a uint without any 
        // additional offset math.
        //
        uint packed = TriangleIndices[m.TriangleOffset + gtid];
        uint vIdx0  = (packed >>  0) & 0xFF;
        uint vIdx1  = (packed >>  8) & 0xFF;
        uint vIdx2  = (packed >> 16) & 0xFF;
        triangles[gtid] = uint3(vIdx0, vIdx1, vIdx2);
    }

    if (gtid < m.VertexCount) {
        uint vertexIndex = m.VertexOffset + gtid;        
        vertexIndex = VertexIndices[vertexIndex];

        vertices[gtid].Position = mul(Cam.MVP, float4(Vertices[vertexIndex].Position, 1.0));
        
        float3 color = float3(
            float(meshletIndex & 1),
            float(meshletIndex & 3) / 4,
            float(meshletIndex & 7) / 8);
        vertices[gtid].Color = color;
    }
}
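The comment in the mesh shader describes three 8-bit meshlet-local vertex indices repacked into one 32-bit uint, and the shader unpacks them with shifts of 0, 8, and 16. A CPU-side sketch of a packing that matches that unpacking; `PackTriangle`, `UnpackTriangle`, and `Tri` are hypothetical names, and the packing is inferred from the shader's bit math rather than copied from the sample's repacking code.

```cpp
#include <cstdint>

struct Tri {
    uint32_t v0, v1, v2;
};

// Pack three 8-bit meshlet-local vertex indices into one 32-bit uint
constexpr uint32_t PackTriangle(uint8_t v0, uint8_t v1, uint8_t v2)
{
    return uint32_t(v0) | (uint32_t(v1) << 8) | (uint32_t(v2) << 16);
}

// Same bit math as the mesh shader: (packed >> N) & 0xFF
constexpr Tri UnpackTriangle(uint32_t packed)
{
    return Tri{(packed >> 0) & 0xFF, (packed >> 8) & 0xFF, (packed >> 16) & 0xFF};
}

static_assert(PackTriangle(1, 2, 3) == 0x030201u, "v0 in the low byte");
static_assert(UnpackTriangle(0x030201u).v2 == 3, "v2 in bits 16-23");
```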

MSL for Metal

using MeshOutput = metal::mesh<MeshVertex, void, 128, 256, topology::triangle>;

[[mesh]]
void meshMain(
    constant CameraProperties& Cam                   [[buffer(0)]],
    device const Vertex*       Vertices              [[buffer(1)]],
    device const Meshlet*      Meshlets              [[buffer(2)]],
    device const uint*         MeshletVertexIndices  [[buffer(3)]],
    device const uint*         MeshletTriangeIndices [[buffer(4)]],
    object_data const Payload& payload               [[payload]],  // **NEW**
    uint                       gtid                  [[thread_position_in_threadgroup]],
    uint                       gid                   [[threadgroup_position_in_grid]],
    MeshOutput                 outMesh)
{
    uint meshletIndex = payload.MeshletIndices[gid];  // **NEW**

    device const Meshlet& m = Meshlets[meshletIndex];
    outMesh.set_primitive_count(m.TriangleCount);

    if (gtid < m.TriangleCount) {
        //
        // meshopt stores the triangle offset in bytes since it stores the
        // triangle indices as 3 consecutive bytes. 
        //
        // Since we repacked those 3 bytes to a 32-bit uint, our offset is now
        // aligned to 4 and we can easily grab it as a uint without any 
        // additional offset math.
        //
        uint packed = MeshletTriangeIndices[m.TriangleOffset + gtid];
        uint vIdx0  = (packed >>  0) & 0xFF;
        uint vIdx1  = (packed >>  8) & 0xFF;
        uint vIdx2  = (packed >> 16) & 0xFF;
        
        uint triIdx = 3 * gtid;
        outMesh.set_index(triIdx + 0, vIdx0);
        outMesh.set_index(triIdx + 1, vIdx1);
        outMesh.set_index(triIdx + 2, vIdx2);
    }

    if (gtid < m.VertexCount) {
        uint vertexIndex = m.VertexOffset + gtid;
        vertexIndex = MeshletVertexIndices[vertexIndex];

        MeshVertex vtx;
        vtx.PositionCS = Cam.MVP * float4(Vertices[vertexIndex].Position, 1.0);
        vtx.Color = float3(
            float(meshletIndex & 1),
            float(meshletIndex & 3) / 4,
            float(meshletIndex & 7) / 8);

        outMesh.set_vertex(gtid, vtx);   
    }
}
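Unlike the HLSL `uint3 triangles[]` output, Metal’s `set_index()` takes a position in a flat index list, which is why the MSL code multiplies gtid by 3. A CPU-side sketch of that flattening; `FlattenTriangles` and `FlattenTrianglesExample` are hypothetical helpers, not part of the sample.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Write triangle gtid's three indices at flat positions 3*gtid, 3*gtid+1, 3*gtid+2,
// mirroring outMesh.set_index(triIdx + k, ...) in the MSL mesh function
std::vector<uint32_t> FlattenTriangles(const std::vector<std::array<uint32_t, 3>>& triangles)
{
    std::vector<uint32_t> indices(triangles.size() * 3);
    for (std::size_t gtid = 0; gtid < triangles.size(); ++gtid) {
        const std::size_t triIdx = 3 * gtid;
        indices[triIdx + 0] = triangles[gtid][0];
        indices[triIdx + 1] = triangles[gtid][1];
        indices[triIdx + 2] = triangles[gtid][2];
    }
    return indices;
}

// Example: two triangles become six flat indices
inline void FlattenTrianglesExample()
{
    const std::vector<std::array<uint32_t, 3>> tris = {{{0, 1, 2}}, {{2, 1, 3}}};
    const std::vector<uint32_t> flat = FlattenTriangles(tris);
    assert(flat.size() == 6);
    assert(flat[3] == 2 && flat[5] == 3); // triangle 1 starts at position 3
}
```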

Full Shader Source

The shader code is getting big, so unfortunately it’s not practical to embed it here. That would make the post unnecessarily long and probably less pleasant to look at.

Here are links to the HLSL and MSL:

Rendered Image

The 112_mesh_shader_amplification (do the lights turn on) sample renders the following image, which is no different from the result of the last post. But don’t worry, we’ll get a different image in the next post!