Skip to content

Compiler: simplify branch #2057

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
* Compiler: support for OCaml 4.14.3+trunk (#1844)
* Compiler: add the `--empty-sourcemap` flag
* Compiler: improve debug/sourcemap location of closures (#1947)
* Compiler: optimize compilation of switches
* Compiler: optimize compilation of switches (#1921, #2057)
* Compiler: evaluate statically more primitives (#1912, #1915, #1965, #1969)
* Compiler: rewrote inlining pass (#1935, #2018, #2027)
* Compiler: improve tailcall optimization (#1943)
Expand Down
232 changes: 176 additions & 56 deletions compiler/lib/specialize.ml
Original file line number Diff line number Diff line change
Expand Up @@ -130,11 +130,107 @@ let f ~shape ~update_def p =

(***)

module Simple_block : sig
type t

val hash : t -> int

val equal : t -> t -> bool

val make : block -> t
end = struct
type t = block

let subst_cont s (pc, arg) = pc, List.map arg ~f:s

let expr s e =
match e with
| Constant _ -> e
| Apply { f; args; exact } -> Apply { f = s f; args = List.map args ~f:s; exact }
| Block (n, a, k, mut) -> Block (n, Array.map a ~f:s, k, mut)
| Field (x, n, typ) -> Field (s x, n, typ)
| Closure (l, pc, loc) -> Closure (l, subst_cont s pc, loc)
| Special _ -> e
| Prim (p, l) ->
Prim
( p
, List.map l ~f:(fun x ->
match x with
| Pv x -> Pv (s x)
| Pc _ -> x) )

let instr s d i =
match i with
| Let (x, e) ->
let x = d x in
Let (x, expr s e)
| Assign (x, y) -> Assign (s x, s y)
| Set_field (x, n, typ, y) -> Set_field (s x, n, typ, s y)
| Offset_ref (x, n) -> Offset_ref (s x, n)
| Array_set (x, y, z) -> Array_set (s x, s y, s z)
| Event _ -> Event Parse_info.zero

let instrs s d l = List.map l ~f:(fun i -> instr s d i)

let last s l =
match l with
| Stop -> l
| Branch cont -> Branch (subst_cont s cont)
| Pushtrap (cont1, x, cont2) -> Pushtrap (subst_cont s cont1, s x, subst_cont s cont2)
| Return x -> Return (s x)
| Raise (x, k) -> Raise (s x, k)
| Cond (x, cont1, cont2) -> Cond (s x, subst_cont s cont1, subst_cont s cont2)
| Switch (x, conts) -> Switch (s x, Array.map conts ~f:(fun cont -> subst_cont s cont))
| Poptrap cont -> Poptrap (subst_cont s cont)

let block s d block =
let params = List.map block.params ~f:s in
let body = instrs s d block.body in
let branch = last s block.branch in
{ params; body; branch }

let make blk =
let t = Var.Hashtbl.create 17 in
let s x =
match Var.Hashtbl.find_opt t x with
| None -> x
| Some x -> x
in
let d x =
let v = Var.of_idx (-Var.Hashtbl.length t) in
Var.Hashtbl.add t x v;
v
in
block s d blk

let instr_equal a b =
match a, b with
| Event _, Event _ -> true
| Event _, _ | _, Event _ -> false
| a, b -> Poly.equal a b

let equal a b =
List.equal ~eq:Var.equal a.params b.params
&& List.equal ~eq:instr_equal a.body b.body
&& Poly.equal a.branch b.branch

let hash (x : block) = Hashtbl.hash x
end

module SBT = Hashtbl.Make (Simple_block)

(* For switches, at this point, we know that this it is sufficient to
check the [pc]. *)
let equal (pc, _) (pc', _) = pc = pc'

let find_outlier_index arr =
type switch_to_cond =
[ `All_equals
| `Distinguished of int
| `Splitted of int
| `Splitted_shifted of int * int
]

let find_outlier_index arr : [ switch_to_cond | `Many_cases ] =
let len = Array.length arr in
let rec find w i =
if i >= len
Expand All @@ -159,6 +255,37 @@ let find_outlier_index arr =
| `All_equals -> if j = i + 1 then `Distinguished i else `Splitted_shifted (i, j)
| `Distinguished _ -> `Many_cases))

let optimize_switch_to_cond block x l (opt : switch_to_cond) =
match opt with
| `All_equals -> { block with branch = Branch l.(0) }
| `Distinguished i ->
let c = Var.fresh () in
{ block with
body =
block.body @ [ Let (c, Prim (Eq, [ Pc (Int (Targetint.of_int_exn i)); Pv x ])) ]
; branch = Cond (c, l.(i), l.((i + 1) mod Array.length l))
}
| `Splitted i ->
let c = Var.fresh () in
{ block with
body =
block.body @ [ Let (c, Prim (Lt, [ Pv x; Pc (Int (Targetint.of_int_exn i)) ])) ]
; branch = Cond (c, l.(i - 1), l.(i))
}
| `Splitted_shifted (i, j) ->
let shifted = Var.fresh () in
let c = Var.fresh () in
{ block with
body =
block.body
@ [ Let
( shifted
, Prim (Extern "%int_sub", [ Pv x; Pc (Int (Targetint.of_int_exn i)) ]) )
; Let (c, Prim (Ult, [ Pv shifted; Pc (Int (Targetint.of_int_exn (j - i))) ]))
]
; branch = Cond (c, l.(i), l.(j))
}

let switches p =
let previous_p = p in
let t = Timer.make () in
Expand All @@ -171,63 +298,56 @@ let switches p =
match block.branch with
| Switch (x, l) -> (
match find_outlier_index l with
| `All_equals ->
incr opt_count;
Addr.Map.add pc { block with branch = Branch l.(0) } blocks
| `Distinguished i ->
| #switch_to_cond as opt ->
incr opt_count;
let block =
let c = Var.fresh () in
{ block with
body =
block.body
@ [ Let
(c, Prim (Eq, [ Pc (Int (Targetint.of_int_exn i)); Pv x ]))
]
; branch = Cond (c, l.(i), l.((i + 1) mod Array.length l))
}
in
let block = optimize_switch_to_cond block x l opt in
Addr.Map.add pc block blocks
| `Splitted i ->
incr opt_count;
let block =
let c = Var.fresh () in
{ block with
body =
block.body
@ [ Let
(c, Prim (Lt, [ Pv x; Pc (Int (Targetint.of_int_exn i)) ]))
]
; branch = Cond (c, l.(i - 1), l.(i))
}
in
Addr.Map.add pc block blocks
| `Splitted_shifted (i, j) ->
incr opt_count;
let block =
let shifted = Var.fresh () in
let c = Var.fresh () in
{ block with
body =
block.body
@ [ Let
( shifted
, Prim
( Extern "%int_sub"
, [ Pv x; Pc (Int (Targetint.of_int_exn i)) ] ) )
; Let
( c
, Prim
( Ult
, [ Pv shifted
; Pc (Int (Targetint.of_int_exn (j - i)))
] ) )
]
; branch = Cond (c, l.(i), l.(j))
}
| `Many_cases ->
let t = SBT.create 0 in
let rewrite = ref Addr.Set.empty in
let l =
Array.map l ~f:(fun ((pc, _) as cont) ->
let block = Code.Addr.Map.find pc blocks in
if List.compare_length_with block.body ~len:7 <= 0
then (
let sb = Simple_block.make block in
match SBT.find_opt t sb with
| Some cont' when not (equal cont' cont) ->
rewrite := Addr.Set.add (fst cont') !rewrite;
cont'
| Some _ | None ->
SBT.add t sb cont;
cont)
else cont)
in
Addr.Map.add pc block blocks
| `Many_cases -> blocks)
if not (Addr.Set.is_empty !rewrite)
then (
incr opt_count;
let blocks =
Addr.Set.fold
(fun pc blocks ->
let block = Code.Addr.Map.find pc blocks in
Addr.Map.add
pc
{ block with
body =
List.filter
~f:(function
| Event _ -> false
| _ -> true)
block.body
}
blocks)
!rewrite
blocks
in
match find_outlier_index l with
| #switch_to_cond as opt ->
let block = optimize_switch_to_cond block x l opt in
Addr.Map.add pc block blocks
| `Many_cases ->
Addr.Map.add pc { block with branch = Switch (x, l) } blocks)
else blocks)
| _ -> blocks)
p.blocks
p.blocks
Expand All @@ -237,4 +357,4 @@ let switches p =
if stats () then Format.eprintf "Stats - switches: %d@." !opt_count;
if debug_stats ()
then Code.check_updates ~name:"switches" previous_p p ~updates:!opt_count;
p
Deadcode.remove_unused_blocks p
94 changes: 94 additions & 0 deletions compiler/tests-compiler/cond.ml
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,97 @@ let%expect_test "conditional" =
}
//end
|}]

let%expect_test "conditional" =
let program =
compile_and_parse
{|
type rip_relative_kind =
| Explicitly_rip_relative
| Implicitly_rip_relative
| Not_rip_relative

(** val rip_relative_kind_beq :
rip_relative_kind -> rip_relative_kind -> bool **)

let rip_relative_kind_beq x y =
match x with
| Explicitly_rip_relative ->
(match y with
| Explicitly_rip_relative -> true
| Implicitly_rip_relative -> false
| Not_rip_relative -> false)
| Implicitly_rip_relative ->
(match y with
| Explicitly_rip_relative -> false
| Implicitly_rip_relative -> true
| Not_rip_relative -> false)
| Not_rip_relative ->
(match y with
| Explicitly_rip_relative -> false
| Implicitly_rip_relative -> false
| Not_rip_relative -> true)
|}
in
print_fun_decl program (Some "rip_relative_kind_beq");
[%expect
{|
function rip_relative_kind_beq(x, y){
switch(x){
case 0:
return 0 === y ? 1 : 0;
case 1:
return 1 === y ? 1 : 0;
default: return 2 === y ? 1 : 0;
}
}
//end
|}]

let%expect_test "conditional" =
let program =
compile_and_parse
{|
type rip_relative_kind =
| Explicitly_rip_relative
| Implicitly_rip_relative
| Not_rip_relative

(** val rip_relative_kind_beq :
rip_relative_kind -> rip_relative_kind -> bool **)

let rip_relative_kind_beq x y =
let i = match x with
| Explicitly_rip_relative ->
(match y with
| Explicitly_rip_relative -> 1
| Implicitly_rip_relative -> 2
| Not_rip_relative -> 2)
| Implicitly_rip_relative ->
(match y with
| Explicitly_rip_relative -> 2
| Implicitly_rip_relative -> 1
| Not_rip_relative -> 2)
| Not_rip_relative ->
(match y with
| Explicitly_rip_relative -> 2
| Implicitly_rip_relative -> 2
| Not_rip_relative -> 1)
in print_int i
|}
in
print_fun_decl program (Some "rip_relative_kind_beq");
[%expect
{|
function rip_relative_kind_beq(x, y){
switch(x){
case 0:
var i = 0 === y ? 1 : 2; break;
case 1:
var i = 1 === y ? 1 : 2; break;
default: var i = 2 === y ? 1 : 2;
}
return caml_call1(Stdlib[44], i);
}
//end
|}]
Loading
Loading