From fe2adaef52855a0e964b46bd60568ad703dca90f Mon Sep 17 00:00:00 2001 From: Marimuthu Madasamy Date: Wed, 20 Oct 2021 00:32:16 -0400 Subject: [PATCH 1/2] Fix constant case expression scope, char conversion --- .../idrisjvm/runtime/Runtime.java | 10 +++++++- .../idrisjvm/runtime/Strings.java | 2 +- src/Compiler/Jvm/Asm.idr | 2 +- src/Compiler/Jvm/Codegen.idr | 12 ++++----- src/Compiler/Jvm/Optimizer.idr | 13 +++++++++- src/Compiler/Jvm/ShowUtil.idr | 25 +++++++++++++++---- src/Compiler/Jvm/Variable.idr | 5 +++- 7 files changed, 53 insertions(+), 16 deletions(-) diff --git a/idris-jvm-runtime/src/main/java/io/github/mmhelloworld/idrisjvm/runtime/Runtime.java b/idris-jvm-runtime/src/main/java/io/github/mmhelloworld/idrisjvm/runtime/Runtime.java index 64be0bc34..a43611cd5 100644 --- a/idris-jvm-runtime/src/main/java/io/github/mmhelloworld/idrisjvm/runtime/Runtime.java +++ b/idris-jvm-runtime/src/main/java/io/github/mmhelloworld/idrisjvm/runtime/Runtime.java @@ -136,7 +136,15 @@ public static int unwrapIntThunk(Object possibleThunk) { if (possibleThunk instanceof Thunk) { return ((Thunk) possibleThunk).getInt(); } else { - return (int) possibleThunk; + return Conversion.toInt1(possibleThunk); + } + } + + public static char unwrapIntThunkToChar(Object possibleThunk) { + if (possibleThunk instanceof Thunk) { + return (char) ((Thunk) possibleThunk).getInt(); + } else { + return (char) possibleThunk; } } diff --git a/idris-jvm-runtime/src/main/java/io/github/mmhelloworld/idrisjvm/runtime/Strings.java b/idris-jvm-runtime/src/main/java/io/github/mmhelloworld/idrisjvm/runtime/Strings.java index 3172a469c..d3f8ae836 100644 --- a/idris-jvm-runtime/src/main/java/io/github/mmhelloworld/idrisjvm/runtime/Strings.java +++ b/idris-jvm-runtime/src/main/java/io/github/mmhelloworld/idrisjvm/runtime/Strings.java @@ -31,7 +31,7 @@ public static String pack(IdrisList idrisCharacterList) { Object[] objectArray = idrisCharacterList.toArray(); char[] chars = new char[objectArray.length]; for (int index = 0; index < objectArray.length; index++) { - chars[index] = (char) objectArray[index]; + chars[index] = Conversion.toChar(objectArray[index]); } return String.valueOf(chars); } diff --git a/src/Compiler/Jvm/Asm.idr b/src/Compiler/Jvm/Asm.idr index d4cd5eb63..c527ea2ec 100644 --- a/src/Compiler/Jvm/Asm.idr +++ b/src/Compiler/Jvm/Asm.idr @@ -719,7 +719,7 @@ Show Scope where ("parentIndex", show $ parentIndex scope), ("nextVariableIndex", show $ nextVariableIndex scope), ("lineNumbers", show $ lineNumbers scope), - ("labels", show $ labels scope), + ("variableIndices", toString $ variableIndices scope), ("returnType", show $ returnType scope), ("nextVariableIndex", show $ nextVariableIndex scope), ("childIndices", show $ childIndices scope) diff --git a/src/Compiler/Jvm/Codegen.idr b/src/Compiler/Jvm/Codegen.idr index e16b74cfa..a16876324 100644 --- a/src/Compiler/Jvm/Codegen.idr +++ b/src/Compiler/Jvm/Codegen.idr @@ -1220,12 +1220,6 @@ mutual assembleCaseWithScope labelStart labelEnd expr assembleConstantSwitch returnType constantType fc sc alts def = do - constantExprVariableSuffixIndex <- newDynamicVariableIndex - let constantExprVariableName = "constantCaseExpr" ++ show constantExprVariableSuffixIndex - constantExprVariableIndex <- getVariableIndex constantExprVariableName - hashCodePositionVariableSuffixIndex <- newDynamicVariableIndex - let hashCodePositionVariableName = "hashCodePosition" ++ show hashCodePositionVariableSuffixIndex - hashCodePositionVariableIndex <- getVariableIndex hashCodePositionVariableName hashPositionAndAlts <- traverse (constantAltHashCodeExpr fc) $ List.zip [0 .. the Int $ cast $ length $ drop 1 alts] alts let positionAndAltsByHash = multiValueMap fst snd hashPositionAndAlts @@ -1236,6 +1230,12 @@ mutual CreateLabel switchEndLabel traverse_ CreateLabel labels assembleExpr False constantType sc + constantExprVariableSuffixIndex <- newDynamicVariableIndex + let constantExprVariableName = "constantCaseExpr" ++ show constantExprVariableSuffixIndex + constantExprVariableIndex <- getVariableIndex constantExprVariableName + hashCodePositionVariableSuffixIndex <- newDynamicVariableIndex + let hashCodePositionVariableName = "hashCodePosition" ++ show hashCodePositionVariableSuffixIndex + hashCodePositionVariableIndex <- getVariableIndex hashCodePositionVariableName storeVar constantType constantType constantExprVariableIndex constantClass <- getHashCodeSwitchClass fc constantType Iconst (-1) diff --git a/src/Compiler/Jvm/Optimizer.idr b/src/Compiler/Jvm/Optimizer.idr index 1bffaf69d..555b33601 100644 --- a/src/Compiler/Jvm/Optimizer.idr +++ b/src/Compiler/Jvm/Optimizer.idr @@ -82,7 +82,7 @@ getLineNumbers (lineStart, _) (lineEnd, colEnd) = getFileName : OriginDesc -> String getFileName (PhysicalIdrSrc moduleIdent) = case unsafeUnfoldModuleIdent moduleIdent of - (moduleName :: _) => moduleName + (moduleName :: _) => moduleName ++ ".idr" _ => "(unknown-source)" getFileName (PhysicalPkgSrc fname) = fname getFileName (Virtual Interactive) = "(Interactive)" @@ -470,13 +470,21 @@ getConstantType : List NamedConstAlt -> Asm InferredType getConstantType [] = Throw emptyFC "Unknown constant switch type" getConstantType ((MkNConstAlt constant _) :: _) = case constant of I _ => Pure IInt + B8 _ => Pure IInt + B16 _ => Pure IInt + B32 _ => Pure IInt Ch _ => Pure IInt Str _ => Pure inferredStringType BI _ => Pure inferredBigIntegerType + B64 _ => Pure inferredBigIntegerType unsupportedConstant => Throw emptyFC $ "Unsupported constant switch " ++ show unsupportedConstant export isTypeConst : TT.Constant -> Bool +isTypeConst Bits8Type = True +isTypeConst Bits16Type = True +isTypeConst Bits32Type = True +isTypeConst Bits64Type = True isTypeConst IntType = True isTypeConst IntegerType = True isTypeConst StringType = True @@ -488,6 +496,9 @@ isTypeConst _ = False export getIntConstantValue : FC -> TT.Constant -> Asm Int getIntConstantValue _ (I i) = Pure i +getIntConstantValue _ (B8 i) = Pure i +getIntConstantValue _ (B16 i) = Pure i +getIntConstantValue _ (B32 i) = Pure i getIntConstantValue _ (Ch c) = Pure $ ord c getIntConstantValue _ WorldVal = Pure 0 getIntConstantValue fc x = diff --git a/src/Compiler/Jvm/ShowUtil.idr b/src/Compiler/Jvm/ShowUtil.idr index c21acaf22..db89e9fc1 100644 --- a/src/Compiler/Jvm/ShowUtil.idr +++ b/src/Compiler/Jvm/ShowUtil.idr @@ -22,6 +22,22 @@ showType typeName properties = showObj (("__name", quoted typeName) :: propertie indent : Nat -> String -> String indent n s = concat (List.replicate (n * 4) " ") ++ s +showConstant : Constant -> String +showConstant (I value) = "prim$I$" ++ show value +showConstant (BI value) = "prim$BI$" ++ show value +showConstant (Ch value) = "prim$Ch$" ++ show value +showConstant (Str value) = "prim$Str$" ++ show value +showConstant (I8 value) = "prim$I8$" ++ show value +showConstant (I16 value) = "prim$I16$" ++ show value +showConstant (I32 value) = "prim$I32$" ++ show value +showConstant (I64 value) = "prim$I64$" ++ show value +showConstant (B8 value) = "prim$B8$" ++ show value +showConstant (B16 value) = "prim$B16$" ++ show value +showConstant (B32 value) = "prim$B32$" ++ show value +showConstant (B64 value) = "prim$B64$" ++ show value +showConstant (Db value) = "prim$Db$" ++ show value +showConstant value = "prim$" ++ show value + mutual export showNamedCExp : Nat -> NamedCExp -> String @@ -43,17 +59,16 @@ mutual showNamedCExp n (NmForce _ _ x) = "force" ++ "(" ++ showNamedCExp n x ++ ")" showNamedCExp n (NmDelay _ _ x) = "delay" ++ "(" ++ showNamedCExp n x ++ ")" showNamedCExp n (NmConCase fc sc xs def) = "\n" ++ - indent n ("constructorswitch" ++ "(" ++ showNamedCExp n sc ++ ") \n") ++ + indent n ("constructorswitch" ++ "(" ++ showNamedCExp n sc ++ ")\n") ++ showSep "\n" (showNamedConAlt (n + 1) <$> xs) ++ maybe "" (\defExp => "\n" ++ indent (n + 1) "default:\n" ++ indent (n + 1) (showNamedCExp (n + 1) defExp)) def showNamedCExp n (NmConstCase fc sc xs def) = "\n" ++ - indent n ("constantswitch" ++"(" ++ show sc ++ ")\n") ++ + indent n ("constantswitch" ++"(" ++ showNamedCExp n sc ++ ")\n") ++ showSep "\n" (showNamedConstAlt (n + 1) <$> xs) ++ maybe "" (\defExp => "\n" ++ indent (n + 1) "default:\n" ++ indent (n + 1) (showNamedCExp (n + 2) defExp)) def - showNamedCExp n (NmPrimVal fc (BI value)) = "prim$BI$" ++ show value - showNamedCExp n (NmPrimVal fc x) = "prim$" ++ show x + showNamedCExp n (NmPrimVal fc x) = showConstant x showNamedCExp n (NmErased fc) = "erased" showNamedCExp n (NmCrash fc x) = "crash " ++ show x @@ -65,7 +80,7 @@ mutual export showNamedConstAlt : Nat -> NamedConstAlt -> String - showNamedConstAlt n (MkNConstAlt x exp) = indent n ("case " ++ show x ++ ":\n") ++ + showNamedConstAlt n (MkNConstAlt x exp) = indent n ("case " ++ showConstant x ++ ":\n") ++ indent (n + 1) (showNamedCExp (n + 1) exp) diff --git a/src/Compiler/Jvm/Variable.idr b/src/Compiler/Jvm/Variable.idr index f9936634c..5acf3ad64 100644 --- a/src/Compiler/Jvm/Variable.idr +++ b/src/Compiler/Jvm/Variable.idr @@ -92,6 +92,9 @@ unboxToDoubleThunk = unwrapIntThunk : Asm () unwrapIntThunk = InvokeMethod InvokeStatic runtimeClass "unwrapIntThunk" "(Ljava/lang/Object;)I" False +unwrapIntThunkToChar : Asm () +unwrapIntThunkToChar = InvokeMethod InvokeStatic runtimeClass "unwrapIntThunkToChar" "(Ljava/lang/Object;)C" False + unwrapLongThunk : Asm () unwrapLongThunk = InvokeMethod InvokeStatic runtimeClass "unwrapLongThunk" "(Ljava/lang/Object;)J" False @@ -295,7 +298,7 @@ loadAndUnboxByte ty sourceLocTys var = loadAndUnboxChar : InferredType -> Map Int InferredType -> Int -> Asm () loadAndUnboxChar ty sourceLocTys var = - let loadInstr = \index => do Aload index; if ty == intThunkType then unwrapIntThunk else objToChar + let loadInstr = \index => do Aload index; if ty == intThunkType then unwrapIntThunkToChar else objToChar in opWithWordSize sourceLocTys loadInstr var loadAndUnboxShort : InferredType -> Map Int InferredType -> Int -> Asm () From 5f9d2f4005504d2110c8f11013e8cc8c4b21586b Mon Sep 17 00:00:00 2001 From: Marimuthu Madasamy Date: Mon, 27 Dec 2021 02:03:03 -0500 Subject: [PATCH 2/2] Enhance trampoline, fix function signature, optimize lambda --- .github/workflows/release.yml | 2 +- .../idrisjvm/assembler/AsmGlobalState.java | 20 +- .../idrisjvm/assembler/Assembler.java | 10 +- idris-jvm-compiler/pom.xml | 4 +- .../idrisjvm/runtime/Function3.java | 5 + .../idrisjvm/runtime/Function4.java | 6 + .../idrisjvm/runtime/Function5.java | 6 + .../idrisjvm/runtime/Functions.java | 37 ++++ .../idrisjvm/runtime/MemoizedDelayed.java | 4 +- .../idrisjvm/runtime/Runtime.java | 12 +- .../mmhelloworld/idrisjvm/runtime/Thunk.java | 8 +- src/Compiler/Jvm/Asm.idr | 33 +++- src/Compiler/Jvm/Codegen.idr | 179 ++++++++++++++---- src/Compiler/Jvm/Foreign.idr | 21 +- src/Compiler/Jvm/InferredType.idr | 17 +- src/Compiler/Jvm/Optimizer.idr | 139 +++++++++----- src/Compiler/Jvm/Tree.idr | 18 +- src/Compiler/Jvm/Variable.idr | 10 +- 18 files changed, 381 insertions(+), 150 deletions(-) create mode 100644 idris-jvm-runtime/src/main/java/io/github/mmhelloworld/idrisjvm/runtime/Function3.java create mode 100644 idris-jvm-runtime/src/main/java/io/github/mmhelloworld/idrisjvm/runtime/Function4.java create mode 100644 idris-jvm-runtime/src/main/java/io/github/mmhelloworld/idrisjvm/runtime/Function5.java create mode 100644 idris-jvm-runtime/src/main/java/io/github/mmhelloworld/idrisjvm/runtime/Functions.java diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index f63a77a45..b78b5a152 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -39,6 +39,6 @@ jobs: with: repo_token: "${{ secrets.GITHUB_TOKEN }}" prerelease: false - title: "Release 0.4.0-rc.2" + title: "Release 0.4.0-rc.3" files: | idris-jvm-compiler/target/idris2-0.4.0-SNAPSHOT.zip diff --git a/idris-jvm-assembler/src/main/java/io/github/mmhelloworld/idrisjvm/assembler/AsmGlobalState.java b/idris-jvm-assembler/src/main/java/io/github/mmhelloworld/idrisjvm/assembler/AsmGlobalState.java index c4e73cc21..ce8612b3d 100644 --- a/idris-jvm-assembler/src/main/java/io/github/mmhelloworld/idrisjvm/assembler/AsmGlobalState.java +++ b/idris-jvm-assembler/src/main/java/io/github/mmhelloworld/idrisjvm/assembler/AsmGlobalState.java @@ -7,14 +7,19 @@ import java.io.IOException; import java.io.OutputStream; import java.lang.reflect.InvocationTargetException; +import java.util.Collection; import java.util.HashSet; import java.util.Map; import java.util.Map.Entry; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Predicate; +import java.util.regex.Pattern; import java.util.stream.Stream; +import static java.util.Collections.emptyList; import static java.util.Collections.synchronizedSet; +import static java.util.stream.Collectors.toList; import static java.util.stream.Collectors.toMap; public final class AsmGlobalState { @@ -23,16 +28,25 @@ public final class AsmGlobalState { private final Set untypedFunctions; private final Set constructors; private final String programName; + private final Collection> trampolinePredicates; private final Map assemblers; - public AsmGlobalState(String programName) { + public AsmGlobalState(String programName, Collection trampolinePatterns) { this.programName = programName; + this.trampolinePredicates = trampolinePatterns.stream() + .map(Pattern::compile) + .map(Pattern::asPredicate) + .collect(toList()); functions = new ConcurrentHashMap<>(); untypedFunctions = synchronizedSet(new HashSet<>()); constructors = synchronizedSet(new HashSet<>()); assemblers = new ConcurrentHashMap<>(); } + public AsmGlobalState(String programName) { + this(programName, emptyList()); + } + public synchronized void addFunction(String name, Object value) { functions.put(name, value); } @@ -105,4 +119,8 @@ public void writeClass(String className, ClassWriter classWriter, String outputC } } + public boolean shouldTrampoline(String name) { + return trampolinePredicates.stream() + .anyMatch(trampolinePredicate -> trampolinePredicate.test(name)); + } } diff --git a/idris-jvm-assembler/src/main/java/io/github/mmhelloworld/idrisjvm/assembler/Assembler.java b/idris-jvm-assembler/src/main/java/io/github/mmhelloworld/idrisjvm/assembler/Assembler.java index 6a3f116b9..324fc6fe3 100644 --- a/idris-jvm-assembler/src/main/java/io/github/mmhelloworld/idrisjvm/assembler/Assembler.java +++ b/idris-jvm-assembler/src/main/java/io/github/mmhelloworld/idrisjvm/assembler/Assembler.java @@ -1132,11 +1132,13 @@ public void lineNumber(int lineNumber, String label) { public void localVariable(String name, String typeDescriptor, String signature, String lineNumberStartLabel, String lineNumberEndLabel, int index) { Label start = (Label) env.get(lineNumberStartLabel); - requireNonNull(start, format("Line number start label '%s' for variable %s at index %d must not be null", - lineNumberStartLabel, name, index)); + requireNonNull(start, + format("Line number start label '%s' for variable %s at index %d must not be null for method %s/%s", + lineNumberStartLabel, name, index, className, methodName)); Label end = (Label) env.get(lineNumberEndLabel); - requireNonNull(end, format("Line number end label '%s' for variable %s at index %d must not be null", - lineNumberEndLabel, name, index)); + requireNonNull(end, + format("Line number end label '%s' for variable %s at index %d must not be null for method %s/%s", + lineNumberEndLabel, name, index, className, methodName)); mv.visitLocalVariable(name, typeDescriptor, signature, start, end, index); } diff --git a/idris-jvm-compiler/pom.xml b/idris-jvm-compiler/pom.xml index 5c137c50f..8deebe7fe 100644 --- a/idris-jvm-compiler/pom.xml +++ b/idris-jvm-compiler/pom.xml @@ -175,7 +175,7 @@ exec/idris2_app flat ${project.parent.basedir}/build - -Xss36m -Xms3g -Xmx3g + -Xss70m -Xms3g -Xmx3g idris2.Main @@ -194,7 +194,7 @@ lib flat ${project.build.directory}/assembly - -Xss36m -Xms3g -Xmx3g + -Xss70m -Xms3g -Xmx3g idris2.Main diff --git a/idris-jvm-runtime/src/main/java/io/github/mmhelloworld/idrisjvm/runtime/Function3.java b/idris-jvm-runtime/src/main/java/io/github/mmhelloworld/idrisjvm/runtime/Function3.java new file mode 100644 index 000000000..d23031d53 --- /dev/null +++ b/idris-jvm-runtime/src/main/java/io/github/mmhelloworld/idrisjvm/runtime/Function3.java @@ -0,0 +1,5 @@ +package io.github.mmhelloworld.idrisjvm.runtime; + +public interface Function3 { + R apply(T1 t1, T2 t2, T3 t3); +} diff --git a/idris-jvm-runtime/src/main/java/io/github/mmhelloworld/idrisjvm/runtime/Function4.java b/idris-jvm-runtime/src/main/java/io/github/mmhelloworld/idrisjvm/runtime/Function4.java new file mode 100644 index 000000000..06c17fbaf --- /dev/null +++ b/idris-jvm-runtime/src/main/java/io/github/mmhelloworld/idrisjvm/runtime/Function4.java @@ -0,0 +1,6 @@ +package io.github.mmhelloworld.idrisjvm.runtime; + +public interface Function4 { + R apply(T1 t1, T2 t2, T3 t3, T4 t4); +} + diff --git a/idris-jvm-runtime/src/main/java/io/github/mmhelloworld/idrisjvm/runtime/Function5.java b/idris-jvm-runtime/src/main/java/io/github/mmhelloworld/idrisjvm/runtime/Function5.java new file mode 100644 index 000000000..8a6f30761 --- /dev/null +++ b/idris-jvm-runtime/src/main/java/io/github/mmhelloworld/idrisjvm/runtime/Function5.java @@ -0,0 +1,6 @@ +package io.github.mmhelloworld.idrisjvm.runtime; + +public interface Function5 { + R apply(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5); +} + diff --git a/idris-jvm-runtime/src/main/java/io/github/mmhelloworld/idrisjvm/runtime/Functions.java b/idris-jvm-runtime/src/main/java/io/github/mmhelloworld/idrisjvm/runtime/Functions.java new file mode 100644 index 000000000..e4c175106 --- /dev/null +++ b/idris-jvm-runtime/src/main/java/io/github/mmhelloworld/idrisjvm/runtime/Functions.java @@ -0,0 +1,37 @@ +package io.github.mmhelloworld.idrisjvm.runtime; + +import java.util.function.BiFunction; +import java.util.function.Function; + +public final class Functions { + private Functions() { + } + + public static final Function IDENTITY = a -> a; + + public static final Function> IDENTITY_1 = c -> IDENTITY; + + public static final Function>> IDENTITY_2 = d -> IDENTITY_1; + + public static final Function> CONSTANT = a -> b -> a; + + public static final Function>> CONSTANT_1 = c -> CONSTANT; + + public static Function> curry(BiFunction f) { + return t1 -> t2 -> f.apply(t1, t2); + } + + public static Function>> curry(Function3 f) { + return t1 -> t2 -> t3 -> f.apply(t1, t2, t3); + } + + public static Function>>> curry( + Function4 f) { + return t1 -> t2 -> t3 -> t4 -> f.apply(t1, t2, t3, t4); + } + + public static Function>>>> curry( + Function5 f) { + return t1 -> t2 -> t3 -> t4 -> t5 -> f.apply(t1, t2, t3, t4, t5); + } +} diff --git a/idris-jvm-runtime/src/main/java/io/github/mmhelloworld/idrisjvm/runtime/MemoizedDelayed.java b/idris-jvm-runtime/src/main/java/io/github/mmhelloworld/idrisjvm/runtime/MemoizedDelayed.java index aa8e37fd4..ea089cb92 100644 --- a/idris-jvm-runtime/src/main/java/io/github/mmhelloworld/idrisjvm/runtime/MemoizedDelayed.java +++ b/idris-jvm-runtime/src/main/java/io/github/mmhelloworld/idrisjvm/runtime/MemoizedDelayed.java @@ -1,5 +1,7 @@ package io.github.mmhelloworld.idrisjvm.runtime; +import static io.github.mmhelloworld.idrisjvm.runtime.Runtime.unwrap; + public final class MemoizedDelayed implements Delayed { private boolean initialized; private Delayed delayed; @@ -8,7 +10,7 @@ public MemoizedDelayed(Delayed delayed) { this.delayed = () -> { synchronized(this) { if(!initialized) { - Object value = delayed.evaluate(); + Object value = unwrap(delayed.evaluate()); this.delayed = () -> value; initialized = true; } diff --git a/idris-jvm-runtime/src/main/java/io/github/mmhelloworld/idrisjvm/runtime/Runtime.java b/idris-jvm-runtime/src/main/java/io/github/mmhelloworld/idrisjvm/runtime/Runtime.java index a43611cd5..6340f7df7 100644 --- a/idris-jvm-runtime/src/main/java/io/github/mmhelloworld/idrisjvm/runtime/Runtime.java +++ b/idris-jvm-runtime/src/main/java/io/github/mmhelloworld/idrisjvm/runtime/Runtime.java @@ -122,7 +122,11 @@ public static Thunk createThunk(Object value) { public static Object unwrap(Object possibleThunk) { if (possibleThunk instanceof Thunk) { - return ((Thunk) possibleThunk).getObject(); + Thunk thunk = (Thunk) possibleThunk; + while (thunk != null && thunk.isRedex()) { + thunk = thunk.evaluate(); + } + return thunk == null ? null : thunk.getObject(); } else { return possibleThunk; } @@ -165,9 +169,9 @@ public static double unwrapDoubleThunk(Object possibleThunk) { } public static ForkJoinTask fork(Function action) { - return commonPool().submit((Runnable) () -> { + return commonPool().submit(() -> { try { - action.apply(0); + unwrap(action.apply(0)); } catch (Exception e) { e.printStackTrace(); } @@ -175,7 +179,7 @@ public static ForkJoinTask fork(Function action) { } public static ForkJoinTask fork(Delayed action) { - return commonPool().submit(action::evaluate); + return commonPool().submit(() -> unwrap(action.evaluate())); } public static void await(ForkJoinTask task) { diff --git a/idris-jvm-runtime/src/main/java/io/github/mmhelloworld/idrisjvm/runtime/Thunk.java b/idris-jvm-runtime/src/main/java/io/github/mmhelloworld/idrisjvm/runtime/Thunk.java index 6e55703b2..89bf1c368 100644 --- a/idris-jvm-runtime/src/main/java/io/github/mmhelloworld/idrisjvm/runtime/Thunk.java +++ b/idris-jvm-runtime/src/main/java/io/github/mmhelloworld/idrisjvm/runtime/Thunk.java @@ -1,5 +1,7 @@ package io.github.mmhelloworld.idrisjvm.runtime; +import java.util.NoSuchElementException; + import static java.util.Objects.requireNonNull; @FunctionalInterface @@ -11,11 +13,7 @@ default boolean isRedex() { } default Object getObject() { - Thunk thunk = this; - while (thunk != null && thunk.isRedex()) { - thunk = thunk.evaluate(); - } - return thunk == null ? null : thunk.getObject(); + throw new NoSuchElementException("Unevaluated thunk"); } default int getInt() { diff --git a/src/Compiler/Jvm/Asm.idr b/src/Compiler/Jvm/Asm.idr index c527ea2ec..fa55f3fb0 100644 --- a/src/Compiler/Jvm/Asm.idr +++ b/src/Compiler/Jvm/Asm.idr @@ -361,12 +361,12 @@ namespace AsmGlobalState public export %foreign - "jvm:(String io/github/mmhelloworld/idrisjvm/assembler/AsmGlobalState),io/github/mmhelloworld/idrisjvm/assembler/AsmGlobalState" - prim_newAsmGlobalState : String -> PrimIO AsmGlobalState + "jvm:(String java/util/Collection io/github/mmhelloworld/idrisjvm/assembler/AsmGlobalState),io/github/mmhelloworld/idrisjvm/assembler/AsmGlobalState" + prim_newAsmGlobalState : String -> List String -> PrimIO AsmGlobalState public export - newAsmGlobalState : HasIO io => String -> io AsmGlobalState - newAsmGlobalState programName = primIO $ prim_newAsmGlobalState programName + newAsmGlobalState : HasIO io => String -> List String -> io AsmGlobalState + newAsmGlobalState programName trampolinePatterns = primIO $ prim_newAsmGlobalState programName trampolinePatterns public export %foreign jvm' "io/github/mmhelloworld/idrisjvm/assembler/AsmGlobalState" ".getAssembler" @@ -466,6 +466,15 @@ namespace AsmGlobalState classCodeEnd state outputDirectory outputFile mainClass = primIO $ prim_classCodeEnd state outputDirectory outputFile mainClass + public export + %foreign jvm' "io/github/mmhelloworld/idrisjvm/assembler/AsmGlobalState" ".shouldTrampoline" + "io/github/mmhelloworld/idrisjvm/assembler/AsmGlobalState String" "boolean" + prim_shouldTrampoline : AsmGlobalState -> String -> PrimIO Bool + + public export + shouldTrampoline : HasIO io => AsmGlobalState -> String -> io Bool + shouldTrampoline state name = primIO $ prim_shouldTrampoline state name + public export record AsmState where constructor MkAsmState @@ -1162,14 +1171,21 @@ export getVariableType : String -> Asm InferredType getVariableType name = getVariableTypeAtScope !getCurrentScopeIndex name +updateArgumentsForUntyped : Map Int InferredType -> Nat -> IO () +updateArgumentsForUntyped _ Z = pure () +updateArgumentsForUntyped types (S n) = do + ignore $ Map.put types (cast n) inferredObjectType + updateArgumentsForUntyped types n + export -updateScopeVariableTypes : Asm () -updateScopeVariableTypes = go (scopeCounter !GetState - 1) where +updateScopeVariableTypes : Nat -> Asm () +updateScopeVariableTypes arity = go (scopeCounter !GetState - 1) where go : Int -> Asm () go scopeIndex = if scopeIndex < 0 then Pure () else do variableTypes <- retrieveVariableTypesAtScope scopeIndex + when (scopeIndex == 0) $ LiftIo $ updateArgumentsForUntyped variableTypes arity variableIndices <- retrieveVariableIndicesByName scopeIndex scope <- getScope scopeIndex saveScope $ record {allVariableTypes = variableTypes, allVariableIndices = variableIndices} scope @@ -1201,7 +1217,7 @@ addVariableType var ty = do %inline export lambdaMaxCountPerMethod: Int -lambdaMaxCountPerMethod = 25 +lambdaMaxCountPerMethod = 50 export getLambdaImplementationMethodName : String -> Asm Jname @@ -1634,7 +1650,8 @@ runAsm state (ClassCodeStart version access className sig parent intf anns) = as the (JList String) $ believe_me intf, the (JList JAnnotation) $ believe_me janns] runAsm state (CreateClass opts) = - assemble state $ jvmInstance () "io/github/mmhelloworld/idrisjvm/assembler/Assembler.createClass" [toJClassOpts opts] + assemble state $ jvmInstance () "io/github/mmhelloworld/idrisjvm/assembler/Assembler.createClass" + [assembler state, toJClassOpts opts] runAsm state (CreateField accs sourceFileName className fieldName desc sig fieldInitialValue) = assemble state $ do let jaccs = sum $ accessNum <$> accs jvmInstance () "io/github/mmhelloworld/idrisjvm/assembler/Assembler.createField" diff --git a/src/Compiler/Jvm/Codegen.idr b/src/Compiler/Jvm/Codegen.idr index a16876324..7a75d2368 100644 --- a/src/Compiler/Jvm/Codegen.idr +++ b/src/Compiler/Jvm/Codegen.idr @@ -200,6 +200,9 @@ isInterfaceInvocation : InferredType -> Bool isInterfaceInvocation (IRef className) = "i:" `isPrefixOf` className isInterfaceInvocation _ = False +%foreign "jvm:.startsWith(java/lang/String java/lang/String boolean),java/lang/String" +startsWith : String -> String -> Bool + assembleNil : (isTailCall: Bool) -> InferredType -> Asm () assembleNil isTailCall returnType = do Field GetStatic idrisNilClass "INSTANCE" "Lio/github/mmhelloworld/idrisjvm/runtime/IdrisList$Nil;" @@ -218,6 +221,38 @@ getDynamicVariableIndex variablePrefix = do let variableName = variablePrefix ++ show suffixIndex getVariableIndex variableName +assembleIdentityLambda : (isTailCall : Bool) -> Asm () +assembleIdentityLambda isTailCall = do + Field GetStatic functionsClass "IDENTITY" (getJvmTypeDescriptor inferredLambdaType) + when isTailCall $ asmReturn inferredLambdaType + +assembleIdentity1Lambda : (isTailCall : Bool) -> Asm () +assembleIdentity1Lambda isTailCall = do + Field GetStatic functionsClass "IDENTITY_1" (getJvmTypeDescriptor inferredLambdaType) + when isTailCall $ asmReturn inferredLambdaType + +assembleIdentity2Lambda : (isTailCall : Bool) -> Asm () +assembleIdentity2Lambda isTailCall = do + Field GetStatic functionsClass "IDENTITY_2" (getJvmTypeDescriptor inferredLambdaType) + when isTailCall $ asmReturn inferredLambdaType + +assembleConstantLambda : (isTailCall : Bool) -> Asm () +assembleConstantLambda isTailCall = do + Field GetStatic functionsClass "CONSTANT" (getJvmTypeDescriptor inferredLambdaType) + when isTailCall $ asmReturn inferredLambdaType + +assembleConstant1Lambda : (isTailCall : Bool) -> Asm () +assembleConstant1Lambda isTailCall = do + Field GetStatic functionsClass "CONSTANT_1" (getJvmTypeDescriptor inferredLambdaType) + when isTailCall $ asmReturn inferredLambdaType + +getLambdaTypeByArity: (arity: Nat) -> LambdaType +getLambdaTypeByArity 2 = Function2Lambda +getLambdaTypeByArity 3 = Function3Lambda +getLambdaTypeByArity 4 = Function4Lambda +getLambdaTypeByArity 5 = Function5Lambda +getLambdaTypeByArity _ = FunctionLambda + mutual assembleExpr : (isTailCall: Bool) -> InferredType -> NamedCExp -> Asm () assembleExpr isTailCall returnType (NmDelay _ _ expr) = @@ -297,9 +332,7 @@ mutual let jname = jvmName idrisName functionType <- case !(findFunctionType jname) of Just ty => Pure ty - Nothing => do - addUntypedFunction jname - Pure $ MkInferredFunctionType inferredObjectType $ replicate (length args) inferredObjectType + Nothing => Pure $ MkInferredFunctionType inferredObjectType $ replicate (length args) inferredObjectType let paramTypes = parameterTypes functionType if paramTypes == [] then assembleNmAppNilArity isTailCall returnType idrisName @@ -310,15 +343,20 @@ mutual let methodDescriptor = getMethodDescriptor $ MkInferredFunctionType methodReturnType paramTypes let functionName = getIdrisFunctionName !getProgramName (className jname) (methodName jname) InvokeMethod InvokeStatic (className functionName) (methodName functionName) methodDescriptor False - asmCast methodReturnType returnType + currentMethodName <- currentMethodName <$> GetState + isCalleeTrampolined <- LiftIo $ AsmGlobalState.shouldTrampoline !getGlobalState (show jname) + let isLambda = startsWith (methodName currentMethodName) "lambda$" + let shouldUnwrap = isCalleeTrampolined && (not isTailCall || not isLambda) + let possibleThunkType = if shouldUnwrap then thunkType else methodReturnType + asmCast possibleThunkType returnType when isTailCall $ asmReturn returnType assembleExpr isTailCall returnType (NmApp _ lambdaVariable [arg]) = do assembleExpr False inferredLambdaType lambdaVariable assembleExpr False IUnknown arg - InvokeMethod InvokeInterface "java/util/function/Function" "apply" - "(Ljava/lang/Object;)Ljava/lang/Object;" True - asmCast inferredObjectType returnType + InvokeMethod InvokeInterface "java/util/function/Function" "apply" "(Ljava/lang/Object;)Ljava/lang/Object;" True + let possibleThunkType = if not isTailCall then thunkType else inferredObjectType + asmCast possibleThunkType returnType when isTailCall $ asmReturn returnType assembleExpr isTailCall returnType expr@(NmCon _ _ NOTHING _ []) = assembleNothing isTailCall returnType @@ -1021,6 +1059,38 @@ mutual storeVar ty ty targetVariableIndex Pure targetVariableIndex + createMethodReference : (isTailCall: Bool) -> (arity: Nat) -> Name -> Asm () + createMethodReference isTailCall arity name = do + let jname = jvmName name + functionType <- case !(findFunctionType jname) of + Just ty => Pure ty + Nothing => Pure $ MkInferredFunctionType inferredObjectType $ replicate arity inferredObjectType + let methodReturnType = InferredFunctionType.returnType functionType + let paramTypes = parameterTypes functionType + let methodDescriptor = getMethodDescriptor $ MkInferredFunctionType methodReturnType paramTypes + let functionName = getIdrisFunctionName !getProgramName (className jname) (methodName jname) + let functionInterface = getFunctionInterface arity + let invokeDynamicDescriptor = getMethodDescriptor $ MkInferredFunctionType functionInterface [] + invokeDynamic (className functionName) (methodName functionName) "apply" invokeDynamicDescriptor + (getSamDesc (getLambdaTypeByArity arity)) methodDescriptor methodDescriptor + when (arity > 1) $ do + let methodDescriptor = getMethodDescriptor $ MkInferredFunctionType inferredLambdaType [functionInterface] + InvokeMethod InvokeStatic functionsClass "curry" methodDescriptor False + when isTailCall $ asmReturn inferredLambdaType + + assembleSubMethodWithScope1 : (isTailCall: Bool) -> InferredType -> (parameterName : Maybe Name) -> + NamedCExp -> Asm () + assembleSubMethodWithScope1 isTailCall returnType parameterName body = do + parentScope <- getScope !getCurrentScopeIndex + withScope $ assembleSubMethod isTailCall returnType Nothing parameterName parentScope body + + assembleMethodReference : (isTailCall: Bool) -> InferredType -> (isMethodReference : Bool) -> (arity: Nat) -> + (functionName: Name) -> (parameterName : Maybe Name) -> NamedCExp -> Asm () + assembleMethodReference isTailCall returnType isMethodReference arity functionName parameterName body = do + if isMethodReference + then createMethodReference isTailCall arity functionName + else assembleSubMethodWithScope1 isTailCall returnType parameterName body + assembleSubMethodWithScope : (isTailCall: Bool) -> InferredType -> (parameterValue: Maybe NamedCExp) -> (parameterName : Maybe Name) -> NamedCExp -> Asm () assembleSubMethodWithScope isTailCall returnType (Just value) (Just name) body = do @@ -1041,9 +1111,46 @@ mutual assembleExpr False !(getVariableType variableName) value updateCurrentScopeIndex lambdaScopeIndex - assembleSubMethodWithScope isTailCall returnType _ parameterName body = do - parentScope <- getScope !getCurrentScopeIndex - withScope $ assembleSubMethod isTailCall returnType Nothing parameterName parentScope body + assembleSubMethodWithScope isTailCall returnType _ p0 + body@(NmLam _ p1 (NmLam _ p2 (NmLam _ p3 (NmLam _ p4 (NmApp _ (NmRef _ name) [NmLocal _ arg0, NmLocal _ arg1, + NmLocal _ arg2, NmLocal _ arg3, NmLocal _ arg4]))))) = assembleMethodReference + isTailCall returnType + ((fromMaybe (UN "") p0) == arg0 && p1 == arg1 && p2 == arg2 && p3 == arg3 && p4 == arg4) + 5 name p0 body + assembleSubMethodWithScope isTailCall returnType _ p0 + body@(NmLam _ p1 (NmLam _ p2 (NmLam _ p3 (NmApp _ (NmRef _ name) [NmLocal _ arg0, NmLocal _ arg1, NmLocal _ arg2, + NmLocal _ arg3])))) = assembleMethodReference isTailCall returnType + ((fromMaybe (UN "") p0) == arg0 && p1 == arg1 && p2 == arg2 && p3 == arg3) 4 name p0 body + assembleSubMethodWithScope isTailCall returnType _ p0 + body@(NmLam _ p1 (NmLam _ p2 (NmApp _ (NmRef _ name) [NmLocal _ arg0, NmLocal _ arg1, NmLocal _ arg2]))) = + assembleMethodReference isTailCall returnType ((fromMaybe (UN "") p0) == arg0 && p1 == arg1 && p2 == arg2) + 3 name p0 body + assembleSubMethodWithScope isTailCall returnType _ p0 + body@(NmLam _ p1 (NmApp _ (NmRef _ name) [NmLocal _ arg0, NmLocal _ arg1])) = + assembleMethodReference isTailCall returnType ((fromMaybe (UN "") p0) == arg0 && p1 == arg1) + 2 name p0 body + assembleSubMethodWithScope isTailCall returnType _ p0 body@(NmApp _ (NmRef _ name) [NmLocal _ arg0]) = + assembleMethodReference isTailCall returnType ((fromMaybe (UN "") p0) == arg0) 1 name p0 body + + assembleSubMethodWithScope isTailCall returnType _ parameterName body@(NmLam _ c (NmLam _ a (NmLocal _ b))) = + let hasParameter = isJust parameterName + in if hasParameter && c == b + then assembleConstant1Lambda isTailCall + else if hasParameter && a == b + then assembleIdentity2Lambda isTailCall + else assembleSubMethodWithScope1 isTailCall returnType parameterName body + assembleSubMethodWithScope isTailCall returnType _ parameterName body@(NmLam _ a (NmLocal _ b)) = + if (fromMaybe (UN "") parameterName) == b + then assembleConstantLambda isTailCall + else if isJust parameterName && a == b + then assembleIdentity1Lambda isTailCall + else assembleSubMethodWithScope1 isTailCall returnType parameterName body + assembleSubMethodWithScope isTailCall returnType _ parameterName body@(NmLocal _ b) = + if (fromMaybe (UN "") parameterName) == b + then assembleIdentityLambda isTailCall + else assembleSubMethodWithScope1 isTailCall returnType parameterName body + assembleSubMethodWithScope isTailCall returnType _ parameterName body = + assembleSubMethodWithScope1 isTailCall returnType parameterName body assembleSubMethod : (isTailCall: Bool) -> InferredType -> (parameterValueExpr: (Maybe (Asm ()))) -> (parameterName: Maybe Name) -> Scope -> NamedCExp -> Asm () @@ -1051,7 +1158,7 @@ mutual scope <- getScope !getCurrentScopeIndex maybe (Pure ()) (setScopeCounter . succ) (parentIndex scope) let lambdaBodyReturnType = returnType scope - let lambdaType = getLambdaType parameterName + let lambdaType = getLambdaTypeByParameter parameterName when (lambdaType == DelayedLambda) $ do New "io/github/mmhelloworld/idrisjvm/runtime/MemoizedDelayed" Dup @@ -1101,7 +1208,7 @@ mutual let lambdaReturnType = if isExtracted then lambdaBodyReturnType - else if lambdaType == ThunkLambda then thunkType else IUnknown + else if lambdaType == ThunkLambda then thunkType else inferredObjectType assembleExpr True lambdaReturnType expr addLambdaEndLabel scope labelEnd maybe (Pure ()) (\parentScopeIndex => updateScopeEndLabel parentScopeIndex labelEnd) (parentIndex scope) @@ -1580,11 +1687,7 @@ assembleDefinition idrisName fc = do lineNumberLabels = lineNumberLabels } updateCurrentFunction $ record { dynamicVariableCounter = 0 } let optimizedExpr = optimizedBody function - debug $ "**********************" - debug $ "Assembling " ++ declaringClassName ++ "." ++ methodName - debug "Optimized" - debug "---------" - debug $ showNamedCExp 0 optimizedExpr + debug $ "Assembling " ++ declaringClassName ++ "." ++ methodName ++ ":\n" ++ showNamedCExp 0 optimizedExpr let fileName = fst $ getSourceLocationFromFc fc let descriptor = getMethodDescriptor functionType let isField = arity == 0 @@ -1648,8 +1751,8 @@ prim_waitForFuturesToComplete : List ThreadID -> PrimIO () waitForFuturesToComplete : List ThreadID -> IO () waitForFuturesToComplete futures = primIO $ prim_waitForFuturesToComplete futures -goGroupByClassName : String -> List Name -> List (List Name) -goGroupByClassName programName names = unsafePerformIO $ do +groupByClassName : String -> List Name -> List (List Name) +groupByClassName programName names = unsafePerformIO $ do namesByClassName <- Map.newTreeMap {key=String} {value=List Name} go1 namesByClassName names Map.values {key=String} namesByClassName @@ -1665,10 +1768,6 @@ goGroupByClassName programName names = unsafePerformIO $ do _ <- Map.put {key=String} {value=List Name} namesByClassName jvmClassName newNames go2 names -groupByClassName : String -> List (List Name) -> List (List (List Name)) -groupByClassName programName [] = [] -groupByClassName programName (xs :: xxs) = goGroupByClassName programName xs :: groupByClassName programName xxs - createAsmState : AsmGlobalState -> Name -> IO AsmState createAsmState globalState name = do programName <- AsmGlobalState.getProgramName globalState @@ -1692,35 +1791,43 @@ assemble globalState fcAndDefinitionsByName name = do pure () Nothing => pure () -goAssembleAsync : AsmGlobalState -> Map String (FC, NamedDef) -> List (List Name) -> IO () -goAssembleAsync _ _ [] = pure () -goAssembleAsync globalState fcAndDefinitionsByName (xs :: xss) = do +assembleAsync : AsmGlobalState -> Map String (FC, NamedDef) -> List (List Name) -> IO () +assembleAsync _ _ [] = pure () +assembleAsync globalState fcAndDefinitionsByName (xs :: xss) = do threadIds <- traverse forkAssemble xs waitForFuturesToComplete threadIds - goAssembleAsync globalState fcAndDefinitionsByName xss + assembleAsync globalState fcAndDefinitionsByName xss where forkAssemble : Name -> IO ThreadID forkAssemble name = fork $ assemble globalState fcAndDefinitionsByName name -assembleAsync : AsmGlobalState -> Map String (FC, NamedDef) -> List (List Name) -> IO () -assembleAsync globalState fcAndDefinitionsByName names = - goAssembleAsync globalState fcAndDefinitionsByName $ transpose names - getNameStrFcDef : (Name, FC, NamedDef) -> (String, FC, NamedDef) getNameStrFcDef (name, fc, def) = (jvmSimpleName name, fc, def) getNameStrDef : (String, FC, NamedDef) -> (String, NamedDef) getNameStrDef (name, fc, def) = (name, def) +getTrampolinePatterns : List String -> List String +getTrampolinePatterns directives + = mapMaybe getPattern directives + where + getPattern : String -> Maybe String + getPattern directive = + let (k, v) = break (== '=') directive + in + if (trim k) == "trampoline" + then Just $ trim $ substr 1 (length v) v + else Nothing + ||| Compile a TT expression to JVM bytecode compileToJvmBytecode : Ref Ctxt Defs -> String -> String -> ClosedTerm -> Core () compileToJvmBytecode c outputDirectory outputFile term = do cdata <- getCompileData False Cases term + directives <- getDirectives Jvm let ndefs = namedDefs cdata let idrisMainBody = forget (mainExpr cdata) - let unwrappedExpr = unwrapExpr idrisMainBody let programName = if outputFile == "" then "repl" else outputFile - let nameFcDefs = (idrisMainFunctionName programName, emptyFC, MkNmFun [] unwrappedExpr) :: ndefs + let nameFcDefs = (idrisMainFunctionName programName, emptyFC, MkNmFun [] idrisMainBody) :: ndefs let nameStrFcDefs = getNameStrFcDef <$> nameFcDefs fcAndDefinitionsByName <- coreLift $ Map.fromList nameStrFcDefs let nameStrDefs = getNameStrDef <$> nameStrFcDefs @@ -1728,10 +1835,10 @@ compileToJvmBytecode c outputDirectory outputFile term = do coreLift $ when shouldDebug $ do timeString <- currentTimeString putStrLn (timeString ++ ": Analyzing dependencies") - globalState <- coreLift $ newAsmGlobalState programName - let names = groupByClassName programName . levelOrder $ buildFunctionTreeMain programName definitionsByName + globalState <- coreLift $ newAsmGlobalState programName (getTrampolinePatterns directives) + let names = groupByClassName programName . traverseDepthFirst $ buildFunctionTreeMain programName definitionsByName coreLift $ do - traverse_ (assembleAsync globalState fcAndDefinitionsByName) names + assembleAsync globalState fcAndDefinitionsByName (transpose names) let mainFunctionName = jvmName (idrisMainFunctionName programName) asmState <- createAsmState globalState (idrisMainFunctionName programName) _ <- asm asmState $ createMainMethod programName mainFunctionName diff --git a/src/Compiler/Jvm/Foreign.idr b/src/Compiler/Jvm/Foreign.idr index 6e8c559ff..172aab605 100644 --- a/src/Compiler/Jvm/Foreign.idr +++ b/src/Compiler/Jvm/Foreign.idr @@ -178,26 +178,20 @@ inferForeign programName idrisName fc foreignDescriptors argumentTypes returnTyp let jname = jvmName idrisName let jvmClassAndMethodName = getIdrisFunctionName programName (className jname) (methodName jname) jvmArgumentTypes <- traverse (parse fc) argumentTypes + let arityNat = length jvmArgumentTypes + let isNilArity = arityNat == 0 jvmDescriptor <- findJvmDescriptor fc idrisName foreignDescriptors jvmReturnType <- getInferredType fc !(parse fc returnType) (foreignFunctionClassName, foreignFunctionName, jvmReturnType, jvmArgumentTypes) <- parseForeignFunctionDescriptor fc jvmDescriptor jvmArgumentTypes jvmReturnType let jvmArgumentTypes = getInferredType <$> jvmArgumentTypes -- TODO: Do not discard Java lambda type descriptor scopeIndex <- newScopeIndex - let arityNat = length jvmArgumentTypes let arity = the Int $ cast arityNat - let isNilArity = arity == 0 let argumentNames = if isNilArity then [] else (\argumentIndex => "arg" ++ show argumentIndex) <$> [0 .. arity - 1] let argumentTypesByName = SortedMap.fromList $ List.zip argumentNames jvmArgumentTypes - isUntyped <- isUntypedFunction jname - let methodReturnType = - if isNilArity - then delayedType - else if jvmReturnType == IVoid || isUntyped then inferredObjectType else jvmReturnType - let inferredFunctionType = MkInferredFunctionType - methodReturnType - (if isUntyped then replicate arityNat inferredObjectType else stripInterfacePrefix <$> jvmArgumentTypes) + let methodReturnType = if isNilArity then delayedType else inferredObjectType + let inferredFunctionType = MkInferredFunctionType methodReturnType (replicate arityNat inferredObjectType) scopes <- LiftIo $ JList.new {a=Scope} let externalFunctionBody = NmExtPrim fc (NS (mkNamespace "") $ UN $ getPrimMethodName foreignFunctionName) [ @@ -205,10 +199,7 @@ inferForeign programName idrisName fc foreignDescriptors argumentTypes returnTyp NmPrimVal fc (Str $ foreignFunctionClassName ++ "." ++ foreignFunctionName), getJvmExtPrimArguments $ List.zip argumentTypes $ SortedMap.toList argumentTypesByName, NmPrimVal fc WorldVal] - let functionBody = - if isNilArity - then NmDelay fc LLazy externalFunctionBody - else externalFunctionBody + let functionBody = if isNilArity then NmDelay fc LLazy externalFunctionBody else externalFunctionBody let function = MkFunction jname inferredFunctionType scopes 0 jvmClassAndMethodName functionBody setCurrentFunction function LiftIo $ AsmGlobalState.addFunction !getGlobalState jname function @@ -233,7 +224,7 @@ inferForeign programName idrisName fc foreignDescriptors argumentTypes returnTyp MkScope scopeIndex (Just parentScopeIndex) variableTypes allVariableTypes variableIndices allVariableIndices IUnknown 0 (0, 0) ("", "") [] saveScope delayLambdaScope - updateScopeVariableTypes + updateScopeVariableTypes arityNat where getJvmExtPrimArguments : List (CFType, String, InferredType) -> NamedCExp getJvmExtPrimArguments [] = NmCon fc (UN "emptyForeignArg") DATACON (Just 0) [] diff --git a/src/Compiler/Jvm/InferredType.idr b/src/Compiler/Jvm/InferredType.idr index 9854e4c8d..b2ba8c80d 100644 --- a/src/Compiler/Jvm/InferredType.idr +++ b/src/Compiler/Jvm/InferredType.idr @@ -57,16 +57,24 @@ public export stringClass : String stringClass = "java/lang/String" -%inline public export inferredStringType : InferredType inferredStringType = IRef stringClass -%inline public export inferredLambdaType : InferredType inferredLambdaType = IRef "java/util/function/Function" +export +function2Type : InferredType +function2Type = IRef "java/util/function/BiFunction" + +export +getFunctionInterface : (arity: Nat) -> InferredType +getFunctionInterface 1 = inferredLambdaType +getFunctionInterface 2 = function2Type +getFunctionInterface arity = IRef ("io/github/mmhelloworld/idrisjvm/runtime/Function" ++ show arity) + %inline public export inferredForkJoinTaskType : InferredType @@ -137,6 +145,11 @@ public export idrisJustType : InferredType idrisJustType = IRef idrisJustClass +%inline +public export +functionsClass : String +functionsClass = "io/github/mmhelloworld/idrisjvm/runtime/Functions" + export isPrimitive : InferredType -> Bool isPrimitive IBool = True diff --git a/src/Compiler/Jvm/Optimizer.idr b/src/Compiler/Jvm/Optimizer.idr index 555b33601..23ec1c9ee 100644 --- a/src/Compiler/Jvm/Optimizer.idr +++ b/src/Compiler/Jvm/Optimizer.idr @@ -142,7 +142,7 @@ extractedMethodArgumentName = "$jvm$arg" %inline maxCasesInMethod : Int -maxCasesInMethod = 12 +maxCasesInMethod = 5 appliedLambdaSwitchIndicator : FC appliedLambdaSwitchIndicator = MkFC (PhysicalPkgSrc "$jvmAppliedLambdaSwitch$") (0, 0) (0, 0) @@ -175,7 +175,6 @@ mutual pure $ NmApp appliedLambdaLetIndicator (NmLam fc extractedMethodArgumentVarName body) [liftedValue] goLiftToLambda True (NmLet fc var value sc) = pure $ NmLet fc var !(goLiftToLambda False value) !(goLiftToLambda True sc) - goLiftToLambda _ expr@(NmConCase _ sc [] Nothing) = pure expr goLiftToLambda False (NmConCase fc sc alts def) = do put $ succ !get let var = UN extractedMethodArgumentName @@ -195,7 +194,6 @@ mutual liftedAlts <- traverse liftToLambdaCon alts liftedDef <- traverse liftToLambdaDefault def pure $ NmConCase fc !(goLiftToLambda False sc) liftedAlts liftedDef - goLiftToLambda _ expr@(NmConstCase fc sc [] Nothing) = pure expr goLiftToLambda False (NmConstCase fc sc alts def) = do put $ succ !get let var = UN extractedMethodArgumentName @@ -224,7 +222,7 @@ mutual goLiftToLambda _ (NmOp fc f args) = pure $ NmOp fc f !(traverse (goLiftToLambda False) args) goLiftToLambda _ (NmExtPrim fc f args) = pure $ NmExtPrim fc f !(traverse (goLiftToLambda False) args) goLiftToLambda _ (NmForce fc reason t) = pure $ NmForce fc reason !(goLiftToLambda False t) - goLiftToLambda _ (NmDelay fc reason t) = pure $ NmDelay fc reason !(goLiftToLambda False t) + goLiftToLambda _ (NmDelay fc reason t) = pure $ NmDelay fc reason !(goLiftToLambda True t) goLiftToLambda _ expr = pure expr liftToLambdaDefault : NamedCExp -> State Int NamedCExp @@ -325,30 +323,37 @@ mutual markTailRecursionConstAlt : NamedConstAlt -> Asm NamedConstAlt markTailRecursionConstAlt (MkNConstAlt constant caseBody) = MkNConstAlt constant <$> markTailRecursion caseBody +optThunkExpr : Bool -> NamedCExp -> NamedCExp +optThunkExpr True = thunkExpr +optThunkExpr _ = id + mutual - trampolineExpression : NamedCExp -> NamedCExp + trampolineExpression : (isTailRec: Bool) -> NamedCExp -> NamedCExp -- Do not trampoline as tail recursion will be eliminated - trampolineExpression expr@(NmApp fc (NmRef nameFc (UN ":__jvmTailRec__:")) args) = expr - trampolineExpression expr@(NmCon _ _ _ _ _) = thunkExpr expr - trampolineExpression expr@(NmApp _ _ _) = thunkExpr expr - trampolineExpression expr@(NmLet fc var value body) = - NmLet fc var value $ trampolineExpression body - trampolineExpression expr@(NmConCase fc sc alts def) = + trampolineExpression _ (NmApp fc (NmRef nameFc (UN ":__jvmTailRec__:")) args) = + NmApp fc (NmRef nameFc (UN ":__jvmTailRec__:")) (trampolineExpression False <$> args) + trampolineExpression isTailRec (NmApp fc f args) = + optThunkExpr isTailRec $ NmApp fc (trampolineExpression False f) (trampolineExpression False <$> args) + trampolineExpression _ (NmLam fc param body) = NmLam fc param $ trampolineExpression False body + trampolineExpression isTailRec (NmLet fc var value body) = + NmLet fc var (trampolineExpression False value) $ trampolineExpression isTailRec body + trampolineExpression _ (NmConCase fc sc alts def) = let trampolinedAlts = trampolineExpressionConAlt <$> alts - trampolinedDefault = trampolineExpression <$> def + trampolinedDefault = trampolineExpression True <$> def in NmConCase fc sc trampolinedAlts trampolinedDefault - trampolineExpression (NmConstCase fc sc alts def) = + trampolineExpression _ (NmConstCase fc sc alts def) = let trampolinedAlts = trampolineExpressionConstAlt <$> alts - trampolinedDefault = trampolineExpression <$> def + trampolinedDefault = trampolineExpression True <$> def in NmConstCase fc sc trampolinedAlts trampolinedDefault - trampolineExpression expr = expr + trampolineExpression _ expr = expr trampolineExpressionConAlt : NamedConAlt -> NamedConAlt trampolineExpressionConAlt (MkNConAlt name conInfo tag args caseBody) = - MkNConAlt name conInfo tag args $ trampolineExpression caseBody + MkNConAlt name conInfo tag args $ trampolineExpression True caseBody trampolineExpressionConstAlt : NamedConstAlt -> NamedConstAlt - trampolineExpressionConstAlt (MkNConstAlt constant caseBody) = MkNConstAlt constant $ trampolineExpression caseBody + trampolineExpressionConstAlt (MkNConstAlt constant caseBody) = + MkNConstAlt constant $ trampolineExpression True caseBody exitInferenceScope : Int -> Asm () exitInferenceScope scopeIndex = updateCurrentScopeIndex scopeIndex @@ -428,37 +433,51 @@ withInferenceLambdaScope lineNumberStart lineNumberEnd parameterName expr op = d Pure result public export -data LambdaType = ThunkLambda | DelayedLambda | FunctionLambda +data LambdaType = ThunkLambda | DelayedLambda | FunctionLambda | Function2Lambda | Function3Lambda | Function4Lambda | + Function5Lambda export Eq LambdaType where ThunkLambda == ThunkLambda = True DelayedLambda == DelayedLambda = True FunctionLambda == FunctionLambda = True + Function2Lambda == Function2Lambda = True + Function3Lambda == Function3Lambda = True + Function4Lambda == Function4Lambda = True + Function5Lambda == Function5Lambda = True _ == _ = False export -getLambdaType : (parameterName: Maybe Name) -> LambdaType -getLambdaType (Just (UN "$jvm$thunk")) = ThunkLambda -getLambdaType Nothing = DelayedLambda -getLambdaType _ = FunctionLambda +getLambdaTypeByParameter : (parameterName: Maybe Name) -> LambdaType +getLambdaTypeByParameter (Just (UN "$jvm$thunk")) = ThunkLambda +getLambdaTypeByParameter Nothing = DelayedLambda +getLambdaTypeByParameter _ = FunctionLambda export getLambdaInterfaceMethodName : LambdaType -> String -getLambdaInterfaceMethodName FunctionLambda = "apply" -getLambdaInterfaceMethodName _ = "evaluate" +getLambdaInterfaceMethodName ThunkLambda = "evaluate" +getLambdaInterfaceMethodName DelayedLambda = "evaluate" +getLambdaInterfaceMethodName _ = "apply" export getSamDesc : LambdaType -> String getSamDesc ThunkLambda = "()" ++ getJvmTypeDescriptor thunkType getSamDesc DelayedLambda = "()Ljava/lang/Object;" -getSamDesc FunctionLambda = "(Ljava/lang/Object;)Ljava/lang/Object;" +getSamDesc FunctionLambda = getMethodDescriptor $ MkInferredFunctionType inferredObjectType [inferredObjectType] +getSamDesc Function2Lambda = + getMethodDescriptor $ MkInferredFunctionType inferredObjectType $ replicate 2 inferredObjectType +getSamDesc Function3Lambda = + getMethodDescriptor $ MkInferredFunctionType inferredObjectType $ replicate 3 inferredObjectType +getSamDesc Function4Lambda = + getMethodDescriptor $ MkInferredFunctionType inferredObjectType $ replicate 4 inferredObjectType +getSamDesc Function5Lambda = + getMethodDescriptor $ MkInferredFunctionType inferredObjectType $ replicate 5 inferredObjectType export getLambdaInterfaceType : LambdaType -> InferredType -> InferredType getLambdaInterfaceType ThunkLambda returnType = getThunkType returnType getLambdaInterfaceType DelayedLambda returnType = delayedType -getLambdaInterfaceType FunctionLambda returnType = inferredLambdaType +getLambdaInterfaceType _ returnType = inferredLambdaType export getLambdaImplementationMethodReturnType : LambdaType -> InferredType @@ -723,7 +742,7 @@ mutual let hasParameterValue = isJust parameterValueExpr let (_, lineStart, lineEnd) = getSourceLocation expr let jvmParameterNameAndType = (\(name, ty) => (jvmSimpleName name, ty)) <$> parameterNameAndType - let lambdaType = getLambdaType (fst <$> parameterNameAndType) + let lambdaType = getLambdaTypeByParameter (fst <$> parameterNameAndType) lambdaBodyReturnType <- withInferenceLambdaScope lineStart lineEnd (fst <$> parameterNameAndType) expr $ do when (lambdaType /= ThunkLambda) $ traverse_ createAndAddVariable jvmParameterNameAndType @@ -742,6 +761,11 @@ mutual addVariableType name ty Pure () + inferExprLamWithParameterType1 : (isCached : Bool) -> Maybe Name -> NamedCExp -> Asm InferredType + inferExprLamWithParameterType1 True _ _ = Pure inferredLambdaType + inferExprLamWithParameterType1 False parameterName expr = + inferExprLamWithParameterType ((\name => (name, inferredObjectType)) <$> parameterName) Nothing expr + inferExprLam : AppliedLambdaType -> (parameterValue: Maybe NamedCExp) -> (parameterName : Maybe Name) -> NamedCExp -> Asm InferredType inferExprLam appliedLambdaType parameterValue@(Just value) (Just parameterName) lambdaBody = do @@ -779,8 +803,28 @@ mutual inferExpr valueType value addVariableType variableName valueType updateCurrentScopeIndex lambdaScopeIndex - inferExprLam _ _ parameterName expr = - inferExprLamWithParameterType ((\name => (name, inferredObjectType)) <$> parameterName) Nothing expr + inferExprLam _ _ p0 expr@(NmLam _ p1 (NmLam _ p2 (NmLam _ p3 (NmLam _ p4 (NmApp _ (NmRef _ name) + [NmLocal _ arg0, NmLocal _ arg1, NmLocal _ arg2, NmLocal _ arg3, NmLocal _ arg4]))))) = + inferExprLamWithParameterType1 + ((fromMaybe (UN "") p0) == arg0 && p1 == arg1 && p2 == arg2 && p3 == arg3 && p4 == arg4) p0 expr + inferExprLam _ _ p0 expr@(NmLam _ p1 (NmLam _ p2 (NmLam _ p3 (NmApp _ (NmRef _ name) + [NmLocal _ arg0, NmLocal _ arg1, NmLocal _ arg2, NmLocal _ arg3])))) = + inferExprLamWithParameterType1 + ((fromMaybe (UN "") p0) == arg0 && p1 == arg1 && p2 == arg2 && p3 == arg3) p0 expr + inferExprLam _ _ p0 expr@(NmLam _ p1 (NmLam _ p2 (NmApp _ (NmRef _ name) + [NmLocal _ arg0, NmLocal _ arg1, NmLocal _ arg2]))) = + inferExprLamWithParameterType1 ((fromMaybe (UN "") p0) == arg0 && p1 == arg1 && p2 == arg2) p0 expr + inferExprLam _ _ p0 expr@(NmLam _ p1 (NmApp _ (NmRef _ name) [NmLocal _ arg0, NmLocal _ arg1])) = + inferExprLamWithParameterType1 ((fromMaybe (UN "") p0) == arg0 && p1 == arg1) p0 expr + inferExprLam _ _ p0 expr@(NmApp _ (NmRef _ _) [NmLocal _ b]) = + inferExprLamWithParameterType1 ((fromMaybe (UN "") p0) == b) p0 expr + inferExprLam _ _ p0 expr@(NmLam _ c (NmLam _ a (NmLocal _ b))) = + inferExprLamWithParameterType1 (isJust p0 && (c == b || a == b)) p0 expr + inferExprLam _ _ p0 expr@(NmLam _ a (NmLocal _ b)) = + inferExprLamWithParameterType1 ((fromMaybe (UN "") p0) == b || (isJust p0 && a == b)) p0 expr + inferExprLam _ _ p0 expr@(NmLocal _ b) = + inferExprLamWithParameterType1 ((fromMaybe (UN "") p0) == b) p0 expr + inferExprLam _ _ p0 expr = inferExprLamWithParameterType1 False p0 expr inferExprLet : FC -> InferredType -> (x : Name) -> NamedCExp -> NamedCExp -> Asm InferredType inferExprLet fc exprTy var value expr = do @@ -826,9 +870,7 @@ mutual let functionName = jvmName idrisName functionType <- case !(findFunctionType functionName) of Just ty => Pure ty - Nothing => do - addUntypedFunction functionName - Pure $ MkInferredFunctionType inferredObjectType $ replicate (length args) inferredObjectType + Nothing => Pure $ MkInferredFunctionType inferredObjectType $ replicate (length args) inferredObjectType let argsWithTypes = List.zip args (parameterTypes functionType) traverse_ inferParameter argsWithTypes Pure $ returnType functionType @@ -1076,11 +1118,12 @@ mutual inferExprOp Crash [_, msg] = Pure IUnknown inferExprOp op _ = Throw emptyFC ("Unsupported primitive function " ++ show op) -optimize : TailCallCategory -> NamedCExp -> Asm NamedCExp -optimize tailCallCategory expr = do +optimize : Jname -> TailCallCategory -> NamedCExp -> Asm NamedCExp +optimize jname tailCallCategory expr = do inlinedAndTailRecursionMarkedExpr <- markTailRecursion . liftToLambda $ expr - Pure $ if hasNonSelfTailCall tailCallCategory - then trampolineExpression inlinedAndTailRecursionMarkedExpr + shouldTrampoline <- LiftIo $ AsmGlobalState.shouldTrampoline !getGlobalState (show jname) + Pure $ if shouldTrampoline && hasNonSelfTailCall tailCallCategory + then trampolineExpression True inlinedAndTailRecursionMarkedExpr else inlinedAndTailRecursionMarkedExpr export @@ -1107,24 +1150,16 @@ inferDef programName idrisName fc (MkNmFun args body) = do let expr = if arityInt == 0 then NmDelay fc LLazy body else body let argumentNames = jvmSimpleName <$> args argIndices <- LiftIo $ getArgumentIndices arityInt argumentNames - isUntyped <- isUntypedFunction jname - let initialArgumentTypes = replicate arity $ if isUntyped then inferredObjectType else IUnknown + let initialArgumentTypes = replicate arity inferredObjectType + let inferredFunctionType = MkInferredFunctionType inferredObjectType initialArgumentTypes argumentTypesByName <- LiftIo $ Map.fromList $ List.zip argumentNames initialArgumentTypes scopes <- LiftIo $ JList.new {a=Scope} - let function = - MkFunction jname (MkInferredFunctionType IUnknown initialArgumentTypes) - scopes 0 jvmClassAndMethodName emptyFunction + let function = MkFunction jname inferredFunctionType scopes 0 jvmClassAndMethodName emptyFunction setCurrentFunction function LiftIo $ AsmGlobalState.addFunction !getGlobalState jname function let shouldDebugExpr = shouldDebug && (fromMaybe True $ ((\name => name `isInfixOf` (getSimpleName jname)) <$> debugFunction)) - when shouldDebugExpr $ do - debug $ "**********************" - debug $ "Inferring " ++ (className jvmClassAndMethodName) ++ "." ++ (methodName jvmClassAndMethodName) - debug "Unoptimized" - debug "---------" - debug $ showNamedCExp 0 expr - optimizedExpr <- optimize tailCallCategory expr + optimizedExpr <- optimize jname tailCallCategory expr updateCurrentFunction $ record { optimizedBody = optimizedExpr } resetScope @@ -1138,11 +1173,11 @@ inferDef programName idrisName fc (MkNmFun args body) = do saveScope functionScope retTy <- inferExpr IUnknown optimizedExpr - updateScopeVariableTypes - inferredArgumentTypes <- if isUntyped then pure initialArgumentTypes else getArgumentTypes argumentNames - let inferredFunctionType = - MkInferredFunctionType (if isUntyped then inferredObjectType else retTy) inferredArgumentTypes + updateScopeVariableTypes arity updateCurrentFunction $ record { inferredFunctionType = inferredFunctionType } + when shouldDebugExpr $ + debug $ "Inferring " ++ (className jvmClassAndMethodName) ++ "." ++ (methodName jvmClassAndMethodName) ++ + ":\n" ++ showNamedCExp 0 expr ++ "\n" ++ show inferredFunctionType when shouldDebugExpr $ showScopes (scopeCounter !GetState - 1) where getArgumentTypes : List String -> Asm (List InferredType) diff --git a/src/Compiler/Jvm/Tree.idr b/src/Compiler/Jvm/Tree.idr index 4da79328a..a74484f7d 100644 --- a/src/Compiler/Jvm/Tree.idr +++ b/src/Compiler/Jvm/Tree.idr @@ -27,15 +27,9 @@ implementation Show a => Show (Tree a) where show tree = displayTree show tree export -levelOrder : Tree a -> List (List a) -levelOrder tree = go [] [tree] - where - element : Tree b -> b - element (Node value _) = value - - subtrees : Tree b -> List (Tree b) - subtrees (Node _ children) = children - - go : List (List b) -> List (Tree b) -> List (List b) - go acc [] = acc - go acc trees = go (map element trees :: acc) (concatMap subtrees trees) \ No newline at end of file +traverseDepthFirst : Tree a -> List a +traverseDepthFirst tree = go [] [] tree where + go : (acc: List a) -> (stack: List (Tree a)) -> Tree a -> List a + go acc [] (Node d []) = d :: acc + go acc (next :: rest) (Node d []) = go (d :: acc) rest next + go acc stack (Node d (child :: children)) = go acc (child :: stack) (Node d children) diff --git a/src/Compiler/Jvm/Variable.idr b/src/Compiler/Jvm/Variable.idr index 5acf3ad64..63c5931ac 100644 --- a/src/Compiler/Jvm/Variable.idr +++ b/src/Compiler/Jvm/Variable.idr @@ -175,8 +175,6 @@ checkcast cname = Checkcast cname export asmCast : (sourceType: InferredType) -> (targetType: InferredType) -> Asm () -asmCast ty1@(IRef "java/lang/Object") ty2@(IRef "java/lang/Object") = unwrapObjectThunk - asmCast ty1@(IRef class1) ty2@(IRef class2) = cond [ (class1 == class2, Pure ()), @@ -185,13 +183,13 @@ asmCast ty1@(IRef class1) ty2@(IRef class2) = (class1 == thunkClass && class2 == intThunkClass, unboxToIntThunk), (class1 == thunkClass && class2 == longThunkClass, unboxToLongThunk), (class1 == thunkClass && class2 == doubleThunkClass, unboxToDoubleThunk), - (ty1 == inferredObjectType || isThunkType ty1, do unwrapObjectThunk; checkcast class2) + (isThunkType ty1, do unwrapObjectThunk; checkcast class2) ] (checkcast class2) asmCast IUnknown ty@(IRef clazz) = if isThunkType ty then thunkObject - else do unwrapObjectThunk; checkcast clazz + else checkcast clazz asmCast IBool IBool = Pure () asmCast IByte IByte = Pure () @@ -252,9 +250,7 @@ asmCast (IRef _) arr@(IArray _) = Checkcast $ getJvmTypeDescriptor arr asmCast IVoid IVoid = Pure () asmCast IVoid (IRef _) = Aconstnull asmCast IVoid IUnknown = Aconstnull -asmCast IUnknown IUnknown = unwrapObjectThunk -asmCast (IRef "java/lang/Object") IUnknown = unwrapObjectThunk -asmCast _ IUnknown = Pure () +asmCast ty IUnknown = when (isThunkType ty) unwrapObjectThunk asmCast ty1 ty2 = Throw emptyFC $ "Cannot convert from " ++ show ty1 ++ " to " ++ show ty2