1111import sys
1212import subprocess
1313import time
14- from typing import cast
14+ from typing import cast , Any
15+ from pathlib import Path
1516import psycopg
1617from psycopg .abc import Query
1718from psycopg import sql
@@ -79,7 +80,7 @@ def set_db_env_vars():
7980 os .environ ['PGOSM_CONN_PG' ] = connection_string (admin = True )
8081
8182
82- def pg_conn_parts () -> dict :
83+ def pg_conn_parts () -> dict [ str , str | None ] :
8384 """Returns dictionary of connection parts based on environment variables
8485 if they exist.
8586 """
@@ -130,7 +131,7 @@ def pg_conn_parts() -> dict:
130131 LOGGER .debug (f'DB Name: { pg_db } ' )
131132 os .environ ['POSTGRES_DB' ] = pg_db
132133
133- pg_details = {'pg_user' : pg_user ,
134+ pg_details : dict [ str , str | None ] = {'pg_user' : pg_user ,
134135 'pg_pass' : pg_pass ,
135136 'pg_host' : pg_host ,
136137 'pg_port' : pg_port ,
@@ -182,15 +183,13 @@ def pg_isready() -> bool:
182183 Uses `pg_version_check()` for simple approach.
183184 """
184185 try :
185- result = pg_version_check ()
186+ pg_version_check ()
186187 except AttributeError :
187188 err_msg = 'Error checking version, likely waiting for Postgres to start.'
188189 err_msg += ' Only an error if it does not go away after a few attempts.'
189190 logging .getLogger ('pgosm-flex' ).warning (err_msg )
190191 return False
191192
192- if result is None :
193- return False
194193 return True
195194
196195
@@ -210,7 +209,7 @@ def log_pg_details():
210209
211210def prepare_pgosm_db (
212211 skip_qgis_style : bool
213- , db_path : str
212+ , db_path : Path
214213 , import_mode : helpers .ImportMode
215214 , schema_name : str
216215 ):
@@ -258,17 +257,17 @@ def start_import(
258257 pgosm_region : str
259258 , pgosm_date : str
260259 , srid : int
261- , language : str
260+ , language : str | None
262261 , layerset : str
263262 , git_info : str
264263 , osm2pgsql_version : str
265264 , import_mode : helpers .ImportMode
266265 , schema_name : str
267- , input_file : str
266+ , input_file : str | None
268267 ) -> int :
269268 """Creates record in `osm.pgosm_flex` table and returns `id` from `osm.pgosm_flex`.
270269 """
271- params = {'pgosm_region' : pgosm_region
270+ params : dict [ str , Any ] = {'pgosm_region' : pgosm_region
272271 , 'pgosm_date' : pgosm_date
273272 , 'srid' : srid
274273 , 'language' : language
@@ -292,12 +291,18 @@ def start_import(
292291 RETURNING id
293292;
294293"""
295- sql_raw = sql_raw .format (schema_name = schema_name )
296- # FIXME: Why os environ here instead of get conn string???
294+ sql_formatted = sql .SQL (sql_raw ).format (
295+ schema_name = sql .Identifier (schema_name )
296+ )
297297 with get_db_conn (conn_string = connection_string ()) as conn :
298298 cur = conn .cursor ()
299- cur .execute (sql_raw , params = params )
300- import_id = cur .fetchone ()[0 ]
299+ cur .execute (sql_formatted , params = params )
300+ row = cur .fetchone ()
301+ if row :
302+ import_id = int (row [0 ])
303+ else :
304+ msg = 'Invalid response. `import_id` should never be missing.'
305+ raise ValueError (msg )
301306
302307 return import_id
303308
@@ -318,6 +323,9 @@ def pg_version_check() -> int:
318323 cur .execute (sql_raw )
319324 results = cur .fetchone ()
320325
326+ if not results :
327+ raise ValueError ('Unable to return Postgres version number. Likely another error going on.' )
328+
321329 # It's an int https://www.postgresql.org/docs/current/runtime-config-preset.html#GUC-SERVER-VERSION-NUM
322330 pg_version = int (results [0 ])
323331 if pg_version < 120000 :
@@ -374,15 +382,15 @@ def create_pgosm_db() -> bool:
374382 return True
375383
376384
377- def prepare_osm_schema (db_path : str , skip_qgis_style : bool , schema_name : str ):
385+ def prepare_osm_schema (db_path : Path , skip_qgis_style : bool , schema_name : str ):
378386 """Runs deploy scripts to prepare the PgOSM Flex database.
379387
380388 This function's code could be simplified, but currently I like the verbosity
381389 of it. It doesn't need to stay like this forever, but for now... it's fine.
382390
383391 Parameters
384392 ---------------------------
385- db_path : str
393+ db_path : Path
386394 Path to folder with SQL scripts.
387395 skip_qgis_style : bool
388396 scheme_name : str
@@ -405,10 +413,10 @@ def prepare_osm_schema(db_path: str, skip_qgis_style: bool, schema_name: str):
405413 else :
406414 LOGGER .info ('Loading QGIS styles' )
407415 qgis_styles .load_qgis_styles (db_path = db_path ,
408- db_name = pg_conn_parts ()['pg_db' ])
416+ db_name = str ( pg_conn_parts ()['pg_db' ]) )
409417
410418
411- def run_insert_pgosm_road (db_path : str , schema_name : str ):
419+ def run_insert_pgosm_road (db_path : Path , schema_name : str ):
412420 """Runs script to load data to `pgosm.road` table.
413421 """
414422 sql_filename = 'roads-us.sql'
@@ -421,14 +429,14 @@ def run_insert_pgosm_road(db_path: str, schema_name: str):
421429
422430
423431def run_deploy_file (
424- db_path : str
432+ db_path : Path
425433 , sql_filename : str
426434 , schema_name : str
427435 , subfolder : str = 'deploy'
428436 ):
429437 """Run a SQL script under the deploy path. Used to setup PgOSM Flex DB.
430438 """
431- full_path = os . path . join ( db_path , subfolder , sql_filename )
439+ full_path = db_path / subfolder / sql_filename
432440 LOGGER .info (f'Deploying { full_path } ' )
433441
434442 with open (full_path ) as f :
@@ -457,7 +465,7 @@ def get_db_conn(conn_string: str) -> psycopg.Connection:
457465 return conn
458466
459467
460- def pgosm_after_import (flex_path : str ) -> bool :
468+ def pgosm_after_import (flex_path : Path ) -> bool :
461469 """Runs post-processing SQL via Lua script.
462470
463471 Layerset logic is established via environment variable, must happen
@@ -469,7 +477,7 @@ def pgosm_after_import(flex_path: str) -> bool:
469477
470478 output = subprocess .run (cmds ,
471479 text = True ,
472- cwd = flex_path ,
480+ cwd = str ( flex_path ) ,
473481 check = False ,
474482 stdout = subprocess .PIPE ,
475483 stderr = subprocess .STDOUT )
@@ -554,7 +562,7 @@ def osm2pgsql_replication_finish(skip_nested: bool):
554562 cur .execute (sql_raw , params )
555563
556564
557- def run_pg_dump (export_path : str , skip_qgis_style : bool ):
565+ def run_pg_dump (export_path : Path , skip_qgis_style : bool ):
558566 """Runs `pg_dump` to save processed data to load into other PostGIS DBs.
559567 """
560568 logger = logging .getLogger ('pgosm-flex' )
@@ -565,25 +573,25 @@ def run_pg_dump(export_path: str, skip_qgis_style: bool):
565573 logger .info (f'Running pg_dump (only { schema_name } schema)' )
566574 cmds = ['pg_dump' , '-d' , conn_string ,
567575 f'--schema={ schema_name } ' ,
568- '-f' , export_path ]
576+ '-f' , str ( export_path ) ]
569577 else :
570578 logger .info (f'Running pg_dump ({ schema_name } schema plus extras)' )
571579 cmds = ['pg_dump' , '-d' , conn_string ,
572580 f'--schema={ schema_name } ' ,
573581 '--schema=pgosm' ,
574582 '--schema=public' ,
575- '-f' , export_path ]
583+ '-f' , str ( export_path ) ]
576584
577585 output = subprocess .run (cmds ,
578586 text = True ,
579587 capture_output = True ,
580588 check = False )
581589 LOGGER .info (f'pg_dump complete, saved to { export_path } ' )
582590 LOGGER .debug (f'pg_dump output: \n { output .stderr } ' )
583- fix_pg_dump_create_public (export_path )
591+ fix_pg_dump_create_public (export_path = export_path )
584592
585593
586- def fix_pg_dump_create_public (export_path : str ):
594+ def fix_pg_dump_create_public (export_path : Path ):
587595 """Using pg_dump with `--schema=public` results in
588596 a .sql script containing `CREATE SCHEMA public;`, nearly always breaks
589597 in target DB. Replaces with `CREATE SCHEMA IF NOT EXISTS public;`
@@ -607,16 +615,19 @@ def log_import_message(import_id: int, msg: str, schema_name: str):
607615 AND pf.id = %(import_id)s
608616;
609617"""
610- sql_raw = sql_raw .format (schema_name = schema_name )
618+ sql_formatted = sql .SQL (sql_raw ).format (
619+ schema_name = sql .Identifier (schema_name )
620+ )
611621 with get_db_conn (conn_string = os .environ ['PGOSM_CONN' ]) as conn :
612- params = {'import_id' : import_id ,
613- 'msg' : msg
614- }
622+ params : dict [str , int | str ] = {
623+ 'import_id' : import_id
624+ , 'msg' : msg
625+ }
615626 cur = conn .cursor ()
616- cur .execute (sql_raw , params = params )
627+ cur .execute (sql_formatted , params = params )
617628
618629
619- def get_prior_import (schema_name : str ) -> dict :
630+ def get_prior_import (schema_name : str ) -> dict [ str , Any ] :
620631 """Gets the latest import details from `osm.pgosm_flex`.
621632 """
622633 sql_raw = """
@@ -636,6 +647,6 @@ def get_prior_import(schema_name: str) -> dict:
636647 results = cur .execute (sql_raw ).fetchone ()
637648
638649 if isinstance (results , type (None )):
639- results = {}
650+ results : dict [ str , Any ] = {}
640651
641652 return results
0 commit comments