aboutsummaryrefslogtreecommitdiff
path: root/discocaml/draw_tree.ml
blob: 08a8288fe61aa3a3c5080f244319f653fdf75b9b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
open Ast
(* Sets of integer node indices (the [index] field of an [expr index]).
   [draw_tree] uses one to record every node reached by the edge walk, then
   iterates it in increasing order to emit the node declarations. *)
module IntSet = Set.Make (Int)

(** [add_node fmt i expr] writes to [fmt] the DOT declaration of the node
    named [expr<index>] for expression [i], whose contents are [expr]. The
    declaration carries the node's label and, for [App], [Cons] and [Nil]
    only, a bold typewriter font attribute. *)
let add_node (fmt : Format.formatter) (i : expr index) (expr : expr) : unit =
  let bold_font = "fontname=\"CMU Typewriter Text Bold\", " in
  (* Wraps fixed label text in double quotes verbatim. Unlike [%S], this
     leaves non-ASCII text such as the lambda sign unescaped. *)
  let quoted text = "\"" ^ text ^ "\"" in
  let attrs, label =
    match expr with
    | App _ -> (bold_font, quoted "apply")
    | Cons _ -> (bold_font, quoted "::")
    | Nil -> (bold_font, quoted "[]")
    | Bool b -> ("", quoted (string_of_bool b))
    | Int n -> ("", quoted (string_of_int n))
    | If _ -> ("", quoted "if")
    | Lam _ -> ("", quoted "λ")
    | Let (recursive, _, _, _) ->
        ("", quoted (if recursive then "letrec" else "let"))
    | Prim (Add, _) -> ("", quoted "+")
    | Prim (Sub, _) -> ("", quoted "-")
    | Prim (Mul, _) -> ("", quoted "*")
    (* Variable names go through OCaml-style [%S] quoting.
       NOTE(review): [%S] escapes non-ASCII bytes as backslash-ddd sequences,
       which DOT presumably does not decode; harmless for ASCII identifiers. *)
    | Var x -> ("", Printf.sprintf "%S" x)
  in
  Format.fprintf fmt "  expr%d [%slabel=%s];\n" i.index attrs label

(** [add_expr_edges ast fmt nodes] returns a function that, given a node
    index, walks the subtree of [ast] rooted there depth-first and writes one
    DOT edge line per parent/child pair to [fmt].

    - Every visited index is added to [nodes].
    - A child's own subtree is written before the edge pointing to it.
    - [Lam] and [Let] nodes additionally get a synthetic [expr<index>_var]
      node, labelled with the bound name and written before their expression
      children.

    NOTE(review): a subexpression reachable from two parents would be walked
    (and its edges written) twice; presumably the AST is a tree. *)
let add_expr_edges (ast : 'a ast) (fmt : Format.formatter)
    (nodes : IntSet.t ref) : expr index -> unit =
  let rec visit (parent : expr index) : unit =
    nodes := IntSet.add parent.index !nodes;
    (* Declares the binder node of a [Lam]/[Let] and links it to [parent]. *)
    let write_binder (name : string) : unit =
      Format.fprintf fmt "  expr%d -> expr%d_var;\n" parent.index parent.index;
      Format.fprintf fmt "  expr%d_var [label=%S];\n" parent.index name
    in
    let children : expr index list =
      match get_subexpr ast parent with
      | App (f, x) -> [ f; x ]
      | Cons (hd, tl) -> [ hd; tl ]
      | If (cond, then_, else_) -> [ cond; then_; else_ ]
      | Lam (x, body) ->
          write_binder x;
          [ body ]
      | Let (_, name, bound, body) ->
          write_binder name;
          [ bound; body ]
      | Prim ((Add | Sub | Mul), (l, r)) -> [ l; r ]
      | Bool _ | Int _ | Nil | Var _ -> []
    in
    List.iter
      (fun (child : expr index) ->
        visit child;
        Format.fprintf fmt "  expr%d -> expr%d;\n" parent.index child.index)
      children
  in
  visit

(** [draw_tree ast] renders the expression tree [ast] as Graphviz DOT source
    and returns that source as a string.

    The digraph contains, in order:
    - a graph-wide [ordering=out] attribute;
    - every edge reachable from [ast.root], written by [add_expr_edges],
      which also declares the synthetic binder nodes of [Lam]/[Let];
    - a blank line;
    - one declaration per visited expression node, in increasing index
      order, written by [add_node]. *)
let draw_tree (ast : expr ast) : string =
  let buf = Buffer.create 16 and nodes = ref IntSet.empty in
  let fmt = Format.formatter_of_buffer buf in
  Format.fprintf fmt "digraph {\n";
  (* Child order carries meaning in an AST: the operands of [-], the branches
     of [if], and function vs. argument in [apply]. dot only guarantees that
     a node's children are laid out left to right in the order their edges
     are written when [ordering=out] is set. *)
  Format.fprintf fmt "  ordering=out;\n";
  add_expr_edges ast fmt nodes ast.root;
  Format.fprintf fmt "\n";
  (* [IntSet.iter] visits indices in increasing order, so the node section
     is deterministic regardless of traversal order. *)
  IntSet.iter
    (fun index ->
      let i = { index } in
      add_node fmt i (get_subexpr ast i))
    !nodes;
  Format.fprintf fmt "}\n";
  (* The formatter may still hold queued output; flush it into [buf]. *)
  Format.pp_print_flush fmt ();
  Buffer.contents buf