11import datetime
2- from typing import TYPE_CHECKING , Optional , Tuple
2+ from typing import TYPE_CHECKING , Optional , Tuple , Union
33
44from aiobotocore .client import AioBaseClient
55from aiobotocore .session import AioSession , get_session
@@ -30,7 +30,7 @@ class DynamoBackend(Backend):
3030 >> FastAPICache.init(dynamodb)
3131 """
3232
33- client : DynamoDBClient
33+ client : Union [ DynamoDBClient , None ]
3434 session : AioSession
3535 table_name : str
3636 region : Optional [str ]
@@ -46,58 +46,63 @@ async def init(self) -> None:
4646 ).__aenter__ ()
4747
4848 async def close (self ) -> None :
49- self .client = await self .client .__aexit__ (None , None , None )
49+ if self .client :
50+ await self .client .__aexit__ (None , None , None )
51+ self .client = None
5052
5153 async def get_with_ttl (self , key : str ) -> Tuple [int , Optional [bytes ]]:
52- response = await self .client .get_item (TableName = self .table_name , Key = {"key" : {"S" : key }})
54+ if self .client :
55+ response = await self .client .get_item (TableName = self .table_name , Key = {"key" : {"S" : key }})
5356
54- if "Item" in response :
55- value = response ["Item" ].get ("value" , {}).get ("B" )
56- ttl = response ["Item" ].get ("ttl" , {}).get ("N" )
57+ if "Item" in response :
58+ value = response ["Item" ].get ("value" , {}).get ("B" )
59+ ttl = response ["Item" ].get ("ttl" , {}).get ("N" )
5760
58- if not ttl :
59- return - 1 , value
61+ if not ttl :
62+ return - 1 , value
6063
61- # It's only eventually consistent so we need to check ourselves
62- expire = int (ttl ) - int (datetime .datetime .now ().timestamp ())
63- if expire > 0 :
64- return expire , value
64+ # It's only eventually consistent so we need to check ourselves
65+ expire = int (ttl ) - int (datetime .datetime .now ().timestamp ())
66+ if expire > 0 :
67+ return expire , value
6568
6669 return 0 , None
6770
6871 async def get (self , key : str ) -> Optional [bytes ]:
69- response = await self .client .get_item (TableName = self .table_name , Key = {"key" : {"S" : key }})
70- if "Item" in response :
71- return response ["Item" ].get ("value" , {}).get ("B" )
72+ if self .client :
73+ response = await self .client .get_item (TableName = self .table_name , Key = {"key" : {"S" : key }})
74+ if "Item" in response :
75+ return response ["Item" ].get ("value" , {}).get ("B" )
7276 return None
7377
7478 async def set (self , key : str , value : bytes , expire : Optional [int ] = None ) -> None :
75- ttl = (
76- {
77- "ttl" : {
78- "N" : str (
79- int (
80- (
81- datetime .datetime .now () + datetime .timedelta (seconds = expire )
82- ).timestamp ()
79+ if self .client :
80+ ttl = (
81+ {
82+ "ttl" : {
83+ "N" : str (
84+ int (
85+ (
86+ datetime .datetime .now () + datetime .timedelta (seconds = expire )
87+ ).timestamp ()
88+ )
8389 )
84- )
90+ }
8591 }
86- }
87- if expire
88- else {}
89- )
90-
91- await self .client .put_item (
92- TableName = self .table_name ,
93- Item = {
94- ** {
95- "key" : {"S" : key },
96- "value" : {"B" : value },
92+ if expire
93+ else {}
94+ )
95+
96+ await self .client .put_item (
97+ TableName = self .table_name ,
98+ Item = {
99+ ** {
100+ "key" : {"S" : key },
101+ "value" : {"B" : value },
102+ },
103+ ** ttl ,
97104 },
98- ** ttl ,
99- },
100- )
105+ )
101106
102107 async def clear (self , namespace : Optional [str ] = None , key : Optional [str ] = None ) -> int :
103108 raise NotImplementedError
0 commit comments