diff --git a/packages/augmentation-lambda/src/augmentation_lambda/lambda_function.py b/packages/augmentation-lambda/src/augmentation_lambda/lambda_function.py
index c1a40fdb..f09c6866 100644
--- a/packages/augmentation-lambda/src/augmentation_lambda/lambda_function.py
+++ b/packages/augmentation-lambda/src/augmentation_lambda/lambda_function.py
@@ -1,14 +1,26 @@
+import io
 import json
+import os
 from typing import TypedDict

 from aws_lambda_typing import context as lambda_context
 from aws_lambda_typing import events as lambda_events

+import lambda_handler
 from augmentation.models import TTCAugmenterConfig
+from botocore.client import BaseClient
 from augmentation.models.application import TTCAugmenterOutput
 from augmentation.services.eicr_augmenter import EICRAugmenter
 from shared_models import TTCAugmenterInput

+# Environment variables
+S3_BUCKET = os.getenv("S3_BUCKET", "dibbs-text-to-code")
+AUGMENTED_EICR_PREFIX = os.getenv("AUGMENTED_EICR_PREFIX", "AugmentationEICRV2/")
+AUGMENTATION_METADATA_PREFIX = os.getenv("AUGMENTATION_METADATA_PREFIX", "AugmentationMetadata/")
+
+# Cache S3 client to reuse across Lambda invocations
+_cached_s3_client: BaseClient | None = None
+

 class HandlerResponse(TypedDict):
     """Response from the AWS Lambda handler."""
@@ -24,6 +36,12 @@ def handler(event: lambda_events.SQSEvent, context: lambda_context.Context) -> H
     :param context: The AWS Lambda context object.
     :return: A dictionary containing the results of the augmentation and any batch item failures.
     """
+    global _cached_s3_client  # noqa: PLW0603
+
+    if _cached_s3_client is None:
+        _cached_s3_client = lambda_handler.create_s3_client()
+    s3_client = _cached_s3_client
+
     results: list[dict[str, object]] = []
     batch_item_failures: list[dict[str, str]] = []

@@ -64,6 +82,9 @@ def handler(event: lambda_events.SQSEvent, context: lambda_context.Context) -> H
                 metadata=metadata,
             )

+            # Save augmented eICR and metadata to S3
+            _save_augmentation_outputs(augmenter_input.eicr_id, output, s3_client)
+
             results.append(
                 {
                     "messageId": message_id,
@@ -85,3 +106,26 @@ def handler(event: lambda_events.SQSEvent, context: lambda_context.Context) -> H
         "results": results,
         "batchItemFailures": batch_item_failures,
     }
+
+
+def _save_augmentation_outputs(
+    eicr_id: str, output: TTCAugmenterOutput, s3_client: BaseClient
+) -> None:
+    """Save augmented eICR and metadata to S3.
+
+    :param eicr_id: The eICR identifier.
+    :param output: The augmentation output containing the augmented eICR and metadata.
+    :param s3_client: The S3 client to use for uploading files.
+    """
+    lambda_handler.put_file(
+        file_obj=io.BytesIO(output.augmented_eicr.encode("utf-8")),
+        bucket_name=S3_BUCKET,
+        object_key=f"{AUGMENTED_EICR_PREFIX}{eicr_id}",
+        s3_client=s3_client,
+    )
+    lambda_handler.put_file(
+        file_obj=io.BytesIO(output.metadata.model_dump_json().encode("utf-8")),
+        bucket_name=S3_BUCKET,
+        object_key=f"{AUGMENTATION_METADATA_PREFIX}{eicr_id}",
+        s3_client=s3_client,
+    )
diff --git a/packages/augmentation-lambda/tests/test_augmentation_lambda_function.py b/packages/augmentation-lambda/tests/test_augmentation_lambda_function.py
index cee906ec..aa9bc052 100644
--- a/packages/augmentation-lambda/tests/test_augmentation_lambda_function.py
+++ b/packages/augmentation-lambda/tests/test_augmentation_lambda_function.py
@@ -1,4 +1,8 @@
 import json
+from unittest.mock import MagicMock
+from unittest.mock import patch
+
+import pytest

 from augmentation.models import Metadata
 from augmentation_lambda import lambda_function
@@ -30,7 +34,18 @@ def augment(self) -> Metadata:
     )


-def test_handler_returns_success_result(mocker) -> None:
+@pytest.fixture(autouse=True)
+def mock_s3_client():
+    """Mock the S3 client and put_file for all tests."""
+    lambda_function._cached_s3_client = MagicMock()
+    with patch.object(lambda_function, "lambda_handler") as mock_handler:
+        mock_handler.create_s3_client.return_value = MagicMock()
+        mock_handler.put_file = MagicMock()
+        yield mock_handler
+    lambda_function._cached_s3_client = None
+
+
+def test_handler_returns_success_result(mocker, mock_s3_client) -> None:
     """Tests that the handler returns a successful result when the augmenter runs without errors.

     :param mocker: The pytest-mock fixture for mocking objects.
@@ -82,7 +97,49 @@
-def test_handler_uses_provided_config(mocker) -> None:
+def test_handler_saves_outputs_to_s3(mocker, mock_s3_client) -> None:
+    """Tests that the handler writes augmented eICR and metadata to S3.
+
+    :param mocker: The pytest-mock fixture for mocking objects.
+    """
+    mocker.patch.object(lambda_function, "EICRAugmenter", FakeAugmenter)
+
+    event = {
+        "Records": [
+            {
+                "messageId": "message-s3",
+                "body": json.dumps(
+                    {
+                        "eicr_id": "test-eicr-id",
+                        "eicr": "",
+                        "nonstandard_codes": [],
+                    }
+                ),
+            }
+        ]
+    }
+
+    lambda_function.handler(event, None)
+
+    # Verify put_file was called once for augmented eICR and once for metadata
+    expected_put_file_calls = 2
+    assert mock_s3_client.put_file.call_count == expected_put_file_calls
+
+    # First call: augmented eICR
+    eicr_call = mock_s3_client.put_file.call_args_list[0]
+    assert eicr_call.kwargs["bucket_name"] == lambda_function.S3_BUCKET
+    assert eicr_call.kwargs["object_key"] == f"{lambda_function.AUGMENTED_EICR_PREFIX}test-eicr-id"
+
+    # Second call: metadata
+    metadata_call = mock_s3_client.put_file.call_args_list[1]
+    assert metadata_call.kwargs["bucket_name"] == lambda_function.S3_BUCKET
+    assert (
+        metadata_call.kwargs["object_key"]
+        == f"{lambda_function.AUGMENTATION_METADATA_PREFIX}test-eicr-id"
+    )
+
+
+def test_handler_uses_provided_config(mocker, mock_s3_client) -> None:
     """Tests that the handler uses the provided config when creating the augmenter.

     :param mocker: The pytest-mock fixture for mocking objects.
diff --git a/packages/augmentation/src/augmentation/main.py b/packages/augmentation/src/augmentation/main.py
deleted file mode 100644
index c587a847..00000000
--- a/packages/augmentation/src/augmentation/main.py
+++ /dev/null
@@ -1,43 +0,0 @@
-"""I don't think we want this to be in `main.py` but I'm not 100% sure how this will get plumbed with AWS, so this is as good as anywhere for the moment."""
-
-from io import BytesIO
-
-from augmentation.models import Metadata
-from augmentation.models import TTCAugmenterConfig
-from augmentation.services.eicr_augmenter import EICRAugmenter
-from lambda_handler.lambda_handler import put_file
-from shared_models import TTCAugmenterInput
-
-
-def _retrieve_eicr(eicr_id: str) -> str:
-    return ""
-
-
-def _retrieve_config() -> TTCAugmenterConfig:
-    return TTCAugmenterConfig()
-
-
-def _save_eicr(eicr: str, eicr_id: str) -> None:
-    """Save augmented eICR to S3 bucket."""
-    put_file(BytesIO(eicr.encode("utf-8")), "augmented_eicrs", eicr_id)
-
-
-def _save_metadata(metadata: Metadata) -> None:
-    """Save augmentation metadata to S3 bucket."""
-    put_file(
-        BytesIO(metadata.model_dump_json().encode("utf-8")),
-        "augmentation_metadata",
-        f"{metadata.augmented_eicr_id}_metadata.json",
-    )
-
-
-def augment(input: TTCAugmenterInput) -> None:
-    """Main entry point for the augmentation service."""
-    eicr: str = _retrieve_eicr(input.eicr_id)
-    config = _retrieve_config()
-    augmenter = EICRAugmenter(eicr, input.nonstandard_codes, config)
-
-    metadata = augmenter.augment()
-
-    _save_eicr(augmenter.augmented_xml, input.eicr_id)
-    _save_metadata(metadata)
diff --git a/packages/lambda-handler/src/lambda_handler/lambda_handler.py b/packages/lambda-handler/src/lambda_handler/lambda_handler.py
index 7ac35e1c..a269d976 100644
--- a/packages/lambda-handler/src/lambda_handler/lambda_handler.py
+++ b/packages/lambda-handler/src/lambda_handler/lambda_handler.py
@@ -83,14 +83,17 @@ def create_opensearch_client(aws_auth: AWS4Auth) -> OpenSearch:
     )


-def get_file_content_from_s3(bucket_name: str, object_key: str) -> str:
+def get_file_content_from_s3(
+    bucket_name: str, object_key: str, s3_client: BaseClient | None = None
+) -> str:
     """Extracts the file content from an S3 bucket.

     :param bucket_name: The name of the S3 bucket.
     :param object_key: The key of the S3 object.
+    :param s3_client: Optional pre-created S3 client. If None, a new client is created.
     :return: The content of the file as a string.
     """
-    client = create_s3_client()
+    client = s3_client or create_s3_client()

     # Check if object exists
     if not check_s3_object_exists(client, bucket_name, object_key):
@@ -112,14 +115,20 @@ def get_eventbridge_data_from_s3_event(event: lambda_events.EventBridgeEvent) ->
     return {"bucket_name": bucket_name, "object_key": object_key}


-def put_file(file_obj: typing.BinaryIO, bucket_name: str, object_key: str) -> None:
+def put_file(
+    file_obj: typing.BinaryIO,
+    bucket_name: str,
+    object_key: str,
+    s3_client: BaseClient | None = None,
+) -> None:
     """Uploads a file object to a S3 bucket.

     :param file_obj: The file object to upload.
     :param bucket_name: The name of the S3 bucket to upload to.
     :param object_key: The key to assign to the uploaded object in S3.
+    :param s3_client: Optional pre-created S3 client. If None, a new client is created.
     """
-    client = create_s3_client()
+    client = s3_client or create_s3_client()

     client.put_object(Body=file_obj, Bucket=bucket_name, Key=object_key)
diff --git a/packages/text-to-code-lambda/src/text_to_code_lambda/lambda_function.py b/packages/text-to-code-lambda/src/text_to_code_lambda/lambda_function.py
index 9adbb96e..ec90a08c 100644
--- a/packages/text-to-code-lambda/src/text_to_code_lambda/lambda_function.py
+++ b/packages/text-to-code-lambda/src/text_to_code_lambda/lambda_function.py
@@ -25,9 +25,10 @@
 logger = Logger(service="ttc")

 # Environment variables
+S3_BUCKET = os.getenv("S3_BUCKET", "dibbs-text-to-code")
 EICR_INPUT_PREFIX = os.getenv("EICR_INPUT_PREFIX", "eCRMessageV2/")
 SCHEMATRON_ERROR_PREFIX = os.getenv("SCHEMATRON_ERROR_PREFIX", "schematronErrors/")
-TTC_INPUT_PREFIX = os.getenv("TTC_INPUT_PREFIX", "TextToCodeSubmission/")
+TTC_INPUT_PREFIX = os.getenv("TTC_INPUT_PREFIX", "TextToCodeValidateSubmissionV2/")
 TTC_OUTPUT_PREFIX = os.getenv("TTC_OUTPUT_PREFIX", "TTCOutput/")
 TTC_METADATA_PREFIX = os.getenv("TTC_METADATA_PREFIX", "TTCMetadata/")
 AWS_REGION = os.getenv("AWS_REGION")
@@ -118,9 +119,8 @@ def process_record(record: SQSRecord, s3_client: BaseClient, opensearch_client:
     # Parse the EventBridge S3 event from the SQS message body
     eventbridge_data = lambda_handler.get_eventbridge_data_from_s3_event(s3_event)

-    bucket = eventbridge_data["bucket_name"]
     object_key = eventbridge_data["object_key"]
-    logger.info(f"Processing S3 Object: s3://{bucket}/{object_key}")
+    logger.info(f"Processing S3 Object: s3://{S3_BUCKET}/{object_key}")

     # Extract persistence_id from the RR object key
     persistence_id = lambda_handler.get_persistence_id(object_key, TTC_INPUT_PREFIX)
@@ -129,7 +129,7 @@ def process_record(record: SQSRecord, s3_client: BaseClient, opensearch_client:
     with logger.append_context_keys(
         persistence_id=persistence_id,
     ):
-        _process_record_pipeline(bucket, persistence_id, s3_client, opensearch_client)
+        _process_record_pipeline(persistence_id, s3_client, opensearch_client)


 def _initialize_ttc_outputs(persistence_id: str) -> tuple[dict, dict]:
@@ -152,19 +152,19 @@
     return ttc_output, ttc_metadata_output


-def _load_schematron_data_fields(persistence_id: str) -> list:
+def _load_schematron_data_fields(persistence_id: str, s3_client: BaseClient) -> list:
     """Load Schematron errors from S3 and extract relevant fields.

     :param persistence_id: The persistence ID extracted from the S3 object key
+    :param s3_client: The S3 client to use for fetching files.
     :return: The relevant Schematron data fields for TTC processing.
     """
-    # S3 GET Schematron errors
-    # TODO: Confirm with APHL that the Schematron errors will be stored in the same bucket and follow a consistent naming convention that allows us to derive the Schematron error object key from the persistence_id.
-    schematron_bucket_name = SCHEMATRON_ERROR_PREFIX.split("/")[0]
-    logger.info("Loading Schematron errors", s3_key=f"{schematron_bucket_name}{persistence_id}")
+    object_key = f"{SCHEMATRON_ERROR_PREFIX}{persistence_id}"
+    logger.info("Loading Schematron errors", s3_key=f"s3://{S3_BUCKET}/{object_key}")
     schematron_errors = lambda_handler.get_file_content_from_s3(
-        bucket_name=schematron_bucket_name,
-        object_key=f"{persistence_id}",
+        bucket_name=S3_BUCKET,
+        object_key=object_key,
+        s3_client=s3_client,
     )

     # Process Schematron errors to identify relevant data fields for TTC processing
@@ -172,21 +172,17 @@
     return schematron_processor.get_data_fields_from_schematron_error(schematron_errors)


-def _load_original_eicr(bucket: str, persistence_id: str) -> str:
+def _load_original_eicr(persistence_id: str, s3_client: BaseClient) -> str:
     """Load the original eICR from S3.

-    :param bucket: The name of the S3 bucket
     :param persistence_id: The persistence ID extracted from the S3 object key
+    :param s3_client: The S3 client to use for fetching files.
     :return: The original eICR content.
     """
-    # Construct eICR path: s3://<bucket>/<prefix>/<persistence_id>
-    logger.info(f"Retrieving eICR from s3://{EICR_INPUT_PREFIX}{persistence_id}")
-
-    # S3 GET eICR
-    ecr_bucket_name = EICR_INPUT_PREFIX.split("/")[0]
-    logger.info("Loading eICR", s3_key=f"{ecr_bucket_name}/{persistence_id}")
+    object_key = f"{EICR_INPUT_PREFIX}{persistence_id}"
+    logger.info(f"Retrieving eICR from s3://{S3_BUCKET}/{object_key}")
     original_eicr_content = lambda_handler.get_file_content_from_s3(
-        bucket_name=bucket, object_key=persistence_id
+        bucket_name=S3_BUCKET, object_key=object_key, s3_client=s3_client
     )
     logger.info(f"Retrieved eICR content for persistence_id {persistence_id}")
     return original_eicr_content
@@ -286,34 +282,36 @@ def _process_schematron_errors(
     ttc_metadata_output["schematron_errors"][data_field].append(metadata_error)


-def _save_ttc_outputs(persistence_id: str, ttc_output: dict, ttc_metadata_output: dict) -> None:
+def _save_ttc_outputs(
+    persistence_id: str, ttc_output: dict, ttc_metadata_output: dict, s3_client: BaseClient
+) -> None:
     """Save TTC output and metadata output to S3.

     :param persistence_id: The persistence ID extracted from the S3 object key
     :param ttc_output: The TTC output dictionary.
     :param ttc_metadata_output: The TTC metadata output dictionary.
+    :param s3_client: The S3 client to use for uploading files.
""" # Save the TTC output to S3 for the Augmentation Lambda to consume logger.info(f"Saving TTC output to S3 for persistence_id {persistence_id}") - ttc_output_bucket_name = TTC_OUTPUT_PREFIX.split("/")[0] lambda_handler.put_file( file_obj=io.BytesIO(json.dumps(ttc_output, default=str).encode("utf-8")), - bucket_name=ttc_output_bucket_name, - object_key=persistence_id, + bucket_name=S3_BUCKET, + object_key=f"{TTC_OUTPUT_PREFIX}{persistence_id}", + s3_client=s3_client, ) # Save the TTC metadata output for completing model evaluation and analysis of TTC results logger.info(f"Saving TTC metadata output to S3 for persistence_id {persistence_id}") - ttc_metadata_output_bucket_name = TTC_METADATA_PREFIX.split("/")[0] lambda_handler.put_file( file_obj=io.BytesIO(json.dumps(ttc_metadata_output, default=str).encode("utf-8")), - bucket_name=ttc_metadata_output_bucket_name, - object_key=persistence_id, + bucket_name=S3_BUCKET, + object_key=f"{TTC_METADATA_PREFIX}{persistence_id}", + s3_client=s3_client, ) def _process_record_pipeline( - bucket: str, persistence_id: str, s3_client: BaseClient, opensearch_client: OpenSearch, @@ -332,13 +330,14 @@ def _process_record_pipeline( - Creating the output to pass to the Augmentation Lambda and saving it to S3 - Creating the metadata object to save in S3 for analysis of TTC results - :param bucket: The name of the S3 bucket :param persistence_id: The persistence ID extracted from the S3 object key + :param s3_client: The S3 client to use for S3 operations. + :param opensearch_client: The OpenSearch client. """ ttc_output, ttc_metadata_output = _initialize_ttc_outputs(persistence_id) logger.info("Starting TTC processing") - schematron_data_fields = _load_schematron_data_fields(persistence_id) + schematron_data_fields = _load_schematron_data_fields(persistence_id, s3_client) if not schematron_data_fields: logger.warning( @@ -347,15 +346,15 @@ def _process_record_pipeline( ttc_output["message"] = NO_DATA_FIELDS_MESSAGE ttc_metadata_output["reason_for_skipping"] = NO_DATA_FIELDS_MESSAGE logger.info(f"Saving TTC metadata output to S3 for persistence_id {persistence_id}") - ttc_metadata_output_bucket_name = TTC_METADATA_PREFIX.split("/")[0] lambda_handler.put_file( file_obj=io.BytesIO(json.dumps(ttc_metadata_output, default=str).encode("utf-8")), - bucket_name=ttc_metadata_output_bucket_name, - object_key=persistence_id, + bucket_name=S3_BUCKET, + object_key=f"{TTC_METADATA_PREFIX}{persistence_id}", + s3_client=s3_client, ) return ttc_output - original_eicr_content = _load_original_eicr(bucket, persistence_id) + original_eicr_content = _load_original_eicr(persistence_id, s3_client) _populate_eicr_metadata(original_eicr_content, ttc_output, ttc_metadata_output) _process_schematron_errors( original_eicr_content, @@ -364,6 +363,6 @@ def _process_record_pipeline( ttc_output, ttc_metadata_output, ) - _save_ttc_outputs(persistence_id, ttc_output, ttc_metadata_output) + _save_ttc_outputs(persistence_id, ttc_output, ttc_metadata_output, s3_client) return {"statusCode": 200, "message": "TTC processed successfully!"} diff --git a/packages/text-to-code-lambda/tests/conftest.py b/packages/text-to-code-lambda/tests/conftest.py index 2cc386cc..2b7937dc 100644 --- a/packages/text-to-code-lambda/tests/conftest.py +++ b/packages/text-to-code-lambda/tests/conftest.py @@ -11,21 +11,22 @@ from text_to_code_lambda import lambda_function +S3_BUCKET = "dibbs-text-to-code" EICR_INPUT_PREFIX = "eCRMessageV2/" SCHEMATRON_ERROR_PREFIX = "schematronErrors/" -TTC_INPUT_PREFIX = 
"TextToCodeSubmission/" +TTC_INPUT_PREFIX = "TextToCodeValidateSubmissionV2/" TTC_OUTPUT_PREFIX = "TTCOutput/" TTC_METADATA_PREFIX = "TTCMetadata/" AWS_REGION = "us-east-1" AWS_ACCESS_KEY_ID = "test_access_key_id" AWS_SECRET_ACCESS_KEY = "test_secret_access_key" # noqa: S105 OPENSEARCH_ENDPOINT_URL = "https://test-opensearch-endpoint.com" -TEST_BUCKET_NAME = "test-bucket" TEST_PERSISTENCE_ID = "2025/09/03/1-5f84c7a5-91d7f5c6a2b7c9e08f0d1234" def pytest_configure() -> None: """Configure env variables for pytest.""" + os.environ["S3_BUCKET"] = S3_BUCKET os.environ["EICR_INPUT_PREFIX"] = EICR_INPUT_PREFIX os.environ["SCHEMATRON_ERROR_PREFIX"] = SCHEMATRON_ERROR_PREFIX os.environ["TTC_INPUT_PREFIX"] = TTC_INPUT_PREFIX @@ -54,7 +55,7 @@ def example_s3_event_payload() -> dict: "resources": ["arn:aws:s3:::my-bucket-name"], "detail": { "version": "0", - "bucket": {"name": "eCRMessageV2"}, + "bucket": {"name": S3_BUCKET}, "object": { "key": f"{TTC_INPUT_PREFIX}{TEST_PERSISTENCE_ID}", "size": 1024, @@ -112,25 +113,17 @@ def mock_aws_setup(monkeypatch: pytest.MonkeyPatch) -> boto3.client: monkeypatch.setenv("AWS_ACCESS_KEY_ID", AWS_ACCESS_KEY_ID) monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", AWS_SECRET_ACCESS_KEY) monkeypatch.setenv("OPENSEARCH_ENDPOINT_URL", OPENSEARCH_ENDPOINT_URL) - # Create the fake S3 bucket + # Create the single S3 bucket s3 = boto3.client( "s3", - region_name=os.environ["AWS_REGION"], - aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"], - aws_secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"], + region_name=AWS_REGION, + aws_access_key_id=AWS_ACCESS_KEY_ID, + aws_secret_access_key=AWS_SECRET_ACCESS_KEY, ) - s3.create_bucket(Bucket=os.getenv("EICR_INPUT_PREFIX").split("/")[0]) - s3.create_bucket(Bucket=os.getenv("SCHEMATRON_ERROR_PREFIX").split("/")[0]) - s3.create_bucket(Bucket=os.getenv("TTC_INPUT_PREFIX").split("/")[0]) - s3.create_bucket(Bucket=os.getenv("TTC_OUTPUT_PREFIX").split("/")[0]) - s3.create_bucket(Bucket=os.getenv("TTC_METADATA_PREFIX").split("/")[0]) - - # Add convenience attribute for tests - s3.ecr_bucket_name = os.getenv("EICR_INPUT_PREFIX").split("/")[0] - s3.schematron_bucket_name = os.getenv("SCHEMATRON_ERROR_PREFIX").split("/")[0] - s3.ttc_input_bucket_name = os.getenv("TTC_INPUT_PREFIX").split("/")[0] - s3.ttc_output_bucket_name = os.getenv("TTC_OUTPUT_PREFIX").split("/")[0] - s3.ttc_metadata_bucket_name = os.getenv("TTC_METADATA_PREFIX").split("/")[0] + s3.create_bucket(Bucket=S3_BUCKET) + + # Add convenience attributes for tests + s3.bucket_name = S3_BUCKET s3.persistence_id = TEST_PERSISTENCE_ID # Put test Schematron error file in the mock S3 bucket @@ -141,8 +134,8 @@ def mock_aws_setup(monkeypatch: pytest.MonkeyPatch) -> boto3.client: with schematron_path.open() as f: schematron_output = f.read() s3.put_object( - Bucket=s3.schematron_bucket_name, - Key=TEST_PERSISTENCE_ID, + Bucket=S3_BUCKET, + Key=f"{SCHEMATRON_ERROR_PREFIX}{TEST_PERSISTENCE_ID}", Body=schematron_output, ) @@ -151,8 +144,8 @@ def mock_aws_setup(monkeypatch: pytest.MonkeyPatch) -> boto3.client: with ecr_path.open() as f: ecr_message = f.read() s3.put_object( - Bucket=s3.ecr_bucket_name, - Key=TEST_PERSISTENCE_ID, + Bucket=S3_BUCKET, + Key=f"{EICR_INPUT_PREFIX}{TEST_PERSISTENCE_ID}", Body=ecr_message, ) diff --git a/packages/text-to-code-lambda/tests/test_lambda_function.py b/packages/text-to-code-lambda/tests/test_lambda_function.py index e4dfe66f..24b06687 100644 --- a/packages/text-to-code-lambda/tests/test_lambda_function.py +++ 
b/packages/text-to-code-lambda/tests/test_lambda_function.py @@ -3,6 +3,9 @@ import pytest import lambda_handler +from conftest import S3_BUCKET +from conftest import TTC_METADATA_PREFIX +from conftest import TTC_OUTPUT_PREFIX from text_to_code_lambda import lambda_function EXPECTED_RESULTED_ERRORS = 2 @@ -33,8 +36,8 @@ def test_handler_success(self, example_sqs_event, mock_aws_setup, mock_opensearc # Assert that the TTC output was saved to S3 ttc_output = json.loads( lambda_handler.get_file_content_from_s3( - bucket_name=mock_aws_setup.ttc_output_bucket_name, - object_key=mock_aws_setup.persistence_id, + bucket_name=S3_BUCKET, + object_key=f"{TTC_OUTPUT_PREFIX}{mock_aws_setup.persistence_id}", ) ) assert ttc_output is not None @@ -60,8 +63,8 @@ def test_handler_success(self, example_sqs_event, mock_aws_setup, mock_opensearc # Assert that the TTC metadata output was saved to S3 with the expected content ttc_metadata_output = json.loads( lambda_handler.get_file_content_from_s3( - bucket_name=mock_aws_setup.ttc_metadata_bucket_name, - object_key=mock_aws_setup.persistence_id, + bucket_name=S3_BUCKET, + object_key=f"{TTC_METADATA_PREFIX}{mock_aws_setup.persistence_id}", ) ) assert ttc_metadata_output is not None @@ -135,15 +138,15 @@ def test_handler_saves_metadata_when_no_relevant_schematron_fields( # Assert that the TTC output was not saved to S3 with pytest.raises(FileNotFoundError): lambda_handler.get_file_content_from_s3( - bucket_name=mock_aws_setup.ttc_output_bucket_name, - object_key=mock_aws_setup.persistence_id, + bucket_name=S3_BUCKET, + object_key=f"{TTC_OUTPUT_PREFIX}{mock_aws_setup.persistence_id}", ) # Assert that the TTC metadata output was saved to S3 with the expected content ttc_metadata_output = json.loads( lambda_handler.get_file_content_from_s3( - bucket_name=mock_aws_setup.ttc_metadata_bucket_name, - object_key=mock_aws_setup.persistence_id, + bucket_name=S3_BUCKET, + object_key=f"{TTC_METADATA_PREFIX}{mock_aws_setup.persistence_id}", ) ) assert ttc_metadata_output is not None diff --git a/terraform/README.md b/terraform/README.md index 42a5ad6c..ef23ed25 100644 --- a/terraform/README.md +++ b/terraform/README.md @@ -72,7 +72,7 @@ At runtime, the Lambda runs the real `text_to_code_lambda.lambda_function.handle 4. Generates embeddings and executes KNN queries against OpenSearch 5. Returns standardized code mappings (LOINC/SNOMED) -Environment variables injected at deploy time: `OPENSEARCH_ENDPOINT_URL`, `OPENSEARCH_INDEX`, `REGION`, `BUCKET_NAME`, `RETRIEVER_MODEL_PATH`, `RERANKER_MODEL_PATH`, `EICR_INPUT_PREFIX`, `SCHEMATRON_ERROR_PREFIX`, `TTC_INPUT_PREFIX`, `TTC_OUTPUT_PREFIX`, `TTC_METADATA_PREFIX`. +Environment variables injected at deploy time: `OPENSEARCH_ENDPOINT_URL`, `OPENSEARCH_INDEX`, `REGION`, `S3_BUCKET`, `RETRIEVER_MODEL_PATH`, `RERANKER_MODEL_PATH`, `EICR_INPUT_PREFIX`, `SCHEMATRON_ERROR_PREFIX`, `TTC_INPUT_PREFIX`, `TTC_OUTPUT_PREFIX`, `TTC_METADATA_PREFIX`. 

 ### OpenSearch Ingestion Pipeline (`main.tf`)

diff --git a/terraform/_variables.tf b/terraform/_variables.tf
index 8cd59af2..2fd196e9 100644
--- a/terraform/_variables.tf
+++ b/terraform/_variables.tf
@@ -119,7 +119,7 @@ variable "schematron_error_prefix" {

 variable "ttc_input_prefix" {
   type        = string
-  default     = "TextToCodeSubmission/"
+  default     = "TextToCodeValidateSubmissionV2/"
   description = "S3 prefix for TTC input submission files"
 }

@@ -135,6 +135,18 @@ variable "ttc_metadata_prefix" {
   description = "S3 prefix for TTC metadata files"
 }

+variable "augmented_eicr_prefix" {
+  type        = string
+  default     = "AugmentationEICRV2/"
+  description = "S3 prefix for augmented eICR output files"
+}
+
+variable "augmentation_metadata_prefix" {
+  type        = string
+  default     = "AugmentationMetadata/"
+  description = "S3 prefix for augmentation metadata files"
+}
+
 ### Container Image Variables
 variable "ttc_lambda_image_tag" {
   type = string
diff --git a/terraform/main.tf b/terraform/main.tf
index 4e35b7eb..c7ec1d13 100644
--- a/terraform/main.tf
+++ b/terraform/main.tf
@@ -260,7 +260,7 @@ resource "aws_lambda_function" "lambda" {
       OPENSEARCH_ENDPOINT_URL = "https://${aws_opensearch_vpc_endpoint.os_vpc_endpoint.endpoint}"
       OPENSEARCH_INDEX        = var.index_name
       REGION                  = var.region
-      BUCKET_NAME             = var.s3_bucket
+      S3_BUCKET               = var.s3_bucket
       RETRIEVER_MODEL_PATH    = "/opt/retriever_model"
       RERANKER_MODEL_PATH     = "/opt/reranker_model"
       EICR_INPUT_PREFIX       = var.eicr_input_prefix
@@ -473,7 +473,7 @@ resource "aws_lambda_function" "index_lambda" {
       OPENSEARCH_ENDPOINT_URL = "https://${aws_opensearch_vpc_endpoint.os_vpc_endpoint.endpoint}"
       REGION                  = var.region
       INDEX_NAME              = var.index_name
-      BUCKET_NAME             = var.s3_bucket
+      S3_BUCKET               = var.s3_bucket
     }
   }