Skip to content

Commit 5f0b380

Browse files
authored
inline asm!: support writing _ in lieu of return types, for basic inference. (#376)
* Basic type constraints for all non-reserved SPIR-V instructions. * inline asm!: support writing _ in lieu of return types, for basic inference. * Demonstrate using result type inference in inline asm!. * inline asm!: allow inferring the result type of OpSampledImage.
1 parent 1454fe3 commit 5f0b380

File tree

5 files changed

+1042
-24
lines changed

5 files changed

+1042
-24
lines changed

crates/rustc_codegen_spirv/src/builder/spirv_asm.rs

Lines changed: 157 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use crate::builder_spirv::SpirvValue;
22
use crate::spirv_type::SpirvType;
33

44
use super::Builder;
5+
use crate::codegen_cx::CodegenCx;
56
use rspirv::dr;
67
use rspirv::grammar::{LogicalOperand, OperandKind, OperandQuantifier};
78
use rspirv::spirv::{
@@ -131,8 +132,15 @@ impl<'a, 'tcx> AsmBuilderMethods<'tcx> for Builder<'a, 'tcx> {
131132
}
132133

133134
let mut id_map = HashMap::new();
135+
let mut id_to_type_map = HashMap::new();
136+
for operand in operands {
137+
if let InlineAsmOperandRef::In { reg: _, value } = operand {
138+
let value = value.immediate();
139+
id_to_type_map.insert(value.def(self), value.ty);
140+
}
141+
}
134142
for line in tokens {
135-
self.codegen_asm(&mut id_map, line.into_iter());
143+
self.codegen_asm(&mut id_map, &mut id_to_type_map, line.into_iter());
136144
}
137145
}
138146
}
@@ -248,6 +256,10 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
248256
self.err("OpTypeArray in asm! is not supported yet");
249257
return;
250258
}
259+
Op::TypeSampledImage => SpirvType::SampledImage {
260+
image_type: inst.operands[0].unwrap_id_ref(),
261+
}
262+
.def(self.span(), self),
251263
_ => {
252264
self.emit()
253265
.insert_into_block(dr::InsertPoint::End, inst)
@@ -265,6 +277,7 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
265277
fn codegen_asm<'a>(
266278
&mut self,
267279
id_map: &mut HashMap<&'a str, Word>,
280+
id_to_type_map: &mut HashMap<Word, Word>,
268281
mut tokens: impl Iterator<Item = Token<'a, 'cx, 'tcx>>,
269282
) where
270283
'cx: 'a,
@@ -339,7 +352,10 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
339352
result_id,
340353
operands: vec![],
341354
};
342-
self.parse_operands(id_map, tokens, &mut instruction);
355+
self.parse_operands(id_map, id_to_type_map, tokens, &mut instruction);
356+
if let Some(result_type) = instruction.result_type {
357+
id_to_type_map.insert(instruction.result_id.unwrap(), result_type);
358+
}
343359
self.insert_inst(id_map, instruction);
344360
if let Some(OutRegister::Place(place)) = out_register {
345361
self.emit()
@@ -356,13 +372,15 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
356372
fn parse_operands<'a>(
357373
&mut self,
358374
id_map: &mut HashMap<&'a str, Word>,
375+
id_to_type_map: &HashMap<Word, Word>,
359376
mut tokens: impl Iterator<Item = Token<'a, 'cx, 'tcx>>,
360377
instruction: &mut dr::Instruction,
361378
) where
362379
'cx: 'a,
363380
'tcx: 'a,
364381
{
365382
let mut saw_id_result = false;
383+
let mut need_result_type_infer = false;
366384
for &LogicalOperand { kind, quantifier } in instruction.class.operands {
367385
if kind == OperandKind::IdResult {
368386
assert_eq!(quantifier, OperandQuantifier::One);
@@ -375,6 +393,22 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
375393
saw_id_result = true;
376394
continue;
377395
}
396+
if kind == OperandKind::IdResultType {
397+
assert_eq!(quantifier, OperandQuantifier::One);
398+
if let Some(token) = tokens.next() {
399+
if let Token::Word("_") = token {
400+
need_result_type_infer = true;
401+
} else if let Some(id) = self.parse_id_in(id_map, token) {
402+
instruction.result_type = Some(id);
403+
}
404+
} else {
405+
self.err(&format!(
406+
"instruction {} expects a result type",
407+
instruction.class.opname
408+
));
409+
}
410+
continue;
411+
}
378412
match quantifier {
379413
OperandQuantifier::One => {
380414
if !self.parse_one_operand(id_map, instruction, kind, &mut tokens) {
@@ -406,6 +440,125 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
406440
instruction.class.opname
407441
));
408442
}
443+
444+
if need_result_type_infer {
445+
assert!(instruction.result_type.is_none());
446+
447+
match self.infer_result_type(id_to_type_map, instruction) {
448+
Some(result_type) => instruction.result_type = Some(result_type),
449+
None => self.err(&format!(
450+
"instruction {} cannot have its result type inferred",
451+
instruction.class.opname
452+
)),
453+
}
454+
}
455+
}
456+
457+
fn infer_result_type(
458+
&self,
459+
id_to_type_map: &HashMap<Word, Word>,
460+
instruction: &dr::Instruction,
461+
) -> Option<Word> {
462+
use crate::spirv_type_constraints::{instruction_signatures, InstSig, TyListPat, TyPat};
463+
464+
struct Mismatch;
465+
466+
/// Recursively match `ty` against `pat`, returning one of:
467+
/// * `Ok(None)`: `pat` matched but contained no type variables
468+
/// * `Ok(Some(var))`: `pat` matched and `var` is the type variable
469+
/// * `Err(Mismatch)`: `pat` didn't match or isn't supported right now
470+
fn apply_ty_pat(
471+
cx: &CodegenCx<'_>,
472+
pat: &TyPat<'_>,
473+
ty: Word,
474+
) -> Result<Option<Word>, Mismatch> {
475+
match pat {
476+
TyPat::Any => Ok(None),
477+
&TyPat::T => Ok(Some(ty)),
478+
TyPat::Either(a, b) => {
479+
apply_ty_pat(cx, a, ty).or_else(|Mismatch| apply_ty_pat(cx, b, ty))
480+
}
481+
_ => match (pat, cx.lookup_type(ty)) {
482+
(TyPat::Void, SpirvType::Void) => Ok(None),
483+
(TyPat::Pointer(pat), SpirvType::Pointer { pointee: ty, .. })
484+
| (TyPat::Vector(pat), SpirvType::Vector { element: ty, .. })
485+
| (
486+
TyPat::Vector4(pat),
487+
SpirvType::Vector {
488+
element: ty,
489+
count: 4,
490+
},
491+
)
492+
| (
493+
TyPat::Image(pat),
494+
SpirvType::Image {
495+
sampled_type: ty, ..
496+
},
497+
)
498+
| (TyPat::SampledImage(pat), SpirvType::SampledImage { image_type: ty }) => {
499+
apply_ty_pat(cx, pat, ty)
500+
}
501+
_ => Err(Mismatch),
502+
},
503+
}
504+
}
505+
506+
// FIXME(eddyb) try multiple signatures until one fits.
507+
let mut sig = match instruction_signatures(instruction.class.opcode)? {
508+
[sig @ InstSig {
509+
output: Some(_), ..
510+
}] => *sig,
511+
_ => return None,
512+
};
513+
514+
let mut combined_var = None;
515+
516+
let mut ids = instruction.operands.iter().filter_map(|o| o.id_ref_any());
517+
while let TyListPat::Cons { first: pat, suffix } = *sig.inputs {
518+
let &ty = id_to_type_map.get(&ids.next()?)?;
519+
match apply_ty_pat(self, pat, ty) {
520+
Ok(Some(var)) => match combined_var {
521+
Some(combined_var) => {
522+
// FIXME(eddyb) this could use some error reporting
523+
// (it's a type mismatch), although we could also
524+
// just use the first type and let validation take
525+
// care of the mismatch
526+
if var != combined_var {
527+
return None;
528+
}
529+
}
530+
None => combined_var = Some(var),
531+
},
532+
Ok(None) => {}
533+
Err(Mismatch) => return None,
534+
}
535+
sig.inputs = suffix;
536+
}
537+
match sig.inputs {
538+
TyListPat::Any => {}
539+
TyListPat::Nil => {
540+
if ids.next().is_some() {
541+
return None;
542+
}
543+
}
544+
_ => return None,
545+
}
546+
547+
let var = combined_var?;
548+
match sig.output.unwrap() {
549+
&TyPat::T => Some(var),
550+
TyPat::Vector4(&TyPat::T) => Some(
551+
SpirvType::Vector {
552+
element: var,
553+
count: 4,
554+
}
555+
.def(self.span(), self),
556+
),
557+
TyPat::SampledImage(&TyPat::T) => {
558+
Some(SpirvType::SampledImage { image_type: var }.def(self.span(), self))
559+
}
560+
_ => None,
561+
}
409562
}
410563

411564
fn check_reg(&mut self, span: Span, reg: &InlineAsmRegOrRegClass) {
@@ -668,12 +821,9 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
668821
Token::Typeof(_, _, _) => None,
669822
};
670823
match (kind, word) {
671-
(OperandKind::IdResultType, _) => {
672-
if let Some(id) = self.parse_id_in(id_map, token) {
673-
inst.result_type = Some(id)
674-
}
824+
(OperandKind::IdResultType, _) | (OperandKind::IdResult, _) => {
825+
bug!("should be handled by parse_operands")
675826
}
676-
(OperandKind::IdResult, _) => bug!("should be handled by parse_operands"),
677827
(OperandKind::IdMemorySemantics, _) => {
678828
if let Some(id) = self.parse_id_in(id_map, token) {
679829
inst.operands.push(dr::Operand::IdMemorySemantics(id))

crates/rustc_codegen_spirv/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ mod decorations;
8888
mod link;
8989
mod linker;
9090
mod spirv_type;
91+
mod spirv_type_constraints;
9192
mod symbols;
9293

9394
use builder::Builder;

0 commit comments

Comments
 (0)