Skip to content

Commit 5e3e199

Browse files
committed
Forward auth token to MLM calls
1 parent 92c4370 commit 5e3e199

File tree

3 files changed

+81
-23
lines changed

3 files changed

+81
-23
lines changed

src/mcp_server_uyuni/config.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
import os
22

33
REQUIRED_VARS = [
4-
"UYUNI_SERVER",
5-
"UYUNI_USER",
6-
"UYUNI_PASS",
4+
"UYUNI_SERVER"
75
]
86

97
missing_vars = [key for key in REQUIRED_VARS if key not in os.environ]

src/mcp_server_uyuni/server.py

Lines changed: 69 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from datetime import datetime, timezone
2121
from pydantic import BaseModel
2222

23+
from fastmcp.server.middleware import Middleware, MiddlewareContext
2324
from fastmcp import FastMCP, Context
2425
from mcp import LoggingLevel, ServerSession, types
2526
from mcp_server_uyuni.logging_config import get_logger, Transport
@@ -37,6 +38,30 @@ class ActivationKeySchema(BaseModel):
3738

3839
logger = get_logger(log_file=CONFIG["UYUNI_MCP_LOG_FILE_PATH"], transport=CONFIG["UYUNI_MCP_TRANSPORT"])
3940

41+
class AuthTokenMiddleware(Middleware):
42+
async def on_call_tool(self, ctx: MiddlewareContext, call_next):
43+
"""
44+
Extracts the JWT token from the Authorization header (if present)
45+
and injects it into the context state for other tools to use.
46+
"""
47+
fastmcp_ctx = ctx.fastmcp_context
48+
auth_header = fastmcp_ctx.request_context.request.headers['authorization']
49+
token = None
50+
if auth_header:
51+
# Expecting "Authorization: Bearer <token>"
52+
parts = auth_header.split()
53+
if len(parts) == 2 and parts[0] == "Bearer":
54+
token = parts[1]
55+
logger.debug("Successfully extracted token from header.")
56+
else:
57+
logger.warning(f"Malformed Authorization header received: {auth_header}")
58+
else:
59+
logger.debug("No Authorization header found in the request.")
60+
61+
fastmcp_ctx.set_state('token', token)
62+
result = await call_next(ctx)
63+
return result
64+
4065
def write_tool(*decorator_args, **decorator_kwargs):
4166
"""
4267
A decorator that registers a function as an MCP tool only if write
@@ -79,16 +104,17 @@ async def get_list_of_active_systems(ctx: Context) -> List[Dict[str, Any]]:
79104
logger.info(log_string)
80105
await ctx.info(log_string)
81106

82-
return await _get_list_of_active_systems()
107+
return await _get_list_of_active_systems(ctx.get_state('token'))
83108

84-
async def _get_list_of_active_systems() -> List[Dict[str, Union[str, int]]]:
109+
async def _get_list_of_active_systems(token: str) -> List[Dict[str, Union[str, int]]]:
85110

86111
async with httpx.AsyncClient(verify=CONFIG["UYUNI_MCP_SSL_VERIFY"]) as client:
87112
systems_data_result = await call_uyuni_api(
88113
client=client,
89114
method="GET",
90115
api_path="/rhn/manager/api/system/listSystems",
91116
error_context="fetching active systems",
117+
token=token,
92118
default_on_error=[]
93119
)
94120

@@ -104,7 +130,7 @@ async def _get_list_of_active_systems() -> List[Dict[str, Union[str, int]]]:
104130

105131
return filtered_systems
106132

107-
async def _resolve_system_id(system_identifier: Union[str, int]) -> Optional[str]:
133+
async def _resolve_system_id(system_identifier: Union[str, int], token: str) -> Optional[str]:
108134
"""
109135
Resolves a system identifier, which can be a name or an ID, to a numeric system ID string.
110136
@@ -132,6 +158,7 @@ async def _resolve_system_id(system_identifier: Union[str, int]) -> Optional[str
132158
api_path="/rhn/manager/api/system/getId",
133159
params={'name': system_name},
134160
error_context=f"resolving system ID for name '{system_name}'",
161+
token=token,
135162
default_on_error=[] # Return an empty list on failure
136163
)
137164

@@ -174,10 +201,10 @@ async def get_cpu_of_a_system(system_identifier: Union[str, int], ctx: Context)
174201
log_string = f"Getting CPU information of system with id {system_identifier}"
175202
logger.info(log_string)
176203
await ctx.info(log_string)
177-
return await _get_cpu_of_a_system(system_identifier)
204+
return await _get_cpu_of_a_system(system_identifier, ctx.get_state('token'))
178205

179-
async def _get_cpu_of_a_system(system_identifier: Union[str, int]) -> Dict[str, Any]:
180-
system_id = await _resolve_system_id(system_identifier)
206+
async def _get_cpu_of_a_system(system_identifier: Union[str, int], token: str) -> Dict[str, Any]:
207+
system_id = await _resolve_system_id(system_identifier, token)
181208
if not system_id:
182209
return {} # Helper function already logged the reason for failure.
183210

@@ -188,6 +215,7 @@ async def _get_cpu_of_a_system(system_identifier: Union[str, int]) -> Dict[str,
188215
api_path="/rhn/manager/api/system/getCpu",
189216
params={'sid': system_id},
190217
error_context=f"fetching CPU data for system {system_identifier}",
218+
token=token,
191219
default_on_error={}
192220
)
193221

@@ -223,7 +251,7 @@ async def get_all_systems_cpu_info(ctx: Context) -> List[Dict[str, Any]]:
223251
await ctx.info(log_string)
224252

225253
all_systems_cpu_data = []
226-
active_systems = await _get_list_of_active_systems() # Calls your existing tool
254+
active_systems = await _get_list_of_active_systems(ctx.get_state('token'))
227255

228256
if not active_systems:
229257
print("Warning: No active systems found or failed to retrieve system list.")
@@ -238,7 +266,7 @@ async def get_all_systems_cpu_info(ctx: Context) -> List[Dict[str, Any]]:
238266
continue
239267

240268
print(f"Fetching CPU info for system: {system_name} (ID: {system_id})")
241-
cpu_info = await _get_cpu_of_a_system(str(system_id)) # Calls your other existing tool
269+
cpu_info = await _get_cpu_of_a_system(str(system_id), ctx.get_state('token'))
242270

243271
all_systems_cpu_data.append({
244272
'system_name': system_name,
@@ -320,7 +348,8 @@ async def check_system_updates(system_identifier: Union[str, int], ctx: Context)
320348
return await _check_system_updates(system_identifier, ctx)
321349

322350
async def _check_system_updates(system_identifier: Union[str, int], ctx: Context) -> Dict[str, Any]:
323-
system_id = await _resolve_system_id(system_identifier)
351+
token = ctx.get_state('token')
352+
system_id = await _resolve_system_id(system_identifier, token)
324353
default_error_response = {
325354
'system_identifier': system_identifier,
326355
'has_pending_updates': False,
@@ -340,6 +369,7 @@ async def _check_system_updates(system_identifier: Union[str, int], ctx: Context
340369
api_path="/rhn/manager/api/system/getRelevantErrata",
341370
params={'sid': system_id},
342371
error_context=f"checking updates for system {system_identifier}",
372+
token=token,
343373
default_on_error=None # Distinguish API error from empty list
344374
)
345375

@@ -349,6 +379,7 @@ async def _check_system_updates(system_identifier: Union[str, int], ctx: Context
349379
api_path="/rhn/manager/api/system/getUnscheduledErrata",
350380
params={'sid': str(system_id)},
351381
error_context=f"checking unscheduled errata for system ID {system_id}",
382+
token=token,
352383
default_on_error=[] # Return empty list on failure
353384
)
354385

@@ -440,7 +471,7 @@ async def check_all_systems_for_updates(ctx: Context) -> List[Dict[str, Any]]:
440471
await ctx.info(log_string)
441472

442473
systems_with_updates = []
443-
active_systems = await _get_list_of_active_systems() # Get the list of all systems
474+
active_systems = await get_list_of_active_systems(ctx) # Get the list of all systems
444475

445476
if not active_systems:
446477
print("Warning: No active systems found or failed to retrieve system list.")
@@ -505,6 +536,8 @@ async def schedule_apply_pending_updates_to_system(system_identifier: Union[str,
505536
if not is_confirmed:
506537
return f"CONFIRMATION REQUIRED: This will apply pending updates to the system {system_identifier}. Do you confirm?"
507538

539+
token = ctx.get_state('token')
540+
508541
# 1. Use check_system_updates to get relevant errata
509542
update_info = await _check_system_updates(system_identifier, ctx)
510543

@@ -525,7 +558,7 @@ async def schedule_apply_pending_updates_to_system(system_identifier: Union[str,
525558
print(f"Could not extract any valid errata IDs for system {system_identifier} from the update information: {errata_list}")
526559
return ""
527560

528-
system_id = await _resolve_system_id(system_identifier)
561+
system_id = await _resolve_system_id(system_identifier, token)
529562
if not system_id:
530563
return "" # Helper function already logged the reason for failure.
531564

@@ -540,6 +573,7 @@ async def schedule_apply_pending_updates_to_system(system_identifier: Union[str,
540573
api_path="/rhn/manager/api/system/scheduleApplyErrata",
541574
json_body=payload,
542575
error_context=f"scheduling errata application for system {system_identifier}",
576+
token=token,
543577
default_on_error=None # Helper will return None on error
544578
)
545579

@@ -584,7 +618,8 @@ async def schedule_apply_specific_update(system_identifier: Union[str, int], err
584618
return f"Invalid errata ID '{errata_id}'. The ID must be an integer."
585619

586620

587-
system_id = await _resolve_system_id(system_identifier)
621+
token = ctx.get_state('token')
622+
system_id = await _resolve_system_id(system_identifier, token)
588623
if not system_id:
589624
return "" # Helper function already logged the reason for failure.
590625

@@ -602,6 +637,7 @@ async def schedule_apply_specific_update(system_identifier: Union[str, int], err
602637
api_path="/rhn/manager/api/system/scheduleApplyErrata",
603638
json_body=payload,
604639
error_context=f"scheduling specific update (errata ID: {errata_id_int}) for system {system_identifier}",
640+
token=token,
605641
default_on_error=None # Helper returns None on error
606642
)
607643

@@ -681,8 +717,10 @@ async def add_system(
681717
elif not activation_key: # Fallback if elicitation is not supported
682718
return "You need to provide an activation key."
683719

720+
token = ctx.get_state('token')
721+
684722
# Check if the system already exists
685-
active_systems = await _get_list_of_active_systems()
723+
active_systems = await _get_list_of_active_systems(token)
686724
for system in active_systems:
687725
if system.get('system_name') == host:
688726
message = f"System '{host}' already exists in Uyuni. No action taken."
@@ -725,6 +763,7 @@ async def add_system(
725763
api_path="/rhn/manager/api/system/bootstrapWithPrivateSshKey",
726764
json_body=payload,
727765
error_context=f"adding system {host}",
766+
token=token,
728767
default_on_error=None,
729768
expect_timeout=True,
730769
)
@@ -772,12 +811,13 @@ async def remove_system(system_identifier: Union[str, int], ctx: Context, cleanu
772811

773812
is_confirmed = _to_bool(confirm)
774813

775-
system_id = await _resolve_system_id(system_identifier)
814+
token = ctx.get_state('token')
815+
system_id = await _resolve_system_id(system_identifier, token)
776816
if not system_id:
777817
return "" # Helper function already logged the reason for failure.
778818

779819
# Check if the system exists before proceeding
780-
active_systems = await _get_list_of_active_systems()
820+
active_systems = await _get_list_of_active_systems(token)
781821
if not any(s.get('system_id') == int(system_id) for s in active_systems):
782822
message = f"System with ID {system_id} not found."
783823
logger.warning(message)
@@ -796,6 +836,7 @@ async def remove_system(system_identifier: Union[str, int], ctx: Context, cleanu
796836
api_path="/rhn/manager/api/system/deleteSystem",
797837
json_body={"sid": system_id, "cleanupType": cleanup_type},
798838
error_context=f"removing system ID {system_id}",
839+
token=token,
799840
default_on_error=None
800841
)
801842

@@ -840,6 +881,7 @@ async def get_systems_needing_security_update_for_cve(cve_identifier: str, ctx:
840881
find_by_cve_path = '/rhn/manager/api/errata/findByCve'
841882
list_affected_systems_path = '/rhn/manager/api/errata/listAffectedSystems'
842883

884+
token = ctx.get_state('token')
843885
async with httpx.AsyncClient(verify=CONFIG["UYUNI_MCP_SSL_VERIFY"]) as client:
844886
# 1. Call findByCve (login will be handled by the helper)
845887
print(f"Searching for errata related to CVE: {cve_identifier}")
@@ -849,6 +891,7 @@ async def get_systems_needing_security_update_for_cve(cve_identifier: str, ctx:
849891
api_path=find_by_cve_path,
850892
params={'cveName': cve_identifier},
851893
error_context=f"finding errata for CVE {cve_identifier}",
894+
token=token,
852895
default_on_error=None # Distinguish API error from empty list
853896
)
854897

@@ -937,6 +980,7 @@ async def get_systems_needing_reboot(ctx: Context) -> List[Dict[str, Any]]: # No
937980
method="GET",
938981
api_path=list_reboot_path,
939982
error_context="fetching systems needing reboot",
983+
token=ctx.get_state('token'),
940984
default_on_error=[] # Return empty list on error
941985
)
942986

@@ -984,7 +1028,8 @@ async def schedule_system_reboot(system_identifier: Union[str, int], ctx:Context
9841028

9851029
is_confirmed = _to_bool(confirm)
9861030

987-
system_id = await _resolve_system_id(system_identifier)
1031+
token = ctx.get_state('token')
1032+
system_id = await _resolve_system_id(system_identifier, token)
9881033
if not system_id:
9891034
return "" # Helper function already logged the reason for failure.
9901035

@@ -1004,6 +1049,7 @@ async def schedule_system_reboot(system_identifier: Union[str, int], ctx:Context
10041049
api_path=schedule_reboot_path,
10051050
json_body=payload,
10061051
error_context=f"scheduling reboot for system {system_identifier}",
1052+
token=token,
10071053
default_on_error=None # Helper returns None on error
10081054
)
10091055

@@ -1050,6 +1096,7 @@ async def list_all_scheduled_actions(ctx: Context) -> List[Dict[str, Any]]:
10501096
method="GET",
10511097
api_path=list_actions_path,
10521098
error_context="listing all scheduled actions",
1099+
token=ctx.get_state('token'),
10531100
default_on_error=[] # Return empty list on error
10541101
)
10551102

@@ -1112,6 +1159,7 @@ async def cancel_action(action_id: int, ctx: Context, confirm: Union[bool, str]
11121159
api_path=cancel_actions_path,
11131160
json_body=payload,
11141161
error_context=f"canceling action {action_id}",
1162+
token=ctx.get_state('token'),
11151163
default_on_error=0 # API returns 1 on success, so 0 can signify an error or unexpected response from helper
11161164
)
11171165
if api_result == 1:
@@ -1121,7 +1169,7 @@ async def cancel_action(action_id: int, ctx: Context, confirm: Union[bool, str]
11211169
return f"Failed to cancel action: {action_id}. The API did not return success (expected 1, got {api_result}). Check server logs for details."
11221170

11231171
@mcp.tool()
1124-
async def list_activation_keys() -> List[Dict[str, str]]:
1172+
async def list_activation_keys(ctx: Context) -> List[Dict[str, str]]:
11251173
"""
11261174
Fetches a list of activation keys from the Uyuni server.
11271175
@@ -1143,6 +1191,7 @@ async def list_activation_keys() -> List[Dict[str, str]]:
11431191
method="GET",
11441192
api_path=list_keys_path,
11451193
error_context="listing activation keys",
1194+
token=ctx.get_state('token'),
11461195
default_on_error=[]
11471196
)
11481197

@@ -1184,6 +1233,7 @@ async def get_unscheduled_errata(system_id: int, ctx: Context) -> List[Dict[str,
11841233
api_path=get_unscheduled_errata,
11851234
params=payload,
11861235
error_context=f"fetching unscheduled errata for system ID {system_id}",
1236+
token=ctx.get_state('token'),
11871237
default_on_error=None
11881238
)
11891239

@@ -1203,7 +1253,8 @@ def main_cli():
12031253
logger.info("Running Uyuni MCP server.")
12041254

12051255
if CONFIG["UYUNI_MCP_TRANSPORT"] == Transport.HTTP.value:
1206-
mcp.run(transport="streamable-http")
1256+
mcp.add_middleware(AuthTokenMiddleware())
1257+
mcp.run(transport="streamable-http", host=CONFIG["UYUNI_MCP_HOST"], port=CONFIG["UYUNI_MCP_PORT"])
12071258
elif CONFIG["UYUNI_MCP_TRANSPORT"] == Transport.STDIO.value:
12081259
mcp.run(transport="stdio")
12091260
else:

src/mcp_server_uyuni/uyuni_api.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ async def call(
1919
method: str,
2020
api_path: str,
2121
error_context: str,
22+
token: Optional[str] = None,
2223
params: Dict[str, Any] = None,
2324
json_body: Dict[str, Any] = None,
2425
perform_login: bool = True,
@@ -40,9 +41,17 @@ async def call(
4041
return error_msg
4142

4243
if perform_login:
43-
login_data = {"login": CONFIG["UYUNI_USER"], "password": CONFIG["UYUNI_PASS"]}
4444
try:
45-
login_response = await client.post(CONFIG["UYUNI_SERVER"] + '/rhn/manager/api/login', json=login_data)
45+
if token:
46+
login_response = await client.post(
47+
CONFIG["UYUNI_SERVER"] + '/rhn/manager/api/oidcLogin',
48+
headers={"Authorization": f"Bearer {token}"}
49+
)
50+
elif CONFIG["UYUNI_USER"] and CONFIG["UYUNI_PASS"]:
51+
login_response = await client.post(
52+
CONFIG["UYUNI_SERVER"] + '/rhn/manager/api/login',
53+
json={"login": CONFIG["UYUNI_USER"], "password": CONFIG["UYUNI_PASS"]}
54+
)
4655
login_response.raise_for_status()
4756
except httpx.HTTPStatusError as e:
4857
logger.error(f"HTTP error during login for {error_context}: {e.request.url} - {e.response.status_code} - {e.response.text}")

0 commit comments

Comments
 (0)