d3d11_pipeline.cpp (16704B)
1 // SPDX-FileCopyrightText: 2019-2024 Connor McLaughlin <stenzek@gmail.com> 2 // SPDX-License-Identifier: (GPL-3.0 OR PolyForm-Strict-1.0.0) 3 4 #include "d3d11_pipeline.h" 5 #include "d3d11_device.h" 6 #include "d3d_common.h" 7 8 #include "common/assert.h" 9 #include "common/error.h" 10 #include "common/hash_combine.h" 11 12 #include "fmt/format.h" 13 14 #include <array> 15 #include <malloc.h> 16 17 D3D11Shader::D3D11Shader(GPUShaderStage stage, Microsoft::WRL::ComPtr<ID3D11DeviceChild> shader, 18 std::vector<u8> bytecode) 19 : GPUShader(stage), m_shader(std::move(shader)), m_bytecode(std::move(bytecode)) 20 { 21 } 22 23 D3D11Shader::~D3D11Shader() = default; 24 25 ID3D11VertexShader* D3D11Shader::GetVertexShader() const 26 { 27 DebugAssert(m_stage == GPUShaderStage::Vertex); 28 return static_cast<ID3D11VertexShader*>(m_shader.Get()); 29 } 30 31 ID3D11PixelShader* D3D11Shader::GetPixelShader() const 32 { 33 DebugAssert(m_stage == GPUShaderStage::Fragment); 34 return static_cast<ID3D11PixelShader*>(m_shader.Get()); 35 } 36 37 ID3D11GeometryShader* D3D11Shader::GetGeometryShader() const 38 { 39 DebugAssert(m_stage == GPUShaderStage::Geometry); 40 return static_cast<ID3D11GeometryShader*>(m_shader.Get()); 41 } 42 43 ID3D11ComputeShader* D3D11Shader::GetComputeShader() const 44 { 45 DebugAssert(m_stage == GPUShaderStage::Compute); 46 return static_cast<ID3D11ComputeShader*>(m_shader.Get()); 47 } 48 49 void D3D11Shader::SetDebugName(std::string_view name) 50 { 51 SetD3DDebugObjectName(m_shader.Get(), name); 52 } 53 54 std::unique_ptr<GPUShader> D3D11Device::CreateShaderFromBinary(GPUShaderStage stage, std::span<const u8> data, 55 Error* error) 56 { 57 ComPtr<ID3D11DeviceChild> shader; 58 std::vector<u8> bytecode; 59 HRESULT hr; 60 switch (stage) 61 { 62 case GPUShaderStage::Vertex: 63 hr = m_device->CreateVertexShader(data.data(), data.size(), nullptr, 64 reinterpret_cast<ID3D11VertexShader**>(shader.GetAddressOf())); 65 bytecode.resize(data.size()); 66 std::memcpy(bytecode.data(), data.data(), data.size()); 67 break; 68 69 case GPUShaderStage::Fragment: 70 hr = m_device->CreatePixelShader(data.data(), data.size(), nullptr, 71 reinterpret_cast<ID3D11PixelShader**>(shader.GetAddressOf())); 72 break; 73 74 case GPUShaderStage::Geometry: 75 hr = m_device->CreateGeometryShader(data.data(), data.size(), nullptr, 76 reinterpret_cast<ID3D11GeometryShader**>(shader.GetAddressOf())); 77 break; 78 79 case GPUShaderStage::Compute: 80 hr = m_device->CreateComputeShader(data.data(), data.size(), nullptr, 81 reinterpret_cast<ID3D11ComputeShader**>(shader.GetAddressOf())); 82 break; 83 84 default: 85 UnreachableCode(); 86 hr = S_FALSE; 87 break; 88 } 89 90 if (FAILED(hr) || !shader) 91 { 92 Error::SetHResult(error, "Create[Typed]Shader() failed: ", hr); 93 return {}; 94 } 95 96 return std::unique_ptr<GPUShader>(new D3D11Shader(stage, std::move(shader), std::move(bytecode))); 97 } 98 99 std::unique_ptr<GPUShader> D3D11Device::CreateShaderFromSource(GPUShaderStage stage, GPUShaderLanguage language, 100 std::string_view source, const char* entry_point, 101 DynamicHeapArray<u8>* out_binary, Error* error) 102 { 103 const u32 shader_model = D3DCommon::GetShaderModelForFeatureLevel(m_device->GetFeatureLevel()); 104 if (language != GPUShaderLanguage::HLSL) 105 { 106 return TranspileAndCreateShaderFromSource(stage, language, source, entry_point, GPUShaderLanguage::HLSL, 107 shader_model, out_binary, error); 108 } 109 110 std::optional<DynamicHeapArray<u8>> bytecode = 111 D3DCommon::CompileShader(shader_model, m_debug_device, stage, source, entry_point, error); 112 if (!bytecode.has_value()) 113 return {}; 114 115 std::unique_ptr<GPUShader> ret = CreateShaderFromBinary(stage, bytecode.value(), error); 116 if (ret && out_binary) 117 *out_binary = std::move(bytecode.value()); 118 119 return ret; 120 } 121 122 D3D11Pipeline::D3D11Pipeline(ComPtr<ID3D11RasterizerState> rs, ComPtr<ID3D11DepthStencilState> ds, 123 ComPtr<ID3D11BlendState> bs, ComPtr<ID3D11InputLayout> il, ComPtr<ID3D11VertexShader> vs, 124 ComPtr<ID3D11GeometryShader> gs, ComPtr<ID3D11PixelShader> ps, 125 D3D11_PRIMITIVE_TOPOLOGY topology, u32 vertex_stride, u32 blend_factor) 126 : m_rs(std::move(rs)), m_ds(std::move(ds)), m_bs(std::move(bs)), m_il(std::move(il)), m_vs(std::move(vs)), 127 m_gs(std::move(gs)), m_ps(std::move(ps)), m_topology(topology), m_vertex_stride(vertex_stride), 128 m_blend_factor(blend_factor), m_blend_factor_float(GPUDevice::RGBA8ToFloat(blend_factor)) 129 { 130 } 131 132 D3D11Pipeline::~D3D11Pipeline() 133 { 134 D3D11Device::GetInstance().UnbindPipeline(this); 135 } 136 137 void D3D11Pipeline::SetDebugName(std::string_view name) 138 { 139 // can't label this directly 140 } 141 142 D3D11Device::ComPtr<ID3D11RasterizerState> D3D11Device::GetRasterizationState(const GPUPipeline::RasterizationState& rs, 143 Error* error) 144 { 145 ComPtr<ID3D11RasterizerState> drs; 146 147 const auto it = m_rasterization_states.find(rs.key); 148 if (it != m_rasterization_states.end()) 149 { 150 drs = it->second; 151 return drs; 152 } 153 154 static constexpr std::array<D3D11_CULL_MODE, static_cast<u32>(GPUPipeline::CullMode::MaxCount)> cull_mapping = {{ 155 D3D11_CULL_NONE, // None 156 D3D11_CULL_FRONT, // Front 157 D3D11_CULL_BACK, // Back 158 }}; 159 160 D3D11_RASTERIZER_DESC desc = {}; 161 desc.FillMode = D3D11_FILL_SOLID; 162 desc.CullMode = cull_mapping[static_cast<u8>(rs.cull_mode.GetValue())]; 163 desc.ScissorEnable = TRUE; 164 // desc.MultisampleEnable ??? 165 166 HRESULT hr = m_device->CreateRasterizerState(&desc, drs.GetAddressOf()); 167 if (FAILED(hr)) [[unlikely]] 168 Error::SetHResult(error, "CreateRasterizerState() failed: ", hr); 169 else 170 m_rasterization_states.emplace(rs.key, drs); 171 172 return drs; 173 } 174 175 D3D11Device::ComPtr<ID3D11DepthStencilState> D3D11Device::GetDepthState(const GPUPipeline::DepthState& ds, Error* error) 176 { 177 ComPtr<ID3D11DepthStencilState> dds; 178 179 const auto it = m_depth_states.find(ds.key); 180 if (it != m_depth_states.end()) 181 { 182 dds = it->second; 183 return dds; 184 } 185 186 static constexpr std::array<D3D11_COMPARISON_FUNC, static_cast<u32>(GPUPipeline::DepthFunc::MaxCount)> func_mapping = 187 {{ 188 D3D11_COMPARISON_NEVER, // Never 189 D3D11_COMPARISON_ALWAYS, // Always 190 D3D11_COMPARISON_LESS, // Less 191 D3D11_COMPARISON_LESS_EQUAL, // LessEqual 192 D3D11_COMPARISON_GREATER, // Greater 193 D3D11_COMPARISON_GREATER_EQUAL, // GreaterEqual 194 D3D11_COMPARISON_EQUAL, // Equal 195 }}; 196 197 D3D11_DEPTH_STENCIL_DESC desc = {}; 198 desc.DepthEnable = ds.depth_test != GPUPipeline::DepthFunc::Always || ds.depth_write; 199 desc.DepthFunc = func_mapping[static_cast<u8>(ds.depth_test.GetValue())]; 200 desc.DepthWriteMask = ds.depth_write ? D3D11_DEPTH_WRITE_MASK_ALL : D3D11_DEPTH_WRITE_MASK_ZERO; 201 202 HRESULT hr = m_device->CreateDepthStencilState(&desc, dds.GetAddressOf()); 203 if (FAILED(hr)) [[unlikely]] 204 Error::SetHResult(error, "CreateDepthStencilState() failed: ", hr); 205 else 206 m_depth_states.emplace(ds.key, dds); 207 208 return dds; 209 } 210 211 size_t D3D11Device::BlendStateMapHash::operator()(const BlendStateMapKey& key) const 212 { 213 size_t h = std::hash<u64>()(key.first); 214 hash_combine(h, key.second); 215 return h; 216 } 217 218 D3D11Device::ComPtr<ID3D11BlendState> D3D11Device::GetBlendState(const GPUPipeline::BlendState& bs, u32 num_rts, Error* error) 219 { 220 ComPtr<ID3D11BlendState> dbs; 221 222 const std::pair<u64, u32> key(bs.key, num_rts); 223 const auto it = m_blend_states.find(key); 224 if (it != m_blend_states.end()) 225 { 226 dbs = it->second; 227 return dbs; 228 } 229 230 static constexpr std::array<D3D11_BLEND, static_cast<u32>(GPUPipeline::BlendFunc::MaxCount)> blend_mapping = {{ 231 D3D11_BLEND_ZERO, // Zero 232 D3D11_BLEND_ONE, // One 233 D3D11_BLEND_SRC_COLOR, // SrcColor 234 D3D11_BLEND_INV_SRC_COLOR, // InvSrcColor 235 D3D11_BLEND_DEST_COLOR, // DstColor 236 D3D11_BLEND_INV_DEST_COLOR, // InvDstColor 237 D3D11_BLEND_SRC_ALPHA, // SrcAlpha 238 D3D11_BLEND_INV_SRC_ALPHA, // InvSrcAlpha 239 D3D11_BLEND_SRC1_ALPHA, // SrcAlpha1 240 D3D11_BLEND_INV_SRC1_ALPHA, // InvSrcAlpha1 241 D3D11_BLEND_DEST_ALPHA, // DstAlpha 242 D3D11_BLEND_INV_DEST_ALPHA, // InvDstAlpha 243 D3D11_BLEND_BLEND_FACTOR, // ConstantColor 244 D3D11_BLEND_INV_BLEND_FACTOR, // InvConstantColor 245 }}; 246 247 static constexpr std::array<D3D11_BLEND_OP, static_cast<u32>(GPUPipeline::BlendOp::MaxCount)> op_mapping = {{ 248 D3D11_BLEND_OP_ADD, // Add 249 D3D11_BLEND_OP_SUBTRACT, // Subtract 250 D3D11_BLEND_OP_REV_SUBTRACT, // ReverseSubtract 251 D3D11_BLEND_OP_MIN, // Min 252 D3D11_BLEND_OP_MAX, // Max 253 }}; 254 255 D3D11_BLEND_DESC blend_desc = {}; 256 for (u32 i = 0; i < num_rts; i++) 257 { 258 D3D11_RENDER_TARGET_BLEND_DESC& tgt_desc = blend_desc.RenderTarget[i]; 259 tgt_desc.BlendEnable = bs.enable; 260 tgt_desc.RenderTargetWriteMask = bs.write_mask; 261 if (bs.enable) 262 { 263 tgt_desc.SrcBlend = blend_mapping[static_cast<u8>(bs.src_blend.GetValue())]; 264 tgt_desc.DestBlend = blend_mapping[static_cast<u8>(bs.dst_blend.GetValue())]; 265 tgt_desc.BlendOp = op_mapping[static_cast<u8>(bs.blend_op.GetValue())]; 266 tgt_desc.SrcBlendAlpha = blend_mapping[static_cast<u8>(bs.src_alpha_blend.GetValue())]; 267 tgt_desc.DestBlendAlpha = blend_mapping[static_cast<u8>(bs.dst_alpha_blend.GetValue())]; 268 tgt_desc.BlendOpAlpha = op_mapping[static_cast<u8>(bs.alpha_blend_op.GetValue())]; 269 } 270 } 271 272 HRESULT hr = m_device->CreateBlendState(&blend_desc, dbs.GetAddressOf()); 273 if (FAILED(hr)) [[unlikely]] 274 Error::SetHResult(error, "CreateBlendState() failed: ", hr); 275 else 276 m_blend_states.emplace(key, dbs); 277 278 return dbs; 279 } 280 281 D3D11Device::ComPtr<ID3D11InputLayout> D3D11Device::GetInputLayout(const GPUPipeline::InputLayout& il, 282 const D3D11Shader* vs, Error* error) 283 { 284 ComPtr<ID3D11InputLayout> dil; 285 const auto it = m_input_layouts.find(il); 286 if (it != m_input_layouts.end()) 287 { 288 dil = it->second; 289 return dil; 290 } 291 292 static constexpr u32 MAX_COMPONENTS = 4; 293 static constexpr const DXGI_FORMAT 294 format_mapping[static_cast<u8>(GPUPipeline::VertexAttribute::Type::MaxCount)][MAX_COMPONENTS] = { 295 {DXGI_FORMAT_R32_FLOAT, DXGI_FORMAT_R32G32_FLOAT, DXGI_FORMAT_R32G32B32_FLOAT, 296 DXGI_FORMAT_R32G32B32A32_FLOAT}, // Float 297 {DXGI_FORMAT_R8_UINT, DXGI_FORMAT_R8G8_UINT, DXGI_FORMAT_UNKNOWN, DXGI_FORMAT_R8G8B8A8_UINT}, // UInt8 298 {DXGI_FORMAT_R8_SINT, DXGI_FORMAT_R8G8_SINT, DXGI_FORMAT_UNKNOWN, DXGI_FORMAT_R8G8B8A8_SINT}, // SInt8 299 {DXGI_FORMAT_R8_UNORM, DXGI_FORMAT_R8G8_UNORM, DXGI_FORMAT_UNKNOWN, DXGI_FORMAT_R8G8B8A8_UNORM}, // UNorm8 300 {DXGI_FORMAT_R16_UINT, DXGI_FORMAT_R16G16_UINT, DXGI_FORMAT_UNKNOWN, DXGI_FORMAT_R16G16B16A16_UINT}, // UInt16 301 {DXGI_FORMAT_R16_SINT, DXGI_FORMAT_R16G16_SINT, DXGI_FORMAT_UNKNOWN, DXGI_FORMAT_R16G16B16A16_SINT}, // SInt16 302 {DXGI_FORMAT_R16_UNORM, DXGI_FORMAT_R16G16_UNORM, DXGI_FORMAT_UNKNOWN, DXGI_FORMAT_R16G16B16A16_UNORM}, // UNorm16 303 {DXGI_FORMAT_R32_UINT, DXGI_FORMAT_R32G32_UINT, DXGI_FORMAT_UNKNOWN, DXGI_FORMAT_R32G32B32A32_UINT}, // UInt32 304 {DXGI_FORMAT_R32_SINT, DXGI_FORMAT_R32G32_SINT, DXGI_FORMAT_UNKNOWN, DXGI_FORMAT_R32G32B32A32_SINT}, // SInt32 305 }; 306 307 D3D11_INPUT_ELEMENT_DESC* elems = 308 static_cast<D3D11_INPUT_ELEMENT_DESC*>(alloca(sizeof(D3D11_INPUT_ELEMENT_DESC) * il.vertex_attributes.size())); 309 for (size_t i = 0; i < il.vertex_attributes.size(); i++) 310 { 311 const GPUPipeline::VertexAttribute& va = il.vertex_attributes[i]; 312 Assert(va.components > 0 && va.components <= MAX_COMPONENTS); 313 314 D3D11_INPUT_ELEMENT_DESC& elem = elems[i]; 315 elem.SemanticName = "ATTR"; 316 elem.SemanticIndex = va.index; 317 elem.Format = format_mapping[static_cast<u8>(va.type.GetValue())][va.components - 1]; 318 elem.InputSlot = 0; 319 elem.AlignedByteOffset = va.offset; 320 elem.InputSlotClass = D3D11_INPUT_PER_VERTEX_DATA; 321 elem.InstanceDataStepRate = 0; 322 } 323 324 HRESULT hr = m_device->CreateInputLayout(elems, static_cast<UINT>(il.vertex_attributes.size()), 325 vs->GetBytecode().data(), vs->GetBytecode().size(), dil.GetAddressOf()); 326 if (FAILED(hr)) [[unlikely]] 327 Error::SetHResult(error, "CreateInputLayout() failed: ", hr); 328 else 329 m_input_layouts.emplace(il, dil); 330 331 return dil; 332 } 333 334 std::unique_ptr<GPUPipeline> D3D11Device::CreatePipeline(const GPUPipeline::GraphicsConfig& config, Error* error) 335 { 336 ComPtr<ID3D11RasterizerState> rs = GetRasterizationState(config.rasterization, error); 337 ComPtr<ID3D11DepthStencilState> ds = GetDepthState(config.depth, error); 338 ComPtr<ID3D11BlendState> bs = GetBlendState(config.blend, config.GetRenderTargetCount(), error); 339 if (!rs || !ds || !bs) 340 return {}; 341 342 ComPtr<ID3D11InputLayout> il; 343 u32 vertex_stride = 0; 344 if (!config.input_layout.vertex_attributes.empty()) 345 { 346 il = GetInputLayout(config.input_layout, static_cast<const D3D11Shader*>(config.vertex_shader), error); 347 vertex_stride = config.input_layout.vertex_stride; 348 if (!il) 349 return {}; 350 } 351 352 static constexpr std::array<D3D11_PRIMITIVE_TOPOLOGY, static_cast<u32>(GPUPipeline::Primitive::MaxCount)> primitives = 353 {{ 354 D3D11_PRIMITIVE_TOPOLOGY_POINTLIST, // Points 355 D3D11_PRIMITIVE_TOPOLOGY_LINELIST, // Lines 356 D3D11_PRIMITIVE_TOPOLOGY_TRIANGLELIST, // Triangles 357 D3D11_PRIMITIVE_TOPOLOGY_TRIANGLESTRIP, // TriangleStrips 358 }}; 359 360 return std::unique_ptr<GPUPipeline>(new D3D11Pipeline( 361 std::move(rs), std::move(ds), std::move(bs), std::move(il), 362 static_cast<const D3D11Shader*>(config.vertex_shader)->GetVertexShader(), 363 config.geometry_shader ? static_cast<const D3D11Shader*>(config.geometry_shader)->GetGeometryShader() : nullptr, 364 static_cast<const D3D11Shader*>(config.fragment_shader)->GetPixelShader(), 365 primitives[static_cast<u8>(config.primitive)], vertex_stride, config.blend.constant)); 366 } 367 368 void D3D11Device::SetPipeline(GPUPipeline* pipeline) 369 { 370 if (m_current_pipeline == pipeline) 371 return; 372 373 D3D11Pipeline* const PL = static_cast<D3D11Pipeline*>(pipeline); 374 m_current_pipeline = PL; 375 376 if (ID3D11InputLayout* il = PL->GetInputLayout(); m_current_input_layout != il) 377 { 378 m_current_input_layout = il; 379 m_context->IASetInputLayout(il); 380 } 381 382 if (const u32 vertex_stride = PL->GetVertexStride(); m_current_vertex_stride != vertex_stride) 383 { 384 const UINT offset = 0; 385 m_current_vertex_stride = PL->GetVertexStride(); 386 m_context->IASetVertexBuffers(0, 1, m_vertex_buffer.GetD3DBufferArray(), &m_current_vertex_stride, &offset); 387 } 388 389 if (D3D_PRIMITIVE_TOPOLOGY topology = PL->GetPrimitiveTopology(); m_current_primitive_topology != topology) 390 { 391 m_current_primitive_topology = topology; 392 m_context->IASetPrimitiveTopology(topology); 393 } 394 395 if (ID3D11VertexShader* vs = PL->GetVertexShader(); m_current_vertex_shader != vs) 396 { 397 m_current_vertex_shader = vs; 398 m_context->VSSetShader(vs, nullptr, 0); 399 } 400 401 if (ID3D11GeometryShader* gs = PL->GetGeometryShader(); m_current_geometry_shader != gs) 402 { 403 m_current_geometry_shader = gs; 404 m_context->GSSetShader(gs, nullptr, 0); 405 } 406 407 if (ID3D11PixelShader* ps = PL->GetPixelShader(); m_current_pixel_shader != ps) 408 { 409 m_current_pixel_shader = ps; 410 m_context->PSSetShader(ps, nullptr, 0); 411 } 412 413 if (ID3D11RasterizerState* rs = PL->GetRasterizerState(); m_current_rasterizer_state != rs) 414 { 415 m_current_rasterizer_state = rs; 416 m_context->RSSetState(rs); 417 } 418 419 if (ID3D11DepthStencilState* ds = PL->GetDepthStencilState(); m_current_depth_state != ds) 420 { 421 m_current_depth_state = ds; 422 m_context->OMSetDepthStencilState(ds, 0); 423 } 424 425 if (ID3D11BlendState* bs = PL->GetBlendState(); 426 m_current_blend_state != bs || m_current_blend_factor != PL->GetBlendFactor()) 427 { 428 m_current_blend_state = bs; 429 m_current_blend_factor = PL->GetBlendFactor(); 430 m_context->OMSetBlendState(bs, RGBA8ToFloat(m_current_blend_factor).data(), 0xFFFFFFFFu); 431 } 432 } 433 434 void D3D11Device::UnbindPipeline(D3D11Pipeline* pl) 435 { 436 if (m_current_pipeline != pl) 437 return; 438 439 // Let the runtime deal with the dead objects... 440 m_current_pipeline = nullptr; 441 }