diff --git a/wgpu-core/src/device/global.rs b/wgpu-core/src/device/global.rs index 5e6e4fc463e..3c3e8d38dd0 100644 --- a/wgpu-core/src/device/global.rs +++ b/wgpu-core/src/device/global.rs @@ -984,6 +984,18 @@ impl Global { runtime_checks: wgt::ShaderRuntimeChecks::unchecked(), } } + pipeline::ShaderModuleDescriptorPassthrough::Dxil(inner) => { + pipeline::ShaderModuleDescriptor { + label: inner.label.clone(), + runtime_checks: wgt::ShaderRuntimeChecks::unchecked(), + } + } + pipeline::ShaderModuleDescriptorPassthrough::Hlsl(inner) => { + pipeline::ShaderModuleDescriptor { + label: inner.label.clone(), + runtime_checks: wgt::ShaderRuntimeChecks::unchecked(), + } + } }, data, }); diff --git a/wgpu-core/src/device/resource.rs b/wgpu-core/src/device/resource.rs index 53e9586da64..86bb5eb9b1d 100644 --- a/wgpu-core/src/device/resource.rs +++ b/wgpu-core/src/device/resource.rs @@ -1801,6 +1801,22 @@ impl Device { num_workgroups: inner.num_workgroups, } } + pipeline::ShaderModuleDescriptorPassthrough::Dxil(inner) => { + self.require_features(wgt::Features::HLSL_DXIL_SHADER_PASSTHROUGH)?; + hal::ShaderInput::Dxil { + shader: inner.source, + entry_point: inner.entry_point.clone(), + num_workgroups: inner.num_workgroups, + } + } + pipeline::ShaderModuleDescriptorPassthrough::Hlsl(inner) => { + self.require_features(wgt::Features::HLSL_DXIL_SHADER_PASSTHROUGH)?; + hal::ShaderInput::Hlsl { + shader: inner.source, + entry_point: inner.entry_point.clone(), + num_workgroups: inner.num_workgroups, + } + } }; let hal_desc = hal::ShaderModuleDescriptor { diff --git a/wgpu-hal/src/dx12/device.rs b/wgpu-hal/src/dx12/device.rs index 918dd5193b9..24cd3826d4b 100644 --- a/wgpu-hal/src/dx12/device.rs +++ b/wgpu-hal/src/dx12/device.rs @@ -1,3 +1,4 @@ +use alloc::borrow::ToOwned; use alloc::{ borrow::Cow, string::{String, ToString as _}, @@ -264,27 +265,8 @@ impl super::Device { naga_stage: naga::ShaderStage, fragment_stage: Option<&crate::ProgrammableStage>, ) -> Result { - use naga::back::hlsl; - - let frag_ep = fragment_stage - .map(|fs_stage| { - hlsl::FragmentEntryPoint::new(&fs_stage.module.naga.module, fs_stage.entry_point) - .ok_or(crate::PipelineError::EntryPoint( - naga::ShaderStage::Fragment, - )) - }) - .transpose()?; - let stage_bit = auxil::map_naga_stage(naga_stage); - let (module, info) = naga::back::pipeline_constants::process_overrides( - &stage.module.naga.module, - &stage.module.naga.info, - Some((naga_stage, stage.entry_point)), - stage.constants, - ) - .map_err(|e| crate::PipelineError::PipelineConstants(stage_bit, format!("HLSL: {e:?}")))?; - let needs_temp_options = stage.zero_initialize_workgroup_memory != layout.naga_options.zero_initialize_workgroup_memory || stage.module.runtime_checks.bounds_checks != layout.naga_options.restrict_indexing @@ -301,43 +283,90 @@ impl super::Device { &layout.naga_options }; - let pipeline_options = hlsl::PipelineOptions { - entry_point: Some((naga_stage, stage.entry_point.to_string())), - }; + let key = match &stage.module.source { + super::ShaderModuleSource::Naga(naga_shader) => { + use naga::back::hlsl; - //TODO: reuse the writer - let (source, entry_point) = { - let mut source = String::new(); - let mut writer = hlsl::Writer::new(&mut source, naga_options, &pipeline_options); + let frag_ep = match fragment_stage { + Some(crate::ProgrammableStage { + module: + super::ShaderModule { + source: super::ShaderModuleSource::Naga(naga_shader), + .. + }, + entry_point, + .. + }) => Some( + hlsl::FragmentEntryPoint::new(&naga_shader.module, entry_point).ok_or( + crate::PipelineError::EntryPoint(naga::ShaderStage::Fragment), + ), + ), + _ => None, + } + .transpose()?; + let (module, info) = naga::back::pipeline_constants::process_overrides( + &naga_shader.module, + &naga_shader.info, + Some((naga_stage, stage.entry_point)), + stage.constants, + ) + .map_err(|e| { + crate::PipelineError::PipelineConstants(stage_bit, format!("HLSL: {e:?}")) + })?; - profiling::scope!("naga::back::hlsl::write"); - let mut reflection_info = writer - .write(&module, &info, frag_ep.as_ref()) - .map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("HLSL: {e:?}")))?; + let pipeline_options = hlsl::PipelineOptions { + entry_point: Some((naga_stage, stage.entry_point.to_string())), + }; - assert_eq!(reflection_info.entry_point_names.len(), 1); + //TODO: reuse the writer + let (source, entry_point) = { + let mut source = String::new(); + let mut writer = + hlsl::Writer::new(&mut source, naga_options, &pipeline_options); - let entry_point = reflection_info - .entry_point_names - .pop() - .unwrap() - .map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("{e}")))?; + profiling::scope!("naga::back::hlsl::write"); + let mut reflection_info = writer + .write(&module, &info, frag_ep.as_ref()) + .map_err(|e| { + crate::PipelineError::Linkage(stage_bit, format!("HLSL: {e:?}")) + })?; - (source, entry_point) - }; + assert_eq!(reflection_info.entry_point_names.len(), 1); - log::info!( - "Naga generated shader for {:?} at {:?}:\n{}", - entry_point, - naga_stage, - source - ); + let entry_point = reflection_info + .entry_point_names + .pop() + .unwrap() + .map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("{e}")))?; - let key = ShaderCacheKey { - source, - entry_point, - stage: naga_stage, - shader_model: naga_options.shader_model, + (source, entry_point) + }; + log::info!( + "Naga generated shader for {:?} at {:?}:\n{}", + entry_point, + naga_stage, + source + ); + + ShaderCacheKey { + source, + entry_point, + stage: naga_stage, + shader_model: naga_options.shader_model, + } + } + super::ShaderModuleSource::HlslPassthrough(passthrough) => ShaderCacheKey { + source: passthrough.shader.clone(), + entry_point: passthrough.entry_point.clone(), + stage: naga_stage, + shader_model: naga_options.shader_model, + }, + + super::ShaderModuleSource::DxilPassthrough(passthrough) => { + return Ok(super::CompiledShader::Precompiled( + passthrough.shader.clone(), + )) + } }; { @@ -351,11 +380,7 @@ impl super::Device { let source_name = stage.module.raw_name.as_deref(); - let full_stage = format!( - "{}_{}", - naga_stage.to_hlsl_str(), - naga_options.shader_model.to_str() - ); + let full_stage = format!("{}_{}", naga_stage.to_hlsl_str(), key.shader_model.to_str()); let compiled_shader = self.compiler_container.compile( self, @@ -1671,7 +1696,7 @@ impl crate::Device for super::Device { .and_then(|label| alloc::ffi::CString::new(label).ok()); match shader { crate::ShaderInput::Naga(naga) => Ok(super::ShaderModule { - naga, + source: super::ShaderModuleSource::Naga(naga), raw_name, runtime_checks: desc.runtime_checks, }), @@ -1681,6 +1706,32 @@ impl crate::Device for super::Device { crate::ShaderInput::Msl { .. } => { panic!("MSL_SHADER_PASSTHROUGH is not enabled for this backend") } + crate::ShaderInput::Dxil { + shader, + entry_point, + num_workgroups, + } => Ok(super::ShaderModule { + source: super::ShaderModuleSource::DxilPassthrough(super::DxilPassthroughShader { + shader: shader.to_vec(), + entry_point, + num_workgroups, + }), + raw_name, + runtime_checks: desc.runtime_checks, + }), + crate::ShaderInput::Hlsl { + shader, + entry_point, + num_workgroups, + } => Ok(super::ShaderModule { + source: super::ShaderModuleSource::HlslPassthrough(super::HlslPassthroughShader { + shader: shader.to_owned(), + entry_point, + num_workgroups, + }), + raw_name, + runtime_checks: desc.runtime_checks, + }), } } unsafe fn destroy_shader_module(&self, _module: super::ShaderModule) { diff --git a/wgpu-hal/src/dx12/mod.rs b/wgpu-hal/src/dx12/mod.rs index efbffecc4ba..c8695e2f079 100644 --- a/wgpu-hal/src/dx12/mod.rs +++ b/wgpu-hal/src/dx12/mod.rs @@ -1077,7 +1077,7 @@ impl crate::DynPipelineLayout for PipelineLayout {} #[derive(Debug)] pub struct ShaderModule { - naga: crate::NagaShader, + source: ShaderModuleSource, raw_name: Option, runtime_checks: wgt::ShaderRuntimeChecks, } @@ -1109,6 +1109,7 @@ pub(super) struct ShaderCacheValue { pub(super) enum CompiledShader { Dxc(Direct3D::Dxc::IDxcBlob), Fxc(Direct3D::ID3DBlob), + Precompiled(Vec), } impl CompiledShader { @@ -1122,6 +1123,10 @@ impl CompiledShader { pShaderBytecode: unsafe { shader.GetBufferPointer() }, BytecodeLength: unsafe { shader.GetBufferSize() }, }, + CompiledShader::Precompiled(shader) => Direct3D12::D3D12_SHADER_BYTECODE { + pShaderBytecode: shader.as_ptr().cast(), + BytecodeLength: shader.len(), + }, } } } @@ -1490,3 +1495,23 @@ impl crate::Queue for Queue { (1_000_000_000.0 / frequency as f64) as f32 } } +#[derive(Debug)] +pub struct DxilPassthroughShader { + pub shader: Vec, + pub entry_point: String, + pub num_workgroups: (u32, u32, u32), +} + +#[derive(Debug)] +pub struct HlslPassthroughShader { + pub shader: String, + pub entry_point: String, + pub num_workgroups: (u32, u32, u32), +} + +#[derive(Debug)] +pub enum ShaderModuleSource { + Naga(crate::NagaShader), + DxilPassthrough(DxilPassthroughShader), + HlslPassthrough(HlslPassthroughShader), +} diff --git a/wgpu-hal/src/gles/device.rs b/wgpu-hal/src/gles/device.rs index 347904c6ec7..0f36f734b8c 100644 --- a/wgpu-hal/src/gles/device.rs +++ b/wgpu-hal/src/gles/device.rs @@ -1346,6 +1346,9 @@ impl crate::Device for super::Device { panic!("`Features::MSL_SHADER_PASSTHROUGH` is not enabled") } crate::ShaderInput::Naga(naga) => naga, + crate::ShaderInput::Dxil { .. } | crate::ShaderInput::Hlsl { .. } => { + panic!("`Features::HLSL_DXIL_SHADER_PASSTHROUGH` is not enabled") + } }, label: desc.label.map(|str| str.to_string()), id: self.shared.next_shader_id.fetch_add(1, Ordering::Relaxed), diff --git a/wgpu-hal/src/lib.rs b/wgpu-hal/src/lib.rs index 4184c8c6f1b..82a5af8d427 100644 --- a/wgpu-hal/src/lib.rs +++ b/wgpu-hal/src/lib.rs @@ -2104,6 +2104,16 @@ pub enum ShaderInput<'a> { num_workgroups: (u32, u32, u32), }, SpirV(&'a [u32]), + Dxil { + shader: &'a [u8], + entry_point: String, + num_workgroups: (u32, u32, u32), + }, + Hlsl { + shader: &'a str, + entry_point: String, + num_workgroups: (u32, u32, u32), + }, } pub struct ShaderModuleDescriptor<'a> { diff --git a/wgpu-hal/src/metal/device.rs b/wgpu-hal/src/metal/device.rs index b74970e2bd9..6b3aeb3f9a8 100644 --- a/wgpu-hal/src/metal/device.rs +++ b/wgpu-hal/src/metal/device.rs @@ -1039,6 +1039,9 @@ impl crate::Device for super::Device { crate::ShaderInput::SpirV(_) => { panic!("SPIRV_SHADER_PASSTHROUGH is not enabled for this backend") } + crate::ShaderInput::Dxil { .. } | crate::ShaderInput::Hlsl { .. } => { + panic!("`Features::HLSL_DXIL_SHADER_PASSTHROUGH` is not enabled for this backend") + } } } diff --git a/wgpu-hal/src/vulkan/device.rs b/wgpu-hal/src/vulkan/device.rs index 77b9ebe9244..78e958fb1e2 100644 --- a/wgpu-hal/src/vulkan/device.rs +++ b/wgpu-hal/src/vulkan/device.rs @@ -1908,6 +1908,9 @@ impl crate::Device for super::Device { crate::ShaderInput::Msl { .. } => { panic!("MSL_SHADER_PASSTHROUGH is not enabled for this backend") } + crate::ShaderInput::Dxil { .. } | crate::ShaderInput::Hlsl { .. } => { + panic!("`Features::HLSL_DXIL_SHADER_PASSTHROUGH` is not enabled") + } crate::ShaderInput::SpirV(spv) => Cow::Borrowed(spv), }; diff --git a/wgpu-types/src/features.rs b/wgpu-types/src/features.rs index b614b8cb652..e01885fc412 100644 --- a/wgpu-types/src/features.rs +++ b/wgpu-types/src/features.rs @@ -1244,6 +1244,16 @@ bitflags_array! { /// /// [BlasTriangleGeometrySizeDescriptor::vertex_format]: super::BlasTriangleGeometrySizeDescriptor const EXTENDED_ACCELERATION_STRUCTURE_VERTEX_FORMATS = 1 << 51; + + /// Enables creating shader modules from DirectX HLSL or DXIL shaders (unsafe) + /// + /// HLSL/DXIL data is not parsed or interpreted in any way + /// + /// Supported platforms: + /// - DX12 + /// + /// This is a native only feature. + const HLSL_DXIL_SHADER_PASSTHROUGH = 1 << 53; } /// Features that are not guaranteed to be supported. diff --git a/wgpu-types/src/lib.rs b/wgpu-types/src/lib.rs index 18180f0a6e5..7915b88020d 100644 --- a/wgpu-types/src/lib.rs +++ b/wgpu-types/src/lib.rs @@ -7765,6 +7765,10 @@ pub enum CreateShaderModuleDescriptorPassthrough<'a, L> { SpirV(ShaderModuleDescriptorSpirV<'a, L>), /// Passthrough for MSL source code. Msl(ShaderModuleDescriptorMsl<'a, L>), + /// Passthrough for DXIL compiled with DXC + Dxil(ShaderModuleDescriptorDxil<'a, L>), + /// Passthrough for HLSL + Hlsl(ShaderModuleDescriptorHlsl<'a, L>), } impl<'a, L> CreateShaderModuleDescriptorPassthrough<'a, L> { @@ -7790,6 +7794,22 @@ impl<'a, L> CreateShaderModuleDescriptorPassthrough<'a, L> { source: inner.source.clone(), }) } + CreateShaderModuleDescriptorPassthrough::Dxil(inner) => { + CreateShaderModuleDescriptorPassthrough::<'_, K>::Dxil(ShaderModuleDescriptorDxil { + entry_point: inner.entry_point.clone(), + label: fun(&inner.label), + num_workgroups: inner.num_workgroups, + source: inner.source, + }) + } + CreateShaderModuleDescriptorPassthrough::Hlsl(inner) => { + CreateShaderModuleDescriptorPassthrough::<'_, K>::Hlsl(ShaderModuleDescriptorHlsl { + entry_point: inner.entry_point.clone(), + label: fun(&inner.label), + num_workgroups: inner.num_workgroups, + source: inner.source, + }) + } } } @@ -7798,6 +7818,8 @@ impl<'a, L> CreateShaderModuleDescriptorPassthrough<'a, L> { match self { CreateShaderModuleDescriptorPassthrough::SpirV(inner) => &inner.label, CreateShaderModuleDescriptorPassthrough::Msl(inner) => &inner.label, + CreateShaderModuleDescriptorPassthrough::Dxil(inner) => &inner.label, + CreateShaderModuleDescriptorPassthrough::Hlsl(inner) => &inner.label, } } @@ -7809,6 +7831,8 @@ impl<'a, L> CreateShaderModuleDescriptorPassthrough<'a, L> { bytemuck::cast_slice(&inner.source) } CreateShaderModuleDescriptorPassthrough::Msl(inner) => inner.source.as_bytes(), + CreateShaderModuleDescriptorPassthrough::Dxil(inner) => inner.source, + CreateShaderModuleDescriptorPassthrough::Hlsl(inner) => inner.source.as_bytes(), } } @@ -7818,6 +7842,8 @@ impl<'a, L> CreateShaderModuleDescriptorPassthrough<'a, L> { match self { CreateShaderModuleDescriptorPassthrough::SpirV(..) => "spv", CreateShaderModuleDescriptorPassthrough::Msl(..) => "msl", + CreateShaderModuleDescriptorPassthrough::Dxil(..) => "dxil", + CreateShaderModuleDescriptorPassthrough::Hlsl(..) => "hlsl", } } } @@ -7838,6 +7864,38 @@ pub struct ShaderModuleDescriptorMsl<'a, L> { pub source: Cow<'a, str>, } +/// Descriptor for a shader module given by DirectX DXIL source. +/// +/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification, +/// only WGSL source code strings are accepted. +#[derive(Debug, Clone)] +pub struct ShaderModuleDescriptorDxil<'a, L> { + /// Entrypoint. + pub entry_point: String, + /// Debug label of the shader module. This will show up in graphics debuggers for easy identification. + pub label: L, + /// Number of workgroups in each dimension x, y and z. + pub num_workgroups: (u32, u32, u32), + /// Shader DXIL source. + pub source: &'a [u8], +} + +/// Descriptor for a shader module given by DirectX HLSL source. +/// +/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification, +/// only WGSL source code strings are accepted. +#[derive(Debug, Clone)] +pub struct ShaderModuleDescriptorHlsl<'a, L> { + /// Entrypoint. + pub entry_point: String, + /// Debug label of the shader module. This will show up in graphics debuggers for easy identification. + pub label: L, + /// Number of workgroups in each dimension x, y and z. + pub num_workgroups: (u32, u32, u32), + /// Shader HLSL source. + pub source: &'a str, +} + /// Descriptor for a shader module given by SPIR-V binary. /// /// This type is unique to the Rust API of `wgpu`. In the WebGPU specification, diff --git a/wgpu/src/api/shader_module.rs b/wgpu/src/api/shader_module.rs index 1a84d5414cc..c481de6218a 100644 --- a/wgpu/src/api/shader_module.rs +++ b/wgpu/src/api/shader_module.rs @@ -247,3 +247,15 @@ pub type ShaderModuleDescriptorMsl<'a> = wgt::ShaderModuleDescriptorMsl<'a, Labe /// This type is unique to the Rust API of `wgpu`. In the WebGPU specification, /// only WGSL source code strings are accepted. pub type ShaderModuleDescriptorSpirV<'a> = wgt::ShaderModuleDescriptorSpirV<'a, Label<'a>>; + +/// Descriptor for a shader module given by DirectX HLSL source. +/// +/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification, +/// only WGSL source code strings are accepted. +pub type ShaderModuleDescriptorHlsl<'a> = wgt::ShaderModuleDescriptorHlsl<'a, Label<'a>>; + +/// Descriptor for a shader module given by DirectX DXIL source. +/// +/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification, +/// only WGSL source code strings are accepted. +pub type ShaderModuleDescriptorDxil<'a> = wgt::ShaderModuleDescriptorDxil<'a, Label<'a>>;