@@ -3,13 +3,9 @@ type node = {
33 neighbours: list (string ),
44};
55
6- type mark =
7- | Temp
8- | Perm ;
9-
106type directedGraph = list (node );
117
12- exception Not_found ( string) ;
8+ exception Invalid_node_id ( string) ;
139
1410exception Graph_not_DAG ;
1511
@@ -21,9 +17,9 @@ let parseAdjList = adj_list => {
2117
2218 List . iter(insert, adj_list);
2319 let validateNeighbours = node => {
24- List . iter(neighbour => {
25- if (! Hashtbl . mem(adj_tbl, neighbour )) {
26- raise (Not_found (neighbour ));
20+ List . iter(neighbour_id => {
21+ if (! Hashtbl . mem(adj_tbl, neighbour_id )) {
22+ raise (Invalid_node_id (neighbour_id ));
2723 }
2824 }, node. neighbours);
2925 };
@@ -32,44 +28,46 @@ let parseAdjList = adj_list => {
3228 adj_tbl;
3329};
3430
35- let rec visit = (~node_id, ~adj_tbl, ~mark_tbl, ~list) => {
36- if (Hashtbl . mem(mark_tbl, node_id)) {
37- switch (Hashtbl . find(mark_tbl, node_id)) {
38- | Temp => raise (Graph_not_DAG );
39- | Perm => list;
31+ let rec visit = (~node_id, ~adj_tbl, ~visited, ~ordering, ~ancestors) => {
32+ if (Hashtbl . mem(visited, node_id)) {
33+ let pred = id => id == node_id;
34+ switch (List . find(pred, ancestors)) {
35+ | exception Not_found => ordering;
36+ | _ => raise (Graph_not_DAG );
4037 };
4138 }
4239 else {
40+ let ancestors = List . append(ancestors, [ node_id] );
4341 let neighbours = Hashtbl . find(adj_tbl, node_id);
44- let visitNeighbour = (sorting , neighbour_id) => {
42+ let visitNeighbour = (ordering , neighbour_id) => {
4543 visit(
4644 ~node_id= neighbour_id,
4745 ~adj_tbl= adj_tbl,
48- ~mark_tbl= mark_tbl,
49- ~list= sorting);
46+ ~visited= visited,
47+ ~ordering= ordering,
48+ ~ancestors= ancestors);
5049 };
5150
52- Hashtbl . add(mark_tbl, node_id, Temp );
53- let list = List . fold_left(visitNeighbour, list, neighbours);
54- Hashtbl . add(mark_tbl, node_id, Perm );
55-
51+ Hashtbl . add(visited, node_id, node_id);
52+ let list = List . fold_left(visitNeighbour, ordering, neighbours);
5653 List . append(list, [ node_id] );
5754 };
5855};
5956
6057let sort = adj_list => {
6158 let num_nodes = List . length(adj_list);
62- let mark_tbl = Hashtbl . create(num_nodes);
6359 let adj_tbl = parseAdjList(adj_list);
60+ let visited = Hashtbl . create(num_nodes);
6461
65- let traverse = (list , node) => {
62+ let traverse = (ordering , node) => {
6663 visit(
6764 ~node_id= node. id,
6865 ~adj_tbl= adj_tbl,
69- ~mark_tbl= mark_tbl,
70- ~list= list);
66+ ~visited= visited,
67+ ~ordering= ordering,
68+ ~ancestors= [] );
7169 };
7270
73- let list = List . fold_left(traverse, [] , adj_list);
74- List . rev(list );
71+ let top_sorting = List . fold_left(traverse, [] , adj_list);
72+ List . rev(top_sorting );
7573};
0 commit comments