Skip to content

Commit 6bb7645

Browse files
committed
ordered option + stringify all numbers
1 parent d51f53b commit 6bb7645

File tree

2 files changed

+12
-6
lines changed

2 files changed

+12
-6
lines changed

evaluate.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
parser.add_argument('source_file', help='source file for the prediction')
1313
parser.add_argument('db_file', help='source database for the prediction')
1414
parser.add_argument('pred_file', help='predictions by the model')
15+
parser.add_argument('--ordered', action='store_true', help='whether the exact match should consider the order of conditions')
1516
args = parser.parse_args()
1617

1718
engine = DBEngine(args.db_file)
@@ -21,13 +22,13 @@
2122
for ls, lp in tqdm(zip(fs, fp), total=count_lines(args.source_file)):
2223
eg = json.loads(ls)
2324
ep = json.loads(lp)
24-
qg = Query.from_dict(eg['sql'])
25+
qg = Query.from_dict(eg['sql'], ordered=args.ordered)
2526
gold = engine.execute_query(eg['table_id'], qg, lower=True)
2627
pred = ep.get('error', None)
2728
qp = None
2829
if not ep.get('error', None):
2930
try:
30-
qp = Query.from_dict(ep['query'])
31+
qp = Query.from_dict(ep['query'], ordered=args.ordered)
3132
pred = engine.execute_query(eg['table_id'], qp, lower=True)
3233
except Exception as e:
3334
pred = repr(e)

lib/query.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,20 @@ class Query:
1313
cond_ops = ['=', '>', '<', 'OP']
1414
syms = ['SELECT', 'WHERE', 'AND', 'COL', 'TABLE', 'CAPTION', 'PAGE', 'SECTION', 'OP', 'COND', 'QUESTION', 'AGG', 'AGGOPS', 'CONDOPS']
1515

16-
def __init__(self, sel_index, agg_index, conditions=tuple()):
16+
def __init__(self, sel_index, agg_index, conditions=tuple(), ordered=False):
1717
self.sel_index = sel_index
1818
self.agg_index = agg_index
1919
self.conditions = list(conditions)
20+
self.ordered = ordered
2021

2122
def __eq__(self, other):
2223
if isinstance(other, self.__class__):
2324
indices = self.sel_index == other.sel_index and self.agg_index == other.agg_index
24-
conds = [(col, op, cond.lower() if isinstance(cond, str) else cond) for col, op, cond in self.conditions] == [(col, op, cond.lower() if isinstance(cond, str) else cond) for col, op, cond in other.conditions]
25+
if other.ordered:
26+
conds = [(col, op, str(cond).lower()) for col, op, cond in self.conditions] == [(col, op, str(cond).lower()) for col, op, cond in other.conditions]
27+
else:
28+
conds = set([(col, op, str(cond).lower()) for col, op, cond in self.conditions]) == set([(col, op, str(cond).lower()) for col, op, cond in other.conditions])
29+
2530
return indices and conds
2631
return NotImplemented
2732

@@ -52,8 +57,8 @@ def lower(self):
5257
return self.__class__(self.sel_index, self.agg_index, conds)
5358

5459
@classmethod
55-
def from_dict(cls, d):
56-
return cls(sel_index=d['sel'], agg_index=d['agg'], conditions=d['conds'])
60+
def from_dict(cls, d, ordered=False):
61+
return cls(sel_index=d['sel'], agg_index=d['agg'], conditions=d['conds'], ordered=ordered)
5762

5863
@classmethod
5964
def from_tokenized_dict(cls, d):

0 commit comments

Comments
 (0)