duckstation

duckstation, archived from the revision just before upstream changed it to a proprietary software project; this version is the libre one
git clone https://git.neptards.moe/u3shit/duckstation.git
Log | Files | Refs | README | LICENSE

d3d11_pipeline.cpp (16704B)


      1 // SPDX-FileCopyrightText: 2019-2024 Connor McLaughlin <stenzek@gmail.com>
      2 // SPDX-License-Identifier: (GPL-3.0 OR PolyForm-Strict-1.0.0)
      3 
      4 #include "d3d11_pipeline.h"
      5 #include "d3d11_device.h"
      6 #include "d3d_common.h"
      7 
      8 #include "common/assert.h"
      9 #include "common/error.h"
     10 #include "common/hash_combine.h"
     11 
     12 #include "fmt/format.h"
     13 
     14 #include <array>
     15 #include <malloc.h>
     16 
     17 D3D11Shader::D3D11Shader(GPUShaderStage stage, Microsoft::WRL::ComPtr<ID3D11DeviceChild> shader,
     18                          std::vector<u8> bytecode)
     19   : GPUShader(stage), m_shader(std::move(shader)), m_bytecode(std::move(bytecode))
     20 {
     21 }
     22 
     23 D3D11Shader::~D3D11Shader() = default;
     24 
     25 ID3D11VertexShader* D3D11Shader::GetVertexShader() const
     26 {
     27   DebugAssert(m_stage == GPUShaderStage::Vertex);
     28   return static_cast<ID3D11VertexShader*>(m_shader.Get());
     29 }
     30 
     31 ID3D11PixelShader* D3D11Shader::GetPixelShader() const
     32 {
     33   DebugAssert(m_stage == GPUShaderStage::Fragment);
     34   return static_cast<ID3D11PixelShader*>(m_shader.Get());
     35 }
     36 
     37 ID3D11GeometryShader* D3D11Shader::GetGeometryShader() const
     38 {
     39   DebugAssert(m_stage == GPUShaderStage::Geometry);
     40   return static_cast<ID3D11GeometryShader*>(m_shader.Get());
     41 }
     42 
     43 ID3D11ComputeShader* D3D11Shader::GetComputeShader() const
     44 {
     45   DebugAssert(m_stage == GPUShaderStage::Compute);
     46   return static_cast<ID3D11ComputeShader*>(m_shader.Get());
     47 }
     48 
     49 void D3D11Shader::SetDebugName(std::string_view name)
     50 {
     51   SetD3DDebugObjectName(m_shader.Get(), name);
     52 }
     53 
     54 std::unique_ptr<GPUShader> D3D11Device::CreateShaderFromBinary(GPUShaderStage stage, std::span<const u8> data,
     55                                                                Error* error)
     56 {
     57   ComPtr<ID3D11DeviceChild> shader;
     58   std::vector<u8> bytecode;
     59   HRESULT hr;
     60   switch (stage)
     61   {
     62     case GPUShaderStage::Vertex:
     63       hr = m_device->CreateVertexShader(data.data(), data.size(), nullptr,
     64                                         reinterpret_cast<ID3D11VertexShader**>(shader.GetAddressOf()));
     65       bytecode.resize(data.size());
     66       std::memcpy(bytecode.data(), data.data(), data.size());
     67       break;
     68 
     69     case GPUShaderStage::Fragment:
     70       hr = m_device->CreatePixelShader(data.data(), data.size(), nullptr,
     71                                        reinterpret_cast<ID3D11PixelShader**>(shader.GetAddressOf()));
     72       break;
     73 
     74     case GPUShaderStage::Geometry:
     75       hr = m_device->CreateGeometryShader(data.data(), data.size(), nullptr,
     76                                           reinterpret_cast<ID3D11GeometryShader**>(shader.GetAddressOf()));
     77       break;
     78 
     79     case GPUShaderStage::Compute:
     80       hr = m_device->CreateComputeShader(data.data(), data.size(), nullptr,
     81                                          reinterpret_cast<ID3D11ComputeShader**>(shader.GetAddressOf()));
     82       break;
     83 
     84     default:
     85       UnreachableCode();
     86       hr = S_FALSE;
     87       break;
     88   }
     89 
     90   if (FAILED(hr) || !shader)
     91   {
     92     Error::SetHResult(error, "Create[Typed]Shader() failed: ", hr);
     93     return {};
     94   }
     95 
     96   return std::unique_ptr<GPUShader>(new D3D11Shader(stage, std::move(shader), std::move(bytecode)));
     97 }
     98 
     99 std::unique_ptr<GPUShader> D3D11Device::CreateShaderFromSource(GPUShaderStage stage, GPUShaderLanguage language,
    100                                                                std::string_view source, const char* entry_point,
    101                                                                DynamicHeapArray<u8>* out_binary, Error* error)
    102 {
    103   const u32 shader_model = D3DCommon::GetShaderModelForFeatureLevel(m_device->GetFeatureLevel());
    104   if (language != GPUShaderLanguage::HLSL)
    105   {
    106     return TranspileAndCreateShaderFromSource(stage, language, source, entry_point, GPUShaderLanguage::HLSL,
    107                                               shader_model, out_binary, error);
    108   }
    109 
    110   std::optional<DynamicHeapArray<u8>> bytecode =
    111     D3DCommon::CompileShader(shader_model, m_debug_device, stage, source, entry_point, error);
    112   if (!bytecode.has_value())
    113     return {};
    114 
    115   std::unique_ptr<GPUShader> ret = CreateShaderFromBinary(stage, bytecode.value(), error);
    116   if (ret && out_binary)
    117     *out_binary = std::move(bytecode.value());
    118 
    119   return ret;
    120 }
    121 
    122 D3D11Pipeline::D3D11Pipeline(ComPtr<ID3D11RasterizerState> rs, ComPtr<ID3D11DepthStencilState> ds,
    123                              ComPtr<ID3D11BlendState> bs, ComPtr<ID3D11InputLayout> il, ComPtr<ID3D11VertexShader> vs,
    124                              ComPtr<ID3D11GeometryShader> gs, ComPtr<ID3D11PixelShader> ps,
    125                              D3D11_PRIMITIVE_TOPOLOGY topology, u32 vertex_stride, u32 blend_factor)
    126   : m_rs(std::move(rs)), m_ds(std::move(ds)), m_bs(std::move(bs)), m_il(std::move(il)), m_vs(std::move(vs)),
    127     m_gs(std::move(gs)), m_ps(std::move(ps)), m_topology(topology), m_vertex_stride(vertex_stride),
    128     m_blend_factor(blend_factor), m_blend_factor_float(GPUDevice::RGBA8ToFloat(blend_factor))
    129 {
    130 }
    131 
    132 D3D11Pipeline::~D3D11Pipeline()
    133 {
    134   D3D11Device::GetInstance().UnbindPipeline(this);
    135 }
    136 
    137 void D3D11Pipeline::SetDebugName(std::string_view name)
    138 {
    139   // can't label this directly
    140 }
    141 
    142 D3D11Device::ComPtr<ID3D11RasterizerState> D3D11Device::GetRasterizationState(const GPUPipeline::RasterizationState& rs,
    143                                                                               Error* error)
    144 {
    145   ComPtr<ID3D11RasterizerState> drs;
    146 
    147   const auto it = m_rasterization_states.find(rs.key);
    148   if (it != m_rasterization_states.end())
    149   {
    150     drs = it->second;
    151     return drs;
    152   }
    153 
    154   static constexpr std::array<D3D11_CULL_MODE, static_cast<u32>(GPUPipeline::CullMode::MaxCount)> cull_mapping = {{
    155     D3D11_CULL_NONE,  // None
    156     D3D11_CULL_FRONT, // Front
    157     D3D11_CULL_BACK,  // Back
    158   }};
    159 
    160   D3D11_RASTERIZER_DESC desc = {};
    161   desc.FillMode = D3D11_FILL_SOLID;
    162   desc.CullMode = cull_mapping[static_cast<u8>(rs.cull_mode.GetValue())];
    163   desc.ScissorEnable = TRUE;
    164   // desc.MultisampleEnable ???
    165 
    166   HRESULT hr = m_device->CreateRasterizerState(&desc, drs.GetAddressOf());
    167   if (FAILED(hr)) [[unlikely]]
    168     Error::SetHResult(error, "CreateRasterizerState() failed: ", hr);
    169   else
    170     m_rasterization_states.emplace(rs.key, drs);
    171 
    172   return drs;
    173 }
    174 
    175 D3D11Device::ComPtr<ID3D11DepthStencilState> D3D11Device::GetDepthState(const GPUPipeline::DepthState& ds, Error* error)
    176 {
    177   ComPtr<ID3D11DepthStencilState> dds;
    178 
    179   const auto it = m_depth_states.find(ds.key);
    180   if (it != m_depth_states.end())
    181   {
    182     dds = it->second;
    183     return dds;
    184   }
    185 
    186   static constexpr std::array<D3D11_COMPARISON_FUNC, static_cast<u32>(GPUPipeline::DepthFunc::MaxCount)> func_mapping =
    187     {{
    188       D3D11_COMPARISON_NEVER,         // Never
    189       D3D11_COMPARISON_ALWAYS,        // Always
    190       D3D11_COMPARISON_LESS,          // Less
    191       D3D11_COMPARISON_LESS_EQUAL,    // LessEqual
    192       D3D11_COMPARISON_GREATER,       // Greater
    193       D3D11_COMPARISON_GREATER_EQUAL, // GreaterEqual
    194       D3D11_COMPARISON_EQUAL,         // Equal
    195     }};
    196 
    197   D3D11_DEPTH_STENCIL_DESC desc = {};
    198   desc.DepthEnable = ds.depth_test != GPUPipeline::DepthFunc::Always || ds.depth_write;
    199   desc.DepthFunc = func_mapping[static_cast<u8>(ds.depth_test.GetValue())];
    200   desc.DepthWriteMask = ds.depth_write ? D3D11_DEPTH_WRITE_MASK_ALL : D3D11_DEPTH_WRITE_MASK_ZERO;
    201 
    202   HRESULT hr = m_device->CreateDepthStencilState(&desc, dds.GetAddressOf());
    203   if (FAILED(hr)) [[unlikely]]
    204     Error::SetHResult(error, "CreateDepthStencilState() failed: ", hr);
    205   else
    206     m_depth_states.emplace(ds.key, dds);
    207 
    208   return dds;
    209 }
    210 
    211 size_t D3D11Device::BlendStateMapHash::operator()(const BlendStateMapKey& key) const
    212 {
    213   size_t h = std::hash<u64>()(key.first);
    214   hash_combine(h, key.second);
    215   return h;
    216 }
    217 
    218 D3D11Device::ComPtr<ID3D11BlendState> D3D11Device::GetBlendState(const GPUPipeline::BlendState& bs, u32 num_rts, Error* error)
    219 {
    220   ComPtr<ID3D11BlendState> dbs;
    221 
    222   const std::pair<u64, u32> key(bs.key, num_rts);
    223   const auto it = m_blend_states.find(key);
    224   if (it != m_blend_states.end())
    225   {
    226     dbs = it->second;
    227     return dbs;
    228   }
    229 
    230   static constexpr std::array<D3D11_BLEND, static_cast<u32>(GPUPipeline::BlendFunc::MaxCount)> blend_mapping = {{
    231     D3D11_BLEND_ZERO,             // Zero
    232     D3D11_BLEND_ONE,              // One
    233     D3D11_BLEND_SRC_COLOR,        // SrcColor
    234     D3D11_BLEND_INV_SRC_COLOR,    // InvSrcColor
    235     D3D11_BLEND_DEST_COLOR,       // DstColor
    236     D3D11_BLEND_INV_DEST_COLOR,   // InvDstColor
    237     D3D11_BLEND_SRC_ALPHA,        // SrcAlpha
    238     D3D11_BLEND_INV_SRC_ALPHA,    // InvSrcAlpha
    239     D3D11_BLEND_SRC1_ALPHA,       // SrcAlpha1
    240     D3D11_BLEND_INV_SRC1_ALPHA,   // InvSrcAlpha1
    241     D3D11_BLEND_DEST_ALPHA,       // DstAlpha
    242     D3D11_BLEND_INV_DEST_ALPHA,   // InvDstAlpha
    243     D3D11_BLEND_BLEND_FACTOR,     // ConstantColor
    244     D3D11_BLEND_INV_BLEND_FACTOR, // InvConstantColor
    245   }};
    246 
    247   static constexpr std::array<D3D11_BLEND_OP, static_cast<u32>(GPUPipeline::BlendOp::MaxCount)> op_mapping = {{
    248     D3D11_BLEND_OP_ADD,          // Add
    249     D3D11_BLEND_OP_SUBTRACT,     // Subtract
    250     D3D11_BLEND_OP_REV_SUBTRACT, // ReverseSubtract
    251     D3D11_BLEND_OP_MIN,          // Min
    252     D3D11_BLEND_OP_MAX,          // Max
    253   }};
    254 
    255   D3D11_BLEND_DESC blend_desc = {};
    256   for (u32 i = 0; i < num_rts; i++)
    257   {
    258     D3D11_RENDER_TARGET_BLEND_DESC& tgt_desc = blend_desc.RenderTarget[i];
    259     tgt_desc.BlendEnable = bs.enable;
    260     tgt_desc.RenderTargetWriteMask = bs.write_mask;
    261     if (bs.enable)
    262     {
    263       tgt_desc.SrcBlend = blend_mapping[static_cast<u8>(bs.src_blend.GetValue())];
    264       tgt_desc.DestBlend = blend_mapping[static_cast<u8>(bs.dst_blend.GetValue())];
    265       tgt_desc.BlendOp = op_mapping[static_cast<u8>(bs.blend_op.GetValue())];
    266       tgt_desc.SrcBlendAlpha = blend_mapping[static_cast<u8>(bs.src_alpha_blend.GetValue())];
    267       tgt_desc.DestBlendAlpha = blend_mapping[static_cast<u8>(bs.dst_alpha_blend.GetValue())];
    268       tgt_desc.BlendOpAlpha = op_mapping[static_cast<u8>(bs.alpha_blend_op.GetValue())];
    269     }
    270   }
    271 
    272   HRESULT hr = m_device->CreateBlendState(&blend_desc, dbs.GetAddressOf());
    273   if (FAILED(hr)) [[unlikely]]
    274     Error::SetHResult(error, "CreateBlendState() failed: ", hr);
    275   else
    276     m_blend_states.emplace(key, dbs);
    277 
    278   return dbs;
    279 }
    280 
    281 D3D11Device::ComPtr<ID3D11InputLayout> D3D11Device::GetInputLayout(const GPUPipeline::InputLayout& il,
    282                                                                    const D3D11Shader* vs, Error* error)
    283 {
    284   ComPtr<ID3D11InputLayout> dil;
    285   const auto it = m_input_layouts.find(il);
    286   if (it != m_input_layouts.end())
    287   {
    288     dil = it->second;
    289     return dil;
    290   }
    291 
    292   static constexpr u32 MAX_COMPONENTS = 4;
    293   static constexpr const DXGI_FORMAT
    294     format_mapping[static_cast<u8>(GPUPipeline::VertexAttribute::Type::MaxCount)][MAX_COMPONENTS] = {
    295       {DXGI_FORMAT_R32_FLOAT, DXGI_FORMAT_R32G32_FLOAT, DXGI_FORMAT_R32G32B32_FLOAT,
    296        DXGI_FORMAT_R32G32B32A32_FLOAT},                                                                       // Float
    297       {DXGI_FORMAT_R8_UINT, DXGI_FORMAT_R8G8_UINT, DXGI_FORMAT_UNKNOWN, DXGI_FORMAT_R8G8B8A8_UINT},           // UInt8
    298       {DXGI_FORMAT_R8_SINT, DXGI_FORMAT_R8G8_SINT, DXGI_FORMAT_UNKNOWN, DXGI_FORMAT_R8G8B8A8_SINT},           // SInt8
    299       {DXGI_FORMAT_R8_UNORM, DXGI_FORMAT_R8G8_UNORM, DXGI_FORMAT_UNKNOWN, DXGI_FORMAT_R8G8B8A8_UNORM},        // UNorm8
    300       {DXGI_FORMAT_R16_UINT, DXGI_FORMAT_R16G16_UINT, DXGI_FORMAT_UNKNOWN, DXGI_FORMAT_R16G16B16A16_UINT},    // UInt16
    301       {DXGI_FORMAT_R16_SINT, DXGI_FORMAT_R16G16_SINT, DXGI_FORMAT_UNKNOWN, DXGI_FORMAT_R16G16B16A16_SINT},    // SInt16
    302       {DXGI_FORMAT_R16_UNORM, DXGI_FORMAT_R16G16_UNORM, DXGI_FORMAT_UNKNOWN, DXGI_FORMAT_R16G16B16A16_UNORM}, // UNorm16
    303       {DXGI_FORMAT_R32_UINT, DXGI_FORMAT_R32G32_UINT, DXGI_FORMAT_UNKNOWN, DXGI_FORMAT_R32G32B32A32_UINT},    // UInt32
    304       {DXGI_FORMAT_R32_SINT, DXGI_FORMAT_R32G32_SINT, DXGI_FORMAT_UNKNOWN, DXGI_FORMAT_R32G32B32A32_SINT},    // SInt32
    305     };
    306 
    307   D3D11_INPUT_ELEMENT_DESC* elems =
    308     static_cast<D3D11_INPUT_ELEMENT_DESC*>(alloca(sizeof(D3D11_INPUT_ELEMENT_DESC) * il.vertex_attributes.size()));
    309   for (size_t i = 0; i < il.vertex_attributes.size(); i++)
    310   {
    311     const GPUPipeline::VertexAttribute& va = il.vertex_attributes[i];
    312     Assert(va.components > 0 && va.components <= MAX_COMPONENTS);
    313 
    314     D3D11_INPUT_ELEMENT_DESC& elem = elems[i];
    315     elem.SemanticName = "ATTR";
    316     elem.SemanticIndex = va.index;
    317     elem.Format = format_mapping[static_cast<u8>(va.type.GetValue())][va.components - 1];
    318     elem.InputSlot = 0;
    319     elem.AlignedByteOffset = va.offset;
    320     elem.InputSlotClass = D3D11_INPUT_PER_VERTEX_DATA;
    321     elem.InstanceDataStepRate = 0;
    322   }
    323 
    324   HRESULT hr = m_device->CreateInputLayout(elems, static_cast<UINT>(il.vertex_attributes.size()),
    325                                            vs->GetBytecode().data(), vs->GetBytecode().size(), dil.GetAddressOf());
    326   if (FAILED(hr)) [[unlikely]]
    327     Error::SetHResult(error, "CreateInputLayout() failed: ", hr);
    328   else
    329     m_input_layouts.emplace(il, dil);
    330 
    331   return dil;
    332 }
    333 
    334 std::unique_ptr<GPUPipeline> D3D11Device::CreatePipeline(const GPUPipeline::GraphicsConfig& config, Error* error)
    335 {
    336   ComPtr<ID3D11RasterizerState> rs = GetRasterizationState(config.rasterization, error);
    337   ComPtr<ID3D11DepthStencilState> ds = GetDepthState(config.depth, error);
    338   ComPtr<ID3D11BlendState> bs = GetBlendState(config.blend, config.GetRenderTargetCount(), error);
    339   if (!rs || !ds || !bs)
    340     return {};
    341 
    342   ComPtr<ID3D11InputLayout> il;
    343   u32 vertex_stride = 0;
    344   if (!config.input_layout.vertex_attributes.empty())
    345   {
    346     il = GetInputLayout(config.input_layout, static_cast<const D3D11Shader*>(config.vertex_shader), error);
    347     vertex_stride = config.input_layout.vertex_stride;
    348     if (!il)
    349       return {};
    350   }
    351 
    352   static constexpr std::array<D3D11_PRIMITIVE_TOPOLOGY, static_cast<u32>(GPUPipeline::Primitive::MaxCount)> primitives =
    353     {{
    354       D3D11_PRIMITIVE_TOPOLOGY_POINTLIST,     // Points
    355       D3D11_PRIMITIVE_TOPOLOGY_LINELIST,      // Lines
    356       D3D11_PRIMITIVE_TOPOLOGY_TRIANGLELIST,  // Triangles
    357       D3D11_PRIMITIVE_TOPOLOGY_TRIANGLESTRIP, // TriangleStrips
    358     }};
    359 
    360   return std::unique_ptr<GPUPipeline>(new D3D11Pipeline(
    361     std::move(rs), std::move(ds), std::move(bs), std::move(il),
    362     static_cast<const D3D11Shader*>(config.vertex_shader)->GetVertexShader(),
    363     config.geometry_shader ? static_cast<const D3D11Shader*>(config.geometry_shader)->GetGeometryShader() : nullptr,
    364     static_cast<const D3D11Shader*>(config.fragment_shader)->GetPixelShader(),
    365     primitives[static_cast<u8>(config.primitive)], vertex_stride, config.blend.constant));
    366 }
    367 
    368 void D3D11Device::SetPipeline(GPUPipeline* pipeline)
    369 {
    370   if (m_current_pipeline == pipeline)
    371     return;
    372 
    373   D3D11Pipeline* const PL = static_cast<D3D11Pipeline*>(pipeline);
    374   m_current_pipeline = PL;
    375 
    376   if (ID3D11InputLayout* il = PL->GetInputLayout(); m_current_input_layout != il)
    377   {
    378     m_current_input_layout = il;
    379     m_context->IASetInputLayout(il);
    380   }
    381 
    382   if (const u32 vertex_stride = PL->GetVertexStride(); m_current_vertex_stride != vertex_stride)
    383   {
    384     const UINT offset = 0;
    385     m_current_vertex_stride = PL->GetVertexStride();
    386     m_context->IASetVertexBuffers(0, 1, m_vertex_buffer.GetD3DBufferArray(), &m_current_vertex_stride, &offset);
    387   }
    388 
    389   if (D3D_PRIMITIVE_TOPOLOGY topology = PL->GetPrimitiveTopology(); m_current_primitive_topology != topology)
    390   {
    391     m_current_primitive_topology = topology;
    392     m_context->IASetPrimitiveTopology(topology);
    393   }
    394 
    395   if (ID3D11VertexShader* vs = PL->GetVertexShader(); m_current_vertex_shader != vs)
    396   {
    397     m_current_vertex_shader = vs;
    398     m_context->VSSetShader(vs, nullptr, 0);
    399   }
    400 
    401   if (ID3D11GeometryShader* gs = PL->GetGeometryShader(); m_current_geometry_shader != gs)
    402   {
    403     m_current_geometry_shader = gs;
    404     m_context->GSSetShader(gs, nullptr, 0);
    405   }
    406 
    407   if (ID3D11PixelShader* ps = PL->GetPixelShader(); m_current_pixel_shader != ps)
    408   {
    409     m_current_pixel_shader = ps;
    410     m_context->PSSetShader(ps, nullptr, 0);
    411   }
    412 
    413   if (ID3D11RasterizerState* rs = PL->GetRasterizerState(); m_current_rasterizer_state != rs)
    414   {
    415     m_current_rasterizer_state = rs;
    416     m_context->RSSetState(rs);
    417   }
    418 
    419   if (ID3D11DepthStencilState* ds = PL->GetDepthStencilState(); m_current_depth_state != ds)
    420   {
    421     m_current_depth_state = ds;
    422     m_context->OMSetDepthStencilState(ds, 0);
    423   }
    424 
    425   if (ID3D11BlendState* bs = PL->GetBlendState();
    426       m_current_blend_state != bs || m_current_blend_factor != PL->GetBlendFactor())
    427   {
    428     m_current_blend_state = bs;
    429     m_current_blend_factor = PL->GetBlendFactor();
    430     m_context->OMSetBlendState(bs, RGBA8ToFloat(m_current_blend_factor).data(), 0xFFFFFFFFu);
    431   }
    432 }
    433 
    434 void D3D11Device::UnbindPipeline(D3D11Pipeline* pl)
    435 {
    436   if (m_current_pipeline != pl)
    437     return;
    438 
    439   // Let the runtime deal with the dead objects...
    440   m_current_pipeline = nullptr;
    441 }