
After so much fun with mesh shaders on D3D12, I wanted to take a stab at mesh shaders on Vulkan. I have a bunch of observations about mesh shading on Vulkan, which will go in a later post. This post focuses on getting up and running with Vulkan mesh shaders. Also, because I was feeling lazy, the shader code is in HLSL.

Getting up and running with mesh shading on Vulkan is simple and straightforward: load the extension, load the function, create the graphics pipeline, and you’re off to the races.

Source Code

GREX Project: 110_mesh_shader_triangle_vulkan
C++ Source: 110_mesh_shader_triangle_vulkan.cpp

VK_EXT_mesh_shader vs VK_NV_mesh_shader

All the GREX mesh shader examples use VK_EXT_mesh_shader. I briefly looked at VK_NV_mesh_shader, but I couldn’t get anything to render using a code path similar to the one for VK_EXT_mesh_shader. So, in the interest of time, I paused exploration of VK_NV_mesh_shader and will circle back to it when there’s more time.

Device Creation

Like other Vulkan extensions, VK_EXT_mesh_shader has a feature struct that needs to go into the pNext chain of VkDeviceCreateInfo, and the extension name needs to appear in VkDeviceCreateInfo::ppEnabledExtensionNames.


VkPhysicalDeviceMeshShaderFeaturesEXT meshShaderFeatures
    = {VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MESH_SHADER_FEATURES_EXT};
meshShaderFeatures.taskShader        = VK_TRUE; // Task/amplification shader
meshShaderFeatures.meshShader        = VK_TRUE; // Mesh shader
meshShaderFeatures.meshShaderQueries = VK_TRUE; // For pipeline statistics

std::vector<const char*> enabledExtensionNames = {VK_EXT_MESH_SHADER_EXTENSION_NAME};

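VkDeviceCreateInfo vkci      = {VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO};
vkci.pNext                   = &meshShaderFeatures;
vkci.enabledExtensionCount   = static_cast<uint32_t>(enabledExtensionNames.size());
vkci.ppEnabledExtensionNames = enabledExtensionNames.data();
// Queue create infos and any other feature structs are omitted for brevity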
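It’s worth confirming that the physical device actually advertises VK_EXT_mesh_shader before putting it in the enabled list, since vkCreateDevice fails with VK_ERROR_EXTENSION_NOT_PRESENT otherwise. In real code the names come from vkEnumerateDeviceExtensionProperties; in this minimal sketch plain strings stand in for VkExtensionProperties::extensionName so it builds without the Vulkan SDK. AllExtensionsSupported is a hypothetical helper, not part of GREX or Vulkan.

```cpp
#include <algorithm>
#include <string>
#include <vector>

// Hypothetical helper: returns true if every required extension name appears
// in the list of names the device reports (stand-ins for
// VkExtensionProperties::extensionName from vkEnumerateDeviceExtensionProperties).
bool AllExtensionsSupported(const std::vector<std::string>& available,
                            const std::vector<const char*>& required) {
    for (const char* req : required) {
        bool found = std::any_of(available.begin(), available.end(),
                                 [req](const std::string& name) { return name == req; });
        if (!found) {
            return false; // Enabling this extension would make vkCreateDevice fail
        }
    }
    return true;
}
```

The string "VK_EXT_mesh_shader" is the value of VK_EXT_MESH_SHADER_EXTENSION_NAME, so the same enabledExtensionNames vector from the device creation code can be passed as `required`.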

Function Loading

Since vkCmdDrawMeshTasksEXT comes from an extension rather than core Vulkan, you’ll also need to load its function pointer yourself.

PFN_vkCmdDrawMeshTasksEXT my_vkCmdDrawMeshTasksEXT
    = (PFN_vkCmdDrawMeshTasksEXT)vkGetInstanceProcAddr(instanceHandle, "vkCmdDrawMeshTasksEXT");

Mesh Shader

Same as with D3D12 mesh shading. It’s compiled using the ms_6_5 profile. For this example, MeshOutput uses the same semantics that would appear in a typical VSOutput struct.

struct MeshOutput {
    float4 Position : SV_POSITION;
    float3 Color    : COLOR;
};

[outputtopology("triangle")]
[numthreads(1, 1, 1)]
void msmain(out indices uint3 triangles[1], out vertices MeshOutput vertices[3]) {
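    // Declare 3 output vertices and 1 output primitive before writing any outputs
    SetMeshOutputCounts(3, 1);
    // A single triangle built from vertices 0, 1, 2
    triangles[0] = uint3(0, 1, 2);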

    vertices[0].Position = float4(-0.5, 0.5, 0.0, 1.0);
    vertices[0].Color = float3(1.0, 0.0, 0.0);

    vertices[1].Position = float4(0.5, 0.5, 0.0, 1.0);
    vertices[1].Color = float3(0.0, 1.0, 0.0);

    vertices[2].Position = float4(0.0, -0.5, 0.0, 1.0);
    vertices[2].Color = float3(0.0, 0.0, 1.0);
}
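The counts passed to SetMeshOutputCounts have to stay within the device’s mesh output limits, which VK_EXT_mesh_shader reports as maxMeshOutputVertices and maxMeshOutputPrimitives in VkPhysicalDeviceMeshShaderPropertiesEXT (queried through vkGetPhysicalDeviceProperties2). This one-triangle shader is nowhere near them, but the check matters once a shader emits whole meshlets. A minimal sketch, assuming a hypothetical MeshOutputLimits struct that mirrors just those two members so it runs without Vulkan headers:

```cpp
#include <cstdint>

// Hypothetical stand-in for the two relevant members of
// VkPhysicalDeviceMeshShaderPropertiesEXT.
struct MeshOutputLimits {
    uint32_t maxMeshOutputVertices;
    uint32_t maxMeshOutputPrimitives;
};

// Returns true if SetMeshOutputCounts(vertexCount, primitiveCount) fits the limits.
bool FitsMeshOutputLimits(const MeshOutputLimits& limits,
                          uint32_t vertexCount,
                          uint32_t primitiveCount) {
    return vertexCount <= limits.maxMeshOutputVertices &&
           primitiveCount <= limits.maxMeshOutputPrimitives;
}
```

For example, with limits of 256 vertices and 256 primitives (illustrative values, not queried from a device), the SetMeshOutputCounts(3, 1) call above fits easily.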

Pixel Shader

Same as with D3D12 mesh shading. It’s compiled using the ps_6_5 profile to match the mesh shader, for completeness. No surprises here.

float4 psmain(MeshOutput input) : SV_TARGET
{
    return float4(input.Color, 1);
}
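psmain receives MeshOutput values that the rasterizer has interpolated across the triangle. Since every position written in msmain has w = 1.0, perspective-correct interpolation reduces to plain barycentric weighting, so each pixel’s color is a weighted mix of the red, green, and blue corners. A small CPU sketch of that mix; Float2, Float3, Edge, Barycentric, and InterpolateColor are hypothetical names, not part of GREX or HLSL. Because the viewport transform is affine, the weights come out the same whether they’re computed in NDC (as here) or in pixels.

```cpp
#include <array>

struct Float2 { float x, y; };
struct Float3 { float x, y, z; };

// Twice the signed area of triangle (a, b, p); the sign encodes winding.
float Edge(Float2 a, Float2 b, Float2 p) {
    return (b.x - a.x) * (p.y - a.y) - (b.y - a.y) * (p.x - a.x);
}

// Barycentric weights of point p relative to triangle v; the weights sum to 1.
std::array<float, 3> Barycentric(const std::array<Float2, 3>& v, Float2 p) {
    float area = Edge(v[0], v[1], v[2]);
    return {Edge(v[1], v[2], p) / area,
            Edge(v[2], v[0], p) / area,
            Edge(v[0], v[1], p) / area};
}

// Mixes per-vertex colors with the weights, like input.Color arriving in psmain.
Float3 InterpolateColor(const std::array<Float3, 3>& c, const std::array<float, 3>& w) {
    return {c[0].x * w[0] + c[1].x * w[1] + c[2].x * w[2],
            c[0].y * w[0] + c[1].y * w[1] + c[2].y * w[2],
            c[0].z * w[0] + c[1].z * w[1] + c[2].z * w[2]};
}
```

Using the positions and colors from msmain, the midpoint of the edge between vertex 0 (red) and vertex 1 (green) gets weights (0.5, 0.5, 0) and therefore the color (0.5, 0.5, 0).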

Mesh Shading Graphics Pipeline

Creating a graphics pipeline for mesh shading is simple and straightforward. Provide a shader module for the VK_SHADER_STAGE_MESH_BIT_EXT stage using VkPipelineShaderStageCreateInfo. It’s not necessary to explicitly set pVertexInputState and pInputAssemblyState to NULL, since they’re ignored if the pipeline includes a mesh shader stage.

std::vector<VkPipelineShaderStageCreateInfo> shaderStages;

// Mesh shader
VkPipelineShaderStageCreateInfo shaderStageCreateInfo
    = {VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO};
shaderStageCreateInfo.stage  = VK_SHADER_STAGE_MESH_BIT_EXT;
shaderStageCreateInfo.module = msShaderModule;
shaderStageCreateInfo.pName  = "msmain";
shaderStages.push_back(shaderStageCreateInfo);

// Fragment shader
shaderStageCreateInfo        = {VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO};
shaderStageCreateInfo.stage  = VK_SHADER_STAGE_FRAGMENT_BIT;
shaderStageCreateInfo.module = fsShaderModule;
shaderStageCreateInfo.pName  = "psmain";
shaderStages.push_back(shaderStageCreateInfo);

VkGraphicsPipelineCreateInfo pipelineInfo 
    = {VK_STRUCTURE_TYPE_GRAPHICS_PIPELINE_CREATE_INFO};
pipelineInfo.stageCount          = CountU32(shaderStages);
pipelineInfo.pStages             = DataPtr(shaderStages);
pipelineInfo.pVertexInputState   = nullptr; // Ignored if pipeline includes mesh shader stage
pipelineInfo.pInputAssemblyState = nullptr; // Ignored if pipeline includes mesh shader stage
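Since the vertex input and input assembly state are ignored, the shader stage list is where a mesh pipeline really differs from a vertex pipeline. As I understand the VK_EXT_mesh_shader rules, a graphics pipeline uses either the primitive shading stages (vertex, plus optional tessellation and geometry) or the mesh shading stages (mesh, plus an optional task stage), never a mix. A simplified sketch of just that rule; the Stage enum is a hypothetical stand-in for VkShaderStageFlagBits, and other pipeline validation rules (such as tessellation stages coming in pairs) are deliberately ignored.

```cpp
#include <vector>

// Hypothetical stand-in for the relevant VkShaderStageFlagBits values.
enum class Stage { Vertex, TessControl, TessEval, Geometry, Task, Mesh, Fragment };

// Returns true if the stages use exactly one geometry path: primitive shading
// (vertex + optional tessellation/geometry) or mesh shading (mesh + optional task).
bool IsValidGeometryStageSet(const std::vector<Stage>& stages) {
    bool hasVertex = false;
    bool hasPrimitiveExtras = false; // Tessellation or geometry stages
    bool hasTask = false;
    bool hasMesh = false;
    for (Stage s : stages) {
        switch (s) {
            case Stage::Vertex: hasVertex = true; break;
            case Stage::TessControl:
            case Stage::TessEval:
            case Stage::Geometry: hasPrimitiveExtras = true; break;
            case Stage::Task: hasTask = true; break;
            case Stage::Mesh: hasMesh = true; break;
            case Stage::Fragment: break; // Allowed with either path
        }
    }
    if (hasMesh) {
        // Mesh pipelines must not mix in any primitive shading stage
        return !hasVertex && !hasPrimitiveExtras;
    }
    // Without a mesh stage, a vertex shader is required and a task shader is not allowed
    return hasVertex && !hasTask;
}
```

The mesh + fragment list built above passes; adding a vertex stage to it would not.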

One of the nice things about mesh shading is that it saves you from having to fill out a VkPipelineVertexInputStateCreateInfo and a VkPipelineInputAssemblyStateCreateInfo. But of course, it’s Vulkan, so there are plenty of other create info structs waiting for you!

Dispatching the Mesh Shader

Finally, to render using vkCmdDrawMeshTasksEXT(), the command buffer recording looks something like:

vkCmdBeginRendering(commandBuffer, &vkri);

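// Flipped viewport: y starts at the bottom and height is negative so that
// +Y in clip space points up, as in D3D12 (needs Vulkan 1.1 or VK_KHR_maintenance1)
VkViewport viewport = {
    0,                                  // x
    static_cast<float>(gWindowHeight),  // y
    static_cast<float>(gWindowWidth),   // width
    -static_cast<float>(gWindowHeight), // height
    0.0f,                               // minDepth
    1.0f};                              // maxDepth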
vkCmdSetViewport(commandBuffer, 0, 1, &viewport);

VkRect2D scissor = {0, 0, gWindowWidth, gWindowHeight};
vkCmdSetScissor(commandBuffer, 0, 1, &scissor);

vkCmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_GRAPHICS, pipeline);

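// Launch one mesh shader workgroup (groupCountX, groupCountY, groupCountZ),
// which emits the single triangle
my_vkCmdDrawMeshTasksEXT(commandBuffer, 1, 1, 1);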

vkCmdEndRendering(commandBuffer);
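The flipped viewport (y set to the window height, negative height) is what keeps the y-up positions from msmain oriented the way they are in D3D12. A worked sketch of Vulkan’s viewport transform, ignoring depth, for an example 800x600 window; Viewport, Point2, and NdcToFramebuffer are hypothetical names used only for illustration.

```cpp
// Hypothetical stand-in for the x/y/width/height fields of VkViewport.
struct Viewport { float x, y, width, height; };
struct Point2 { float x, y; };

// Vulkan viewport transform (depth omitted):
//   xf = (width  / 2) * xd + (x + width  / 2)
//   yf = (height / 2) * yd + (y + height / 2)
Point2 NdcToFramebuffer(const Viewport& vp, float xd, float yd) {
    return {vp.width * 0.5f * xd + (vp.x + vp.width * 0.5f),
            vp.height * 0.5f * yd + (vp.y + vp.height * 0.5f)};
}
```

With the flipped viewport {0, 600, 800, -600}, vertices 0 and 1 land at y = 150 and vertex 2 at (400, 450). Framebuffer y grows downward, so the triangle’s apex points down the screen, exactly as the y-up NDC positions describe. With an unflipped {0, 0, 800, 600} viewport, vertex 0 would land at y = 450 instead, mirroring the triangle vertically.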