@@ -2,6 +2,7 @@ use crate::builder_spirv::SpirvValue;
2
2
use crate :: spirv_type:: SpirvType ;
3
3
4
4
use super :: Builder ;
5
+ use crate :: codegen_cx:: CodegenCx ;
5
6
use rspirv:: dr;
6
7
use rspirv:: grammar:: { LogicalOperand , OperandKind , OperandQuantifier } ;
7
8
use rspirv:: spirv:: {
@@ -131,8 +132,15 @@ impl<'a, 'tcx> AsmBuilderMethods<'tcx> for Builder<'a, 'tcx> {
131
132
}
132
133
133
134
let mut id_map = HashMap :: new ( ) ;
135
+ let mut id_to_type_map = HashMap :: new ( ) ;
136
+ for operand in operands {
137
+ if let InlineAsmOperandRef :: In { reg : _, value } = operand {
138
+ let value = value. immediate ( ) ;
139
+ id_to_type_map. insert ( value. def ( self ) , value. ty ) ;
140
+ }
141
+ }
134
142
for line in tokens {
135
- self . codegen_asm ( & mut id_map, line. into_iter ( ) ) ;
143
+ self . codegen_asm ( & mut id_map, & mut id_to_type_map , line. into_iter ( ) ) ;
136
144
}
137
145
}
138
146
}
@@ -248,6 +256,10 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
248
256
self . err ( "OpTypeArray in asm! is not supported yet" ) ;
249
257
return ;
250
258
}
259
+ Op :: TypeSampledImage => SpirvType :: SampledImage {
260
+ image_type : inst. operands [ 0 ] . unwrap_id_ref ( ) ,
261
+ }
262
+ . def ( self . span ( ) , self ) ,
251
263
_ => {
252
264
self . emit ( )
253
265
. insert_into_block ( dr:: InsertPoint :: End , inst)
@@ -265,6 +277,7 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
265
277
fn codegen_asm < ' a > (
266
278
& mut self ,
267
279
id_map : & mut HashMap < & ' a str , Word > ,
280
+ id_to_type_map : & mut HashMap < Word , Word > ,
268
281
mut tokens : impl Iterator < Item = Token < ' a , ' cx , ' tcx > > ,
269
282
) where
270
283
' cx : ' a ,
@@ -339,7 +352,10 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
339
352
result_id,
340
353
operands : vec ! [ ] ,
341
354
} ;
342
- self . parse_operands ( id_map, tokens, & mut instruction) ;
355
+ self . parse_operands ( id_map, id_to_type_map, tokens, & mut instruction) ;
356
+ if let Some ( result_type) = instruction. result_type {
357
+ id_to_type_map. insert ( instruction. result_id . unwrap ( ) , result_type) ;
358
+ }
343
359
self . insert_inst ( id_map, instruction) ;
344
360
if let Some ( OutRegister :: Place ( place) ) = out_register {
345
361
self . emit ( )
@@ -356,13 +372,15 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
356
372
fn parse_operands < ' a > (
357
373
& mut self ,
358
374
id_map : & mut HashMap < & ' a str , Word > ,
375
+ id_to_type_map : & HashMap < Word , Word > ,
359
376
mut tokens : impl Iterator < Item = Token < ' a , ' cx , ' tcx > > ,
360
377
instruction : & mut dr:: Instruction ,
361
378
) where
362
379
' cx : ' a ,
363
380
' tcx : ' a ,
364
381
{
365
382
let mut saw_id_result = false ;
383
+ let mut need_result_type_infer = false ;
366
384
for & LogicalOperand { kind, quantifier } in instruction. class . operands {
367
385
if kind == OperandKind :: IdResult {
368
386
assert_eq ! ( quantifier, OperandQuantifier :: One ) ;
@@ -375,6 +393,22 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
375
393
saw_id_result = true ;
376
394
continue ;
377
395
}
396
+ if kind == OperandKind :: IdResultType {
397
+ assert_eq ! ( quantifier, OperandQuantifier :: One ) ;
398
+ if let Some ( token) = tokens. next ( ) {
399
+ if let Token :: Word ( "_" ) = token {
400
+ need_result_type_infer = true ;
401
+ } else if let Some ( id) = self . parse_id_in ( id_map, token) {
402
+ instruction. result_type = Some ( id) ;
403
+ }
404
+ } else {
405
+ self . err ( & format ! (
406
+ "instruction {} expects a result type" ,
407
+ instruction. class. opname
408
+ ) ) ;
409
+ }
410
+ continue ;
411
+ }
378
412
match quantifier {
379
413
OperandQuantifier :: One => {
380
414
if !self . parse_one_operand ( id_map, instruction, kind, & mut tokens) {
@@ -406,6 +440,125 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
406
440
instruction. class. opname
407
441
) ) ;
408
442
}
443
+
444
+ if need_result_type_infer {
445
+ assert ! ( instruction. result_type. is_none( ) ) ;
446
+
447
+ match self . infer_result_type ( id_to_type_map, instruction) {
448
+ Some ( result_type) => instruction. result_type = Some ( result_type) ,
449
+ None => self . err ( & format ! (
450
+ "instruction {} cannot have its result type inferred" ,
451
+ instruction. class. opname
452
+ ) ) ,
453
+ }
454
+ }
455
+ }
456
+
457
+ fn infer_result_type (
458
+ & self ,
459
+ id_to_type_map : & HashMap < Word , Word > ,
460
+ instruction : & dr:: Instruction ,
461
+ ) -> Option < Word > {
462
+ use crate :: spirv_type_constraints:: { instruction_signatures, InstSig , TyListPat , TyPat } ;
463
+
464
+ struct Mismatch ;
465
+
466
+ /// Recursively match `ty` against `pat`, returning one of:
467
+ /// * `Ok(None)`: `pat` matched but contained no type variables
468
+ /// * `Ok(Some(var))`: `pat` matched and `var` is the type variable
469
+ /// * `Err(Mismatch)`: `pat` didn't match or isn't supported right now
470
+ fn apply_ty_pat (
471
+ cx : & CodegenCx < ' _ > ,
472
+ pat : & TyPat < ' _ > ,
473
+ ty : Word ,
474
+ ) -> Result < Option < Word > , Mismatch > {
475
+ match pat {
476
+ TyPat :: Any => Ok ( None ) ,
477
+ & TyPat :: T => Ok ( Some ( ty) ) ,
478
+ TyPat :: Either ( a, b) => {
479
+ apply_ty_pat ( cx, a, ty) . or_else ( |Mismatch | apply_ty_pat ( cx, b, ty) )
480
+ }
481
+ _ => match ( pat, cx. lookup_type ( ty) ) {
482
+ ( TyPat :: Void , SpirvType :: Void ) => Ok ( None ) ,
483
+ ( TyPat :: Pointer ( pat) , SpirvType :: Pointer { pointee : ty, .. } )
484
+ | ( TyPat :: Vector ( pat) , SpirvType :: Vector { element : ty, .. } )
485
+ | (
486
+ TyPat :: Vector4 ( pat) ,
487
+ SpirvType :: Vector {
488
+ element : ty,
489
+ count : 4 ,
490
+ } ,
491
+ )
492
+ | (
493
+ TyPat :: Image ( pat) ,
494
+ SpirvType :: Image {
495
+ sampled_type : ty, ..
496
+ } ,
497
+ )
498
+ | ( TyPat :: SampledImage ( pat) , SpirvType :: SampledImage { image_type : ty } ) => {
499
+ apply_ty_pat ( cx, pat, ty)
500
+ }
501
+ _ => Err ( Mismatch ) ,
502
+ } ,
503
+ }
504
+ }
505
+
506
+ // FIXME(eddyb) try multiple signatures until one fits.
507
+ let mut sig = match instruction_signatures ( instruction. class . opcode ) ? {
508
+ [ sig @ InstSig {
509
+ output : Some ( _) , ..
510
+ } ] => * sig,
511
+ _ => return None ,
512
+ } ;
513
+
514
+ let mut combined_var = None ;
515
+
516
+ let mut ids = instruction. operands . iter ( ) . filter_map ( |o| o. id_ref_any ( ) ) ;
517
+ while let TyListPat :: Cons { first : pat, suffix } = * sig. inputs {
518
+ let & ty = id_to_type_map. get ( & ids. next ( ) ?) ?;
519
+ match apply_ty_pat ( self , pat, ty) {
520
+ Ok ( Some ( var) ) => match combined_var {
521
+ Some ( combined_var) => {
522
+ // FIXME(eddyb) this could use some error reporting
523
+ // (it's a type mismatch), although we could also
524
+ // just use the first type and let validation take
525
+ // care of the mismatch
526
+ if var != combined_var {
527
+ return None ;
528
+ }
529
+ }
530
+ None => combined_var = Some ( var) ,
531
+ } ,
532
+ Ok ( None ) => { }
533
+ Err ( Mismatch ) => return None ,
534
+ }
535
+ sig. inputs = suffix;
536
+ }
537
+ match sig. inputs {
538
+ TyListPat :: Any => { }
539
+ TyListPat :: Nil => {
540
+ if ids. next ( ) . is_some ( ) {
541
+ return None ;
542
+ }
543
+ }
544
+ _ => return None ,
545
+ }
546
+
547
+ let var = combined_var?;
548
+ match sig. output . unwrap ( ) {
549
+ & TyPat :: T => Some ( var) ,
550
+ TyPat :: Vector4 ( & TyPat :: T ) => Some (
551
+ SpirvType :: Vector {
552
+ element : var,
553
+ count : 4 ,
554
+ }
555
+ . def ( self . span ( ) , self ) ,
556
+ ) ,
557
+ TyPat :: SampledImage ( & TyPat :: T ) => {
558
+ Some ( SpirvType :: SampledImage { image_type : var } . def ( self . span ( ) , self ) )
559
+ }
560
+ _ => None ,
561
+ }
409
562
}
410
563
411
564
fn check_reg ( & mut self , span : Span , reg : & InlineAsmRegOrRegClass ) {
@@ -668,12 +821,9 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
668
821
Token :: Typeof ( _, _, _) => None ,
669
822
} ;
670
823
match ( kind, word) {
671
- ( OperandKind :: IdResultType , _) => {
672
- if let Some ( id) = self . parse_id_in ( id_map, token) {
673
- inst. result_type = Some ( id)
674
- }
824
+ ( OperandKind :: IdResultType , _) | ( OperandKind :: IdResult , _) => {
825
+ bug ! ( "should be handled by parse_operands" )
675
826
}
676
- ( OperandKind :: IdResult , _) => bug ! ( "should be handled by parse_operands" ) ,
677
827
( OperandKind :: IdMemorySemantics , _) => {
678
828
if let Some ( id) = self . parse_id_in ( id_map, token) {
679
829
inst. operands . push ( dr:: Operand :: IdMemorySemantics ( id) )
0 commit comments