Skip to content

Commit 281561d

Browse files
Clean up sql parser function.
1 parent f403047 commit 281561d

File tree

2 files changed

+20
-30
lines changed

2 files changed

+20
-30
lines changed

datajoint/utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,3 +95,19 @@ def safe_copy(src, dest, overwrite=False):
9595
temp_file = dest + '.copying'
9696
shutil.copyfile(src, temp_file)
9797
os.rename(temp_file, dest)
98+
99+
100+
def parse_sql(filepath):
101+
DELIMITER = ';'
102+
statement = ''
103+
with open(filepath, 'rt') as f:
104+
for line in f:
105+
line = line.strip()
106+
if not line.startswith('--') and len(line) > 1:
107+
if not line.startswith('DELIMITER'):
108+
statement += ' ' + line
109+
if line.endswith(DELIMITER):
110+
yield statement[1:]
111+
statement = ''
112+
else:
113+
DELIMITER = line.split()[1]

tests/__init__.py

Lines changed: 4 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import urllib3
1616
import certifi
1717
import shutil
18+
from datajoint.utils import parse_sql
1819

1920
__author__ = 'Edgar Walker, Fabian Sinz, Dimitri Yatsenko, Raphael Guzman'
2021

@@ -118,36 +119,9 @@ def setup_package():
118119
CREATE DATABASE {};
119120
""".format(db_name))
120121

121-
def parse_sql(filename):
122-
stmts = []
123-
DELIMITER = ';'
124-
stmt = ''
125-
for line in open(filename, 'r').readlines():
126-
if not line.strip():
127-
continue
128-
129-
if line.startswith('--'):
130-
continue
131-
132-
if 'DELIMITER' in line:
133-
DELIMITER = line.split()[1]
134-
continue
135-
136-
if (DELIMITER not in line):
137-
stmt += line.replace(DELIMITER, ';')
138-
continue
139-
140-
if stmt:
141-
stmt += line
142-
stmts.append(stmt.strip())
143-
stmt = ''
144-
else:
145-
stmts.append(line.strip())
146-
return stmts
147-
148-
stmts = parse_sql('{}/{}'.format(source, db_file))
149-
for stmt in stmts:
150-
conn_root.query(stmt)
122+
statements = parse_sql('{}/{}'.format(source, db_file))
123+
for s in statements:
124+
conn_root.query(s)
151125

152126
# Add old S3
153127
source = os.path.dirname(os.path.realpath(__file__)) + \

0 commit comments

Comments
 (0)