Skip to content

Commit 7c9c68d

Browse files
committed
Add evaluation code
1 parent 1df9e8a commit 7c9c68d

File tree

5 files changed

+585
-0
lines changed

5 files changed

+585
-0
lines changed

cypherbench/evaluate.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
import argparse
2+
import copy
3+
import json
4+
import os
5+
import math
6+
from tqdm import tqdm
7+
from concurrent.futures import ThreadPoolExecutor, as_completed
8+
from cypherbench.metrics import *
9+
from cypherbench.neo4j_connector import Neo4jConnector
10+
from cypherbench.schema import Nl2CypherSample
11+
12+
13+
# Maps a template's return_pattern_id to the bucket name it is reported under
# in the 'by_return' aggregation; several property-style patterns share the
# single "n_prop_combined" bucket.
RETURN_PATTERN_MAPPING = dict(
    n_name="n_name",
    n_prop="n_prop_combined",
    n_name_prop="n_prop_combined",
    n_prop_distinct="n_prop_combined",
    n_prop_array_distinct="n_prop_combined",
    n_order_by="n_order_by",
    n_argmax="n_argmax",
    n_where="n_where",
    n_agg="n_agg",
    n_group_by="n_group_by",
)
25+
26+
# Registry of metric name -> metric function. Every function is invoked with
# the keyword arguments pred_cypher, target_cypher and neo4j_connector.
METRIC_FUNC_MAPPING = dict(
    execution_accuracy=execution_accuracy,
    psjs=provenance_subgraph_jaccard_similarity,
    executable=executable,
)
31+
32+
33+
def compute_metrics(item: Nl2CypherSample, metrics, neo4j_conn):
    """Score one sample with every requested metric and return an annotated copy.

    Args:
        item: Sample providing ``pred_cypher``, ``gold_cypher`` and a mutable
            ``metrics`` mapping.
        metrics: Iterable of metric names; each must be a key of
            ``METRIC_FUNC_MAPPING``.
        neo4j_conn: Connector passed to each metric function as
            ``neo4j_connector``.

    Returns:
        A deep copy of ``item`` whose ``metrics`` mapping holds one entry per
        requested metric. The input sample itself is not modified.

    Raises:
        KeyError: If a metric name is not registered in ``METRIC_FUNC_MAPPING``.
    """
    item = copy.deepcopy(item)

    # The cleaned prediction is identical for every metric, so compute it once
    # instead of once per loop iteration.
    # NOTE(review): '<end_of_turn>' is presumably a generation stop marker
    # left in the raw model output - confirm against the generation code.
    end_marker = '<end_of_turn>'
    pred_cypher = item.pred_cypher
    if pred_cypher.endswith(end_marker):
        pred_cypher = pred_cypher[:-len(end_marker)].strip()

    for m in metrics:
        item.metrics[m] = METRIC_FUNC_MAPPING[m](
            pred_cypher=pred_cypher,
            target_cypher=item.gold_cypher,
            neo4j_connector=neo4j_conn
        )
    return item
45+
46+
47+
def avg_and_round(nums: list[float], n: int = 4):
    """Return the arithmetic mean of ``nums`` rounded to ``n`` decimal places.

    An empty ``nums`` yields ``math.nan`` instead of raising.
    """
    if not nums:
        return math.nan
    mean = sum(nums) / len(nums)
    return round(mean, n)
49+
50+
51+
def aggregate(results: list[tuple[str, float]]):
    """Group ``(key, value)`` pairs by key and average each group.

    Returns a dict mapping every key (in order of first appearance) to the
    rounded mean of its values, as computed by ``avg_and_round``.
    """
    grouped = {}
    for key, value in results:
        grouped.setdefault(key, []).append(value)
    return {key: avg_and_round(values) for key, values in grouped.items()}
60+
61+
62+
def main():
    """Command-line entry point: score saved predictions and write metric reports.

    Reads ``result.json`` from ``--result_dir`` and the connection settings
    from ``--neo4j_info`` (the ``'full'`` entry maps each graph name to the
    keyword arguments of its ``Neo4jConnector``). Every sample is scored in a
    thread pool, then two files are written into ``--result_dir``:

    * ``result_with_metrics.json`` - the samples with their per-metric scores,
      in the same order as the input file.
    * ``aggregated_metrics.json`` - overall averages for every metric, plus
      averages of ``--metric_for_agg`` by graph, match category and return
      pattern.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--neo4j_info', default='neo4j_info.json')
    parser.add_argument('--result_dir', default='output/gpt-4o')
    parser.add_argument('--num_threads', type=int, default=8)
    parser.add_argument('--metrics', nargs='+', default=['execution_accuracy', 'psjs', 'executable'])
    parser.add_argument('--metric_for_agg', default='execution_accuracy')
    args = parser.parse_args()

    # Fail fast on bad metric names: otherwise the mistake only surfaces as a
    # KeyError during or after running every query against the database.
    unknown_metrics = [m for m in args.metrics if m not in METRIC_FUNC_MAPPING]
    if unknown_metrics:
        parser.error(f'unknown metric(s) {unknown_metrics}; choose from {sorted(METRIC_FUNC_MAPPING)}')
    if args.metric_for_agg not in args.metrics:
        parser.error(f'--metric_for_agg {args.metric_for_agg!r} must be one of --metrics {args.metrics}')

    print(args)
    print()

    with open(os.path.join(args.result_dir, 'result.json'), encoding='utf-8') as fin:
        result = [Nl2CypherSample(**item) for item in json.load(fin)]

    with open(args.neo4j_info, encoding='utf-8') as fin:
        neo4j_info = json.load(fin)

    graph2conn = {graph: Neo4jConnector(name=graph, **info) for graph, info in
                  neo4j_info['full'].items()}

    # Use ThreadPoolExecutor for multithreading.
    # as_completed() yields futures in completion order, so remember each
    # future's input position and write its result back into that slot; this
    # keeps the saved file deterministic and aligned with result.json.
    result_with_metrics = [None] * len(result)
    with ThreadPoolExecutor(max_workers=args.num_threads) as executor:
        future_to_index = {
            executor.submit(compute_metrics, item, args.metrics, graph2conn[item.graph]): index
            for index, item in enumerate(result)
        }
        for future in tqdm(as_completed(future_to_index), total=len(result)):
            result_with_metrics[future_to_index[future]] = future.result()

    aggregated = {}
    aggregated['overall'] = {m: avg_and_round([item.metrics[m] for item in result_with_metrics]) for m in args.metrics}

    metric_for_agg = args.metric_for_agg
    aggregated['by_graph'] = aggregate([(item.graph, item.metrics[metric_for_agg]) for item in result_with_metrics])
    aggregated['by_match'] = aggregate([(item.from_template.match_category, item.metrics[metric_for_agg])
                                        for item in result_with_metrics])
    # Samples whose return pattern has no entry in RETURN_PATTERN_MAPPING are
    # left out of the by_return breakdown.
    aggregated['by_return'] = aggregate(
        [(RETURN_PATTERN_MAPPING[item.from_template.return_pattern_id], item.metrics[metric_for_agg])
         for item in result_with_metrics if item.from_template.return_pattern_id in RETURN_PATTERN_MAPPING]
    )

    output_path = os.path.join(args.result_dir, 'result_with_metrics.json')
    with open(output_path, 'w', encoding='utf-8') as fout:
        json.dump([item.model_dump(mode='json') for item in result_with_metrics], fout, indent=2)
    print(f'Saved result with metrics to {output_path}')

    output_path = os.path.join(args.result_dir, 'aggregated_metrics.json')
    with open(output_path, 'w', encoding='utf-8') as fout:
        json.dump(aggregated, fout, indent=2)
    print(f'Saved aggregated metrics to {output_path}')

    print()
    print('Aggregated metrics:')
    print(json.dumps(aggregated, indent=2))
114+
115+
116+
# Run the evaluation only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()

cypherbench/metrics/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .execution_accuracy import execution_accuracy
2+
from .executable import executable
3+
from .provenance_subgraph_jaccard_similarity import provenance_subgraph_jaccard_similarity

cypherbench/metrics/executable.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import neo4j
2+
from cypherbench.neo4j_connector import Neo4jConnector
3+
4+
5+
def executable(pred_cypher: str,
               target_cypher: str,
               neo4j_connector: Neo4jConnector,
               timeout: int = 120) -> float:
    """Return 1.0 if the predicted Cypher query runs without raising, else 0.0.

    Args:
        pred_cypher: Query to execute.
        target_cypher: Not used by this metric.
        neo4j_connector: Connector whose ``run_query`` executes the query.
        timeout: Forwarded to ``run_query`` as ``timeout``.

    The listed Neo4j errors are treated as an ordinary "not executable"
    outcome and scored silently; any other exception is also scored 0.0 but
    is reported with a printed warning.
    """
    try:
        neo4j_connector.run_query(pred_cypher, timeout=timeout)
    except (
        neo4j.exceptions.CypherSyntaxError,
        neo4j.exceptions.DatabaseError,
        neo4j.exceptions.CypherTypeError,
        neo4j.exceptions.ClientError,
    ):
        score = 0.0
    except Exception as e:
        print(f"Warning: Exception {e} occurred while executing the predicted Cypher query {pred_cypher}")
        score = 0.0
    else:
        score = 1.0

    return score
Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
"""
2+
Some of this code is adapted from https://github.com/taoyds/test-suite-sql-eval
3+
"""
4+
5+
6+
import re
7+
from itertools import chain
8+
from itertools import product
9+
from collections import defaultdict
10+
import random
11+
import math
12+
import time
13+
import neo4j
14+
from typing import List, Tuple, Dict, Set
15+
from typing import Literal
16+
from cypherbench.neo4j_connector import Neo4jConnector
17+
18+
19+
def to_hashable(obj, unorder_list=True):
    """
    Recursively transforms a list, tuple, dictionary, or set into a hashable object.

    Lists and sets are converted to tuples. Dictionaries are converted to
    tuples of sorted (key, value) pairs. Tuples keep their element order but
    have their elements converted, so a tuple that contains a list or dict
    also becomes hashable. ``neo4j.time.Date`` values become ISO strings.

    Args:
        obj: The object to be transformed into a hashable form.
        unorder_list: If True (default), list elements are sorted so that two
            lists with the same items in a different order compare equal.
            If False, list order is preserved. The flag also applies to lists
            nested anywhere inside ``obj``.

    Returns:
        A hashable version of the input object.

    Raises:
        TypeError: If ``obj`` (or a nested value) has an unsupported type, or
            if elements that must be sorted are not mutually orderable.
    """
    if isinstance(obj, (int, float, str, bool, type(None))):
        # These are already hashable
        return obj
    if isinstance(obj, tuple):
        # A tuple is only hashable if all of its members are, so convert the
        # members; positions are kept because tuples are positional.
        return tuple(to_hashable(item, unorder_list) for item in obj)
    if isinstance(obj, list):
        converted = [to_hashable(item, unorder_list) for item in obj]
        if unorder_list:
            converted.sort()
        return tuple(converted)
    if isinstance(obj, set):
        # Sets have no inherent order, so always sort for a canonical form
        return tuple(sorted(to_hashable(item, unorder_list) for item in obj))
    if isinstance(obj, dict):
        # Convert dict to a tuple of sorted key-value pairs
        return tuple(sorted(
            (to_hashable(k, unorder_list), to_hashable(v, unorder_list)) for k, v in obj.items()
        ))
    if isinstance(obj, neo4j.time.Date):
        return obj.iso_format()
    # For other types, raise an error
    raise TypeError(f"Unhashable type: {type(obj)}")
50+
51+
52+
def execution_accuracy(pred_cypher: str,
                       target_cypher: str,
                       neo4j_connector: Neo4jConnector,
                       timeout: int = 120) -> float:
    """Execution accuracy for two cypher queries.

    Args:
        pred_cypher: Predicted query.
        target_cypher: Reference query.
        neo4j_connector: Connector whose ``run_query`` executes both queries.
        timeout: Passed to ``run_query`` for the predicted query. The target
            query is run without it; a warning is printed if the target takes
            longer than this many seconds.

    Returns:
        1.0 if the queries are textually identical or their results are
        judged equivalent by ``_compare_execution``; 0.0 otherwise, including
        when the predicted query fails to run or its result cannot be made
        hashable.
    """
    if pred_cypher == target_cypher:
        return 1.0

    # A monotonic clock is unaffected by system clock adjustments, which
    # makes it the right tool for measuring an elapsed duration.
    started = time.monotonic()
    target_executed = neo4j_connector.run_query(target_cypher)
    target_seconds = time.monotonic() - started
    if target_seconds > timeout:
        print(f"Warning: Execution of target cypher query {target_cypher} took longer than {timeout} seconds")

    try:
        pred_executed = neo4j_connector.run_query(pred_cypher, timeout=timeout)
        pred_executed = [{k: to_hashable(v) for k, v in record.items()} for record in pred_executed]
    except (
        neo4j.exceptions.CypherSyntaxError,
        neo4j.exceptions.DatabaseError,
        neo4j.exceptions.CypherTypeError,
        neo4j.exceptions.ClientError,
    ):
        return 0.0
    except TypeError:
        # TODO: For some queries (e.g. queries that bind the path to a variable), the result is not hashable
        # However, currently we don't have such queries in the benchmark
        # So this exception indicates the predicted Cypher query is incorrect
        return 0.0
    except Exception as e:
        print(f"Warning: Exception {e} occurred while executing the predicted Cypher query {pred_cypher}")
        return 0.0

    target_executed = [{k: to_hashable(v) for k, v in record.items()} for record in target_executed]

    # Row order only matters when the target query sorts its output. Match
    # any run of whitespace between the keywords so that "ORDER  BY" or a
    # line break between them is detected as well.
    order_matters = re.search(r'order\s+by', target_cypher, re.IGNORECASE) is not None
    return _compare_execution(
        pred_executed=pred_executed,
        target_executed=target_executed,
        order_matters=order_matters
    )
89+
90+
91+
def permute_tuple(element: Tuple, perm: Tuple) -> Tuple:
    """Return a reordered copy of ``element``: position k holds ``element[perm[k]]``.

    ``element`` and ``perm`` must have the same length (checked by assertion).
    """
    assert len(perm) == len(element)
    return tuple(element[source] for source in perm)
94+
95+
96+
def unorder_row(row: Tuple) -> Tuple:
    """Return the cells of ``row`` in a canonical order.

    Cells are sorted by their string form followed by the string form of
    their type, which allows values of different types to be ordered.
    """
    def _text_key(cell):
        return str(cell) + str(type(cell))

    cells = list(row)
    cells.sort(key=_text_key)
    return tuple(cells)
98+
99+
100+
# unorder each row in the table
101+
# [result_1 and result_2 have the same bag of unordered rows]
102+
# is a necessary condition of
103+
# [result_1 and result_2 are equivalent in denotation]
104+
def quick_rej(result1: List[Tuple], result2: List[Tuple], order_matters: bool) -> bool:
    """Cheap necessary-condition check run before the full comparison.

    Each row is put into canonical cell order with ``unorder_row``. Returns
    True when the two results still agree after that normalization (as
    sequences if ``order_matters``, otherwise as sets of rows); False means
    the results cannot be equivalent.
    """
    normalized1 = list(map(unorder_row, result1))
    normalized2 = list(map(unorder_row, result2))
    if not order_matters:
        return set(normalized1) == set(normalized2)
    return normalized1 == normalized2
111+
112+
113+
# return whether two bags of relations are equivalent
114+
def multiset_eq(l1: List, l2: List) -> bool:
    """Return True if ``l1`` and ``l2`` hold the same items with the same multiplicities.

    Items must be hashable; order is ignored.
    """
    if len(l1) != len(l2):
        return False

    # Count the items of l1, then consume those counts with the items of l2.
    remaining = defaultdict(int)
    for item in l1:
        remaining[item] += 1
    for item in l2:
        if remaining[item] == 0:
            # l2 contains this item more often than l1 does
            return False
        remaining[item] -= 1
    return True
125+
126+
127+
def get_constraint_permutation(tab1_sets_by_columns: List[Set], result2: List[Tuple]):
    """Enumerate candidate column assignments between two result tables.

    Args:
        tab1_sets_by_columns: For each column of the first table, the set of
            values that occur in it.
        result2: Rows of the second table (must be non-empty).

    Returns:
        An ``itertools.product`` over, for every column of the first table,
        the second-table column indices that may correspond to it. Yielded
        tuples can repeat an index, so they are not necessarily permutations.

    With three or fewer columns every combination is produced. With more,
    20 randomly sampled rows of ``result2`` are used to discard column
    pairings whose sampled value never occurs in the first table's column.
    """
    width = len(result2[0])
    candidates = [set(range(width)) for _ in range(width)]

    if width > 3:
        # we sample 20 rows and constrain the space of permutations
        for _ in range(20):
            sampled_row = random.choice(result2)
            for col, allowed in enumerate(candidates):
                rejected = [
                    other_col for other_col in allowed
                    if sampled_row[other_col] not in tab1_sets_by_columns[col]
                ]
                allowed.difference_update(rejected)

    return product(*candidates)
142+
143+
144+
# check whether two denotations are equivalent
145+
def result_eq(result1: List[Tuple], result2: List[Tuple], order_matters: bool) -> bool:
    """Decide whether two result tables are equivalent up to column order.

    When ``order_matters`` is False, row order is ignored as well (rows are
    compared as a multiset). Returns True if some column permutation of
    ``result2`` makes it equal to ``result1``.
    """
    row_count1, row_count2 = len(result1), len(result2)
    if row_count1 == 0 and row_count2 == 0:
        return True

    # a different number of rows can never be the same bag of rows
    if row_count1 != row_count2:
        return False

    # neither can tables of different width
    width = len(result1[0])
    if width != len(result2[0]):
        return False

    # cheap necessary condition: compare the rows with their cells unordered;
    # this already rejects most non-equivalent pairs
    if not quick_rej(result1, result2, order_matters):
        return False

    def _same_table(candidate):
        if order_matters:
            return result1 == candidate
        # the set comparison is implied by the multiset comparison, but it is
        # cheaper and quickly rejects impossible candidates
        return set(result1) == set(candidate) and multiset_eq(result1, candidate)

    # What remains is a search for a column permutation (and, when order does
    # not matter, a row permutation) that maps result2 onto result1.
    # get_constraint_permutation narrows the column assignments to try.
    values_per_column = [{row[col] for row in result1} for col in range(width)]
    for perm in get_constraint_permutation(values_per_column, result2):
        # skip assignments that use a column twice; they are not permutations
        if len(set(perm)) != len(perm):
            continue
        if width == 1:
            candidate = result2
        else:
            candidate = [permute_tuple(row, perm) for row in result2]
        if _same_table(candidate):
            return True
    return False
191+
192+
193+
def to_tuples(result: List[Dict]) -> List[Tuple]:
    """Convert a non-empty list of records (dicts) into a list of value tuples.

    The column order is taken from the keys of the first record; every record
    must have exactly that key set (checked by assertion).
    """
    columns = list(result[0].keys())
    expected_keys = set(columns)
    for record in result:
        assert set(record.keys()) == expected_keys
    return [tuple(record[column] for column in columns) for record in result]
198+
199+
200+
def _compare_execution(
    pred_executed: list[dict], target_executed: list[dict], order_matters: bool
) -> float:
    """Score two executed results as 1.0 (equivalent) or 0.0 (different).

    Two empty results match, and an empty result never matches a non-empty
    one. Otherwise both results are converted to row tuples and compared with
    ``result_eq``, honouring row order only when ``order_matters`` is True.
    """
    pred_empty = not pred_executed
    target_empty = not target_executed
    if pred_empty and target_empty:
        return 1.0
    if pred_empty or target_empty:
        return 0.0

    gold_rows = to_tuples(target_executed)
    pred_rows = to_tuples(pred_executed)
    is_match = result_eq(gold_rows, pred_rows, order_matters=order_matters)
    return float(is_match)

0 commit comments

Comments
 (0)