Skip to content

Commit 1267c8c

Browse files
feat: Add thread to refresh tokens
1 parent e2dba24 commit 1267c8c

File tree

2 files changed

+47
-45
lines changed

2 files changed

+47
-45
lines changed

Pilot/pilotCommands.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -587,6 +587,7 @@ class PilotLoginX(CommandBase):
587587
"""The pilot logs in and fetches their JWT.
588588
589589
.. note:: This command is only compatible with DiracX, and requires Dirac version >= 9.0
590+
.. note:: This command will start a new thread to refresh tokens regularly
590591
"""
591592

592593

@@ -611,10 +612,6 @@ def execute(self):
611612
self.log.error("DiracXServer (url) not given, exiting...")
612613
sys.exit(-1)
613614

614-
if not self.pp.clientID:
615-
self.log.error("ClientID not given, exiting...")
616-
sys.exit(-1)
617-
618615
self.log.info("Fetching JWT in DiracX (URL: %s)" % self.pp.diracXServer)
619616

620617
config = BaseRequest(
@@ -648,7 +645,6 @@ def execute(self):
648645
self.pp.jwt,
649646
self.jwt_lock,
650647
self.log,
651-
self.pp.clientID,
652648
self.pp.refreshTokenEvery
653649
)
654650
)

Pilot/proxyTools.py

Lines changed: 46 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -127,17 +127,23 @@ def executeRequest(self, raw_data, insecure=False, content_type="json"):
127127
if content_type == "json":
128128
data = json.dumps(raw_data).encode("utf-8")
129129
self.addHeader("Content-Type", "application/json")
130-
elif content_type == "x-www-form-urlencoded":
131-
if sys.version_info.major == 3:
132-
data = urlencode(raw_data).encode("utf-8") # encode to bytes ! for python3
133-
else:
134-
# Python2
135-
data = urlencode(raw_data)
136-
self.addHeader("Content-Type", "application/x-www-form-urlencoded")
130+
self.addHeader("Content-Length", str(len(data)))
137131
else:
138-
raise ValueError("Invalid content_type. Use 'json' or 'x-www-form-urlencoded'.")
139132

140-
self.addHeader("Content-Length", str(len(data)))
133+
data = urlencode(raw_data)
134+
135+
if content_type == "x-www-form-urlencoded":
136+
if sys.version_info.major == 3:
137+
data = urlencode(raw_data).encode("utf-8") # encode to bytes ! for python3
138+
139+
self.addHeader("Content-Type", "application/x-www-form-urlencoded")
140+
self.addHeader("Content-Length", str(len(data)))
141+
elif content_type == "query":
142+
self.url = self.url + "?" + data
143+
data = None # No body
144+
else:
145+
raise ValueError("Invalid content_type. Use 'json' or 'x-www-form-urlencoded'.")
146+
141147

142148
request = Request(self.url, data=data, headers=self.headers, method="POST")
143149

@@ -150,22 +156,26 @@ def executeRequest(self, raw_data, insecure=False, content_type="json"):
150156
ctx.check_hostname = False
151157
ctx.verify_mode = ssl.CERT_NONE
152158

153-
if sys.version_info.major == 3:
154-
# Python 3 code
155-
with urlopen(request, context=ctx) as res:
156-
response_data = res.read().decode("utf-8") # Decode response bytes
157-
else:
158-
# Python 2 code
159-
res = urlopen(request, context=ctx)
160-
try:
161-
response_data = res.read()
162-
finally:
163-
res.close()
159+
160+
try:
161+
if sys.version_info.major == 3:
162+
# Python 3 code
163+
with urlopen(request, context=ctx) as res:
164+
response_data = res.read().decode("utf-8") # Decode response bytes
165+
else:
166+
# Python 2 code
167+
res = urlopen(request, context=ctx)
168+
try:
169+
response_data = res.read()
170+
finally:
171+
res.close()
172+
except HTTPError as e:
173+
raise RuntimeError("HTTPError : %s" % e.read().decode())
164174

165175
try:
166176
return json.loads(response_data) # Parse JSON response
167177
except ValueError: # In Python 2, json.JSONDecodeError is a subclass of ValueError
168-
raise Exception("Invalid JSON response: %s" % response_data)
178+
raise ValueError("Invalid JSON response: %s" % response_data)
169179

170180

171181
class TokenBasedRequest(BaseRequest):
@@ -174,12 +184,13 @@ class TokenBasedRequest(BaseRequest):
174184
def __init__(self, url, caPath, jwtData, pilotUUID):
175185
super(TokenBasedRequest, self).__init__(url, caPath, pilotUUID, "TokenBasedConnection")
176186
self.jwtData = jwtData
187+
self.addJwtToHeader()
177188

178189
def addJwtToHeader(self):
179190
# Adds the JWT in the HTTP request (in the Bearer field)
180-
self.headers["Authorization"] = "Bearer: %s" % self.jwtData["access_token"]
191+
self.headers["Authorization"] = "Bearer %s" % self.jwtData["access_token"]
181192

182-
def executeRequest(self, raw_data, insecure=False, content_type="json", is_token_refreshed=False):
193+
def executeRequest(self, raw_data, insecure=False, content_type="json"):
183194

184195
return super(TokenBasedRequest, self).executeRequest(
185196
raw_data,
@@ -216,7 +227,7 @@ def executeRequest(self, raw_data, insecure=False, content_type="json"):
216227
)
217228

218229

219-
def refreshPilotToken(url, pilotUUID, jwt, jwt_lock, clientID):
230+
def refreshPilotToken(url, pilotUUID, jwt, jwt_lock):
220231
"""
221232
Refresh the JWT token in a separate thread.
222233
@@ -235,24 +246,19 @@ def refreshPilotToken(url, pilotUUID, jwt, jwt_lock, clientID):
235246
caPath = os.getenv("X509_CERT_DIR")
236247

237248
# Create request object with required configuration
238-
config = BaseRequest(
239-
url="%s/api/auth/token" % url,
249+
config = TokenBasedRequest(
250+
url="%s/api/pilots/refresh-token" % url,
240251
caPath=caPath,
241-
pilotUUID=pilotUUID
252+
pilotUUID=pilotUUID,
253+
jwtData=jwt
242254
)
243255

244-
# Prepare refresh token payload
245-
payload = {
246-
"grant_type": "refresh_token",
247-
"refresh_token": jwt["refresh_token"],
248-
"client_id": clientID
249-
}
250-
251256
# Perform the request to refresh the token
252257
response = config.executeRequest(
253-
raw_data=payload,
258+
raw_data={
259+
"refresh_token": jwt["refresh_token"]
260+
},
254261
insecure=True,
255-
content_type="x-www-form-urlencoded"
256262
)
257263

258264
# Ensure thread-safe update of the shared jwt dictionary
@@ -269,6 +275,7 @@ def revokePilotToken(url, pilotUUID, jwt, clientID):
269275
270276
:param str url: Server URL
271277
:param str pilotUUID: Pilot unique ID
278+
:param str clientID: ClientID used to revoke tokens
272279
:param dict jwt: Shared dict with current JWT;
273280
:return: None
274281
"""
@@ -297,11 +304,11 @@ def revokePilotToken(url, pilotUUID, jwt, clientID):
297304
_response = config.executeRequest(
298305
raw_data=payload,
299306
insecure=True,
300-
content_type="x-www-form-urlencoded"
307+
content_type="query"
301308
)
302309

303310
# === Token refresher thread function ===
304-
def refreshTokenLoop(url, pilotUUID, jwt, jwt_lock, logger, clientID, interval=600):
311+
def refreshTokenLoop(url, pilotUUID, jwt, jwt_lock, logger, interval=600):
305312
"""
306313
Periodically refresh the pilot JWT token.
307314
@@ -310,15 +317,14 @@ def refreshTokenLoop(url, pilotUUID, jwt, jwt_lock, logger, clientID, interval=6
310317
:param dict jwt: Shared JWT dictionary
311318
:param threading.Lock jwt_lock: Lock to safely update JWT
312319
:param Logger logger: Logger to debug
313-
:param str clientID: ClientID used to refresh tokens
314320
:param int interval: Sleep time between refreshes in seconds
315321
:return: None
316322
"""
317323
while True:
318324
time.sleep(interval)
319325

320326
try:
321-
refreshPilotToken(url, pilotUUID, jwt, jwt_lock, clientID)
327+
refreshPilotToken(url, pilotUUID, jwt, jwt_lock)
322328

323329
logger.info("Token refreshed.")
324330
except Exception as e:

0 commit comments

Comments
 (0)