From 44366a0e3083e02aa7db2b9643db121233a1fe60 Mon Sep 17 00:00:00 2001 From: Katarzyna Marek Date: Fri, 27 Jun 2025 15:26:34 +0200 Subject: [PATCH] fix: simplify infer type for apply --- .../dotty/tools/pc/InferExpectedType.scala | 38 +--------- .../tools/pc/completions/Completions.scala | 2 +- .../pc/tests/InferExpectedTypeSuite.scala | 75 +++++++++++++++++++ .../SingletonCompletionsSuite.scala | 26 +++++++ 4 files changed, 106 insertions(+), 35 deletions(-) diff --git a/presentation-compiler/src/main/dotty/tools/pc/InferExpectedType.scala b/presentation-compiler/src/main/dotty/tools/pc/InferExpectedType.scala index 2e6c7b39ba65..8640f518c0f1 100644 --- a/presentation-compiler/src/main/dotty/tools/pc/InferExpectedType.scala +++ b/presentation-compiler/src/main/dotty/tools/pc/InferExpectedType.scala @@ -50,12 +50,12 @@ class InferExpectedType( val indexedCtx = IndexedContext(pos)(using locatedCtx) val printer = ShortenedTypePrinter(search, IncludeDefaultParam.ResolveLater)(using indexedCtx) - InterCompletionType.inferType(path)(using newctx).map{ + InferCompletionType.inferType(path)(using newctx).map{ tpe => printer.tpe(tpe) } case None => None -object InterCompletionType: +object InferCompletionType: def inferType(path: List[Tree])(using Context): Option[Type] = path match case (lit: Literal) :: Select(Literal(_), _) :: Apply(Select(Literal(_), _), List(s: Select)) :: rest if s.symbol == defn.Predef_undefined => inferType(rest, lit.span) @@ -94,37 +94,7 @@ object InterCompletionType: else Some(UnapplyArgs(fun.tpe.finalResultType, fun, pats, NoSourcePosition).argTypes(ind)) // f(@@) case ApplyExtractor(app) => - val argsAndParams = ApplyArgsExtractor.getArgsAndParams(None, app, span).headOption - argsAndParams.flatMap: - case (args, params) => - val idx = args.indexWhere(_.span.contains(span)) - val param = - if idx >= 0 && params.length > idx then Some(params(idx).info) - else None - param match - // def f[T](a: T): T = ??? - // f[Int](@@) - // val _: Int = f(@@) - case Some(t : TypeRef) if t.symbol.is(Flags.TypeParam) => - for - (typeParams, args) <- - app match - case Apply(TypeApply(fun, args), _) => - val typeParams = fun.symbol.paramSymss.headOption.filter(_.forall(_.isTypeParam)) - typeParams.map((_, args.map(_.tpe))) - // val f: (j: "a") => Int - // f(@@) - case Apply(Select(v, StdNames.nme.apply), _) => - v.symbol.info match - case AppliedType(des, args) => - Some((des.typeSymbol.typeParams, args)) - case _ => None - case _ => None - ind = typeParams.indexOf(t.symbol) - tpe <- args.get(ind) - if !tpe.isErroneous - yield tpe - case Some(tpe) => Some(tpe) - case _ => None + val idx = app.args.indexWhere(_.span.contains(span)) + app.fun.tpe.widenTermRefExpr.paramInfoss.flatten.get(idx) case _ => None diff --git a/presentation-compiler/src/main/dotty/tools/pc/completions/Completions.scala b/presentation-compiler/src/main/dotty/tools/pc/completions/Completions.scala index a07f501eedbb..e7902ef8aa44 100644 --- a/presentation-compiler/src/main/dotty/tools/pc/completions/Completions.scala +++ b/presentation-compiler/src/main/dotty/tools/pc/completions/Completions.scala @@ -520,7 +520,7 @@ class Completions( config.isCompletionSnippetsEnabled() ) (args, false) - val singletonCompletions = InterCompletionType.inferType(path).map( + val singletonCompletions = InferCompletionType.inferType(path).map( SingletonCompletions.contribute(path, _, completionPos) ).getOrElse(Nil) (singletonCompletions ++ advanced, exclusive) diff --git a/presentation-compiler/test/dotty/tools/pc/tests/InferExpectedTypeSuite.scala b/presentation-compiler/test/dotty/tools/pc/tests/InferExpectedTypeSuite.scala index ba96488471b6..b796f44f12ca 100644 --- a/presentation-compiler/test/dotty/tools/pc/tests/InferExpectedTypeSuite.scala +++ b/presentation-compiler/test/dotty/tools/pc/tests/InferExpectedTypeSuite.scala @@ -47,6 +47,24 @@ class InferExpectedTypeSuite extends BasePCSuite: |""".stripMargin ) + @Test def `basic-params` = + check( + """|def paint(c: Int, f: String, d: List[String]) = ??? + |val _ = paint(1, "aa", @@) + |""".stripMargin, + """|List[String] + |""".stripMargin + ) + + @Test def `basic-type-param` = + check( + """|def paint[T](c: T) = ??? + |val _ = paint[Int](@@) + |""".stripMargin, + """|Int + |""".stripMargin + ) + @Test def `type-ascription` = check( """|def doo = (@@ : Double) @@ -335,3 +353,60 @@ class InferExpectedTypeSuite extends BasePCSuite: """|String |""".stripMargin ) + + @Test def using = + check( + """|def go(using Ordering[Int])(x: Int, y: Int): Int = + | Ordering[Int].compare(x, y) + | + |def test = + | go(???, @@) + |""".stripMargin, + """|Int + |""".stripMargin + ) + + @Test def `apply-dynamic` = + check( + """|object TypedHoleApplyDynamic { + | val obj: reflect.Selectable { + | def method(x: Int): Unit + | } = new reflect.Selectable { + | def method(x: Int): Unit = () + | } + | + | obj.method(@@) + |} + |""".stripMargin, + "Int" + ) + + @Test def `apply-dynamic-2` = + check( + """|object TypedHoleApplyDynamic { + | val obj: reflect.Selectable { + | def method[T](x: Int, y: T): Unit + | } = new reflect.Selectable { + | def method[T](x: Int, y: T): Unit = () + | } + | + | obj.method[String](1, @@) + |} + |""".stripMargin, + "String" + ) + + @Test def `apply-dynamic-3` = + check( + """|object TypedHoleApplyDynamic { + | val obj: reflect.Selectable { + | def method[T](a: Int)(x: Int, y: T): Unit + | } = new reflect.Selectable { + | def method[T](a: Int)(x: Int, y: T): Unit = () + | } + | + | obj.method[String](1)(1, @@) + |} + |""".stripMargin, + "String" + ) diff --git a/presentation-compiler/test/dotty/tools/pc/tests/completion/SingletonCompletionsSuite.scala b/presentation-compiler/test/dotty/tools/pc/tests/completion/SingletonCompletionsSuite.scala index 25d1418900fd..17e4ad2ad9f3 100644 --- a/presentation-compiler/test/dotty/tools/pc/tests/completion/SingletonCompletionsSuite.scala +++ b/presentation-compiler/test/dotty/tools/pc/tests/completion/SingletonCompletionsSuite.scala @@ -297,4 +297,30 @@ class SingletonCompletionsSuite extends BaseCompletionSuite { """|"foo": "foo" |""".stripMargin ) + + @Test def `type-apply` = + check( + """|class Consumer[A]: + | def eat(a: A) = () + | + |def test = + | Consumer[7].eat(@@) + |""".stripMargin, + "7: 7", + topLines = Some(1) + ) + + @Test def `type-apply-2` = + check( + """|class Consumer[A]: + | def eat(a: A) = () + | + |object Consumer7 extends Consumer[7] + | + |def test = + | Consumer7.eat(@@) + |""".stripMargin, + "7: 7", + topLines = Some(1) + ) }