Skip to content

Commit e2dba24

Browse files
feat: Add multi thread support to refresh tokens
1 parent 1412a18 commit e2dba24

File tree

4 files changed

+190
-32
lines changed

4 files changed

+190
-32
lines changed

Pilot/dirac-pilot.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,12 @@
4949
getCommand,
5050
pythonPathCheck,
5151
)
52+
53+
try:
54+
from Pilot.proxyTools import revokePilotToken
55+
except ImportError:
56+
from proxyTools import revokePilotToken
57+
5258
############################
5359

5460
if __name__ == "__main__":
@@ -124,3 +130,14 @@
124130
if remote:
125131
log.buffer.flush()
126132
sys.exit(-1)
133+
134+
log.info("Pilot tasks finished.")
135+
136+
if pilotParams.jwt:
137+
log.info("Revoking pilot token.")
138+
revokePilotToken(
139+
pilotParams.diracXServer,
140+
pilotParams.pilotUUID,
141+
pilotParams.jwt,
142+
pilotParams.clientID
143+
)

Pilot/pilotCommands.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def __init__(self, pilotParams):
2121

2222
import filecmp
2323
import os
24+
import threading
2425
import platform
2526
import shutil
2627
import socket
@@ -63,9 +64,9 @@ def __init__(self, pilotParams):
6364
)
6465

6566
try:
66-
from Pilot.proxyTools import BaseRequest
67+
from Pilot.proxyTools import BaseRequest, refreshTokenLoop
6768
except ImportError:
68-
from proxyTools import BaseRequest
69+
from proxyTools import BaseRequest, refreshTokenLoop
6970

7071
try:
7172
from urllib.error import HTTPError, URLError
@@ -592,6 +593,7 @@ class PilotLoginX(CommandBase):
592593
def __init__(self, pilotParams):
593594
"""c'tor"""
594595
super(PilotLoginX, self).__init__(pilotParams)
596+
self.jwt_lock = threading.Lock()
595597

596598
@logFinalizer
597599
def execute(self):
@@ -609,17 +611,20 @@ def execute(self):
609611
self.log.error("DiracXServer (url) not given, exiting...")
610612
sys.exit(-1)
611613

614+
if not self.pp.clientID:
615+
self.log.error("ClientID not given, exiting...")
616+
sys.exit(-1)
617+
612618
self.log.info("Fetching JWT in DiracX (URL: %s)" % self.pp.diracXServer)
613619

614620
config = BaseRequest(
615-
"%s/api/auth/pilot-login" % (
621+
"%s/api/pilots/token" % (
616622
self.pp.diracXServer
617623
),
618-
os.getenv("X509_CERT_DIR")
624+
os.getenv("X509_CERT_DIR"),
625+
self.pp.pilotUUID
619626
)
620627

621-
config.generateUserAgent(self.pp.pilotUUID)
622-
623628
try:
624629
self.pp.jwt = config.executeRequest({
625630
"pilot_stamp": self.pp.pilotUUID,
@@ -632,6 +637,25 @@ def execute(self):
632637

633638
self.log.info("Fetched the pilot token with the pilot secret.")
634639

640+
self.log.info("Starting the refresh thread.")
641+
self.log.info("Refreshing the token every %d seconds." % self.pp.refreshTokenEvery)
642+
# Start background refresh thread
643+
t = threading.Thread(
644+
target=refreshTokenLoop,
645+
args=(
646+
self.pp.diracXServer,
647+
self.pp.pilotUUID,
648+
self.pp.jwt,
649+
self.jwt_lock,
650+
self.log,
651+
self.pp.clientID,
652+
self.pp.refreshTokenEvery
653+
)
654+
)
655+
t.daemon = True
656+
t.start()
657+
658+
635659
class CheckCECapabilities(CommandBase):
636660
"""Used to get CE tags and other relevant parameters."""
637661

Pilot/pilotTools.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -705,15 +705,12 @@ def sendMessage(url, pilotUUID, wnVO, method, rawMessage, jwt={}):
705705
config = None
706706

707707
if jwt:
708-
try:
709-
access_token = jwt["access_token"]
710-
except ValueError as e:
711-
raise ValueError("JWT is needed, with an access_token field")
712708

713709
config = TokenBasedRequest(
714710
url=url,
715711
caPath=caPath,
716-
jwtData=access_token
712+
jwtData=jwt,
713+
pilotUUID=pilotUUID
717714
)
718715

719716
else:
@@ -722,12 +719,10 @@ def sendMessage(url, pilotUUID, wnVO, method, rawMessage, jwt={}):
722719
config = X509BasedRequest(
723720
url=url,
724721
caPath=caPath,
725-
certEnv=cert
722+
certEnv=cert,
723+
pilotUUID=pilotUUID
726724
)
727725

728-
# Config the header, will help debugging
729-
config.generateUserAgent(pilotUUID=pilotUUID)
730-
731726
# Do the request
732727
_res = config.executeRequest(
733728
raw_data=raw_data,
@@ -926,6 +921,8 @@ def __init__(self):
926921
self.queueName = ""
927922
self.gridCEType = ""
928923
self.pilotSecret = ""
924+
self.clientID = ""
925+
self.refreshTokenEvery = 300
929926
self.jwt = {
930927
"access_token": "",
931928
"refresh_token": ""
@@ -1041,6 +1038,8 @@ def __init__(self):
10411038
("", "architectureScript=", "architecture script to use"),
10421039
("", "CVMFS_locations=", "comma-separated list of CVMS locations"),
10431040
("", "pilotSecret=", "secret that the pilot uses with DiracX"),
1041+
("", "clientID=", "client id used by DiracX to revoke a token"),
1042+
("", "refreshTokenEvery=", "how often we have to refresh a token (in seconds)")
10441043
)
10451044

10461045
# Possibly get Setup and JSON URL/filename from command line
@@ -1248,6 +1247,10 @@ def __initCommandLine2(self):
12481247
self.CVMFS_locations = v.split(",")
12491248
elif o == "--pilotSecret":
12501249
self.pilotSecret = v
1250+
elif o == "--clientID":
1251+
self.clientID = v
1252+
elif o == "--refreshTokenEvery":
1253+
self.refreshTokenEvery = int(v)
12511254

12521255
def __loadJSON(self):
12531256
"""

Pilot/proxyTools.py

Lines changed: 131 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import json
66
import os
7+
import time
78
import re
89
import ssl
910
import sys
@@ -83,26 +84,23 @@ def getVO(proxy_data):
8384
class BaseRequest(object):
8485
"""This class helps supporting multiple kinds of requests that require connections"""
8586

86-
def __init__(self, url, caPath, name="unknown"):
87+
def __init__(self, url, caPath, pilotUUID, name="unknown"):
8788
self.name = name
8889
self.url = url
8990
self.caPath = caPath
9091
self.headers = {
9192
"User-Agent": "Dirac Pilot [Unknown ID]"
9293
}
94+
self.pilotUUID = pilotUUID
9395
# We assume we have only one context, so this variable could be shared to avoid opening n times a cert.
9496
# On the contrary, to avoid race conditions, we do avoid using "self.data" and "self.headers"
9597
self._context = None
9698

9799
self._prepareRequest()
98100

99-
def generateUserAgent(self, pilotUUID):
100-
"""To analyse the traffic, we can send a taylor-made User-Agent
101-
102-
:param pilotUUID: Unique ID of the Pilot
103-
:type pilotUUID: str
104-
"""
105-
self.headers["User-Agent"] = "Dirac Pilot [%s]" % pilotUUID
101+
def generateUserAgent(self):
102+
"""To analyse the traffic, we can send a taylor-made User-Agent"""
103+
self.addHeader("User-Agent", "Dirac Pilot [%s]" % self.pilotUUID)
106104

107105
def _prepareRequest(self):
108106
"""As previously, loads the SSL certificates of the server (to avoid "unknown issuer")"""
@@ -128,18 +126,18 @@ def executeRequest(self, raw_data, insecure=False, content_type="json"):
128126
"""
129127
if content_type == "json":
130128
data = json.dumps(raw_data).encode("utf-8")
131-
self.headers["Content-Type"] = "application/json"
129+
self.addHeader("Content-Type", "application/json")
132130
elif content_type == "x-www-form-urlencoded":
133131
if sys.version_info.major == 3:
134132
data = urlencode(raw_data).encode("utf-8") # encode to bytes ! for python3
135133
else:
136134
# Python2
137135
data = urlencode(raw_data)
138-
self.headers["Content-Type"] = "application/x-www-form-urlencoded"
136+
self.addHeader("Content-Type", "application/x-www-form-urlencoded")
139137
else:
140138
raise ValueError("Invalid content_type. Use 'json' or 'x-www-form-urlencoded'.")
141139

142-
self.headers["Content-Length"] = str(len(data))
140+
self.addHeader("Content-Length", str(len(data)))
143141

144142
request = Request(self.url, data=data, headers=self.headers, method="POST")
145143

@@ -173,21 +171,27 @@ def executeRequest(self, raw_data, insecure=False, content_type="json"):
173171
class TokenBasedRequest(BaseRequest):
174172
"""Connected Request with JWT support"""
175173

176-
def __init__(self, url, caPath, jwtData):
177-
super(TokenBasedRequest, self).__init__(url, caPath, "TokenBasedConnection")
178-
174+
def __init__(self, url, caPath, jwtData, pilotUUID):
175+
super(TokenBasedRequest, self).__init__(url, caPath, pilotUUID, "TokenBasedConnection")
179176
self.jwtData = jwtData
180177

181178
def addJwtToHeader(self):
182179
# Adds the JWT in the HTTP request (in the Bearer field)
183-
self.headers["Authorization"] = "Bearer: %s" % self.jwtData
180+
self.headers["Authorization"] = "Bearer: %s" % self.jwtData["access_token"]
184181

182+
def executeRequest(self, raw_data, insecure=False, content_type="json", is_token_refreshed=False):
183+
184+
return super(TokenBasedRequest, self).executeRequest(
185+
raw_data,
186+
insecure=insecure,
187+
content_type=content_type
188+
)
185189

186190
class X509BasedRequest(BaseRequest):
187191
"""Connected Request with X509 support"""
188192

189-
def __init__(self, url, caPath, certEnv):
190-
super(X509BasedRequest, self).__init__(url, caPath, "X509BasedConnection")
193+
def __init__(self, url, caPath, certEnv, pilotUUID):
194+
super(X509BasedRequest, self).__init__(url, caPath, pilotUUID, "X509BasedConnection")
191195

192196
self.certEnv = certEnv
193197
self._hasExtraCredentials = False
@@ -210,3 +214,113 @@ def executeRequest(self, raw_data, insecure=False, content_type="json"):
210214
insecure=insecure,
211215
content_type=content_type
212216
)
217+
218+
219+
def refreshPilotToken(url, pilotUUID, jwt, jwt_lock, clientID):
220+
"""
221+
Refresh the JWT token in a separate thread.
222+
223+
:param str url: Server URL
224+
:param str pilotUUID: Pilot unique ID
225+
:param dict jwt: Shared dict with current JWT; updated in-place
226+
:param threading.Lock jwt_lock: Lock to safely update the jwt dict
227+
:return: None
228+
"""
229+
230+
# PRECONDITION: jwt must contain "refresh_token"
231+
if not jwt or "refresh_token" not in jwt:
232+
raise ValueError("To refresh a token, a pilot needs a JWT with refresh_token")
233+
234+
# Get CA path from environment
235+
caPath = os.getenv("X509_CERT_DIR")
236+
237+
# Create request object with required configuration
238+
config = BaseRequest(
239+
url="%s/api/auth/token" % url,
240+
caPath=caPath,
241+
pilotUUID=pilotUUID
242+
)
243+
244+
# Prepare refresh token payload
245+
payload = {
246+
"grant_type": "refresh_token",
247+
"refresh_token": jwt["refresh_token"],
248+
"client_id": clientID
249+
}
250+
251+
# Perform the request to refresh the token
252+
response = config.executeRequest(
253+
raw_data=payload,
254+
insecure=True,
255+
content_type="x-www-form-urlencoded"
256+
)
257+
258+
# Ensure thread-safe update of the shared jwt dictionary
259+
jwt_lock.acquire()
260+
try:
261+
jwt.update(response)
262+
finally:
263+
jwt_lock.release()
264+
265+
266+
def revokePilotToken(url, pilotUUID, jwt, clientID):
267+
"""
268+
Refresh the JWT token in a separate thread.
269+
270+
:param str url: Server URL
271+
:param str pilotUUID: Pilot unique ID
272+
:param dict jwt: Shared dict with current JWT;
273+
:return: None
274+
"""
275+
276+
# PRECONDITION: jwt must contain "refresh_token"
277+
if not jwt or "refresh_token" not in jwt:
278+
raise ValueError("To refresh a token, a pilot needs a JWT with refresh_token")
279+
280+
# Get CA path from environment
281+
caPath = os.getenv("X509_CERT_DIR")
282+
283+
# Create request object with required configuration
284+
config = BaseRequest(
285+
url="%s/api/auth/revoke" % url,
286+
caPath=caPath,
287+
pilotUUID=pilotUUID
288+
)
289+
290+
# Prepare refresh token payload
291+
payload = {
292+
"refresh_token": jwt["refresh_token"],
293+
"client_id": clientID
294+
}
295+
296+
# Perform the request to revoke the token
297+
_response = config.executeRequest(
298+
raw_data=payload,
299+
insecure=True,
300+
content_type="x-www-form-urlencoded"
301+
)
302+
303+
# === Token refresher thread function ===
304+
def refreshTokenLoop(url, pilotUUID, jwt, jwt_lock, logger, clientID, interval=600):
305+
"""
306+
Periodically refresh the pilot JWT token.
307+
308+
:param str url: DiracX server URL
309+
:param str pilotUUID: Pilot UUID
310+
:param dict jwt: Shared JWT dictionary
311+
:param threading.Lock jwt_lock: Lock to safely update JWT
312+
:param Logger logger: Logger to debug
313+
:param str clientID: ClientID used to refresh tokens
314+
:param int interval: Sleep time between refreshes in seconds
315+
:return: None
316+
"""
317+
while True:
318+
time.sleep(interval)
319+
320+
try:
321+
refreshPilotToken(url, pilotUUID, jwt, jwt_lock, clientID)
322+
323+
logger.info("Token refreshed.")
324+
except Exception as e:
325+
logger.error("Token refresh failed: %s\n" % str(e))
326+
continue

0 commit comments

Comments
 (0)