Skip to content

Commit ceea73f

Browse files
authored
Read database info from configuration file (#330)
1 parent 739d1ca commit ceea73f

File tree

3 files changed

+43
-46
lines changed

3 files changed

+43
-46
lines changed

dashboard/app.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,6 @@
3333
opt_manager = None
3434
cal_manager = None
3535

36-
# load database
37-
db = load_database()
3836
# list of available experiments
3937
experiments = load_experiments()
4038

@@ -60,12 +58,13 @@ def update(
6058
global par_manager
6159
global opt_manager
6260
global cal_manager
63-
# load data
64-
exp_data, sim_data = load_data(db)
6561
# load input and output variables
6662
input_variables, output_variables, simulation_calibration = load_variables(
6763
state.experiment
6864
)
65+
# load data
66+
db = load_database(state.experiment)
67+
exp_data, sim_data = load_data(db)
6968
# reset output
7069
if reset_output:
7170
out_manager = OutputManager(output_variables)
@@ -232,7 +231,7 @@ def find_simulation(event, db):
232231
print(msg)
233232

234233

235-
def open_simulation_dialog(event):
234+
def open_simulation_dialog(event, db):
236235
try:
237236
data_directory, file_path = find_simulation(event, db)
238237
state.simulation_video = file_path.endswith(".mp4")
@@ -314,7 +313,10 @@ def home_route():
314313
figure = plotly.Figure(
315314
display_mode_bar="true",
316315
config={"responsive": True},
317-
click=(open_simulation_dialog, "[utils.safe($event)]"),
316+
click=(
317+
open_simulation_dialog,
318+
"[utils.safe($event), db]",
319+
),
318320
)
319321
ctrl.figure_update = figure.update
320322

dashboard/utils.py

Lines changed: 15 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -106,40 +106,29 @@ def verify_input_variables(model_file, experiment):
106106

107107

108108
@timer
109-
def load_database():
109+
def load_database(experiment):
110110
print("Loading database...")
111-
# load database
112-
db_defaults = {
113-
"host": "mongodb05.nersc.gov",
114-
"port": 27017,
115-
"name": "bella_sf",
116-
"auth": "bella_sf",
117-
"user": "bella_sf_ro",
118-
}
119-
# read database information from environment variables (if unset, use defaults)
120-
db_host = os.getenv("SF_DB_HOST", db_defaults["host"])
121-
db_port = int(os.getenv("SF_DB_PORT", db_defaults["port"]))
122-
db_name = os.getenv("SF_DB_NAME", db_defaults["name"])
123-
db_auth = os.getenv("SF_DB_AUTH_SOURCE", db_defaults["auth"])
124-
db_user = os.getenv("SF_DB_USER", db_defaults["user"])
125-
# read database password from environment variable (no default provided)
126-
db_password = os.getenv("SF_DB_READONLY_PASSWORD")
111+
# load configuration dictionary
112+
config_dict = load_config_dict(experiment)
113+
# read database information from configuration dictionary
114+
db_host = config_dict["database"]["host"]
115+
db_port = config_dict["database"]["port"]
116+
db_name = config_dict["database"]["name"]
117+
db_auth = config_dict["database"]["auth"]
118+
db_username = config_dict["database"]["username_ro"]
119+
db_password_env = config_dict["database"]["password_ro_env"]
120+
db_password = os.getenv(db_password_env)
127121
if db_password is None:
128-
raise RuntimeError("Environment variable SF_DB_READONLY_PASSWORD must be set!")
129-
# SSH forward?
130-
if db_host == "localhost" or db_host == "127.0.0.1":
131-
direct_connection = True
132-
else:
133-
direct_connection = False
122+
raise RuntimeError(f"Environment variable {db_password_env} must be set!")
134123
# get database instance
135124
print(f"Connecting to database {db_name}@{db_host}:{db_port}...")
136125
db = pymongo.MongoClient(
137126
host=db_host,
138127
port=db_port,
139-
username=db_user,
140-
password=db_password,
141128
authSource=db_auth,
142-
directConnection=direct_connection,
129+
username=db_username,
130+
password=db_password,
131+
directConnection=(db_host in ["localhost", "127.0.0.1"]), # SSH forwarding
143132
)[db_name]
144133
return db
145134

ml/train_model.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -63,20 +63,6 @@
6363
if model_type not in ["NN", "ensemble_NN", "GP"]:
6464
raise ValueError(f"Invalid model type: {model_type}")
6565

66-
###############################################
67-
# Open credential file for database
68-
###############################################
69-
with open(os.path.join(os.getenv("HOME"), "db.profile")) as f:
70-
db_profile = f.read()
71-
72-
# Connect to the MongoDB database with read-only access
73-
db = pymongo.MongoClient(
74-
host="mongodb05.nersc.gov",
75-
username="bella_sf_admin",
76-
password=re.findall("SF_DB_ADMIN_PASSWORD='(.+)'", db_profile)[0],
77-
authSource="bella_sf",
78-
)["bella_sf"]
79-
8066
# Extract configurations of experiments & models
8167
yaml_dict = None
8268
current_file_directory = os.path.dirname(os.path.abspath(__file__))
@@ -94,6 +80,26 @@
9480
if yaml_dict is None:
9581
raise RuntimeError("File config.yaml not found.")
9682

83+
# Connect to the MongoDB database with read-write access
84+
db_host = yaml_dict["database"]["host"]
85+
db_name = yaml_dict["database"]["name"]
86+
db_auth = yaml_dict["database"]["auth"]
87+
db_username = yaml_dict["database"]["username_rw"]
88+
db_password_env = yaml_dict["database"]["password_rw_env"]
89+
# Look for the password in the profile file
90+
with open(os.path.join(os.getenv("HOME"), "db.profile")) as f:
91+
db_profile = f.read()
92+
match = re.search(f"{db_password_env}='([^']*)'", db_profile)
93+
if not match:
94+
raise RuntimeError(f"Environment variable {db_password_env} must be set")
95+
db_password = match.group(1)
96+
db = pymongo.MongoClient(
97+
host=db_host,
98+
authSource=db_auth,
99+
username=db_username,
100+
password=db_password,
101+
)[db_name]
102+
97103
input_variables = yaml_dict["inputs"]
98104
input_names = [v["name"] for v in input_variables.values()]
99105
output_variables = yaml_dict["outputs"]

0 commit comments

Comments
 (0)