open Ast
module IntSet = Set.Make (Int)

(** [escape_dot_string s] escapes [s] for use between double quotes in a DOT
    file. Only the double-quote and backslash characters are escaped; every
    other byte, including UTF-8 multi-byte sequences, is passed through
    unchanged. (The [%S] conversion is not suitable here: it emits OCaml
    decimal escapes for non-ASCII bytes, which Graphviz shows verbatim.) *)
let escape_dot_string (s : string) : string =
  let buf = Buffer.create (String.length s) in
  String.iter
    (fun c ->
      (match c with '"' | '\\' -> Buffer.add_char buf '\\' | _ -> ());
      Buffer.add_char buf c)
    s;
  Buffer.contents buf

(** [escape_html s] replaces the characters that are markup-significant in a
    Graphviz HTML-like label (ampersand, less-than, greater-than, double
    quote) with the corresponding character entities. *)
let escape_html (s : string) : string =
  let buf = Buffer.create (String.length s) in
  String.iter
    (function
      | '&' -> Buffer.add_string buf "&amp;"
      | '<' -> Buffer.add_string buf "&lt;"
      | '>' -> Buffer.add_string buf "&gt;"
      | '"' -> Buffer.add_string buf "&quot;"
      | c -> Buffer.add_char buf c)
    s;
  Buffer.contents buf

(** Print [s] as a complete Graphviz label value, delimiters included.
    When [Util.break_to_subscript s] returns [Some (name, sub)], the result is
    an HTML-like label with [sub] rendered inside a SUB element. Otherwise [s]
    is printed as a double-quoted DOT string. *)
let fmt_with_possible_subscript (fmt : Format.formatter) (s : string) : unit =
  match Util.break_to_subscript s with
  | Some (name, sub) ->
      Format.fprintf fmt "<%s<SUB>%d</SUB>>" (escape_html name) sub
  | None -> Format.fprintf fmt "\"%s\"" (escape_dot_string s)

(** Palette cycled through when several binders share a base name. *)
let colors =
  [| "#56b4e9"; "#009e73"; "#e69f00"; "#d55e00"; "#cc79a7"; "#ffaa14" |]

(** [get_color i] is palette entry [i], wrapping around the palette.
    Raises [Invalid_argument] if [i] is negative. *)
let get_color (i : int) = colors.(i mod Array.length colors)

(** Compute an outline color for every subexpression of [ast], indexed by
    subexpression index.

    Binders are grouped by base name (the name with any subscript removed by
    [Util.break_to_subscript]). Within a group of two or more binders, each
    binder gets successive palette colors in index order. Every bound
    variable then copies the color of the binder it refers to. All other
    entries stay ["#000000"]. *)
let assign_colors (ast : expr ast) : string array =
  let binders = get_binders ast in
  let out = Array.make (Arraylist.length ast.subexprs) "#000000" in
  (* Group binder indices by base name. Indices are pushed in increasing
     order, so each group is sorted. *)
  let by_base = Hashtbl.create 16 in
  binders
  |> Array.iteri (fun i -> function
       | Some (`Binder name) ->
           let base =
             match Util.break_to_subscript name with
             | Some (base, _) -> base
             | None -> name
           in
           let group =
             match Hashtbl.find_opt by_base base with
             | Some group -> group
             | None ->
                 let group = Arraylist.make 0 0 in
                 Hashtbl.add by_base base group;
                 group
           in
           Arraylist.push group i
       | _ -> ());
  (* Only base names that collide need to be told apart by color. *)
  Hashtbl.iter
    (fun _ group ->
      if Arraylist.length group > 1 then
        Seq.iteri (fun k j -> out.(j) <- get_color k) (Arraylist.to_seq group))
    by_base;
  (* Every binder color is final at this point, so bound variables can copy
     theirs. *)
  binders
  |> Array.iteri (fun i -> function
       | Some (`Bound j) -> out.(i) <- out.(j.index)
       | _ -> ());
  out

(** Emit the DOT node statement for subexpression [i], whose value is [expr].
    Keyword-like constructors use the bold typewriter font. Literals and
    primitive operators use the graph's default font. [color] is applied only
    to variable occurrences. *)
let add_node (fmt : Format.formatter) (i : expr index) (expr : expr)
    (color : string) : unit =
  let keyword label =
    Format.fprintf fmt
      " expr%d [fontname=\"CMU Typewriter Text Bold\", label=\"%s\"];\n"
      i.index label
  in
  let plain label =
    Format.fprintf fmt " expr%d [label=\"%s\"];\n" i.index label
  in
  match expr with
  | App _ -> keyword "apply"
  | Bool b -> plain (string_of_bool b)
  | Cons _ -> keyword "::"
  | If (_, _, _) -> keyword "if"
  | Int n -> plain (string_of_int n)
  | Lam _ -> keyword "λ"
  | Let (recursive, _, _, _) ->
      keyword (if recursive then "let rec" else "let")
  | Nil -> keyword "[]"
  | Prim (Add, _) -> plain "+"
  | Prim (Sub, _) -> plain "-"
  | Prim (Mul, _) -> plain "*"
  | Prim (RelOp, (op, _, _)) -> plain (string_of_relop op)
  | Var x ->
      Format.fprintf fmt " expr%d [color=\"%s\", label=%a];\n" i.index color
        fmt_with_possible_subscript x

(** Walk the tree from [ast.root] and add every visited subexpression index
    to [nodes]. For each parent/child pair, emit one DOT edge.

    The binding site of a [Lam] or [Let] at index [i] also gets an ellipse
    node named [expr<i>_var], outlined in [colors.(i)]. A child's subtree is
    emitted before the edge from its parent to that child. *)
let add_expr_edges (ast : 'a ast) (fmt : Format.formatter)
    (nodes : IntSet.t ref) (colors : string array) : unit =
  let rec loop (i : expr index) : unit =
    nodes := IntSet.add i.index !nodes;
    let edge_to (j : expr index) : unit =
      loop j;
      Format.fprintf fmt " expr%d -> expr%d;\n" i.index j.index
    in
    let binder_node (name : string) : unit =
      Format.fprintf fmt " expr%d -> expr%d_var;\n" i.index i.index;
      Format.fprintf fmt
        " expr%d_var [color=\"%s\", label=%a, shape=\"ellipse\"];\n" i.index
        colors.(i.index) fmt_with_possible_subscript name
    in
    match get_subexpr ast i with
    | App (f, x) ->
        edge_to f;
        edge_to x
    | Cons (hd, tl) ->
        edge_to hd;
        edge_to tl
    | If (cond, then_, else_) ->
        edge_to cond;
        edge_to then_;
        edge_to else_
    | Lam (x, b) ->
        binder_node x;
        edge_to b
    | Let (_, name, bound, body) ->
        binder_node name;
        edge_to bound;
        edge_to body
    | Prim (Add, (l, r)) ->
        edge_to l;
        edge_to r
    | Prim (Sub, (l, r)) ->
        edge_to l;
        edge_to r
    | Prim (Mul, (l, r)) ->
        edge_to l;
        edge_to r
    | Prim (RelOp, (_, l, r)) ->
        edge_to l;
        edge_to r
    | Var _ | Bool _ | Int _ | Nil -> ()
  in
  loop ast.root

(** Render [ast] as a Graphviz DOT digraph and return the DOT source text.
    Edges are written first; node statements follow for every subexpression
    reached from the root. *)
let draw_tree (ast : expr ast) : string =
  let buf = Buffer.create 16 in
  let nodes = ref IntSet.empty in
  let fmt = Format.formatter_of_buffer buf in
  Format.fprintf fmt "digraph {\n";
  Format.fprintf fmt
    " node [fontsize=\"20pt\", penwidth=\"2.0\", shape=\"box\", \
     style=\"rounded\"];\n";
  let node_colors = assign_colors ast in
  add_expr_edges ast fmt nodes node_colors;
  Format.fprintf fmt "\n";
  IntSet.iter
    (fun index ->
      let i = { index } in
      add_node fmt i (get_subexpr ast i) node_colors.(index))
    !nodes;
  (* Disabled debugging aid: draw a faint edge from each bound variable to
     the binder it refers to.
     get_binders ast
     |> Array.iteri (fun i -> function
          | Some (`Bound j) ->
              Format.fprintf fmt
                " expr%d -> expr%d_var [color=\"#cccccc\", constraint=false];\n"
                i j.index
          | _ -> ()); *)
  Format.fprintf fmt "}\n";
  Format.pp_print_flush fmt ();
  Buffer.contents buf