154 changes: 146 additions & 8 deletions healthchain/gateway/fhir/base.py
@@ -1,11 +1,13 @@
import logging
import inspect
import warnings
import asyncio

from fastapi import Depends, HTTPException, Path, Query
from datetime import datetime
from typing import Any, Callable, Dict, List, Type, TypeVar, Optional
from fastapi.responses import JSONResponse
from fhir.resources.reference import Reference

from fhir.resources.capabilitystatement import CapabilityStatement
from fhir.resources.resource import Resource
@@ -153,6 +155,16 @@ def build_capability_statement(self) -> CapabilityStatement:
}
)
operation_details.append("Aggregate: Multi-source data aggregation")
elif operation == "predict":
interactions.append(
{
"code": "read",
"documentation": "ML model prediction via REST endpoint",
}
)
operation_details.append(
"Predict: Serve ML model predictions as FHIR resources"
)

if interactions:
resource_name = self._get_resource_name(resource_type)
@@ -266,6 +278,16 @@ def get_gateway_status(self) -> Dict[str, Any]:
"parameters": ["id (optional)", "sources (optional)"],
}
)
elif operation == "predict":
operation_list.append(
{
"type": "predict",
"endpoint": f"/predict/{resource_name}/{{id}}",
"description": f"Serve ML predictions as {resource_name} resources",
"method": "GET",
"parameters": ["id"],
}
)

if operation_list:
available_operations[resource_name] = operation_list
@@ -312,15 +334,22 @@ def _register_resource_handler(
resource_type: Type[Resource],
operation: str,
handler: Callable,
**kwargs,
) -> None:
"""Register a custom handler for a resource operation."""
resource_name = self._get_resource_name(resource_type)
self._validate_handler_annotations(resource_type, operation, handler)

if resource_type not in self._resource_handlers:
self._resource_handlers[resource_type] = {}

# Store the handler function
self._resource_handlers[resource_type][operation] = handler

resource_name = self._get_resource_name(resource_type)
# Store any additional decorator kwargs
if kwargs.get("decorator_kwargs"):
self._resource_handlers[resource_type][f"{operation}_kwargs"] = kwargs["decorator_kwargs"]

logger.debug(
f"Registered {operation} handler for {resource_name}: {handler.__name__}"
)
@@ -334,7 +363,7 @@ def _validate_handler_annotations(
handler: Callable,
) -> None:
"""Validate that handler annotations match the decorator resource type."""
if operation != "transform":
if operation not in ["transform", "predict"]:
return

try:
@@ -347,7 +376,7 @@
)
return

if return_annotation != resource_type:
if operation == "transform" and return_annotation != resource_type:
raise TypeError(
f"Handler {handler.__name__} return type ({return_annotation}) "
f"doesn't match decorator resource type ({resource_type})"
@@ -374,6 +403,12 @@ def _register_operation_route(
path = f"/aggregate/{resource_name}"
summary = f"Aggregate {resource_name}"
description = f"Aggregate {resource_name} resources from multiple sources"
elif operation == "predict":
path = f"/predict/{resource_name}/{{id}}"
summary = f"Predict using {resource_name}"
description = (
f"Generate a {resource_name} resource using a registered ML model"
)
else:
raise ValueError(f"Unsupported operation: {operation}")

@@ -398,11 +433,16 @@ def _create_route_handler(
"""Create a route handler for the given resource type and operation."""
get_self_gateway = self._get_gateway_dependency()

def _execute_handler(fhir: "BaseFHIRGateway", *args) -> Any:
async def _execute_handler(fhir: "BaseFHIRGateway", *args) -> Any:
"""Common handler execution logic with error handling."""
handler_func = fhir._resource_handlers[resource_type][operation]
try:
handler_func = fhir._resource_handlers[resource_type][operation]
result = handler_func(*args)
# Await if the handler is async
if asyncio.iscoroutinefunction(handler_func):
result = await handler_func(*args)
else:
result = handler_func(*args)

return result
except Exception as e:
logger.error(f"Error in {operation} handler: {str(e)}")
@@ -418,7 +458,12 @@ async def handler(
fhir: "BaseFHIRGateway" = Depends(get_self_gateway),
):
"""Transform a resource with registered handler."""
return _execute_handler(fhir, id, source)
result = await _execute_handler(fhir, id, source)
return result

elif operation == "aggregate":

@@ -430,13 +475,72 @@ async def handler(
fhir: "BaseFHIRGateway" = Depends(get_self_gateway),
):
"""Aggregate resources with registered handler."""
return _execute_handler(fhir, id, sources)
result = await _execute_handler(fhir, id, sources)
return result

elif operation == "predict":
# Retrieve kwargs passed to the decorator
decorator_kwargs = self._resource_handlers[resource_type].get(
"predict_kwargs", {}
)

async def handler(
id: str = Path(..., description="Patient ID to run prediction for"),
fhir: "BaseFHIRGateway" = Depends(get_self_gateway),
):
"""Generate a prediction resource with a registered handler."""
result = await _execute_handler(fhir, id)
# Wrap the prediction using decorator-provided kwargs
return fhir._wrap_prediction(
resource_type, id, result, **decorator_kwargs
)

else:
raise ValueError(f"Unsupported operation: {operation}")

return handler

def _wrap_prediction(
self,
resource_type: Type[Resource],
patient_id: str,
prediction_output: Any,
status: str = "final",
) -> Resource:
"""Wrap a raw prediction output into a FHIR resource."""
resource_name = self._get_resource_name(resource_type)

if resource_name == "RiskAssessment":
prediction_data = {}
if isinstance(prediction_output, float):
prediction_data["probabilityDecimal"] = prediction_output
elif isinstance(prediction_output, dict):
# Assuming keys like 'score', 'qualitativeRisk', etc.
if "score" in prediction_output:
prediction_data["probabilityDecimal"] = prediction_output["score"]
if "qualitativeRisk" in prediction_output:
prediction_data["qualitativeRisk"] = prediction_output[
"qualitativeRisk"
]
# The fhir.resource model expects a CodeableConcept, not a string.
prediction_data["qualitativeRisk"] = {
"coding": [{"display": prediction_output["qualitativeRisk"]}],
"text": prediction_output["qualitativeRisk"],
}

elif not isinstance(prediction_output, (float, dict)):
raise TypeError(
f"Prediction function must return a float or dict, but returned {type(prediction_output)}"
)

return resource_type(
status=status,
subject=Reference(reference=f"Patient/{patient_id}"),
prediction=[prediction_data],
)
raise NotImplementedError(f"Prediction for {resource_name} not implemented.")


def add_source(self, name: str, connection_string: str) -> None:
"""
Add a FHIR data source using connection string with OAuth2.0 flow.
@@ -469,6 +573,40 @@ def decorator(handler: Callable):

return decorator

def predict(self, resource: Type[Resource], status: str = "final", **kwargs):
"""
Decorator to simplify ML model deployment as FHIR endpoints.

Wraps a function that returns a prediction score (float) or dictionary,
and automatically constructs the specified FHIR resource.

Currently, only `RiskAssessment` is fully supported.

Args:
resource: The FHIR resource type to create (e.g., RiskAssessment).
status: The status to set on the created FHIR resource. Defaults to "final",
which is a spec-compliant value for RiskAssessment.
**kwargs: Reserved for additional resource fields; currently not applied by the wrapper.

Example:
@fhir.predict(resource=RiskAssessment)
def predict_sepsis_risk(patient_id: str) -> float: # The patient_id is passed from the URL
# Your model logic here
return 0.85 # High risk
"""

def decorator(handler: Callable):
# The user-provided handler is registered for the 'predict' operation
self._register_resource_handler(
resource, "predict", handler, decorator_kwargs={"status": status}
)

# The actual endpoint handler is created by _register_operation_route
# which calls _create_route_handler, which wraps our logic.
return handler

return decorator

def transform(self, resource_type: Type[Resource]):
"""
Decorator for custom transformation functions. Must return the same resource type.
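Taken together, the changes above let a service author expose a model behind a FHIR-shaped GET endpoint with a single decorator. A minimal usage sketch follows; the gateway construction, the stand-in scoring value, and the risk threshold are illustrative assumptions rather than part of this diff, while the decorator signature, the accepted return types (float, or dict with "score"/"qualitativeRisk"), and the /predict/RiskAssessment/{id} route come from the code above.

from fhir.resources.riskassessment import RiskAssessment
from healthchain.gateway.fhir import FHIRGateway

fhir = FHIRGateway()  # assumes default construction; real setup may differ

@fhir.predict(resource=RiskAssessment, status="final")
def sepsis_risk(patient_id: str) -> dict:
    # Any model call works here; this fixed score is a stand-in.
    score = 0.85
    # A bare float populates probabilityDecimal only; a dict can also set
    # qualitativeRisk, which the wrapper turns into a CodeableConcept.
    return {"score": score, "qualitativeRisk": "high" if score > 0.7 else "low"}

# The gateway now serves GET {prefix}/predict/RiskAssessment/{id}, returning a
# RiskAssessment whose subject references Patient/{id}.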
91 changes: 91 additions & 0 deletions tests/gateway/test_base_fhir_gateway.py
@@ -1,9 +1,11 @@
import pytest
from unittest.mock import Mock, patch, AsyncMock
from typing import Dict, Any, List
import asyncio

from fhir.resources.patient import Patient
from fhir.resources.observation import Observation
from fhir.resources.riskassessment import RiskAssessment

from healthchain.gateway.fhir import FHIRGateway, AsyncFHIRGateway

@@ -222,3 +224,92 @@ def test_resource_name_extraction(fhir_gateway):
"""_get_resource_name correctly extracts resource names from types."""
assert fhir_gateway._get_resource_name(Patient) == "Patient"
assert fhir_gateway._get_resource_name(Observation) == "Observation"


@pytest.mark.asyncio
async def test_predict_handler_raises_for_invalid_output_type(fhir_gateway):
"""The predict route handler should raise a TypeError for unsupported return types."""

@fhir_gateway.predict(resource=RiskAssessment)
def invalid_prediction(patient_id: str) -> list:
return [0.1, 0.9] # Invalid return type

route_handler = fhir_gateway.routes[-1].endpoint

with pytest.raises(TypeError, match="Prediction function must return a float or dict"):
await route_handler(id="Patient789", fhir=fhir_gateway)


def test_predict_decorator_registers_handler_and_route(fhir_gateway):
"""The @predict decorator should register a 'predict' handler and create a route."""
initial_routes = len(fhir_gateway.routes)

@fhir_gateway.predict(resource=RiskAssessment)
def predict_risk(patient_id: str) -> float:
return 0.5

# Check handler registration
assert "predict" in fhir_gateway._resource_handlers[RiskAssessment]
assert (
fhir_gateway._resource_handlers[RiskAssessment]["predict"]
== predict_risk
)
# Check that decorator kwargs are stored
assert "predict_kwargs" in fhir_gateway._resource_handlers[RiskAssessment]
assert fhir_gateway._resource_handlers[RiskAssessment]["predict_kwargs"] == {
"status": "final"
}

# Check route creation
assert len(fhir_gateway.routes) == initial_routes + 1
new_route = fhir_gateway.routes[-1]
assert new_route.path == f"{fhir_gateway.prefix}/predict/RiskAssessment/{{id}}"


@pytest.mark.asyncio
async def test_predict_handler_wraps_float_output(fhir_gateway):
"""The predict route handler should wrap a float from a sync function into a FHIR resource."""

@fhir_gateway.predict(resource=RiskAssessment)
def simple_prediction(patient_id: str) -> float:
return 0.75

# The handler created by _create_route_handler is what we need to test
# It's an async function that wraps the user's sync function
route_handler = fhir_gateway.routes[-1].endpoint
result = await route_handler(id="Patient123", fhir=fhir_gateway)

assert isinstance(result, RiskAssessment)
assert result.status == "final"
assert result.subject.reference == "Patient/Patient123"
assert result.prediction[0].probabilityDecimal == 0.75


@pytest.mark.asyncio
async def test_predict_handler_wraps_dict_output_from_async_func(fhir_gateway):
"""The predict route handler should wrap a dict from an async function into a FHIR resource."""

@fhir_gateway.predict(resource=RiskAssessment)
async def complex_prediction(patient_id: str) -> dict:
await asyncio.sleep(0.01) # Simulate async work
return {"score": 0.4, "qualitativeRisk": "low"}

route_handler = fhir_gateway.routes[-1].endpoint
result = await route_handler(id="Patient456", fhir=fhir_gateway)

assert isinstance(result, RiskAssessment)
assert result.subject.reference == "Patient/Patient456"
assert result.prediction[0].probabilityDecimal == 0.4
assert result.prediction[0].qualitativeRisk.text == "low"


def test_predict_decorator_in_capability_statement(fhir_gateway):
"""CapabilityStatement should include resources with registered predict handlers."""

@fhir_gateway.predict(resource=RiskAssessment)
def predict_risk(patient_id: str) -> float:
return 0.5

capability = fhir_gateway.build_capability_statement()
resources = capability.rest[0].resource
assert "RiskAssessment" in [r.type for r in resources]