diff --git a/healthchain/gateway/fhir/base.py b/healthchain/gateway/fhir/base.py index 845e3a5..797997d 100644 --- a/healthchain/gateway/fhir/base.py +++ b/healthchain/gateway/fhir/base.py @@ -1,11 +1,13 @@ import logging import inspect import warnings +import asyncio from fastapi import Depends, HTTPException, Path, Query from datetime import datetime from typing import Any, Callable, Dict, List, Type, TypeVar, Optional from fastapi.responses import JSONResponse +from fhir.resources.reference import Reference from fhir.resources.capabilitystatement import CapabilityStatement from fhir.resources.resource import Resource @@ -153,6 +155,16 @@ def build_capability_statement(self) -> CapabilityStatement: } ) operation_details.append("Aggregate: Multi-source data aggregation") + elif operation == "predict": + interactions.append( + { + "code": "read", + "documentation": "ML model prediction via REST endpoint", + } + ) + operation_details.append( + "Predict: Serve ML model predictions as FHIR resources" + ) if interactions: resource_name = self._get_resource_name(resource_type) @@ -266,6 +278,16 @@ def get_gateway_status(self) -> Dict[str, Any]: "parameters": ["id (optional)", "sources (optional)"], } ) + elif operation == "predict": + operation_list.append( + { + "type": "predict", + "endpoint": f"/predict/{resource_name}/{{id}}", + "description": f"Serve ML predictions as {resource_name} resources", + "method": "GET", + "parameters": ["id"], + } + ) if operation_list: available_operations[resource_name] = operation_list @@ -312,15 +334,22 @@ def _register_resource_handler( resource_type: Type[Resource], operation: str, handler: Callable, + **kwargs, ) -> None: """Register a custom handler for a resource operation.""" + resource_name = self._get_resource_name(resource_type) self._validate_handler_annotations(resource_type, operation, handler) if resource_type not in self._resource_handlers: self._resource_handlers[resource_type] = {} + + # Store the handler function self._resource_handlers[resource_type][operation] = handler - resource_name = self._get_resource_name(resource_type) + # Store any additional decorator kwargs + if kwargs.get("decorator_kwargs"): + self._resource_handlers[resource_type][f"{operation}_kwargs"] = kwargs["decorator_kwargs"] + logger.debug( f"Registered {operation} handler for {resource_name}: {handler.__name__}" ) @@ -334,7 +363,7 @@ def _validate_handler_annotations( handler: Callable, ) -> None: """Validate that handler annotations match the decorator resource type.""" - if operation != "transform": + if operation not in ["transform", "predict"]: return try: @@ -347,7 +376,7 @@ def _validate_handler_annotations( ) return - if return_annotation != resource_type: + if operation == "transform" and return_annotation != resource_type: raise TypeError( f"Handler {handler.__name__} return type ({return_annotation}) " f"doesn't match decorator resource type ({resource_type})" @@ -374,6 +403,12 @@ def _register_operation_route( path = f"/aggregate/{resource_name}" summary = f"Aggregate {resource_name}" description = f"Aggregate {resource_name} resources from multiple sources" + elif operation == "predict": + path = f"/predict/{resource_name}/{{id}}" + summary = f"Predict using {resource_name}" + description = ( + f"Generate a {resource_name} resource using a registered ML model" + ) else: raise ValueError(f"Unsupported operation: {operation}") @@ -398,11 +433,16 @@ def _create_route_handler( """Create a route handler for the given resource type and operation.""" get_self_gateway = self._get_gateway_dependency() - def _execute_handler(fhir: "BaseFHIRGateway", *args) -> Any: + async def _execute_handler(fhir: "BaseFHIRGateway", *args) -> Any: """Common handler execution logic with error handling.""" + handler_func = fhir._resource_handlers[resource_type][operation] try: - handler_func = fhir._resource_handlers[resource_type][operation] - result = handler_func(*args) + # Await if the handler is async + if asyncio.iscoroutinefunction(handler_func): + result = await handler_func(*args) + else: + result = handler_func(*args) + return result except Exception as e: logger.error(f"Error in {operation} handler: {str(e)}") @@ -418,7 +458,12 @@ async def handler( fhir: "BaseFHIRGateway" = Depends(get_self_gateway), ): """Transform a resource with registered handler.""" - return _execute_handler(fhir, id, source) + result = await _execute_handler(fhir, id, source) + # For predict, wrap the result in the FHIR resource + if operation == "predict": + # This part is now inside the route handler to access decorator kwargs + result = fhir._wrap_prediction(resource_type, id, result) + return result elif operation == "aggregate": @@ -430,13 +475,72 @@ async def handler( fhir: "BaseFHIRGateway" = Depends(get_self_gateway), ): """Aggregate resources with registered handler.""" - return _execute_handler(fhir, id, sources) + result = await _execute_handler(fhir, id, sources) + return result + + elif operation == "predict": + # Retrieve kwargs passed to the decorator + decorator_kwargs = self._resource_handlers[resource_type].get( + "predict_kwargs", {} + ) + + async def handler( + id: str = Path(..., description="Patient ID to run prediction for"), + fhir: "BaseFHIRGateway" = Depends(get_self_gateway), + ): + """Generate a prediction resource with a registered handler.""" + result = await _execute_handler(fhir, id) + # Wrap the prediction using decorator-provided kwargs + return fhir._wrap_prediction( + resource_type, id, result, **decorator_kwargs + ) else: raise ValueError(f"Unsupported operation: {operation}") return handler + def _wrap_prediction( + self, + resource_type: Type[Resource], + patient_id: str, + prediction_output: Any, + status: str = "final", + ) -> Resource: + """Wrap a raw prediction output into a FHIR resource.""" + resource_name = self._get_resource_name(resource_type) + + if resource_name == "RiskAssessment": + prediction_data = {} + if isinstance(prediction_output, float): + prediction_data["probabilityDecimal"] = prediction_output + elif isinstance(prediction_output, dict): + # Assuming keys like 'score', 'qualitativeRisk', etc. + if "score" in prediction_output: + prediction_data["probabilityDecimal"] = prediction_output["score"] + if "qualitativeRisk" in prediction_output: + prediction_data["qualitativeRisk"] = prediction_output[ + "qualitativeRisk" + ] + # The fhir.resource model expects a CodeableConcept, not a string. + prediction_data["qualitativeRisk"] = { + "coding": [{"display": prediction_output["qualitativeRisk"]}], + "text": prediction_output["qualitativeRisk"], + } + + elif not isinstance(prediction_output, (float, dict)): + raise TypeError( + f"Prediction function must return a float or dict, but returned {type(prediction_output)}" + ) + + return resource_type( + status=status, + subject=Reference(reference=f"Patient/{patient_id}"), + prediction=[prediction_data], + ) + raise NotImplementedError(f"Prediction for {resource_name} not implemented.") + + def add_source(self, name: str, connection_string: str) -> None: """ Add a FHIR data source using connection string with OAuth2.0 flow. @@ -469,6 +573,40 @@ def decorator(handler: Callable): return decorator + def predict(self, resource: Type[Resource], status: str = "final", **kwargs): + """ + Decorator to simplify ML model deployment as FHIR endpoints. + + Wraps a function that returns a prediction score (float) or dictionary, + and automatically constructs the specified FHIR resource. + + Currently, only `RiskAssessment` is fully supported. + + Args: + resource: The FHIR resource type to create (e.g., RiskAssessment). + status: The status to set on the created FHIR resource. Defaults to "final", + which is a spec-compliant value for RiskAssessment. + **kwargs: Additional fields to set on the created resource. + + Example: + @fhir.predict(resource=RiskAssessment) + def predict_sepsis_risk(patient_id: str) -> float: # The patient_id is passed from the URL + # Your model logic here + return 0.85 # High risk + """ + + def decorator(handler: Callable): + # The user-provided handler is registered for the 'predict' operation + self._register_resource_handler( + resource, "predict", handler, decorator_kwargs={"status": status} + ) + + # The actual endpoint handler is created by _register_operation_route + # which calls _create_route_handler, which wraps our logic. + return handler + + return decorator + def transform(self, resource_type: Type[Resource]): """ Decorator for custom transformation functions. Must return the same resource type. diff --git a/tests/gateway/test_base_fhir_gateway.py b/tests/gateway/test_base_fhir_gateway.py index c354c68..0ef2f6f 100644 --- a/tests/gateway/test_base_fhir_gateway.py +++ b/tests/gateway/test_base_fhir_gateway.py @@ -1,9 +1,11 @@ import pytest from unittest.mock import Mock, patch, AsyncMock from typing import Dict, Any, List +import asyncio from fhir.resources.patient import Patient from fhir.resources.observation import Observation +from fhir.resources.riskassessment import RiskAssessment from healthchain.gateway.fhir import FHIRGateway, AsyncFHIRGateway @@ -222,3 +224,92 @@ def test_resource_name_extraction(fhir_gateway): """_get_resource_name correctly extracts resource names from types.""" assert fhir_gateway._get_resource_name(Patient) == "Patient" assert fhir_gateway._get_resource_name(Observation) == "Observation" + + +@pytest.mark.asyncio +async def test_predict_handler_raises_for_invalid_output_type(fhir_gateway): + """The predict route handler should raise a TypeError for unsupported return types.""" + + @fhir_gateway.predict(resource=RiskAssessment) + def invalid_prediction(patient_id: str) -> list: + return [0.1, 0.9] # Invalid return type + + route_handler = fhir_gateway.routes[-1].endpoint + + with pytest.raises(TypeError, match="Prediction function must return a float or dict"): + await route_handler(id="Patient789", fhir=fhir_gateway) + + +def test_predict_decorator_registers_handler_and_route(fhir_gateway): + """The @predict decorator should register a 'predict' handler and create a route.""" + initial_routes = len(fhir_gateway.routes) + + @fhir_gateway.predict(resource=RiskAssessment) + def predict_risk(patient_id: str) -> float: + return 0.5 + + # Check handler registration + assert "predict" in fhir_gateway._resource_handlers[RiskAssessment] + assert ( + fhir_gateway._resource_handlers[RiskAssessment]["predict"] + == predict_risk + ) + # Check that decorator kwargs are stored + assert "predict_kwargs" in fhir_gateway._resource_handlers[RiskAssessment] + assert fhir_gateway._resource_handlers[RiskAssessment]["predict_kwargs"] == { + "status": "final" + } + + # Check route creation + assert len(fhir_gateway.routes) == initial_routes + 1 + new_route = fhir_gateway.routes[-1] + assert new_route.path == f"{fhir_gateway.prefix}/predict/RiskAssessment/{{id}}" + + +@pytest.mark.asyncio +async def test_predict_handler_wraps_float_output(fhir_gateway): + """The predict route handler should wrap a float from a sync function into a FHIR resource.""" + + @fhir_gateway.predict(resource=RiskAssessment) + def simple_prediction(patient_id: str) -> float: + return 0.75 + + # The handler created by _create_route_handler is what we need to test + # It's an async function that wraps the user's sync function + route_handler = fhir_gateway.routes[-1].endpoint + result = await route_handler(id="Patient123", fhir=fhir_gateway) + + assert isinstance(result, RiskAssessment) + assert result.status == "final" + assert result.subject.reference == "Patient/Patient123" + assert result.prediction[0].probabilityDecimal == 0.75 + + +@pytest.mark.asyncio +async def test_predict_handler_wraps_dict_output_from_async_func(fhir_gateway): + """The predict route handler should wrap a dict from an async function into a FHIR resource.""" + + @fhir_gateway.predict(resource=RiskAssessment) + async def complex_prediction(patient_id: str) -> dict: + await asyncio.sleep(0.01) # Simulate async work + return {"score": 0.4, "qualitativeRisk": "low"} + + route_handler = fhir_gateway.routes[-1].endpoint + result = await route_handler(id="Patient456", fhir=fhir_gateway) + + assert isinstance(result, RiskAssessment) + assert result.subject.reference == "Patient/Patient456" + assert result.prediction[0].probabilityDecimal == 0.4 + assert result.prediction[0].qualitativeRisk.text == "low" + + +def test_predict_decorator_in_capability_statement(fhir_gateway): + """CapabilityStatement should include resources with registered predict handlers.""" + + @fhir_gateway.predict(resource=RiskAssessment) + def predict_risk(patient_id: str) -> float: + return 0.5 + + capability = fhir_gateway.build_capability_statement() + resources = capability.rest[0].resource + assert "RiskAssessment" in [r.type for r in resources]