44
55import json
66import os
7+ import time
78import re
89import ssl
910import sys
@@ -83,26 +84,23 @@ def getVO(proxy_data):
8384class BaseRequest (object ):
8485 """This class helps supporting multiple kinds of requests that require connections"""
8586
86- def __init__ (self , url , caPath , name = "unknown" ):
87+ def __init__ (self , url , caPath , pilotUUID , name = "unknown" ):
8788 self .name = name
8889 self .url = url
8990 self .caPath = caPath
9091 self .headers = {
9192 "User-Agent" : "Dirac Pilot [Unknown ID]"
9293 }
94+ self .pilotUUID = pilotUUID
9395 # We assume we have only one context, so this variable could be shared to avoid opening n times a cert.
9496 # On the contrary, to avoid race conditions, we do avoid using "self.data" and "self.headers"
9597 self ._context = None
9698
9799 self ._prepareRequest ()
98100
99- def generateUserAgent (self , pilotUUID ):
100- """To analyse the traffic, we can send a taylor-made User-Agent
101-
102- :param pilotUUID: Unique ID of the Pilot
103- :type pilotUUID: str
104- """
105- self .headers ["User-Agent" ] = "Dirac Pilot [%s]" % pilotUUID
101+ def generateUserAgent (self ):
102+ """To analyse the traffic, we can send a taylor-made User-Agent"""
103+ self .addHeader ("User-Agent" , "Dirac Pilot [%s]" % self .pilotUUID )
106104
107105 def _prepareRequest (self ):
108106 """As previously, loads the SSL certificates of the server (to avoid "unknown issuer")"""
@@ -128,18 +126,18 @@ def executeRequest(self, raw_data, insecure=False, content_type="json"):
128126 """
129127 if content_type == "json" :
130128 data = json .dumps (raw_data ).encode ("utf-8" )
131- self .headers [ "Content-Type" ] = "application/json"
129+ self .addHeader ( "Content-Type" , "application/json" )
132130 elif content_type == "x-www-form-urlencoded" :
133131 if sys .version_info .major == 3 :
134132 data = urlencode (raw_data ).encode ("utf-8" ) # encode to bytes ! for python3
135133 else :
136134 # Python2
137135 data = urlencode (raw_data )
138- self .headers [ "Content-Type" ] = "application/x-www-form-urlencoded"
136+ self .addHeader ( "Content-Type" , "application/x-www-form-urlencoded" )
139137 else :
140138 raise ValueError ("Invalid content_type. Use 'json' or 'x-www-form-urlencoded'." )
141139
142- self .headers [ "Content-Length" ] = str (len (data ))
140+ self .addHeader ( "Content-Length" , str (len (data ) ))
143141
144142 request = Request (self .url , data = data , headers = self .headers , method = "POST" )
145143
@@ -173,21 +171,27 @@ def executeRequest(self, raw_data, insecure=False, content_type="json"):
173171class TokenBasedRequest (BaseRequest ):
174172 """Connected Request with JWT support"""
175173
176- def __init__ (self , url , caPath , jwtData ):
177- super (TokenBasedRequest , self ).__init__ (url , caPath , "TokenBasedConnection" )
178-
174+ def __init__ (self , url , caPath , jwtData , pilotUUID ):
175+ super (TokenBasedRequest , self ).__init__ (url , caPath , pilotUUID , "TokenBasedConnection" )
179176 self .jwtData = jwtData
180177
181178 def addJwtToHeader (self ):
182179 # Adds the JWT in the HTTP request (in the Bearer field)
183- self .headers ["Authorization" ] = "Bearer: %s" % self .jwtData
180+ self .headers ["Authorization" ] = "Bearer: %s" % self .jwtData [ "access_token" ]
184181
182+ def executeRequest (self , raw_data , insecure = False , content_type = "json" , is_token_refreshed = False ):
183+
184+ return super (TokenBasedRequest , self ).executeRequest (
185+ raw_data ,
186+ insecure = insecure ,
187+ content_type = content_type
188+ )
185189
186190class X509BasedRequest (BaseRequest ):
187191 """Connected Request with X509 support"""
188192
189- def __init__ (self , url , caPath , certEnv ):
190- super (X509BasedRequest , self ).__init__ (url , caPath , "X509BasedConnection" )
193+ def __init__ (self , url , caPath , certEnv , pilotUUID ):
194+ super (X509BasedRequest , self ).__init__ (url , caPath , pilotUUID , "X509BasedConnection" )
191195
192196 self .certEnv = certEnv
193197 self ._hasExtraCredentials = False
@@ -210,3 +214,113 @@ def executeRequest(self, raw_data, insecure=False, content_type="json"):
210214 insecure = insecure ,
211215 content_type = content_type
212216 )
217+
218+
219+ def refreshPilotToken (url , pilotUUID , jwt , jwt_lock , clientID ):
220+ """
221+ Refresh the JWT token in a separate thread.
222+
223+ :param str url: Server URL
224+ :param str pilotUUID: Pilot unique ID
225+ :param dict jwt: Shared dict with current JWT; updated in-place
226+ :param threading.Lock jwt_lock: Lock to safely update the jwt dict
227+ :return: None
228+ """
229+
230+ # PRECONDITION: jwt must contain "refresh_token"
231+ if not jwt or "refresh_token" not in jwt :
232+ raise ValueError ("To refresh a token, a pilot needs a JWT with refresh_token" )
233+
234+ # Get CA path from environment
235+ caPath = os .getenv ("X509_CERT_DIR" )
236+
237+ # Create request object with required configuration
238+ config = BaseRequest (
239+ url = "%s/api/auth/token" % url ,
240+ caPath = caPath ,
241+ pilotUUID = pilotUUID
242+ )
243+
244+ # Prepare refresh token payload
245+ payload = {
246+ "grant_type" : "refresh_token" ,
247+ "refresh_token" : jwt ["refresh_token" ],
248+ "client_id" : clientID
249+ }
250+
251+ # Perform the request to refresh the token
252+ response = config .executeRequest (
253+ raw_data = payload ,
254+ insecure = True ,
255+ content_type = "x-www-form-urlencoded"
256+ )
257+
258+ # Ensure thread-safe update of the shared jwt dictionary
259+ jwt_lock .acquire ()
260+ try :
261+ jwt .update (response )
262+ finally :
263+ jwt_lock .release ()
264+
265+
266+ def revokePilotToken (url , pilotUUID , jwt , clientID ):
267+ """
268+ Refresh the JWT token in a separate thread.
269+
270+ :param str url: Server URL
271+ :param str pilotUUID: Pilot unique ID
272+ :param dict jwt: Shared dict with current JWT;
273+ :return: None
274+ """
275+
276+ # PRECONDITION: jwt must contain "refresh_token"
277+ if not jwt or "refresh_token" not in jwt :
278+ raise ValueError ("To refresh a token, a pilot needs a JWT with refresh_token" )
279+
280+ # Get CA path from environment
281+ caPath = os .getenv ("X509_CERT_DIR" )
282+
283+ # Create request object with required configuration
284+ config = BaseRequest (
285+ url = "%s/api/auth/revoke" % url ,
286+ caPath = caPath ,
287+ pilotUUID = pilotUUID
288+ )
289+
290+ # Prepare refresh token payload
291+ payload = {
292+ "refresh_token" : jwt ["refresh_token" ],
293+ "client_id" : clientID
294+ }
295+
296+ # Perform the request to revoke the token
297+ _response = config .executeRequest (
298+ raw_data = payload ,
299+ insecure = True ,
300+ content_type = "x-www-form-urlencoded"
301+ )
302+
303+ # === Token refresher thread function ===
304+ def refreshTokenLoop (url , pilotUUID , jwt , jwt_lock , logger , clientID , interval = 600 ):
305+ """
306+ Periodically refresh the pilot JWT token.
307+
308+ :param str url: DiracX server URL
309+ :param str pilotUUID: Pilot UUID
310+ :param dict jwt: Shared JWT dictionary
311+ :param threading.Lock jwt_lock: Lock to safely update JWT
312+ :param Logger logger: Logger to debug
313+ :param str clientID: ClientID used to refresh tokens
314+ :param int interval: Sleep time between refreshes in seconds
315+ :return: None
316+ """
317+ while True :
318+ time .sleep (interval )
319+
320+ try :
321+ refreshPilotToken (url , pilotUUID , jwt , jwt_lock , clientID )
322+
323+ logger .info ("Token refreshed." )
324+ except Exception as e :
325+ logger .error ("Token refresh failed: %s\n " % str (e ))
326+ continue
0 commit comments