1- from typing import List
1+ from contextlib import contextmanager
2+ from typing import List , Optional
23from guardrails_api_client import Guard as GuardStruct
34from guardrails_api .classes .http_error import HttpError
45from guardrails_api .clients .guard_client import GuardClient
@@ -18,48 +19,20 @@ def __init__(self):
1819 self .initialized = True
1920 self .pgClient = PostgresClient ()
2021
21- def get_db (self ): # generator for local sessions
22+ @contextmanager
23+ def get_db_context (self ):
2224 db = self .pgClient .SessionLocal ()
2325 try :
2426 yield db
2527 finally :
2628 db .close ()
2729
28- def get_guard (self , guard_name : str , as_of_date : str = None ) -> GuardStruct :
29- db = next (self .get_db ())
30- latest_guard_item = db .query (GuardItem ).filter_by (name = guard_name ).first ()
31- audit_item = None
32- if as_of_date is not None :
33- audit_item = (
34- db .query (GuardItemAudit )
35- .filter_by (name = guard_name )
36- .filter (GuardItemAudit .replaced_on > as_of_date )
37- .order_by (GuardItemAudit .replaced_on .asc ())
38- .first ()
39- )
40- guard_item = audit_item if audit_item is not None else latest_guard_item
41- if guard_item is None :
42- raise HttpError (
43- status = 404 ,
44- message = "NotFound" ,
45- cause = "A Guard with the name {guard_name} does not exist!" .format (
46- guard_name = guard_name
47- ),
48- )
49- return from_guard_item (guard_item )
50-
51- def get_guard_item (self , guard_name : str ) -> GuardItem :
52- db = next (self .get_db ())
53- return db .query (GuardItem ).filter_by (name = guard_name ).first ()
30+ def util_get_guard_item (self , guard_name : str , db ) -> GuardItem :
31+ item = db .query (GuardItem ).filter_by (name = guard_name ).first ()
32+ return item
5433
55- def get_guards (self ) -> List [GuardStruct ]:
56- db = next (self .get_db ())
57- guard_items = db .query (GuardItem ).all ()
5834
59- return [from_guard_item (gi ) for gi in guard_items ]
60-
61- def create_guard (self , guard : GuardStruct ) -> GuardStruct :
62- db = next (self .get_db ())
35+ def util_create_guard (self , guard : GuardStruct , db ) -> GuardStruct :
6336 guard_item = GuardItem (
6437 name = guard .name ,
6538 railspec = guard .to_dict (),
@@ -69,48 +42,82 @@ def create_guard(self, guard: GuardStruct) -> GuardStruct:
6942 db .add (guard_item )
7043 db .commit ()
7144 return from_guard_item (guard_item )
45+
46+ # Below are used directly by Controllers and start db sessions
7247
73- def update_guard (self , guard_name : str , guard : GuardStruct ) -> GuardStruct :
74- db = next (self .get_db ())
75- guard_item = self .get_guard_item (guard_name )
76- if guard_item is None :
77- raise HttpError (
78- status = 404 ,
79- message = "NotFound" ,
80- cause = "A Guard with the name {guard_name} does not exist!" .format (
81- guard_name = guard_name
82- ),
83- )
84- # guard_item.num_reasks = guard.num_reasks
85- guard_item .railspec = guard .to_dict ()
86- guard_item .description = guard .description
87- db .commit ()
88- return from_guard_item (guard_item )
48+ def get_guard (self , guard_name : str , as_of_date : Optional [str ] = None ) -> GuardStruct :
49+ with self .get_db_context () as db :
50+ latest_guard_item = db .query (GuardItem ).filter_by (name = guard_name ).first ()
51+ audit_item = None
52+ if as_of_date is not None :
53+ audit_item = (
54+ db .query (GuardItemAudit )
55+ .filter_by (name = guard_name )
56+ .filter (GuardItemAudit .replaced_on > as_of_date )
57+ .order_by (GuardItemAudit .replaced_on .asc ())
58+ .first ()
59+ )
60+ guard_item = audit_item if audit_item is not None else latest_guard_item
61+ if guard_item is None :
62+ raise HttpError (
63+ status = 404 ,
64+ message = "NotFound" ,
65+ cause = "A Guard with the name {guard_name} does not exist!" .format (
66+ guard_name = guard_name
67+ ),
68+ )
69+ return from_guard_item (guard_item )
8970
90- def upsert_guard (self , guard_name : str , guard : GuardStruct ) -> GuardStruct :
91- db = next (self .get_db ())
92- guard_item = self .get_guard_item (guard_name )
93- if guard_item is not None :
71+ def get_guards (self ) -> List [GuardStruct ]:
72+ with self .get_db_context () as db :
73+ guard_items = db .query (GuardItem ).all ()
74+ return [from_guard_item (gi ) for gi in guard_items ]
75+
76+ def create_guard (self , guard : GuardStruct ) -> GuardStruct :
77+ with self .get_db_context () as db :
78+ return self .util_create_guard (guard , db )
79+
80+ def update_guard (self , guard_name : str , guard : GuardStruct ) -> GuardStruct :
81+ with self .get_db_context () as db :
82+ guard_item = self .util_get_guard_item (guard_name , db )
83+ if guard_item is None :
84+ raise HttpError (
85+ status = 404 ,
86+ message = "NotFound" ,
87+ cause = "A Guard with the name {guard_name} does not exist!" .format (
88+ guard_name = guard_name
89+ ),
90+ )
91+ # guard_item.num_reasks = guard.num_reasks
9492 guard_item .railspec = guard .to_dict ()
9593 guard_item .description = guard .description
96- # guard_item.num_reasks = guard.num_reasks
9794 db .commit ()
9895 return from_guard_item (guard_item )
99- else :
100- return self .create_guard (guard )
96+
97+ def upsert_guard (self , guard_name : str , guard : GuardStruct ) -> GuardStruct :
98+ with self .get_db_context () as db :
99+ guard_item = self .util_get_guard_item (guard_name , db )
100+ if guard_item is not None :
101+ guard_item .railspec = guard .to_dict ()
102+ guard_item .description = guard .description
103+ # guard_item.num_reasks = guard.num_reasks
104+ db .commit ()
105+ return from_guard_item (guard_item )
106+ else :
107+ return self .util_create_guard (guard , db )
101108
102109 def delete_guard (self , guard_name : str ) -> GuardStruct :
103- db = next ( self .get_db ())
104- guard_item = self .get_guard_item (guard_name )
105- if guard_item is None :
106- raise HttpError (
107- status = 404 ,
108- message = "NotFound" ,
109- cause = "A Guard with the name {guard_name} does not exist!" .format (
110- guard_name = guard_name
111- ),
112- )
113- db .delete (guard_item )
114- db .commit ()
115- guard = from_guard_item (guard_item )
116- return guard
110+ with self .get_db_context () as db :
111+ guard_item = self .util_get_guard_item (guard_name , db )
112+ if guard_item is None :
113+ raise HttpError (
114+ status = 404 ,
115+ message = "NotFound" ,
116+ cause = "A Guard with the name {guard_name} does not exist!" .format (
117+ guard_name = guard_name
118+ ),
119+ )
120+ db .delete (guard_item )
121+ db .commit ()
122+ guard = from_guard_item (guard_item )
123+ return guard
0 commit comments