Skip to content

Commit acb05d3

Browse files
committed
Arena-allocate slices to replace Vecs in SpirvType and SpirvConst.
1 parent 1000dec commit acb05d3

File tree

11 files changed

+260
-162
lines changed

11 files changed

+260
-162
lines changed

crates/rustc_codegen_spirv/src/abi.rs

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ use rustc_middle::ty::{
1717
};
1818
use rustc_middle::{bug, span_bug};
1919
use rustc_span::def_id::DefId;
20-
use rustc_span::Span;
2120
use rustc_span::DUMMY_SP;
21+
use rustc_span::{Span, Symbol};
2222
use rustc_target::abi::call::{ArgAbi, ArgAttributes, FnAbi, PassMode};
2323
use rustc_target::abi::{
2424
Abi, Align, FieldsShape, LayoutS, Primitive, Scalar, Size, TagEncoding, VariantIdx, Variants,
@@ -300,6 +300,7 @@ impl<'tcx> ConvSpirvType<'tcx> for PointeeTy<'tcx> {
300300

301301
impl<'tcx> ConvSpirvType<'tcx> for FnAbi<'tcx, Ty<'tcx>> {
302302
fn spirv_type(&self, span: Span, cx: &CodegenCx<'tcx>) -> Word {
303+
// FIXME(eddyb) use `AccumulateVec`s just like `rustc` itself does.
303304
let mut argument_types = Vec::new();
304305

305306
let return_type = match self.ret.mode {
@@ -332,7 +333,7 @@ impl<'tcx> ConvSpirvType<'tcx> for FnAbi<'tcx, Ty<'tcx>> {
332333

333334
SpirvType::Function {
334335
return_type,
335-
arguments: argument_types,
336+
arguments: &argument_types,
336337
}
337338
.def(span, cx)
338339
}
@@ -364,8 +365,8 @@ impl<'tcx> ConvSpirvType<'tcx> for TyAndLayout<'tcx> {
364365
def_id: def_id_for_spirv_type_adt(*self),
365366
size: Some(Size::ZERO),
366367
align: Align::from_bytes(0).unwrap(),
367-
field_types: Vec::new(),
368-
field_offsets: Vec::new(),
368+
field_types: &[],
369+
field_offsets: &[],
369370
field_names: None,
370371
}
371372
.def_with_name(cx, span, TyLayoutNameKey::from(*self)),
@@ -416,23 +417,24 @@ impl<'tcx> ConvSpirvType<'tcx> for TyAndLayout<'tcx> {
416417
} else {
417418
Some(self.size)
418419
};
420+
// FIXME(eddyb) use `ArrayVec` here.
419421
let mut field_names = Vec::new();
420422
if let TyKind::Adt(adt, _) = self.ty.kind() {
421423
if let Variants::Single { index } = self.variants {
422424
for i in self.fields.index_by_increasing_offset() {
423425
let field = &adt.variants()[index].fields[i];
424-
field_names.push(field.name.to_ident_string());
426+
field_names.push(field.name);
425427
}
426428
}
427429
}
428430
SpirvType::Adt {
429431
def_id: def_id_for_spirv_type_adt(*self),
430432
size,
431433
align: self.align.abi,
432-
field_types: vec![a, b],
433-
field_offsets: vec![a_offset, b_offset],
434+
field_types: &[a, b],
435+
field_offsets: &[a_offset, b_offset],
434436
field_names: if field_names.len() == 2 {
435-
Some(field_names)
437+
Some(&field_names)
436438
} else {
437439
None
438440
},
@@ -598,8 +600,8 @@ fn trans_aggregate<'tcx>(cx: &CodegenCx<'tcx>, span: Span, ty: TyAndLayout<'tcx>
598600
def_id: def_id_for_spirv_type_adt(ty),
599601
size: Some(Size::ZERO),
600602
align: Align::from_bytes(0).unwrap(),
601-
field_types: Vec::new(),
602-
field_offsets: Vec::new(),
603+
field_types: &[],
604+
field_offsets: &[],
603605
field_names: None,
604606
}
605607
.def_with_name(cx, span, TyLayoutNameKey::from(ty))
@@ -664,6 +666,7 @@ pub fn auto_struct_layout<'tcx>(
664666
cx: &CodegenCx<'tcx>,
665667
field_types: &[Word],
666668
) -> (Vec<Size>, Option<Size>, Align) {
669+
// FIXME(eddyb) use `AccumulateVec`s just like `rustc` itself does.
667670
let mut field_offsets = Vec::with_capacity(field_types.len());
668671
let mut offset = Some(Size::ZERO);
669672
let mut max_align = Align::from_bytes(0).unwrap();
@@ -688,6 +691,7 @@ pub fn auto_struct_layout<'tcx>(
688691
fn trans_struct<'tcx>(cx: &CodegenCx<'tcx>, span: Span, ty: TyAndLayout<'tcx>) -> Word {
689692
let size = if ty.is_unsized() { None } else { Some(ty.size) };
690693
let align = ty.align.abi;
694+
// FIXME(eddyb) use `AccumulateVec`s just like `rustc` itself does.
691695
let mut field_types = Vec::new();
692696
let mut field_offsets = Vec::new();
693697
let mut field_names = Vec::new();
@@ -699,17 +703,18 @@ fn trans_struct<'tcx>(cx: &CodegenCx<'tcx>, span: Span, ty: TyAndLayout<'tcx>) -
699703
if let Variants::Single { index } = ty.variants {
700704
if let TyKind::Adt(adt, _) = ty.ty.kind() {
701705
let field = &adt.variants()[index].fields[i];
702-
field_names.push(field.name.to_ident_string());
706+
field_names.push(field.name);
703707
} else {
704-
field_names.push(format!("{}", i));
708+
// FIXME(eddyb) this looks like something that should exist in rustc.
709+
field_names.push(Symbol::intern(&format!("{i}")));
705710
}
706711
} else {
707712
if let TyKind::Adt(_, _) = ty.ty.kind() {
708713
} else {
709714
span_bug!(span, "Variants::Multiple not TyKind::Adt");
710715
}
711716
if i == 0 {
712-
field_names.push("discriminant".to_string());
717+
field_names.push(cx.sym.discriminant);
713718
} else {
714719
cx.tcx.sess.fatal("Variants::Multiple has multiple fields")
715720
}
@@ -719,9 +724,9 @@ fn trans_struct<'tcx>(cx: &CodegenCx<'tcx>, span: Span, ty: TyAndLayout<'tcx>) -
719724
def_id: def_id_for_spirv_type_adt(ty),
720725
size,
721726
align,
722-
field_types,
723-
field_offsets,
724-
field_names: Some(field_names),
727+
field_types: &field_types,
728+
field_offsets: &field_offsets,
729+
field_names: Some(&field_names),
725730
}
726731
.def_with_name(cx, span, TyLayoutNameKey::from(ty))
727732
}

crates/rustc_codegen_spirv/src/builder/builder_methods.rs

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
176176
semantics
177177
}
178178

179-
fn memset_const_pattern(&self, ty: &SpirvType, fill_byte: u8) -> Word {
179+
fn memset_const_pattern(&self, ty: &SpirvType<'tcx>, fill_byte: u8) -> Word {
180180
match *ty {
181181
SpirvType::Void => self.fatal("memset invalid on void pattern"),
182182
SpirvType::Bool => self.fatal("memset invalid on bool pattern"),
@@ -212,7 +212,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
212212
SpirvType::Vector { element, count } | SpirvType::Matrix { element, count } => {
213213
let elem_pat = self.memset_const_pattern(&self.lookup_type(element), fill_byte);
214214
self.constant_composite(
215-
ty.clone().def(self.span(), self),
215+
ty.def(self.span(), self),
216216
iter::repeat(elem_pat).take(count as usize),
217217
)
218218
.def(self)
@@ -221,7 +221,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
221221
let elem_pat = self.memset_const_pattern(&self.lookup_type(element), fill_byte);
222222
let count = self.builder.lookup_const_u64(count).unwrap() as usize;
223223
self.constant_composite(
224-
ty.clone().def(self.span(), self),
224+
ty.def(self.span(), self),
225225
iter::repeat(elem_pat).take(count),
226226
)
227227
.def(self)
@@ -242,7 +242,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
242242
}
243243
}
244244

245-
fn memset_dynamic_pattern(&self, ty: &SpirvType, fill_var: Word) -> Word {
245+
fn memset_dynamic_pattern(&self, ty: &SpirvType<'tcx>, fill_var: Word) -> Word {
246246
match *ty {
247247
SpirvType::Void => self.fatal("memset invalid on void pattern"),
248248
SpirvType::Bool => self.fatal("memset invalid on bool pattern"),
@@ -270,7 +270,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
270270
let count = self.builder.lookup_const_u64(count).unwrap() as usize;
271271
self.emit()
272272
.composite_construct(
273-
ty.clone().def(self.span(), self),
273+
ty.def(self.span(), self),
274274
None,
275275
iter::repeat(elem_pat).take(count),
276276
)
@@ -280,7 +280,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
280280
let elem_pat = self.memset_dynamic_pattern(&self.lookup_type(element), fill_var);
281281
self.emit()
282282
.composite_construct(
283-
ty.clone().def(self.span(), self),
283+
ty.def(self.span(), self),
284284
None,
285285
iter::repeat(elem_pat).take(count as usize),
286286
)
@@ -1260,9 +1260,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
12601260
};
12611261
let pointee_kind = self.lookup_type(pointee);
12621262
let result_pointee_type = match pointee_kind {
1263-
SpirvType::Adt {
1264-
ref field_types, ..
1265-
} => field_types[idx as usize],
1263+
SpirvType::Adt { field_types, .. } => field_types[idx as usize],
12661264
SpirvType::Array { element, .. }
12671265
| SpirvType::RuntimeArray { element, .. }
12681266
| SpirvType::Vector { element, .. }
@@ -2345,7 +2343,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
23452343
),
23462344
};
23472345

2348-
for (argument, argument_type) in args.iter().zip(argument_types) {
2346+
for (argument, &argument_type) in args.iter().zip(argument_types) {
23492347
assert_ty_eq!(self, argument.ty, argument_type);
23502348
}
23512349
let libm_intrinsic = self.libm_intrinsics.borrow().get(&callee_val).copied();

crates/rustc_codegen_spirv/src/builder/spirv_asm.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1081,7 +1081,7 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
10811081
(OperandKind::LiteralContextDependentNumber, Some(word)) => {
10821082
assert!(matches!(inst.class.opcode, Op::Constant | Op::SpecConstant));
10831083
let ty = inst.result_type.unwrap();
1084-
fn parse(ty: SpirvType, w: &str) -> Result<dr::Operand, String> {
1084+
fn parse(ty: SpirvType<'_>, w: &str) -> Result<dr::Operand, String> {
10851085
fn fmt(x: impl ToString) -> String {
10861086
x.to_string()
10871087
}

crates/rustc_codegen_spirv/src/builder_spirv.rs

Lines changed: 56 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ use rustc_span::symbol::Symbol;
1313
use rustc_span::{Span, DUMMY_SP};
1414
use std::assert_matches::assert_matches;
1515
use std::cell::{RefCell, RefMut};
16-
use std::rc::Rc;
1716
use std::{fs::File, io::Write, path::Path};
1817

1918
#[derive(Copy, Clone, Debug, Ord, PartialOrd, Eq, PartialEq, Hash)]
@@ -68,7 +67,7 @@ impl SpirvValue {
6867
pub fn const_fold_load(self, cx: &CodegenCx<'_>) -> Option<Self> {
6968
match self.kind {
7069
SpirvValueKind::Def(id) | SpirvValueKind::IllegalConst(id) => {
71-
let entry = cx.builder.id_to_const.borrow().get(&id)?.clone();
70+
let &entry = cx.builder.id_to_const.borrow().get(&id)?;
7271
match entry.val {
7372
SpirvConst::PtrTo { pointee } => {
7473
let ty = match cx.lookup_type(self.ty) {
@@ -213,8 +212,8 @@ impl SpirvValueExt for Word {
213212
}
214213
}
215214

216-
#[derive(Debug, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)]
217-
pub enum SpirvConst {
215+
#[derive(Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)]
216+
pub enum SpirvConst<'tcx> {
218217
U32(u32),
219218
U64(u64),
220219
/// f32 isn't hash, so store bits
@@ -232,8 +231,7 @@ pub enum SpirvConst {
232231
// different functions, but of the same type, don't overlap their zombies.
233232
ZombieUndefForFnAddr,
234233

235-
// FIXME(eddyb) use `tcx.arena.dropless` to get `&'tcx [_]`, instead of `Rc`.
236-
Composite(Rc<[Word]>),
234+
Composite(&'tcx [Word]),
237235

238236
/// Pointer to constant data, i.e. `&pointee`, represented as an `OpVariable`
239237
/// in the `Private` storage class, and with `pointee` as its initializer.
@@ -242,6 +240,40 @@ pub enum SpirvConst {
242240
},
243241
}
244242

243+
impl SpirvConst<'_> {
244+
/// Replace `&[T]` fields with `&'tcx [T]` ones produced by calling
245+
/// `tcx.arena.dropless.alloc_slice(...)` - this is done late for two reasons:
246+
/// 1. it avoids allocating in the arena when the cache would be hit anyway,
247+
/// which would create "garbage" (as in, unreachable allocations)
248+
/// (ideally these would also be interned, but that's even more refactors)
249+
/// 2. an empty slice is disallowed (as it's usually handled as a special
250+
/// case elsewhere, e.g. `rustc`'s `ty::List` - sadly we can't use that)
251+
fn tcx_arena_alloc_slices<'tcx>(self, cx: &CodegenCx<'tcx>) -> SpirvConst<'tcx> {
252+
fn arena_alloc_slice<'tcx, T: Copy>(cx: &CodegenCx<'tcx>, xs: &[T]) -> &'tcx [T] {
253+
if xs.is_empty() {
254+
&[]
255+
} else {
256+
cx.tcx.arena.dropless.alloc_slice(xs)
257+
}
258+
}
259+
260+
match self {
261+
// FIXME(eddyb) these are all noop cases, could they be automated?
262+
SpirvConst::U32(v) => SpirvConst::U32(v),
263+
SpirvConst::U64(v) => SpirvConst::U64(v),
264+
SpirvConst::F32(v) => SpirvConst::F32(v),
265+
SpirvConst::F64(v) => SpirvConst::F64(v),
266+
SpirvConst::Bool(v) => SpirvConst::Bool(v),
267+
SpirvConst::Null => SpirvConst::Null,
268+
SpirvConst::Undef => SpirvConst::Undef,
269+
SpirvConst::ZombieUndefForFnAddr => SpirvConst::ZombieUndefForFnAddr,
270+
SpirvConst::PtrTo { pointee } => SpirvConst::PtrTo { pointee },
271+
272+
SpirvConst::Composite(fields) => SpirvConst::Composite(arena_alloc_slice(cx, fields)),
273+
}
274+
}
275+
}
276+
245277
#[derive(Debug, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)]
246278
struct WithType<V> {
247279
ty: Word,
@@ -317,22 +349,22 @@ pub struct BuilderCursor {
317349
pub block: Option<usize>,
318350
}
319351

320-
pub struct BuilderSpirv {
352+
pub struct BuilderSpirv<'tcx> {
321353
builder: RefCell<Builder>,
322354

323355
// Bidirectional maps between `SpirvConst` and the ID of the defined global
324356
// (e.g. `OpConstant...`) instruction.
325357
// NOTE(eddyb) both maps have `WithConstLegality` around their keys, which
326358
// allows getting that legality information without additional lookups.
327-
const_to_id: RefCell<FxHashMap<WithType<SpirvConst>, WithConstLegality<Word>>>,
328-
id_to_const: RefCell<FxHashMap<Word, WithConstLegality<SpirvConst>>>,
359+
const_to_id: RefCell<FxHashMap<WithType<SpirvConst<'tcx>>, WithConstLegality<Word>>>,
360+
id_to_const: RefCell<FxHashMap<Word, WithConstLegality<SpirvConst<'tcx>>>>,
329361
string_cache: RefCell<FxHashMap<String, Word>>,
330362

331363
enabled_capabilities: FxHashSet<Capability>,
332364
enabled_extensions: FxHashSet<Symbol>,
333365
}
334366

335-
impl BuilderSpirv {
367+
impl<'tcx> BuilderSpirv<'tcx> {
336368
pub fn new(sym: &Symbols, target: &SpirvTarget, features: &[TargetFeature]) -> Self {
337369
let version = target.spirv_version();
338370
let memory_model = target.memory_model();
@@ -457,7 +489,12 @@ impl BuilderSpirv {
457489
bug!("Function not found: {}", id);
458490
}
459491

460-
pub fn def_constant(&self, ty: Word, val: SpirvConst) -> SpirvValue {
492+
pub(crate) fn def_constant_cx(
493+
&self,
494+
ty: Word,
495+
val: SpirvConst<'_>,
496+
cx: &CodegenCx<'tcx>,
497+
) -> SpirvValue {
461498
let val_with_type = WithType { ty, val };
462499
let mut builder = self.builder(BuilderCursor::default());
463500
if let Some(entry) = self.const_to_id.borrow().get(&val_with_type) {
@@ -486,7 +523,7 @@ impl BuilderSpirv {
486523
SpirvConst::Null => builder.constant_null(ty),
487524
SpirvConst::Undef | SpirvConst::ZombieUndefForFnAddr => builder.undef(ty, None),
488525

489-
SpirvConst::Composite(ref v) => builder.constant_composite(ty, v.iter().copied()),
526+
SpirvConst::Composite(v) => builder.constant_composite(ty, v.iter().copied()),
490527

491528
SpirvConst::PtrTo { pointee } => {
492529
builder.variable(ty, None, StorageClass::Private, Some(pointee))
@@ -517,7 +554,7 @@ impl BuilderSpirv {
517554
Ok(())
518555
}
519556

520-
SpirvConst::Composite(ref v) => v.iter().fold(Ok(()), |composite_legal, field| {
557+
SpirvConst::Composite(v) => v.iter().fold(Ok(()), |composite_legal, field| {
521558
let field_entry = &self.id_to_const.borrow()[field];
522559
let field_legal_in_composite = field_entry.legal.and(
523560
// `field` is itself some legal `SpirvConst`, but can we have
@@ -556,14 +593,11 @@ impl BuilderSpirv {
556593
}
557594
},
558595
};
596+
let val = val.tcx_arena_alloc_slices(cx);
559597
assert_matches!(
560-
self.const_to_id.borrow_mut().insert(
561-
WithType {
562-
ty,
563-
val: val.clone()
564-
},
565-
WithConstLegality { val: id, legal }
566-
),
598+
self.const_to_id
599+
.borrow_mut()
600+
.insert(WithType { ty, val }, WithConstLegality { val: id, legal }),
567601
None
568602
);
569603
assert_matches!(
@@ -581,10 +615,10 @@ impl BuilderSpirv {
581615
SpirvValue { kind, ty }
582616
}
583617

584-
pub fn lookup_const(&self, def: SpirvValue) -> Option<SpirvConst> {
618+
pub fn lookup_const(&self, def: SpirvValue) -> Option<SpirvConst<'tcx>> {
585619
match def.kind {
586620
SpirvValueKind::Def(id) | SpirvValueKind::IllegalConst(id) => {
587-
Some(self.id_to_const.borrow().get(&id)?.val.clone())
621+
Some(self.id_to_const.borrow().get(&id)?.val)
588622
}
589623
_ => None,
590624
}

0 commit comments

Comments
 (0)