|
4 | 4 | from collections import Counter
|
5 | 5 | import logging
|
6 | 6 | import random
|
| 7 | +import textwrap |
7 | 8 |
|
8 | 9 | import rdflib
|
9 | 10 | from rdflib import URIRef
|
|
13 | 14 | from gp_learner import mutate_increase_dist
|
14 | 15 | from gp_learner import mutate_merge_var
|
15 | 16 | from gp_learner import mutate_simplify_pattern
|
16 |
| -from gp_learner import mutate_deep_narrow_path |
17 | 17 | from graph_pattern import GraphPattern
|
18 | 18 | from graph_pattern import SOURCE_VAR
|
19 | 19 | from graph_pattern import TARGET_VAR
|
@@ -109,33 +109,36 @@ def test_mutate_merge_var():
|
109 | 109 | assert False, "merge never reached one of the cases: %s" % cases
|
110 | 110 |
|
111 | 111 |
|
112 |
| -def test_mutate_deep_narrow_path(): |
113 |
| - p = Variable('p') |
114 |
| - gp = GraphPattern([ |
115 |
| - (SOURCE_VAR, p, TARGET_VAR) |
116 |
| - ]) |
117 |
| - child = mutate_deep_narrow_path(gp) |
118 |
| - assert gp == child or len(child) > len(gp) |
119 |
| - print(gp) |
120 |
| - print(child) |
121 |
| - |
| 112 | +def test_deep_narrow_path_query(): |
| 113 | + node_var = Variable('node_var') |
| 114 | + edge_var = Variable('edge_var') |
| 115 | + gtps = [ |
| 116 | + (dbp['Barrel'], dbp['Wine']), |
| 117 | + (dbp['Barrister'], dbp['Law']), |
| 118 | + (dbp['Beak'], dbp['Bird']), |
| 119 | + (dbp['Blanket'], dbp['Bed']), |
| 120 | + ] |
122 | 121 |
|
123 |
| -def test_to_find_edge_var_for_narrow_path_query(): |
124 |
| - node_var = Variable('node_variable') |
125 |
| - edge_var = Variable('edge_variable') |
126 | 122 | gp = GraphPattern([
|
127 | 123 | (node_var, edge_var, SOURCE_VAR),
|
128 | 124 | (SOURCE_VAR, wikilink, TARGET_VAR)
|
129 | 125 | ])
|
130 |
| - filter_node_count = 10 |
131 |
| - filter_edge_count = 1 |
132 |
| - limit_res = 32 |
133 |
| - vars_ = {SOURCE_VAR,TARGET_VAR} |
134 |
| - res = GraphPattern.to_find_edge_var_for_narrow_path_query(gp, edge_var, node_var, |
135 |
| - vars_, filter_node_count, |
136 |
| - filter_edge_count, limit_res) |
137 |
| - print(gp) |
138 |
| - print(res) |
| 126 | + |
| 127 | + vars_ = (SOURCE_VAR, TARGET_VAR) |
| 128 | + res = gp.to_deep_narrow_path_query( |
| 129 | + edge_var, node_var, vars_, {vars_: gtps}, |
| 130 | + limit=32, |
| 131 | + max_node_count=10, |
| 132 | + min_edge_count=2, |
| 133 | + ).strip() |
| 134 | + doc = gp.to_deep_narrow_path_query.__doc__ |
| 135 | + doc_str_example_query = "\n".join([ |
| 136 | + l for l in doc.splitlines() |
| 137 | + if l.startswith(' ') |
| 138 | + ]) |
| 139 | + doc_str_example_query = textwrap.dedent(doc_str_example_query) |
| 140 | + assert res == doc_str_example_query, \ |
| 141 | + "res:\n%s\n\ndoes not look like:\n\n%s" % (res, doc_str_example_query) |
139 | 142 |
|
140 | 143 |
|
141 | 144 | def test_simplify_pattern():
|
@@ -303,5 +306,4 @@ def test_gtp_scores():
|
303 | 306 |
|
304 | 307 |
|
305 | 308 | if __name__ == '__main__':
|
306 |
| - # test_mutate_deep_narrow_path() |
307 |
| - test_to_find_edge_var_for_narrow_path_query() |
| 309 | + test_deep_narrow_path_query() |
0 commit comments