Skip to content

Commit 01e8e0d

Browse files
ported over pg guard client
1 parent a7aed37 commit 01e8e0d

File tree

1 file changed

+79
-72
lines changed

1 file changed

+79
-72
lines changed
Lines changed: 79 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from typing import List
1+
from contextlib import contextmanager
2+
from typing import List, Optional
23
from guardrails_api_client import Guard as GuardStruct
34
from guardrails_api.classes.http_error import HttpError
45
from guardrails_api.clients.guard_client import GuardClient
@@ -18,48 +19,20 @@ def __init__(self):
1819
self.initialized = True
1920
self.pgClient = PostgresClient()
2021

21-
def get_db(self): # generator for local sessions
22+
@contextmanager
23+
def get_db_context(self):
2224
db = self.pgClient.SessionLocal()
2325
try:
2426
yield db
2527
finally:
2628
db.close()
2729

28-
def get_guard(self, guard_name: str, as_of_date: str = None) -> GuardStruct:
29-
db = next(self.get_db())
30-
latest_guard_item = db.query(GuardItem).filter_by(name=guard_name).first()
31-
audit_item = None
32-
if as_of_date is not None:
33-
audit_item = (
34-
db.query(GuardItemAudit)
35-
.filter_by(name=guard_name)
36-
.filter(GuardItemAudit.replaced_on > as_of_date)
37-
.order_by(GuardItemAudit.replaced_on.asc())
38-
.first()
39-
)
40-
guard_item = audit_item if audit_item is not None else latest_guard_item
41-
if guard_item is None:
42-
raise HttpError(
43-
status=404,
44-
message="NotFound",
45-
cause="A Guard with the name {guard_name} does not exist!".format(
46-
guard_name=guard_name
47-
),
48-
)
49-
return from_guard_item(guard_item)
50-
51-
def get_guard_item(self, guard_name: str) -> GuardItem:
52-
db = next(self.get_db())
53-
return db.query(GuardItem).filter_by(name=guard_name).first()
30+
def util_get_guard_item(self, guard_name: str, db) -> GuardItem:
31+
item = db.query(GuardItem).filter_by(name=guard_name).first()
32+
return item
5433

55-
def get_guards(self) -> List[GuardStruct]:
56-
db = next(self.get_db())
57-
guard_items = db.query(GuardItem).all()
5834

59-
return [from_guard_item(gi) for gi in guard_items]
60-
61-
def create_guard(self, guard: GuardStruct) -> GuardStruct:
62-
db = next(self.get_db())
35+
def util_create_guard(self, guard: GuardStruct, db) -> GuardStruct:
6336
guard_item = GuardItem(
6437
name=guard.name,
6538
railspec=guard.to_dict(),
@@ -69,48 +42,82 @@ def create_guard(self, guard: GuardStruct) -> GuardStruct:
6942
db.add(guard_item)
7043
db.commit()
7144
return from_guard_item(guard_item)
45+
46+
# Below are used directly by Controllers and start db sessions
7247

73-
def update_guard(self, guard_name: str, guard: GuardStruct) -> GuardStruct:
74-
db = next(self.get_db())
75-
guard_item = self.get_guard_item(guard_name)
76-
if guard_item is None:
77-
raise HttpError(
78-
status=404,
79-
message="NotFound",
80-
cause="A Guard with the name {guard_name} does not exist!".format(
81-
guard_name=guard_name
82-
),
83-
)
84-
# guard_item.num_reasks = guard.num_reasks
85-
guard_item.railspec = guard.to_dict()
86-
guard_item.description = guard.description
87-
db.commit()
88-
return from_guard_item(guard_item)
48+
def get_guard(self, guard_name: str, as_of_date: Optional[str] = None) -> GuardStruct:
49+
with self.get_db_context() as db:
50+
latest_guard_item = db.query(GuardItem).filter_by(name=guard_name).first()
51+
audit_item = None
52+
if as_of_date is not None:
53+
audit_item = (
54+
db.query(GuardItemAudit)
55+
.filter_by(name=guard_name)
56+
.filter(GuardItemAudit.replaced_on > as_of_date)
57+
.order_by(GuardItemAudit.replaced_on.asc())
58+
.first()
59+
)
60+
guard_item = audit_item if audit_item is not None else latest_guard_item
61+
if guard_item is None:
62+
raise HttpError(
63+
status=404,
64+
message="NotFound",
65+
cause="A Guard with the name {guard_name} does not exist!".format(
66+
guard_name=guard_name
67+
),
68+
)
69+
return from_guard_item(guard_item)
8970

90-
def upsert_guard(self, guard_name: str, guard: GuardStruct) -> GuardStruct:
91-
db = next(self.get_db())
92-
guard_item = self.get_guard_item(guard_name)
93-
if guard_item is not None:
71+
def get_guards(self) -> List[GuardStruct]:
72+
with self.get_db_context() as db:
73+
guard_items = db.query(GuardItem).all()
74+
return [from_guard_item(gi) for gi in guard_items]
75+
76+
def create_guard(self, guard: GuardStruct) -> GuardStruct:
77+
with self.get_db_context() as db:
78+
return self.util_create_guard(guard, db)
79+
80+
def update_guard(self, guard_name: str, guard: GuardStruct) -> GuardStruct:
81+
with self.get_db_context() as db:
82+
guard_item = self.util_get_guard_item(guard_name, db)
83+
if guard_item is None:
84+
raise HttpError(
85+
status=404,
86+
message="NotFound",
87+
cause="A Guard with the name {guard_name} does not exist!".format(
88+
guard_name=guard_name
89+
),
90+
)
91+
# guard_item.num_reasks = guard.num_reasks
9492
guard_item.railspec = guard.to_dict()
9593
guard_item.description = guard.description
96-
# guard_item.num_reasks = guard.num_reasks
9794
db.commit()
9895
return from_guard_item(guard_item)
99-
else:
100-
return self.create_guard(guard)
96+
97+
def upsert_guard(self, guard_name: str, guard: GuardStruct) -> GuardStruct:
98+
with self.get_db_context() as db:
99+
guard_item = self.util_get_guard_item(guard_name, db)
100+
if guard_item is not None:
101+
guard_item.railspec = guard.to_dict()
102+
guard_item.description = guard.description
103+
# guard_item.num_reasks = guard.num_reasks
104+
db.commit()
105+
return from_guard_item(guard_item)
106+
else:
107+
return self.util_create_guard(guard, db)
101108

102109
def delete_guard(self, guard_name: str) -> GuardStruct:
103-
db = next(self.get_db())
104-
guard_item = self.get_guard_item(guard_name)
105-
if guard_item is None:
106-
raise HttpError(
107-
status=404,
108-
message="NotFound",
109-
cause="A Guard with the name {guard_name} does not exist!".format(
110-
guard_name=guard_name
111-
),
112-
)
113-
db.delete(guard_item)
114-
db.commit()
115-
guard = from_guard_item(guard_item)
116-
return guard
110+
with self.get_db_context() as db:
111+
guard_item = self.util_get_guard_item(guard_name, db)
112+
if guard_item is None:
113+
raise HttpError(
114+
status=404,
115+
message="NotFound",
116+
cause="A Guard with the name {guard_name} does not exist!".format(
117+
guard_name=guard_name
118+
),
119+
)
120+
db.delete(guard_item)
121+
db.commit()
122+
guard = from_guard_item(guard_item)
123+
return guard

0 commit comments

Comments
 (0)