Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions src/mcp_server_uyuni/config.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import os

REQUIRED_VARS = [
"UYUNI_SERVER",
"UYUNI_USER",
"UYUNI_PASS",
"UYUNI_SERVER"
]

missing_vars = [key for key in REQUIRED_VARS if key not in os.environ]
Expand Down
87 changes: 69 additions & 18 deletions src/mcp_server_uyuni/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from datetime import datetime, timezone
from pydantic import BaseModel

from fastmcp.server.middleware import Middleware, MiddlewareContext
from fastmcp import FastMCP, Context
from mcp import LoggingLevel, ServerSession, types
from mcp_server_uyuni.logging_config import get_logger, Transport
Expand All @@ -37,6 +38,30 @@ class ActivationKeySchema(BaseModel):

logger = get_logger(log_file=CONFIG["UYUNI_MCP_LOG_FILE_PATH"], transport=CONFIG["UYUNI_MCP_TRANSPORT"])

class AuthTokenMiddleware(Middleware):
async def on_call_tool(self, ctx: MiddlewareContext, call_next):
"""
Extracts the JWT token from the Authorization header (if present)
and injects it into the context state for other tools to use.
"""
fastmcp_ctx = ctx.fastmcp_context
auth_header = fastmcp_ctx.request_context.request.headers['authorization']
token = None
if auth_header:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could we add a config variable that controls if auth is expected for mlm? Could be we have a token for the mcp authorization but mlm does not yet support it..

# Expecting "Authorization: Bearer <token>"
parts = auth_header.split()
if len(parts) == 2 and parts[0] == "Bearer":
token = parts[1]
logger.debug("Successfully extracted token from header.")
else:
logger.warning(f"Malformed Authorization header received: {auth_header}")
else:
logger.debug("No Authorization header found in the request.")

fastmcp_ctx.set_state('token', token)
result = await call_next(ctx)
return result

def write_tool(*decorator_args, **decorator_kwargs):
"""
A decorator that registers a function as an MCP tool only if write
Expand Down Expand Up @@ -79,16 +104,17 @@ async def get_list_of_active_systems(ctx: Context) -> List[Dict[str, Any]]:
logger.info(log_string)
await ctx.info(log_string)

return await _get_list_of_active_systems()
return await _get_list_of_active_systems(ctx.get_state('token'))

async def _get_list_of_active_systems() -> List[Dict[str, Union[str, int]]]:
async def _get_list_of_active_systems(token: str) -> List[Dict[str, Union[str, int]]]:

async with httpx.AsyncClient(verify=CONFIG["UYUNI_MCP_SSL_VERIFY"]) as client:
systems_data_result = await call_uyuni_api(
client=client,
method="GET",
api_path="/rhn/manager/api/system/listSystems",
error_context="fetching active systems",
token=token,
default_on_error=[]
)

Expand All @@ -104,7 +130,7 @@ async def _get_list_of_active_systems() -> List[Dict[str, Union[str, int]]]:

return filtered_systems

async def _resolve_system_id(system_identifier: Union[str, int]) -> Optional[str]:
async def _resolve_system_id(system_identifier: Union[str, int], token: str) -> Optional[str]:
"""
Resolves a system identifier, which can be a name or an ID, to a numeric system ID string.

Expand Down Expand Up @@ -132,6 +158,7 @@ async def _resolve_system_id(system_identifier: Union[str, int]) -> Optional[str
api_path="/rhn/manager/api/system/getId",
params={'name': system_name},
error_context=f"resolving system ID for name '{system_name}'",
token=token,
default_on_error=[] # Return an empty list on failure
)

Expand Down Expand Up @@ -174,10 +201,10 @@ async def get_cpu_of_a_system(system_identifier: Union[str, int], ctx: Context)
log_string = f"Getting CPU information of system with id {system_identifier}"
logger.info(log_string)
await ctx.info(log_string)
return await _get_cpu_of_a_system(system_identifier)
return await _get_cpu_of_a_system(system_identifier, ctx.get_state('token'))

async def _get_cpu_of_a_system(system_identifier: Union[str, int]) -> Dict[str, Any]:
system_id = await _resolve_system_id(system_identifier)
async def _get_cpu_of_a_system(system_identifier: Union[str, int], token: str) -> Dict[str, Any]:
system_id = await _resolve_system_id(system_identifier, token)
if not system_id:
return {} # Helper function already logged the reason for failure.

Expand All @@ -188,6 +215,7 @@ async def _get_cpu_of_a_system(system_identifier: Union[str, int]) -> Dict[str,
api_path="/rhn/manager/api/system/getCpu",
params={'sid': system_id},
error_context=f"fetching CPU data for system {system_identifier}",
token=token,
default_on_error={}
)

Expand Down Expand Up @@ -223,7 +251,7 @@ async def get_all_systems_cpu_info(ctx: Context) -> List[Dict[str, Any]]:
await ctx.info(log_string)

all_systems_cpu_data = []
active_systems = await _get_list_of_active_systems() # Calls your existing tool
active_systems = await _get_list_of_active_systems(ctx.get_state('token'))

if not active_systems:
print("Warning: No active systems found or failed to retrieve system list.")
Expand All @@ -238,7 +266,7 @@ async def get_all_systems_cpu_info(ctx: Context) -> List[Dict[str, Any]]:
continue

print(f"Fetching CPU info for system: {system_name} (ID: {system_id})")
cpu_info = await _get_cpu_of_a_system(str(system_id)) # Calls your other existing tool
cpu_info = await _get_cpu_of_a_system(str(system_id), ctx.get_state('token'))

all_systems_cpu_data.append({
'system_name': system_name,
Expand Down Expand Up @@ -320,7 +348,8 @@ async def check_system_updates(system_identifier: Union[str, int], ctx: Context)
return await _check_system_updates(system_identifier, ctx)

async def _check_system_updates(system_identifier: Union[str, int], ctx: Context) -> Dict[str, Any]:
system_id = await _resolve_system_id(system_identifier)
token = ctx.get_state('token')
system_id = await _resolve_system_id(system_identifier, token)
default_error_response = {
'system_identifier': system_identifier,
'has_pending_updates': False,
Expand All @@ -340,6 +369,7 @@ async def _check_system_updates(system_identifier: Union[str, int], ctx: Context
api_path="/rhn/manager/api/system/getRelevantErrata",
params={'sid': system_id},
error_context=f"checking updates for system {system_identifier}",
token=token,
default_on_error=None # Distinguish API error from empty list
)

Expand All @@ -349,6 +379,7 @@ async def _check_system_updates(system_identifier: Union[str, int], ctx: Context
api_path="/rhn/manager/api/system/getUnscheduledErrata",
params={'sid': str(system_id)},
error_context=f"checking unscheduled errata for system ID {system_id}",
token=token,
default_on_error=[] # Return empty list on failure
)

Expand Down Expand Up @@ -440,7 +471,7 @@ async def check_all_systems_for_updates(ctx: Context) -> List[Dict[str, Any]]:
await ctx.info(log_string)

systems_with_updates = []
active_systems = await _get_list_of_active_systems() # Get the list of all systems
active_systems = await get_list_of_active_systems(ctx) # Get the list of all systems

if not active_systems:
print("Warning: No active systems found or failed to retrieve system list.")
Expand Down Expand Up @@ -505,6 +536,8 @@ async def schedule_apply_pending_updates_to_system(system_identifier: Union[str,
if not is_confirmed:
return f"CONFIRMATION REQUIRED: This will apply pending updates to the system {system_identifier}. Do you confirm?"

token = ctx.get_state('token')

# 1. Use check_system_updates to get relevant errata
update_info = await _check_system_updates(system_identifier, ctx)

Expand All @@ -525,7 +558,7 @@ async def schedule_apply_pending_updates_to_system(system_identifier: Union[str,
print(f"Could not extract any valid errata IDs for system {system_identifier} from the update information: {errata_list}")
return ""

system_id = await _resolve_system_id(system_identifier)
system_id = await _resolve_system_id(system_identifier, token)
if not system_id:
return "" # Helper function already logged the reason for failure.

Expand All @@ -540,6 +573,7 @@ async def schedule_apply_pending_updates_to_system(system_identifier: Union[str,
api_path="/rhn/manager/api/system/scheduleApplyErrata",
json_body=payload,
error_context=f"scheduling errata application for system {system_identifier}",
token=token,
default_on_error=None # Helper will return None on error
)

Expand Down Expand Up @@ -584,7 +618,8 @@ async def schedule_apply_specific_update(system_identifier: Union[str, int], err
return f"Invalid errata ID '{errata_id}'. The ID must be an integer."


system_id = await _resolve_system_id(system_identifier)
token = ctx.get_state('token')
system_id = await _resolve_system_id(system_identifier, token)
if not system_id:
return "" # Helper function already logged the reason for failure.

Expand All @@ -602,6 +637,7 @@ async def schedule_apply_specific_update(system_identifier: Union[str, int], err
api_path="/rhn/manager/api/system/scheduleApplyErrata",
json_body=payload,
error_context=f"scheduling specific update (errata ID: {errata_id_int}) for system {system_identifier}",
token=token,
default_on_error=None # Helper returns None on error
)

Expand Down Expand Up @@ -681,8 +717,10 @@ async def add_system(
elif not activation_key: # Fallback if elicitation is not supported
return "You need to provide an activation key."

token = ctx.get_state('token')

# Check if the system already exists
active_systems = await _get_list_of_active_systems()
active_systems = await _get_list_of_active_systems(token)
for system in active_systems:
if system.get('system_name') == host:
message = f"System '{host}' already exists in Uyuni. No action taken."
Expand Down Expand Up @@ -725,6 +763,7 @@ async def add_system(
api_path="/rhn/manager/api/system/bootstrapWithPrivateSshKey",
json_body=payload,
error_context=f"adding system {host}",
token=token,
default_on_error=None,
expect_timeout=True,
)
Expand Down Expand Up @@ -772,12 +811,13 @@ async def remove_system(system_identifier: Union[str, int], ctx: Context, cleanu

is_confirmed = _to_bool(confirm)

system_id = await _resolve_system_id(system_identifier)
token = ctx.get_state('token')
system_id = await _resolve_system_id(system_identifier, token)
if not system_id:
return "" # Helper function already logged the reason for failure.

# Check if the system exists before proceeding
active_systems = await _get_list_of_active_systems()
active_systems = await _get_list_of_active_systems(token)
if not any(s.get('system_id') == int(system_id) for s in active_systems):
message = f"System with ID {system_id} not found."
logger.warning(message)
Expand All @@ -796,6 +836,7 @@ async def remove_system(system_identifier: Union[str, int], ctx: Context, cleanu
api_path="/rhn/manager/api/system/deleteSystem",
json_body={"sid": system_id, "cleanupType": cleanup_type},
error_context=f"removing system ID {system_id}",
token=token,
default_on_error=None
)

Expand Down Expand Up @@ -840,6 +881,7 @@ async def get_systems_needing_security_update_for_cve(cve_identifier: str, ctx:
find_by_cve_path = '/rhn/manager/api/errata/findByCve'
list_affected_systems_path = '/rhn/manager/api/errata/listAffectedSystems'

token = ctx.get_state('token')
async with httpx.AsyncClient(verify=CONFIG["UYUNI_MCP_SSL_VERIFY"]) as client:
# 1. Call findByCve (login will be handled by the helper)
print(f"Searching for errata related to CVE: {cve_identifier}")
Expand All @@ -849,6 +891,7 @@ async def get_systems_needing_security_update_for_cve(cve_identifier: str, ctx:
api_path=find_by_cve_path,
params={'cveName': cve_identifier},
error_context=f"finding errata for CVE {cve_identifier}",
token=token,
default_on_error=None # Distinguish API error from empty list
)

Expand Down Expand Up @@ -937,6 +980,7 @@ async def get_systems_needing_reboot(ctx: Context) -> List[Dict[str, Any]]: # No
method="GET",
api_path=list_reboot_path,
error_context="fetching systems needing reboot",
token=ctx.get_state('token'),
default_on_error=[] # Return empty list on error
)

Expand Down Expand Up @@ -984,7 +1028,8 @@ async def schedule_system_reboot(system_identifier: Union[str, int], ctx:Context

is_confirmed = _to_bool(confirm)

system_id = await _resolve_system_id(system_identifier)
token = ctx.get_state('token')
system_id = await _resolve_system_id(system_identifier, token)
if not system_id:
return "" # Helper function already logged the reason for failure.

Expand All @@ -1004,6 +1049,7 @@ async def schedule_system_reboot(system_identifier: Union[str, int], ctx:Context
api_path=schedule_reboot_path,
json_body=payload,
error_context=f"scheduling reboot for system {system_identifier}",
token=token,
default_on_error=None # Helper returns None on error
)

Expand Down Expand Up @@ -1050,6 +1096,7 @@ async def list_all_scheduled_actions(ctx: Context) -> List[Dict[str, Any]]:
method="GET",
api_path=list_actions_path,
error_context="listing all scheduled actions",
token=ctx.get_state('token'),
default_on_error=[] # Return empty list on error
)

Expand Down Expand Up @@ -1112,6 +1159,7 @@ async def cancel_action(action_id: int, ctx: Context, confirm: Union[bool, str]
api_path=cancel_actions_path,
json_body=payload,
error_context=f"canceling action {action_id}",
token=ctx.get_state('token'),
default_on_error=0 # API returns 1 on success, so 0 can signify an error or unexpected response from helper
)
if api_result == 1:
Expand All @@ -1121,7 +1169,7 @@ async def cancel_action(action_id: int, ctx: Context, confirm: Union[bool, str]
return f"Failed to cancel action: {action_id}. The API did not return success (expected 1, got {api_result}). Check server logs for details."

@mcp.tool()
async def list_activation_keys() -> List[Dict[str, str]]:
async def list_activation_keys(ctx: Context) -> List[Dict[str, str]]:
"""
Fetches a list of activation keys from the Uyuni server.

Expand All @@ -1143,6 +1191,7 @@ async def list_activation_keys() -> List[Dict[str, str]]:
method="GET",
api_path=list_keys_path,
error_context="listing activation keys",
token=ctx.get_state('token'),
default_on_error=[]
)

Expand Down Expand Up @@ -1184,6 +1233,7 @@ async def get_unscheduled_errata(system_id: int, ctx: Context) -> List[Dict[str,
api_path=get_unscheduled_errata,
params=payload,
error_context=f"fetching unscheduled errata for system ID {system_id}",
token=ctx.get_state('token'),
default_on_error=None
)

Expand All @@ -1203,7 +1253,8 @@ def main_cli():
logger.info("Running Uyuni MCP server.")

if CONFIG["UYUNI_MCP_TRANSPORT"] == Transport.HTTP.value:
mcp.run(transport="streamable-http")
mcp.add_middleware(AuthTokenMiddleware())
mcp.run(transport="streamable-http", host=CONFIG["UYUNI_MCP_HOST"], port=CONFIG["UYUNI_MCP_PORT"])
elif CONFIG["UYUNI_MCP_TRANSPORT"] == Transport.STDIO.value:
mcp.run(transport="stdio")
else:
Expand Down
13 changes: 11 additions & 2 deletions src/mcp_server_uyuni/uyuni_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ async def call(
method: str,
api_path: str,
error_context: str,
token: Optional[str] = None,
params: Dict[str, Any] = None,
json_body: Dict[str, Any] = None,
perform_login: bool = True,
Expand All @@ -40,9 +41,17 @@ async def call(
return error_msg

if perform_login:
login_data = {"login": CONFIG["UYUNI_USER"], "password": CONFIG["UYUNI_PASS"]}
try:
login_response = await client.post(CONFIG["UYUNI_SERVER"] + '/rhn/manager/api/login', json=login_data)
if token:
login_response = await client.post(
CONFIG["UYUNI_SERVER"] + '/rhn/manager/api/oidcLogin',
headers={"Authorization": f"Bearer {token}"}
)
elif CONFIG["UYUNI_USER"] and CONFIG["UYUNI_PASS"]:
login_response = await client.post(
CONFIG["UYUNI_SERVER"] + '/rhn/manager/api/login',
json={"login": CONFIG["UYUNI_USER"], "password": CONFIG["UYUNI_PASS"]}
)
login_response.raise_for_status()
except httpx.HTTPStatusError as e:
logger.error(f"HTTP error during login for {error_context}: {e.request.url} - {e.response.status_code} - {e.response.text}")
Expand Down