diff --git a/compiler/rustc_smir/src/stable_mir/ty.rs b/compiler/rustc_smir/src/stable_mir/ty.rs index 4415cd6e2e3ba..92fa97566c5ad 100644 --- a/compiler/rustc_smir/src/stable_mir/ty.rs +++ b/compiler/rustc_smir/src/stable_mir/ty.rs @@ -757,6 +757,12 @@ crate_def! { } impl CoroutineDef { + /// Retrieves the body of the coroutine definition. Returns None if the body + /// isn't available. + pub fn body(&self) -> Option { + with(|cx| cx.has_body(self.0).then(|| cx.mir_body(self.0))) + } + pub fn discriminant_for_variant(&self, args: &GenericArgs, idx: VariantIdx) -> Discr { with(|cx| cx.coroutine_discr_for_variant(*self, args, idx)) } diff --git a/tests/ui-fulldeps/stable-mir/check_coroutine_body.rs b/tests/ui-fulldeps/stable-mir/check_coroutine_body.rs new file mode 100644 index 0000000000000..677734929589d --- /dev/null +++ b/tests/ui-fulldeps/stable-mir/check_coroutine_body.rs @@ -0,0 +1,105 @@ +//@ run-pass +//! Tests stable mir API for retrieving the body of a coroutine. + +//@ ignore-stage1 +//@ ignore-cross-compile +//@ ignore-remote +//@ edition: 2024 + +#![feature(rustc_private)] +#![feature(assert_matches)] + +extern crate rustc_middle; +#[macro_use] +extern crate rustc_smir; +extern crate rustc_driver; +extern crate rustc_interface; +extern crate stable_mir; + +use std::io::Write; +use std::ops::ControlFlow; + +use stable_mir::mir::Body; +use stable_mir::ty::{RigidTy, TyKind}; + +const CRATE_NAME: &str = "crate_coroutine_body"; + +fn test_coroutine_body() -> ControlFlow<()> { + let crate_items = stable_mir::all_local_items(); + if let Some(body) = crate_items.iter().find_map(|item| { + let item_ty = item.ty(); + if let TyKind::RigidTy(RigidTy::Coroutine(def, ..)) = &item_ty.kind() { + if def.0.name() == "gbc::{closure#0}".to_string() { + def.body() + } else { + None + } + } else { + None + } + }) { + check_coroutine_body(body); + } else { + panic!("Cannot find `gbc::{{closure#0}}`. All local items are: {:#?}", crate_items); + } + + ControlFlow::Continue(()) +} + +fn check_coroutine_body(body: Body) { + let ret_ty = &body.locals()[0].ty; + let local_3 = &body.locals()[3].ty; + let local_4 = &body.locals()[4].ty; + + let TyKind::RigidTy(RigidTy::Adt(def, ..)) = &ret_ty.kind() + else { + panic!("Expected RigidTy::Adt, got: {:#?}", ret_ty); + }; + + assert_eq!("std::task::Poll", def.0.name()); + + let TyKind::RigidTy(RigidTy::Coroutine(def, ..)) = &local_3.kind() + else { + panic!("Expected RigidTy::Coroutine, got: {:#?}", local_3); + }; + + assert_eq!("gbc::{closure#0}::{closure#0}", def.0.name()); + + let TyKind::RigidTy(RigidTy::Coroutine(def, ..)) = &local_4.kind() + else { + panic!("Expected RigidTy::Coroutine, got: {:#?}", local_4); + }; + + assert_eq!("gbc::{closure#0}::{closure#0}", def.0.name()); +} + +fn main() { + let path = "coroutine_body.rs"; + generate_input(&path).unwrap(); + let args = &[ + "rustc".to_string(), + "-Cpanic=abort".to_string(), + "--edition".to_string(), + "2024".to_string(), + "--crate-name".to_string(), + CRATE_NAME.to_string(), + path.to_string(), + ]; + run!(args, test_coroutine_body).unwrap(); +} + +fn generate_input(path: &str) -> std::io::Result<()> { + let mut file = std::fs::File::create(path)?; + write!( + file, + r#" + async fn gbc() -> i32 {{ + let a = async {{ 1 }}.await; + a + }} + + fn main() {{}} + "# + )?; + Ok(()) +}