aboutsummaryrefslogtreecommitdiff
path: root/discocaml/draw_tree.ml
blob: 6ea4ff2c9916c135be6df323f87bff5f5f09efd7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
open Ast
(* Ordered set of integers; used below to collect expression indices. *)
module IntSet = Set.Make (struct
  type t = int

  let compare = Int.compare
end)

(** [add_node fmt i expr] writes to [fmt] the Graphviz declaration of the
    node for [expr], whose arena index is [i]. The line has the form
    [  exprN [attrs];] followed by a newline.

    Applications are labelled "apply" in a bold typewriter font. Integer
    literals show their value, lambdas show "λ", and the primitives show
    their operator symbol. Variables show their name.

    NOTE(review): variable names are quoted with [%S], i.e. OCaml string
    escaping. Non-ASCII bytes would come out as [\ddd] escapes, which DOT
    does not interpret. This is fine for ASCII identifiers; confirm names
    are ASCII. *)
let add_node (fmt : Format.formatter) (i : expr index) (expr : expr) : unit =
  let attrs =
    match expr with
    | App _ -> "fontname=\"CMU Typewriter Text Bold\", label=\"apply\""
    | Int value -> Printf.sprintf "label=\"%d\"" value
    | Lam (_, _) -> "label=\"λ\""
    | Prim (`Add, _) -> "label=\"+\""
    | Prim (`Sub, _) -> "label=\"-\""
    | Prim (`Mul, _) -> "label=\"*\""
    | Var name -> Printf.sprintf "label=%S" name
  in
  Format.fprintf fmt "  expr%d [%s];\n" i.index attrs

(** [add_expr_edges ast fmt nodes] returns a walker. Given a starting index,
    the walker visits every expression reachable from it and writes one DOT
    edge line ([  exprP -> exprC;]) per parent/child link to [fmt]. It also
    adds the index of each visited expression to [nodes], so the caller can
    declare those nodes afterwards.

    A lambda additionally gets an [exprN_var] node, labelled with its bound
    variable's name and linked from the lambda node.

    Edges are written child-first: a child's whole subgraph is emitted
    before the edge that points at it.

    Each expression is expanded at most once per walker. When several parents
    refer to the same index, every parent-to-child edge is still written. The
    shared node's own outgoing edges and binder node, however, are not
    duplicated. This also keeps the walk finite should the index graph ever
    contain a cycle. *)
let add_expr_edges (ast : 'a ast) (fmt : Format.formatter)
    (nodes : IntSet.t ref) : expr index -> unit =
  (* Indices whose children have already been emitted. This is kept separate
     from [nodes] so a caller passing a pre-populated set still gets a full
     walk. *)
  let expanded = ref IntSet.empty in
  let rec loop (i : expr index) : unit =
    nodes := IntSet.add i.index !nodes;
    if not (IntSet.mem i.index !expanded) then begin
      expanded := IntSet.add i.index !expanded;
      (* Emit the child's subgraph first, then the edge pointing at it. *)
      let edge_to (j : expr index) : unit =
        loop j;
        Format.fprintf fmt "  expr%d -> expr%d;\n" i.index j.index
      in
      match get_subexpr ast i with
      | App (f, x) ->
          edge_to f;
          edge_to x
      | Int _ -> ()
      | Lam (x, b) ->
          (* The binder is drawn as its own node hanging off the lambda. *)
          Format.fprintf fmt "  expr%d -> expr%d_var;\n" i.index i.index;
          Format.fprintf fmt "  expr%d_var [label=%S];\n" i.index x;
          edge_to b
      | Prim (_, xs) -> Array.iter edge_to xs
      | Var _ -> ()
    end
  in
  loop

(** [draw_tree ast] renders the expression graph rooted at [ast.root] as a
    Graphviz DOT [digraph] and returns the text.

    The body has two sections. The first holds every edge, plus the lambda
    binder nodes, written by [add_expr_edges]. Then, after a blank line, comes
    one declaration per visited expression, written by [add_node] in ascending
    index order. *)
let draw_tree (ast : expr ast) : string =
  let visited = ref IntSet.empty in
  let emit_body fmt =
    add_expr_edges ast fmt visited ast.root;
    Format.fprintf fmt "\n";
    !visited
    |> IntSet.elements
    |> List.iter (fun index ->
           let node = { index } in
           add_node fmt node (get_subexpr ast node))
  in
  (* [asprintf] uses a fresh buffer formatter and flushes it (without a
     trailing newline) before returning its contents. *)
  Format.asprintf "digraph {\n%t}\n" emit_body