Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,14 +1,26 @@
import io
import json
import os
from typing import TypedDict

from aws_lambda_typing import context as lambda_context
from aws_lambda_typing import events as lambda_events

import lambda_handler
from augmentation.models import TTCAugmenterConfig
from botocore.client import BaseClient
from augmentation.models.application import TTCAugmenterOutput
from augmentation.services.eicr_augmenter import EICRAugmenter
from shared_models import TTCAugmenterInput

# Environment variables
S3_BUCKET = os.getenv("S3_BUCKET", "dibbs-text-to-code")
AUGMENTED_EICR_PREFIX = os.getenv("AUGMENTED_EICR_PREFIX", "AugmentationEICRV2/")
AUGMENTATION_METADATA_PREFIX = os.getenv("AUGMENTATION_METADATA_PREFIX", "AugmentationMetadata/")

# Cache S3 client to reuse across Lambda invocations
_cached_s3_client: BaseClient | None = None


class HandlerResponse(TypedDict):
"""Response from the AWS Lambda handler."""
Expand All @@ -24,6 +36,12 @@ def handler(event: lambda_events.SQSEvent, context: lambda_context.Context) -> H
:param context: The AWS Lambda context object.
:return: A dictionary containing the results of the augmentation and any batch item failures.
"""
global _cached_s3_client # noqa: PLW0603

if _cached_s3_client is None:
_cached_s3_client = lambda_handler.create_s3_client()
s3_client = _cached_s3_client

results: list[dict[str, object]] = []
batch_item_failures: list[dict[str, str]] = []

Expand Down Expand Up @@ -64,6 +82,9 @@ def handler(event: lambda_events.SQSEvent, context: lambda_context.Context) -> H
metadata=metadata,
)

# Save augmented eICR and metadata to S3
_save_augmentation_outputs(augmenter_input.eicr_id, output, s3_client)

results.append(
{
"messageId": message_id,
Expand All @@ -85,3 +106,26 @@ def handler(event: lambda_events.SQSEvent, context: lambda_context.Context) -> H
"results": results,
"batchItemFailures": batch_item_failures,
}


def _save_augmentation_outputs(
    eicr_id: str, output: TTCAugmenterOutput, s3_client: BaseClient
) -> None:
    """Persist the augmentation results for one eICR to S3.

    Writes two objects to the configured bucket: the augmented eICR XML under
    ``AUGMENTED_EICR_PREFIX`` and the run metadata (serialized to JSON) under
    ``AUGMENTATION_METADATA_PREFIX``, both keyed by ``eicr_id``.

    :param eicr_id: The eICR identifier.
    :param output: The augmentation output containing the augmented eICR and metadata.
    :param s3_client: The S3 client to use for uploading files.
    """
    # (prefix, payload) pairs — eICR first, metadata second.
    uploads = (
        (AUGMENTED_EICR_PREFIX, output.augmented_eicr),
        (AUGMENTATION_METADATA_PREFIX, output.metadata.model_dump_json()),
    )
    for prefix, payload in uploads:
        lambda_handler.put_file(
            file_obj=io.BytesIO(payload.encode("utf-8")),
            bucket_name=S3_BUCKET,
            object_key=f"{prefix}{eicr_id}",
            s3_client=s3_client,
        )
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
import json
from unittest.mock import MagicMock
from unittest.mock import patch

import pytest

from augmentation.models import Metadata
from augmentation_lambda import lambda_function
Expand Down Expand Up @@ -30,7 +34,18 @@ def augment(self) -> Metadata:
)


def test_handler_returns_success_result(mocker) -> None:
@pytest.fixture(autouse=True)
def mock_s3_client():
    """Mock the S3 client and put_file for all tests.

    Seeds ``lambda_function._cached_s3_client`` with a ``MagicMock`` and
    patches the module's ``lambda_handler`` so no real AWS calls are made.
    Yields the patched ``lambda_handler`` mock so tests can inspect
    ``put_file`` / ``create_s3_client`` calls.
    """
    lambda_function._cached_s3_client = MagicMock()
    try:
        with patch.object(lambda_function, "lambda_handler") as mock_handler:
            mock_handler.create_s3_client.return_value = MagicMock()
            mock_handler.put_file = MagicMock()
            yield mock_handler
    finally:
        # Always clear the module-level cache, even if teardown raises,
        # so a MagicMock never leaks into other test modules.
        lambda_function._cached_s3_client = None


def test_handler_returns_success_result(mocker, mock_s3_client) -> None:
"""Tests that the handler returns a successful result when the augmenter runs without errors.

:param mocker: The pytest-mock fixture for mocking objects.
Expand Down Expand Up @@ -82,7 +97,49 @@ def test_handler_returns_success_result(mocker) -> None:
}


def test_handler_uses_provided_config(mocker) -> None:
def test_handler_saves_outputs_to_s3(mocker, mock_s3_client) -> None:
    """Tests that the handler writes augmented eICR and metadata to S3.

    :param mocker: The pytest-mock fixture for mocking objects.
    """
    mocker.patch.object(lambda_function, "EICRAugmenter", FakeAugmenter)

    message_body = {
        "eicr_id": "test-eicr-id",
        "eicr": "<ClinicalDocument />",
        "nonstandard_codes": [],
    }
    event = {
        "Records": [
            {"messageId": "message-s3", "body": json.dumps(message_body)}
        ]
    }

    lambda_function.handler(event, None)

    # Exactly two uploads: the augmented eICR, then its metadata.
    expected_put_file_calls = 2
    assert mock_s3_client.put_file.call_count == expected_put_file_calls

    eicr_call, metadata_call = mock_s3_client.put_file.call_args_list

    # First call: augmented eICR
    assert eicr_call.kwargs["bucket_name"] == lambda_function.S3_BUCKET
    assert eicr_call.kwargs["object_key"] == f"{lambda_function.AUGMENTED_EICR_PREFIX}test-eicr-id"

    # Second call: metadata
    assert metadata_call.kwargs["bucket_name"] == lambda_function.S3_BUCKET
    assert (
        metadata_call.kwargs["object_key"]
        == f"{lambda_function.AUGMENTATION_METADATA_PREFIX}test-eicr-id"
    )


def test_handler_uses_provided_config(mocker, mock_s3_client) -> None:
"""Tests that the handler uses the provided config when creating the augmenter.

:param mocker: The pytest-mock fixture for mocking objects.
Expand Down
11 changes: 8 additions & 3 deletions packages/augmentation/src/augmentation/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""I don't think we want this to be in `main.py` but I'm not 100% sure how this will get plumbed with AWS, so this is as good as anywhere for the moment."""

import os
from io import BytesIO

from augmentation.models import Metadata
Expand All @@ -8,6 +9,10 @@
from lambda_handler.lambda_handler import put_file
from shared_models import TTCAugmenterInput

S3_BUCKET = os.getenv("S3_BUCKET", "dibbs-text-to-code")
AUGMENTED_EICR_PREFIX = os.getenv("AUGMENTED_EICR_PREFIX", "AugmentationEICRV2/")
AUGMENTATION_METADATA_PREFIX = os.getenv("AUGMENTATION_METADATA_PREFIX", "AugmentationMetadata/")


def _retrieve_eicr(eicr_id: str) -> str:
return "<ClinicalDocument></ClinicalDocument>"
Expand All @@ -19,15 +24,15 @@ def _retrieve_config() -> TTCAugmenterConfig:

def _save_eicr(eicr: str, eicr_id: str) -> None:
    """Save an augmented eICR to the configured S3 bucket.

    :param eicr: The augmented eICR XML as a string.
    :param eicr_id: The eICR identifier, used as the object-key suffix under
        ``AUGMENTED_EICR_PREFIX``.
    """
    # Single upload to the env-configured bucket; the old hard-coded
    # "augmented_eicrs" bucket call is removed.
    put_file(BytesIO(eicr.encode("utf-8")), S3_BUCKET, f"{AUGMENTED_EICR_PREFIX}{eicr_id}")


def _save_metadata(metadata: Metadata) -> None:
    """Save augmentation metadata to the configured S3 bucket.

    :param metadata: The augmentation metadata; serialized to JSON and stored
        under ``AUGMENTATION_METADATA_PREFIX`` keyed by its
        ``augmented_eicr_id``.
    """
    put_file(
        BytesIO(metadata.model_dump_json().encode("utf-8")),
        S3_BUCKET,
        f"{AUGMENTATION_METADATA_PREFIX}{metadata.augmented_eicr_id}",
    )


Expand Down
17 changes: 13 additions & 4 deletions packages/lambda-handler/src/lambda_handler/lambda_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,17 @@ def create_opensearch_client(aws_auth: AWS4Auth) -> OpenSearch:
)


def get_file_content_from_s3(bucket_name: str, object_key: str) -> str:
def get_file_content_from_s3(
bucket_name: str, object_key: str, s3_client: BaseClient | None = None
) -> str:
"""Extracts the file content from an S3 bucket.

:param bucket_name: The name of the S3 bucket.
:param object_key: The key of the S3 object.
:param s3_client: Optional pre-created S3 client. If None, a new client is created.
:return: The content of the file as a string.
"""
client = create_s3_client()
client = s3_client or create_s3_client()

# Check if object exists
if not check_s3_object_exists(client, bucket_name, object_key):
Expand All @@ -112,14 +115,20 @@ def get_eventbridge_data_from_s3_event(event: lambda_events.EventBridgeEvent) ->
return {"bucket_name": bucket_name, "object_key": object_key}


def put_file(
    file_obj: typing.BinaryIO,
    bucket_name: str,
    object_key: str,
    s3_client: BaseClient | None = None,
) -> None:
    """Uploads a file object to a S3 bucket.

    :param file_obj: The file object to upload.
    :param bucket_name: The name of the S3 bucket to upload to.
    :param object_key: The key to assign to the uploaded object in S3.
    :param s3_client: Optional pre-created S3 client. If None, a new client is created.
    """
    # Reuse the caller-supplied client when given so cached clients survive
    # across Lambda invocations; otherwise fall back to creating a fresh one.
    client = s3_client or create_s3_client()
    client.put_object(Body=file_obj, Bucket=bucket_name, Key=object_key)


Expand Down
Loading
Loading