|
| 1 | +""" |
| 2 | +Some of this code is adapted from https://github.com/taoyds/test-suite-sql-eval |
| 3 | +""" |
| 4 | + |
| 5 | + |
| 6 | +import re |
| 7 | +from itertools import chain |
| 8 | +from itertools import product |
| 9 | +from collections import defaultdict |
| 10 | +import random |
| 11 | +import math |
| 12 | +import time |
| 13 | +import neo4j |
| 14 | +from typing import List, Tuple, Dict, Set |
| 15 | +from typing import Literal |
| 16 | +from cypherbench.neo4j_connector import Neo4jConnector |
| 17 | + |
| 18 | + |
def _canonical_sorted(items):
    """Return ``items`` as a list in a deterministic order, tolerating mixed types."""
    items = list(items)
    try:
        return sorted(items)
    except TypeError:
        # Members such as None and int have no natural ordering between them;
        # fall back to a key that is defined for any mix of hashable values.
        return sorted(items, key=lambda x: (type(x).__name__, repr(x)))


def to_hashable(obj, unorder_list=True):
    """
    Recursively transforms a list, tuple, dictionary, or set into a hashable object.

    Lists, tuples and sets are converted to tuples. Dictionaries are converted to
    tuples of sorted (key, value) pairs. ``neo4j.time.Date`` values are converted
    to their ISO-format string.

    Args:
        obj: The object to be transformed into a hashable form.
        unorder_list: If True (default), the members of every list/tuple, at any
            nesting depth, are sorted so that element order is ignored. If False,
            list/tuple order is preserved at every depth.

    Returns:
        A hashable version of the input object.

    Raises:
        TypeError: If ``obj`` (or a nested member) is of an unsupported type.
    """
    if isinstance(obj, (int, float, str, bool, type(None))):
        # These are already hashable
        return obj
    if isinstance(obj, (list, tuple)):
        # Tuples are processed like lists: a tuple may itself contain
        # unhashable members (e.g. a nested list).
        members = [to_hashable(item, unorder_list) for item in obj]
        if unorder_list:
            return tuple(_canonical_sorted(members))
        return tuple(members)
    if isinstance(obj, set):
        # A set has no inherent order, so it is always sorted
        return tuple(_canonical_sorted(to_hashable(item, unorder_list) for item in obj))
    if isinstance(obj, dict):
        # Convert dict to a tuple of sorted key-value pairs
        return tuple(_canonical_sorted(
            (to_hashable(k, unorder_list), to_hashable(v, unorder_list))
            for k, v in obj.items()
        ))
    if isinstance(obj, neo4j.time.Date):
        return obj.iso_format()
    # Any other type is rejected; callers rely on TypeError to detect
    # results that cannot be compared.
    raise TypeError(f"Unhashable type: {type(obj)}")
| 50 | + |
| 51 | + |
def execution_accuracy(pred_cypher: str,
                       target_cypher: str,
                       neo4j_connector: Neo4jConnector,
                       timeout: int = 120) -> float:
    """Execution accuracy for two cypher queries.

    Both queries are executed through ``neo4j_connector.run_query`` and their
    results are compared with ``_compare_execution``.

    Args:
        pred_cypher: The predicted Cypher query.
        target_cypher: The reference Cypher query.
        neo4j_connector: Connector used to run both queries.
        timeout: Passed as ``timeout`` when running the predicted query. The
            target query is run without it; a warning is printed if the target
            query takes longer than this many seconds.

    Returns:
        1.0 if the two query strings are identical or their results match,
        0.0 otherwise (including when the predicted query fails to execute or
        yields a result that cannot be made hashable).
    """
    if pred_cypher == target_cypher:
        # Identical text: skip execution entirely
        return 1.0

    # Monotonic clock: elapsed time is unaffected by system clock adjustments
    started = time.monotonic()
    target_executed = neo4j_connector.run_query(target_cypher)
    target_seconds = time.monotonic() - started
    if target_seconds > timeout:
        print(f"Warning: Execution of target cypher query {target_cypher} took longer than {timeout} seconds")

    try:
        pred_executed = neo4j_connector.run_query(pred_cypher, timeout=timeout)
        pred_executed = [{k: to_hashable(v) for k, v in record.items()} for record in pred_executed]
    except (
            neo4j.exceptions.CypherSyntaxError,
            neo4j.exceptions.DatabaseError,
            neo4j.exceptions.CypherTypeError,
            neo4j.exceptions.ClientError,
    ):
        return 0.0
    except TypeError:
        # TODO: For some queries (e.g. queries that bind the path to a variable), the result is not hashable
        # However, currently we don't have such queries in the benchmark
        # So this exception indicates the predicted Cypher query is incorrect
        return 0.0
    except Exception as e:
        print(f"Warning: Exception {e} occurred while executing the predicted Cypher query {pred_cypher}")
        return 0.0

    target_executed = [{k: to_hashable(v) for k, v in record.items()} for record in target_executed]

    # Row order is only compared when the target query sorts its output.
    # Any amount of whitespace (including newlines) may separate ORDER and BY.
    order_matters = re.search(r'\border\s+by\b', target_cypher, flags=re.IGNORECASE) is not None
    return _compare_execution(
        pred_executed=pred_executed,
        target_executed=target_executed,
        order_matters=order_matters,
    )
| 89 | + |
| 90 | + |
def permute_tuple(element: Tuple, perm: Tuple) -> Tuple:
    """Return a tuple whose k-th entry is ``element[perm[k]]``.

    ``element`` and ``perm`` must have the same length.
    """
    assert len(element) == len(perm)
    reordered = []
    for source_index in perm:
        reordered.append(element[source_index])
    return tuple(reordered)
| 94 | + |
| 95 | + |
def unorder_row(row: Tuple) -> Tuple:
    """Return the values of ``row`` sorted into a column-order-independent form.

    Values are ordered by their string form followed by the string form of
    their type, so values of different types can be sorted together.
    """
    def text_key(value):
        return str(value) + str(type(value))

    values = list(row)
    values.sort(key=text_key)
    return tuple(values)
| 98 | + |
| 99 | + |
def quick_rej(result1: List[Tuple], result2: List[Tuple], order_matters: bool) -> bool:
    """Cheap necessary-condition check for two results being equivalent.

    Every row of both tables is unordered with ``unorder_row``. Two results
    having the same unordered rows is a necessary condition for them being
    equivalent in denotation, so a False return means they cannot be
    equivalent; a True return means further checking is needed.

    When ``order_matters`` is True the unordered rows are compared as lists
    (position by position); otherwise they are compared as sets.
    """
    normalized1 = [unorder_row(r) for r in result1]
    normalized2 = [unorder_row(r) for r in result2]
    if not order_matters:
        return set(normalized1) == set(normalized2)
    return normalized1 == normalized2
| 111 | + |
| 112 | + |
def multiset_eq(l1: List, l2: List) -> bool:
    """Return whether ``l1`` and ``l2`` contain the same elements with the same multiplicities."""
    if len(l1) != len(l2):
        return False

    # Count occurrences in l1, then consume them with the elements of l2
    remaining = defaultdict(int)
    for item in l1:
        remaining[item] += 1
    for item in l2:
        left = remaining[item] - 1
        if left < 0:
            # l2 holds more copies of this element than l1 does
            return False
        remaining[item] = left
    return True
| 125 | + |
| 126 | + |
def get_constraint_permutation(tab1_sets_by_columns: List[Set], result2: List[Tuple]):
    """Enumerate candidate column mappings from table 1 columns to table 2 columns.

    Args:
        tab1_sets_by_columns: For each column of table 1, the set of values
            appearing in that column.
        result2: Rows of table 2; must be non-empty.

    Returns:
        An iterator (``itertools.product``) over tuples ``perm`` in which
        ``perm[i]`` is a table-2 column that may correspond to table-1 column
        ``i``. The tuples are not guaranteed to be permutations (a column may
        repeat). With three or fewer columns every combination is produced;
        otherwise the candidates are first narrowed using randomly sampled rows.
    """
    width = len(result2[0])
    candidates = [set(range(width)) for _ in range(width)]

    if width > 3:
        # we sample 20 rows and constrain the space of permutations
        for _ in range(20):
            sampled_row = random.choice(result2)
            for col, allowed in enumerate(candidates):
                # A table-2 column whose sampled value never occurs in
                # table-1 column `col` cannot be mapped to it.
                rejected = [
                    tab2_col for tab2_col in allowed
                    if sampled_row[tab2_col] not in tab1_sets_by_columns[col]
                ]
                for tab2_col in rejected:
                    allowed.remove(tab2_col)

    return product(*candidates)
| 142 | + |
| 143 | + |
def result_eq(result1: List[Tuple], result2: List[Tuple], order_matters: bool) -> bool:
    """Check whether two query results (lists of row tuples) denote the same table.

    The results are considered equal if some reordering of the columns of
    ``result2`` makes it equal to ``result1``: as a list of rows when
    ``order_matters`` is True, or as a multiset of rows otherwise.
    """
    # different row counts can never be the same bag of rows
    if len(result1) != len(result2):
        return False
    if len(result1) == 0:
        # both results are empty
        return True

    width = len(result1[0])
    # a different number of columns means the results differ
    if width != len(result2[0]):
        return False

    # unordering each row and comparing already rejects most differing pairs
    if not quick_rej(result1, result2, order_matters):
        return False

    # What remains is to find a column permutation (and, when order does not
    # matter, a row permutation) that maps result2 onto result1.
    column_values = [{row[i] for row in result1} for i in range(width)]

    # get_constraint_permutation shrinks the space of column mappings to try;
    # if any of them makes the two results equal, they are equivalent.
    for perm in get_constraint_permutation(column_values, result2):
        if len(set(perm)) != len(perm):
            # a column is used twice, so this is not a real permutation
            continue

        if width == 1:
            candidate = result2
        else:
            candidate = [permute_tuple(row, perm) for row in result2]

        if order_matters:
            matched = result1 == candidate
        else:
            # The set comparison is implied by the multiset comparison, but it
            # is much cheaper and quickly rejects impossible candidates.
            matched = set(result1) == set(candidate) and multiset_eq(result1, candidate)

        if matched:
            return True
    return False
| 191 | + |
| 192 | + |
def to_tuples(result: List[Dict]) -> List[Tuple]:
    """Convert a list of row dictionaries into a list of row tuples.

    The column order of every tuple follows the key order of the first row.
    All rows must have exactly the same set of keys.

    Args:
        result: Rows as dictionaries mapping column name to value.

    Returns:
        One tuple per row; an empty list if ``result`` is empty.
    """
    if not result:
        # No rows: there is no first row to take the column order from
        return []
    keys = list(result[0].keys())
    expected_keys = set(keys)
    for row in result:
        assert set(row.keys()) == expected_keys
    return [tuple(row[key] for key in keys) for row in result]
| 198 | + |
| 199 | + |
| 200 | +def _compare_execution( |
| 201 | + pred_executed: list[dict], target_executed: list[dict], order_matters: bool |
| 202 | +) -> float: |
| 203 | + """Execution match considering same order of the output""" |
| 204 | + if not pred_executed and not target_executed: |
| 205 | + return 1.0 |
| 206 | + elif not pred_executed or not target_executed: |
| 207 | + return 0.0 |
| 208 | + |
| 209 | + gold_tuples = to_tuples(target_executed) |
| 210 | + pred_tuples = to_tuples(pred_executed) |
| 211 | + return float(result_eq(gold_tuples, pred_tuples, order_matters=order_matters)) |
0 commit comments