Skip to content

Create tables in the specified schema to avoid moving the tables afterwards #21

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Sep 4, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 25 additions & 48 deletions load_into_pg.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,24 +46,30 @@ def show_progress(block_num, block_size, total_size):
file_part = None
six.print_("")

def getConnectionParameters():
"""Get the parameters for the connection to the database."""

def buildConnectionString(dbname, mbHost, mbPort, mbUsername, mbPassword):
dbConnectionParam = "dbname={}".format(dbname)
parameters = {}

if mbPort is not None:
dbConnectionParam += " port={}".format(mbPort)
if args.dbname:
parameters['dbname'] = args.dbname

if mbHost is not None:
dbConnectionParam += " host={}".format(mbHost)
if args.host:
parameters['host'] = args.host

# TODO Is the escaping done here correct?
if mbUsername is not None:
dbConnectionParam += " user={}".format(mbUsername)
if args.port:
parameters['port'] = args.port

# TODO Is the escaping done here correct?
if mbPassword is not None:
dbConnectionParam += " password={}".format(mbPassword)
return dbConnectionParam
if args.username:
parameters['user'] = args.username

if args.password:
parameters['password'] = args.password

if args.schema_name:
parameters['options'] = "-c search_path=" + args.schema_name

return parameters


def _makeDefValues(keys):
Expand Down Expand Up @@ -174,7 +180,7 @@ def _getTableKeys(table):
return keys


def handleTable(table, insertJson, createFk, mbDbFile, dbConnectionParam):
def handleTable(table, insertJson, createFk, mbDbFile):
"""Handle the table including the post/pre processing."""
keys = _getTableKeys(table)
dbFile = mbDbFile if mbDbFile is not None else table + ".xml"
Expand All @@ -193,7 +199,7 @@ def handleTable(table, insertJson, createFk, mbDbFile, dbConnectionParam):
sys.exit(-1)

try:
with pg.connect(dbConnectionParam) as conn:
with pg.connect(**getConnectionParameters()) as conn:
with conn.cursor() as cur:
try:
with open(dbFile, "rb") as xml:
Expand Down Expand Up @@ -273,29 +279,8 @@ def handleTable(table, insertJson, createFk, mbDbFile, dbConnectionParam):
six.print_("Warning from the database.", file=sys.stderr)
six.print_("pg.Warning: {0}".format(str(w)), file=sys.stderr)


def moveTableToSchema(table, schemaName, dbConnectionParam):
try:
with pg.connect(dbConnectionParam) as conn:
with conn.cursor() as cur:
# create the schema
cur.execute("CREATE SCHEMA IF NOT EXISTS " + schemaName + ";")
conn.commit()
# move the table to the right schema
cur.execute("ALTER TABLE " + table + " SET SCHEMA " + schemaName + ";")
conn.commit()
except pg.Error as e:
six.print_("Error in dealing with the database.", file=sys.stderr)
six.print_("pg.Error ({0}): {1}".format(e.pgcode, e.pgerror), file=sys.stderr)
six.print_(str(e), file=sys.stderr)
except pg.Warning as w:
six.print_("Warning from the database.", file=sys.stderr)
six.print_("pg.Warning: {0}".format(str(w)), file=sys.stderr)


#############################################################


parser = argparse.ArgumentParser()
parser.add_argument(
"-t",
Expand Down Expand Up @@ -384,10 +369,6 @@ def moveTableToSchema(table, schemaName, dbConnectionParam):
except NameError:
pass

dbConnectionParam = buildConnectionString(
args.dbname, args.host, args.port, args.username, args.password
)

# load given file in table
if args.file and args.table:
table = args.table
Expand All @@ -398,14 +379,13 @@ def moveTableToSchema(table, schemaName, dbConnectionParam):
specialRules[("Posts", "Body")] = "NULL"

choice = input("This will drop the {} table. Are you sure [y/n]?".format(table))

if len(choice) > 0 and choice[0].lower() == "y":
handleTable(
table, args.insert_json, args.foreign_keys, args.file, dbConnectionParam
)
table, args.insert_json, args.foreign_keys, args.file)
else:
six.print_("Cancelled.")
if args.schema_name != "public":
moveTableToSchema(table, args.schema_name, dbConnectionParam)

exit(0)

# load a project
Expand Down Expand Up @@ -453,7 +433,7 @@ def moveTableToSchema(table, schemaName, dbConnectionParam):

for table in tables:
six.print_("Load {0}.xml file".format(table))
handleTable(table, args.insert_json, args.foreign_keys, None, dbConnectionParam)
handleTable(table, args.insert_json, args.foreign_keys, None)
# remove file
os.remove(table + ".xml")

Expand All @@ -465,9 +445,6 @@ def moveTableToSchema(table, schemaName, dbConnectionParam):
else:
six.print_("Archive '{0}' deleted".format(filepath))

if args.schema_name != "public":
for table in tables:
moveTableToSchema(table, args.schema_name, dbConnectionParam)
exit(0)

else:
Expand Down