aboutsummaryrefslogtreecommitdiff
path: root/discocaml/draw_tree.ml
blob: 60a8aa5c8c74f415dabfcc48e4c018bbadf64228 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
open Ast
module IntSet = Set.Make (Int)

(** [fmt_with_possible_subscript fmt s] prints [s] as the value of a Graphviz
    [label] attribute.

    If [Util.break_to_subscript] splits [s] into a base name and an integer
    subscript, an HTML-like label is emitted so that the subscript is drawn as
    a smaller [<SUB>] number. Otherwise [s] is emitted as a DOT double-quoted
    string. *)
let fmt_with_possible_subscript (fmt : Format.formatter) (s : string) : unit =
  match Util.break_to_subscript s with
  | Some (name, sub) ->
      Format.fprintf fmt "<%s<FONT POINT-SIZE=\"16\"><SUB>%d</SUB></FONT>>" name
        sub
  | None ->
      (* Quote for DOT rather than for OCaml. Double quotes and backslashes
         are escaped exactly as the previous [%S] conversion did. Every other
         byte, including UTF-8 multibyte sequences, is passed through
         unchanged: [%S] turned those bytes into OCaml decimal escapes, and
         Graphviz renders such escapes literally instead of decoding them. *)
      let quoted = Buffer.create (String.length s + 2) in
      Buffer.add_char quoted '"';
      String.iter
        (fun c ->
          if c = '"' || c = '\\' then Buffer.add_char quoted '\\';
          Buffer.add_char quoted c)
        s;
      Buffer.add_char quoted '"';
      Format.pp_print_string fmt (Buffer.contents quoted)

(** Palette of hex colors, used to tell apart binders whose names share a
    basename. *)
let colors =
  [| "#56b4e9"; "#009e73"; "#e69f00"; "#d55e00"; "#cc79a7"; "#ffaa14" |]

(** [get_color i] is entry [i] of [colors], wrapping around once [i] reaches
    the palette length. A negative [i] raises [Invalid_argument]. *)
let get_color (i : int) =
  let palette_size = Array.length colors in
  Array.get colors (i mod palette_size)

(** [assign_colors ast] picks a display color for every subexpression of
    [ast]. The result is indexed like [ast.subexprs].

    Binders are grouped by basename, meaning the name with any subscript
    found by [Util.break_to_subscript] stripped off. In a group with more than
    one binder, the k-th binder (in [Arraylist.to_seq] order) gets
    [get_color k]. A basename with a single binder keeps the default black
    (#000000). Each bound occurrence then copies the color of its binder. Every
    other entry stays black. *)
let assign_colors (ast : expr ast) : string array =
  let binders = get_binders ast in
  let out = Array.make (Arraylist.length ast.subexprs) "#000000" in

  (* The name with its subscript (if any) removed. *)
  let basename_of name =
    match Util.break_to_subscript name with
    | Some (base, _) -> base
    | None -> name
  in

  (* Basename -> indices of the binders using it, pushed in ascending index
     order. *)
  let groups = Hashtbl.create 16 in
  let group_of base =
    match Hashtbl.find_opt groups base with
    | Some group -> group
    | None ->
        let group = Arraylist.make 0 0 in
        Hashtbl.add groups base group;
        group
  in
  Array.iteri
    (fun idx binder ->
      match binder with
      | Some (`Binder name) -> Arraylist.push (group_of (basename_of name)) idx
      | _ -> ())
    binders;

  (* Only basenames shared by several binders are colored. Each group restarts
     at palette entry 0. *)
  Hashtbl.iter
    (fun _ group ->
      if Arraylist.length group > 1 then
        Seq.iteri
          (fun k idx -> out.(idx) <- get_color k)
          (Arraylist.to_seq group))
    groups;

  (* Bound occurrences reuse the color chosen for their binder. *)
  Array.iteri
    (fun idx binder ->
      match binder with
      | Some (`Bound target) -> out.(idx) <- out.(target.index)
      | _ -> ())
    binders;

  out

(** [add_node fmt i expr color] emits the DOT declaration of node [expr<i>],
    labelled according to the constructor of [expr].

    Structural nodes ([apply], [::], [if], [λ], [let] / [let rec], [[]]) use a
    bold typewriter font. Literals and primitive operators use the default
    font. [color] is applied only to [Var] nodes. *)
let add_node (fmt : Format.formatter) (i : expr index) (expr : expr)
    (color : string) : unit =
  let keyword label =
    Format.fprintf fmt
      "  expr%d [fontname=\"CMU Typewriter Text Bold\", label=\"%s\"];\n"
      i.index label
  in
  let plain label =
    Format.fprintf fmt "  expr%d [label=\"%s\"];\n" i.index label
  in
  match expr with
  | App _ -> keyword "apply"
  | Cons _ -> keyword "::"
  | If _ -> keyword "if"
  | Lam _ -> keyword "λ"
  | Let (true, _, _, _) -> keyword "let rec"
  | Let (false, _, _, _) -> keyword "let"
  | Nil -> keyword "[]"
  | Bool b -> plain (string_of_bool b)
  | Int n -> plain (string_of_int n)
  | Prim (Add, _) -> plain "+"
  | Prim (Sub, _) -> plain "-"
  | Prim (Mul, _) -> plain "*"
  | Prim (RelOp, (op, _, _)) -> plain (string_of_relop op)
  | Var x ->
      Format.fprintf fmt "  expr%d [color=\"%s\", label=%a];\n" i.index color
        fmt_with_possible_subscript x

(** [add_expr_edges ast fmt nodes colors] walks [ast] depth-first from its
    root. It emits a DOT edge from every node to each of its children and adds
    every visited index to [nodes].

    A child subtree is fully emitted before the edge that leads to it. [Lam]
    and [Let] nodes also get an ellipse node [expr<i>_var] showing the bound
    name, colored with [colors.(i)]; it is emitted before any child. Plain
    [expr<i>] node declarations are left to the caller. *)
let add_expr_edges (ast : 'a ast) (fmt : Format.formatter)
    (nodes : IntSet.t ref) (colors : string array) : unit =
  let rec visit (parent : expr index) : unit =
    nodes := IntSet.add parent.index !nodes;
    (* Ellipse for the name introduced by a Lam or Let, plus its edge. *)
    let draw_binder name =
      Format.fprintf fmt "  expr%d -> expr%d_var;\n" parent.index parent.index;
      Format.fprintf fmt
        "  expr%d_var [color=\"%s\", label=%a, shape=\"ellipse\"];\n"
        parent.index colors.(parent.index) fmt_with_possible_subscript name
    in
    (* The annotation fixes the result type for the Prim (GADT) arms. *)
    let children : expr index list =
      match get_subexpr ast parent with
      | App (f, x) -> [ f; x ]
      | Cons (hd, tl) -> [ hd; tl ]
      | If (cond, then_, else_) -> [ cond; then_; else_ ]
      | Lam (x, body) ->
          draw_binder x;
          [ body ]
      | Let (_, name, bound, body) ->
          draw_binder name;
          [ bound; body ]
      | Prim (Add, (l, r)) -> [ l; r ]
      | Prim (Sub, (l, r)) -> [ l; r ]
      | Prim (Mul, (l, r)) -> [ l; r ]
      | Prim (RelOp, (_, l, r)) -> [ l; r ]
      | Var _ | Bool _ | Int _ | Nil -> []
    in
    List.iter
      (fun (child : expr index) ->
        visit child;
        Format.fprintf fmt "  expr%d -> expr%d;\n" parent.index child.index)
      children
  in
  visit ast.root

(** [draw_tree ast] renders [ast] as a Graphviz DOT [digraph] and returns the
    DOT source text.

    Edges and binder ellipses are written first by [add_expr_edges]. After
    that, [add_node] declares every node reached from the root exactly once,
    in increasing index order. *)
let draw_tree (ast : expr ast) : string =
  let colors = assign_colors ast in
  let reached = ref IntSet.empty in
  let buf = Buffer.create 16 in
  let fmt = Format.formatter_of_buffer buf in

  Format.fprintf fmt "digraph {\n";
  Format.fprintf fmt
    "  node [fontsize=\"20pt\", penwidth=\"2.0\", shape=\"box\", \
     style=\"rounded\"];\n";

  add_expr_edges ast fmt reached colors;
  Format.fprintf fmt "\n";

  (* IntSet.iter presents the indices in increasing order. *)
  !reached
  |> IntSet.iter (fun index ->
         let node = { index } in
         add_node fmt node (get_subexpr ast node) colors.(index));

  (* NOTE(review): a disabled snippet for drawing faint edges from each bound
     variable to its binder is kept below. It matches get_binders entries as
     Some j, but assign_colors shows those entries are now `Binder / `Bound
     variants, so it needs updating before it can be re-enabled.

  get_binders ast
  |> Array.iteri (fun i -> function
       | Some j ->
           Format.fprintf fmt
             "  expr%d -> expr%d_var [color=\"#cccccc\", constraint=false];\n" i
             j.index
       | None -> ());
  *)
  Format.fprintf fmt "}\n";
  Format.pp_print_flush fmt ();
  Buffer.contents buf