Skip to content

Commit 1a7e0b2

Browse files
Merge pull request #432 from rustprooflabs/cleanup-type-hinting
Code quality improvements
2 parents 298f584 + 59d33f8 commit 1a7e0b2

8 files changed

Lines changed: 248 additions & 177 deletions

File tree

docker/database.py

Lines changed: 45 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
import sys
1212
import subprocess
1313
import time
14-
from typing import cast
14+
from typing import cast, Any
15+
from pathlib import Path
1516
import psycopg
1617
from psycopg.abc import Query
1718
from psycopg import sql
@@ -79,7 +80,7 @@ def set_db_env_vars():
7980
os.environ['PGOSM_CONN_PG'] = connection_string(admin=True)
8081

8182

82-
def pg_conn_parts() -> dict:
83+
def pg_conn_parts() -> dict[str, str | None]:
8384
"""Returns dictionary of connection parts based on environment variables
8485
if they exist.
8586
"""
@@ -130,7 +131,7 @@ def pg_conn_parts() -> dict:
130131
LOGGER.debug(f'DB Name: {pg_db}')
131132
os.environ['POSTGRES_DB'] = pg_db
132133

133-
pg_details = {'pg_user': pg_user,
134+
pg_details: dict[str, str | None] = {'pg_user': pg_user,
134135
'pg_pass': pg_pass,
135136
'pg_host': pg_host,
136137
'pg_port': pg_port,
@@ -182,15 +183,13 @@ def pg_isready() -> bool:
182183
Uses `pg_version_check()` as a simple readiness check.
183184
"""
184185
try:
185-
result = pg_version_check()
186+
pg_version_check()
186187
except AttributeError:
187188
err_msg = 'Error checking version, likely waiting for Postgres to start.'
188189
err_msg += ' Only an error if it does not go away after a few attempts.'
189190
logging.getLogger('pgosm-flex').warning(err_msg)
190191
return False
191192

192-
if result is None:
193-
return False
194193
return True
195194

196195

@@ -210,7 +209,7 @@ def log_pg_details():
210209

211210
def prepare_pgosm_db(
212211
skip_qgis_style: bool
213-
, db_path: str
212+
, db_path: Path
214213
, import_mode: helpers.ImportMode
215214
, schema_name: str
216215
):
@@ -258,17 +257,17 @@ def start_import(
258257
pgosm_region: str
259258
, pgosm_date: str
260259
, srid: int
261-
, language: str
260+
, language: str | None
262261
, layerset: str
263262
, git_info: str
264263
, osm2pgsql_version: str
265264
, import_mode: helpers.ImportMode
266265
, schema_name: str
267-
, input_file: str
266+
, input_file: str | None
268267
) -> int:
269268
"""Creates record in `osm.pgosm_flex` table and returns `id` from `osm.pgosm_flex`.
270269
"""
271-
params = {'pgosm_region': pgosm_region
270+
params: dict[str, Any] = {'pgosm_region': pgosm_region
272271
, 'pgosm_date': pgosm_date
273272
, 'srid': srid
274273
, 'language': language
@@ -292,12 +291,18 @@ def start_import(
292291
RETURNING id
293292
;
294293
"""
295-
sql_raw = sql_raw.format(schema_name=schema_name)
296-
# FIXME: Why os environ here instead of get conn string???
294+
sql_formatted = sql.SQL(sql_raw).format(
295+
schema_name=sql.Identifier(schema_name)
296+
)
297297
with get_db_conn(conn_string=connection_string()) as conn:
298298
cur = conn.cursor()
299-
cur.execute(sql_raw, params=params)
300-
import_id = cur.fetchone()[0]
299+
cur.execute(sql_formatted, params=params)
300+
row = cur.fetchone()
301+
if row:
302+
import_id = int(row[0])
303+
else:
304+
msg = 'Invalid response. `import_id` should never be missing.'
305+
raise ValueError(msg)
301306

302307
return import_id
303308

@@ -318,6 +323,9 @@ def pg_version_check() -> int:
318323
cur.execute(sql_raw)
319324
results = cur.fetchone()
320325

326+
if not results:
327+
raise ValueError('Unable to return Postgres version number. Likely another error going on.')
328+
321329
# It's an int https://www.postgresql.org/docs/current/runtime-config-preset.html#GUC-SERVER-VERSION-NUM
322330
pg_version = int(results[0])
323331
if pg_version < 120000:
@@ -374,15 +382,15 @@ def create_pgosm_db() -> bool:
374382
return True
375383

376384

377-
def prepare_osm_schema(db_path: str, skip_qgis_style: bool, schema_name: str):
385+
def prepare_osm_schema(db_path: Path, skip_qgis_style: bool, schema_name: str):
378386
"""Runs deploy scripts to prepare the PgOSM Flex database.
379387
380388
This function's code could be simplified, but currently I like the verbosity
381389
of it. It doesn't need to stay like this forever, but for now... it's fine.
382390
383391
Parameters
384392
---------------------------
385-
db_path : str
393+
db_path : Path
386394
Path to folder with SQL scripts.
387395
skip_qgis_style : bool
388396
schema_name : str
@@ -405,10 +413,10 @@ def prepare_osm_schema(db_path: str, skip_qgis_style: bool, schema_name: str):
405413
else:
406414
LOGGER.info('Loading QGIS styles')
407415
qgis_styles.load_qgis_styles(db_path=db_path,
408-
db_name=pg_conn_parts()['pg_db'])
416+
db_name=str(pg_conn_parts()['pg_db']))
409417

410418

411-
def run_insert_pgosm_road(db_path: str, schema_name: str):
419+
def run_insert_pgosm_road(db_path: Path, schema_name: str):
412420
"""Runs script to load data to `pgosm.road` table.
413421
"""
414422
sql_filename = 'roads-us.sql'
@@ -421,14 +429,14 @@ def run_insert_pgosm_road(db_path: str, schema_name: str):
421429

422430

423431
def run_deploy_file(
424-
db_path: str
432+
db_path: Path
425433
, sql_filename: str
426434
, schema_name: str
427435
, subfolder: str='deploy'
428436
):
429437
"""Run a SQL script under the deploy path. Used to setup PgOSM Flex DB.
430438
"""
431-
full_path = os.path.join(db_path, subfolder, sql_filename)
439+
full_path = db_path / subfolder / sql_filename
432440
LOGGER.info(f'Deploying {full_path}')
433441

434442
with open(full_path) as f:
@@ -457,7 +465,7 @@ def get_db_conn(conn_string: str) -> psycopg.Connection:
457465
return conn
458466

459467

460-
def pgosm_after_import(flex_path: str) -> bool:
468+
def pgosm_after_import(flex_path: Path) -> bool:
461469
"""Runs post-processing SQL via Lua script.
462470
463471
Layerset logic is established via environment variable, must happen
@@ -469,7 +477,7 @@ def pgosm_after_import(flex_path: str) -> bool:
469477

470478
output = subprocess.run(cmds,
471479
text=True,
472-
cwd=flex_path,
480+
cwd=str(flex_path),
473481
check=False,
474482
stdout=subprocess.PIPE,
475483
stderr=subprocess.STDOUT)
@@ -554,7 +562,7 @@ def osm2pgsql_replication_finish(skip_nested: bool):
554562
cur.execute(sql_raw, params)
555563

556564

557-
def run_pg_dump(export_path: str, skip_qgis_style: bool):
565+
def run_pg_dump(export_path: Path, skip_qgis_style: bool):
558566
"""Runs `pg_dump` to save processed data to load into other PostGIS DBs.
559567
"""
560568
logger = logging.getLogger('pgosm-flex')
@@ -565,25 +573,25 @@ def run_pg_dump(export_path: str, skip_qgis_style: bool):
565573
logger.info(f'Running pg_dump (only {schema_name} schema)')
566574
cmds = ['pg_dump', '-d', conn_string,
567575
f'--schema={schema_name}',
568-
'-f', export_path]
576+
'-f', str(export_path)]
569577
else:
570578
logger.info(f'Running pg_dump ({schema_name} schema plus extras)')
571579
cmds = ['pg_dump', '-d', conn_string,
572580
f'--schema={schema_name}',
573581
'--schema=pgosm',
574582
'--schema=public',
575-
'-f', export_path]
583+
'-f', str(export_path)]
576584

577585
output = subprocess.run(cmds,
578586
text=True,
579587
capture_output=True,
580588
check=False)
581589
LOGGER.info(f'pg_dump complete, saved to {export_path}')
582590
LOGGER.debug(f'pg_dump output: \n {output.stderr}')
583-
fix_pg_dump_create_public(export_path)
591+
fix_pg_dump_create_public(export_path=export_path)
584592

585593

586-
def fix_pg_dump_create_public(export_path: str):
594+
def fix_pg_dump_create_public(export_path: Path):
587595
"""Using pg_dump with `--schema=public` results in
588596
a .sql script containing `CREATE SCHEMA public;`, which nearly always breaks
589597
in target DB. Replaces with `CREATE SCHEMA IF NOT EXISTS public;`
@@ -607,16 +615,19 @@ def log_import_message(import_id: int, msg: str, schema_name: str):
607615
AND pf.id = %(import_id)s
608616
;
609617
"""
610-
sql_raw = sql_raw.format(schema_name=schema_name)
618+
sql_formatted = sql.SQL(sql_raw).format(
619+
schema_name=sql.Identifier(schema_name)
620+
)
611621
with get_db_conn(conn_string=os.environ['PGOSM_CONN']) as conn:
612-
params = {'import_id': import_id,
613-
'msg': msg
614-
}
622+
params: dict[str, int | str] = {
623+
'import_id': import_id
624+
, 'msg': msg
625+
}
615626
cur = conn.cursor()
616-
cur.execute(sql_raw, params=params)
627+
cur.execute(sql_formatted, params=params)
617628

618629

619-
def get_prior_import(schema_name: str) -> dict:
630+
def get_prior_import(schema_name: str) -> dict[str, Any]:
620631
"""Gets the latest import details from `osm.pgosm_flex`.
621632
"""
622633
sql_raw = """
@@ -636,6 +647,6 @@ def get_prior_import(schema_name: str) -> dict:
636647
results = cur.execute(sql_raw).fetchone()
637648

638649
if isinstance(results, type(None)):
639-
results = {}
650+
results: dict[str, Any] = {}
640651

641652
return results

0 commit comments

Comments
 (0)