Skip to content

Commit 7a2cfc3

Browse files
committed
Rewrite DecEq deriving with better performance
I see ~47% speedup in instance generation time (typechecking) and ~22% speedup at runtime (normalization) with admittedly flaky benchmarks. Crucially, the generated instances should now only ever force booleans during normalization if all we want to know is the boolean equality.
1 parent 4fbf11b commit 7a2cfc3

File tree

1 file changed

+52
-38
lines changed

1 file changed

+52
-38
lines changed

Tactic/Derive/DecEq.agda

Lines changed: 52 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,12 @@ import Data.List as L
1717
import Data.List.NonEmpty as NE
1818

1919
open import Relation.Nullary
20-
open import Relation.Nullary.Decidable
2120

2221
open import Reflection.Tactic
23-
open import Reflection.AST.Term using (_≟-Pattern_)
24-
open import Reflection.Utils
25-
open import Reflection.Utils.TCI
2622
open import Reflection.QuotedDefinitions
2723

2824
open import Class.DecEq.Core
2925
open import Class.Functor
30-
open import Class.Monad
3126
open import Class.MonadTC.Instances
3227
open import Class.Traversable
3328

@@ -40,14 +35,21 @@ open ClauseExprM
4035
private
4136
instance _ = ContextMonad-MonadTC
4237

43-
-- We take the Dec P argument first to improve type checking performance.
44-
-- It's easy to infer the type of P from this argument and we need to know
45-
-- P to be able to check the pattern lambda generated for the P → Q direction
46-
-- of the isomorphism. Having the isomorphism first would cause the type checker
47-
-- to go back and forth between the pattern lambda and the Dec P argument,
48-
-- inferring just enough of the type of make progress on the lambda.
49-
map' : {p q} {P : Set p} {Q : Set q} Dec P P ⇔ Q Dec Q
50-
map' d record { to = to ; from = from } = map′ to from d
38+
-- Here's an example of what code this generates, here for a record R with 3 fields:
39+
-- DecEq : DecEq R
40+
-- DecEq ._≟_ ⟪ x₁ , x₂ , x₃ ⟫ ⟪ y₁ , y₂ , y₃ ⟫ =
41+
-- case (x₁ ≟ y₁) of λ where
42+
-- (false because ¬p) → no (case ¬p of λ where (ofⁿ ¬p) refl → ¬p refl)
43+
-- (true because p₁) → case (x₂ ≟ y₂) of λ where
44+
-- (false because ¬p) → no (case ¬p of λ where (ofⁿ ¬p) refl → ¬p refl)
45+
-- (true because p₂) → case (x₃ ≟ y₃) of λ where
46+
-- (false because ¬p) → no (case ¬p of λ where (ofⁿ ¬p) refl → ¬p refl)
47+
-- (true because p₃) → yes (case p₁ , p₂ , p₃ of λ where (ofʸ refl , ofʸ refl , ofʸ refl) → refl)
48+
49+
-- patterns almost like `yes` and `no`, except that they don't match the `Reflects` proof
50+
-- delaying maching on the `Reflects` proof as late as possible results in a major speed increase
51+
pattern ``yes' x = quote _because_ ◇⟦ quote true ◇ ∣ x ⟧
52+
pattern ``no' x = quote _because_ ◇⟦ quote false ◇ ∣ x ⟧
5153

5254
module _ (transName : Name Maybe Name) where
5355

@@ -57,36 +59,48 @@ private
5759
... | nothing = quote _≟_ ∙⟦ t ∣ t' ⟧
5860
eqFromTerm _ t t' = quote _≟_ ∙⟦ t ∣ t' ⟧
5961

60-
toDecEqName : SinglePattern List (Term Term Term)
61-
toDecEqName (l , _) = L.map (λ where (_ , arg _ t) eqFromTerm t) l
62-
63-
-- on the diagonal we have one pattern, outside we don't care
62+
-- `nothing`: outside of the diagonal, not equal
63+
-- `just`: on the diagonal, with that pattern, could be equal
6464
-- assume that the types in the pattern are properly normalized
65-
mapDiag : Maybe SinglePattern TC Term
66-
mapDiag nothing = return $ `no `λ⦅ [ ("" , vArg?) ] ⦆∅
67-
mapDiag (just p@(l , _)) = let k = length l in do
68-
typeList traverse ⦃ Functor-List ⦄ inferType (applyDownFrom ♯ (length l))
69-
return $ quote map' ∙⟦ genPf k (L.map eqFromTerm typeList) ∣ genEquiv k ⟧
65+
genBranch : Maybe SinglePattern TC Term
66+
genBranch nothing = return $ `no `λ⦅ [ ("" , vArg?) ] ⦆∅
67+
genBranch (just ([] , _)) = return $ `yes `refl
68+
genBranch (just p@(l@(x ∷ xs) , _)) = do
69+
typeList traverse inferType (applyUpTo ♯ (length l))
70+
let eqs = L.map eqFromTerm typeList
71+
return $ foldl (λ t eq genCase eq t) genTrueCase eqs
7072
where
71-
genPf : List (Term Term Term) Term
72-
genPf k [] = `yes (quote tt ◆)
73-
genPf k (n ∷ l) = quote _×-dec_ ∙⟦ genPf k l ∣ n (♯ (length l)) (♯ (length l + k)) ⟧
74-
75-
-- c x1 .. xn ≡ c y1 .. yn ⇔ x1 ≡ y1 .. xn ≡ yn
76-
genEquiv : Term
77-
genEquiv n = quote mk⇔ ∙⟦ `λ⟦ reflPattern n ⇒ `refl ⟧ ∣ `λ⟦ ``refl ⇒ reflTerm n ⟧ ⟧
78-
where
79-
reflPattern : Pattern
80-
reflPattern 0 = quote tt ◇
81-
reflPattern (suc n) = quote _,_ ◇⟦ reflPattern n ∣ ``refl ⟧
82-
83-
reflTerm : Term
84-
reflTerm 0 = quote tt ◆
85-
reflTerm (suc n) = quote _,_ ◆⟦ reflTerm n ∣ `refl ⟧
73+
k = ℕ.suc (length xs)
74+
75+
vars : NE.List⁺ ℕ
76+
vars = 0 NE.∷ applyUpTo ℕ.suc (length xs)
77+
78+
-- case (xᵢ ≟ yᵢ) of λ { (false because ...) → no ... ; (true because p) → t }
79+
-- since we always add one variable to the scope of t the uncompared terms
80+
-- are always at index 2k+1 and k
81+
genCase : (Term Term Term) Term Term
82+
genCase _`≟_ t = `case ♯ (2 * k ∸ 1) `≟ ♯ (k ∸ 1) of clauseExprToPatLam (MatchExpr
83+
( (singlePatternFromPattern (vArg (``yes' (` 0))) , inj₂ (just t))
84+
∷ (singlePatternFromPattern (vArg (``no' (` 0))) , inj₂ (just (`no $
85+
-- case ¬p of λ where (ofⁿ ¬p) refl → ¬p refl
86+
`case ♯ 0 of clauseExprToPatLam (multiClauseExpr
87+
[( singlePatternFromPattern (vArg (quote ofⁿ ◇⟦ ` 0 ⟧))
88+
NE.∷ singlePatternFromPattern (vArg ``refl) ∷ []
89+
, inj₂ (just ♯ 0 ⟦ `refl ⟧)) ]))))
90+
∷ []))
91+
92+
-- yes (case p₁ , ... , pₖ of λ where (ofʸ refl , ... , ofʸ refl) → refl)
93+
genTrueCase : Term
94+
genTrueCase = `yes $
95+
`case NE.foldl₁ (quote _,′_ ∙⟦_∣_⟧) (NE.map ♯ vars)
96+
of clauseExprToPatLam (MatchExpr
97+
[ (singlePatternFromPattern
98+
(vArg (NE.foldl₁ (quote _,_ ◇⟦_∣_⟧) (NE.map (λ _ quote ofʸ ◇⟦ ``refl ⟧) vars)))
99+
, inj₂ (just `refl)) ])
86100

87101
toMapDiag : SinglePattern SinglePattern NE.List⁺ SinglePattern × TC (ClauseExpr ⊎ Maybe Term)
88102
toMapDiag p@(_ , arg _ p₁) p'@(_ , arg _ p₂) =
89-
(p NE.∷ [ p' ] , finishMatch (if ⌊ p₁ ≟-Pattern p₂ ⌋ then mapDiag (just p) else mapDiag nothing))
103+
(p NE.∷ [ p' ] , finishMatch (if ⌊ p₁ ≟-Pattern p₂ ⌋ then genBranch (just p) else genBranch nothing))
90104

91105
module _ ⦃ _ : TCOptions ⦄ where
92106
derive-DecEq : List (Name × Name) UnquoteDecl

0 commit comments

Comments
 (0)