Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions packages/augmentation-lambda/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@
name = "augmentation-lambda"
version = "0.1.0"
readme = "README.md"
dependencies = [
    "aws-lambda-powertools>=2.0.0",
    "structlog>=24.0.0",
]

[dependency-groups]
dev = ["aws-lambda-typing>=2.20.0", "boto3>=1.40.60", "moto"]

[build-system]
requires = ["hatchling"]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,131 +1,198 @@
import io
import json
import os
from typing import TypedDict

from aws_lambda_typing import context as lambda_context
from aws_lambda_typing import events as lambda_events
import structlog
from aws_lambda_powertools.utilities.data_classes import SQSEvent
from aws_lambda_powertools.utilities.data_classes import event_source
from aws_lambda_powertools.utilities.data_classes.sqs_event import SQSRecord
from aws_lambda_powertools.utilities.typing import LambdaContext
from botocore.client import BaseClient

import lambda_handler
from augmentation.models import TTCAugmenterConfig
from augmentation.models.application import TTCAugmenterOutput
from augmentation.services.eicr_augmenter import EICRAugmenter
from shared_models import NonstandardCodeInstance
from shared_models import TTCAugmenterInput

# Module-level structlog logger shared by all handler code in this file.
logger = structlog.get_logger()

# Environment variables — S3 bucket and key prefixes for the augmentation
# pipeline. Defaults presumably match the deployed layout; override via env.
S3_BUCKET = os.getenv("S3_BUCKET", "dibbs-text-to-code")
EICR_INPUT_PREFIX = os.getenv("EICR_INPUT_PREFIX", "eCRMessageV2/")
TTC_OUTPUT_PREFIX = os.getenv("TTC_OUTPUT_PREFIX", "TTCAugmentationMetadataV2/")
AUGMENTED_EICR_PREFIX = os.getenv("AUGMENTED_EICR_PREFIX", "AugmentationEICRV2/")
AUGMENTATION_METADATA_PREFIX = os.getenv("AUGMENTATION_METADATA_PREFIX", "AugmentationMetadataV2/")

# Cache S3 client to reuse across Lambda invocations (warm starts).
_cached_s3_client: BaseClient | None = None


class HandlerResponse(TypedDict):
    """Response from the AWS Lambda handler."""

    # Per-record outcomes; each entry carries messageId, status, and result/error.
    results: list[dict[str, object]]
    # SQS partial-batch-failure entries of the form {"itemIdentifier": message_id}.
    batchItemFailures: list[dict[str, str]]


@event_source(data_class=SQSEvent)
def handler(event: SQSEvent, context: LambdaContext) -> dict:
    """AWS Lambda handler for augmenting eICRs with nonstandard codes.

    Triggered by S3 events when TTC output objects are created in TTCAugmentationMetadataV2/.
    Reads TTC output and original eICR from S3, performs augmentation, and writes results to S3.

    :param event: The SQS event containing S3 event data.
    :param context: The AWS Lambda context object.
    :return: A dictionary containing processing results and any batch item failures.
    """
    global _cached_s3_client  # noqa: PLW0603

    # Lazily create and cache the S3 client so warm invocations reuse it.
    if _cached_s3_client is None:
        _cached_s3_client = lambda_handler.create_s3_client()
    s3_client = _cached_s3_client

    logger.info(f"Received event with {len(event['Records'])} record(s)")

    failures = []
    successes = []

    for record in event.records:
        try:
            _process_record(record, s3_client)
            successes.append(record.message_id)
        except Exception as e:
            # A failure on one record must not abort the rest of the batch.
            logger.exception(f"Error processing record: {e}", message_id=record.message_id)
            failures.append({"message_id": record.message_id, "error": str(e)})

    # NOTE(review): failed records are reported in the body but not as SQS
    # "batchItemFailures", so SQS will not retry them individually — confirm
    # this is intentional for the queue's redrive configuration.
    return (
        {
            "statusCode": 200,
            "message": "Augmentation processed with some failures!",
            "failures": failures,
            "num_failure_eicrs": len(failures),
            "num_success_eicrs": len(successes),
        }
        if failures
        else {
            "statusCode": 200,
            "message": "Augmentation processed successfully!",
            "num_success_eicrs": len(successes),
        }
    )


def _process_record(record: SQSRecord, s3_client: BaseClient) -> None:
    """Process a single SQS record containing an S3 event.

    :param record: The SQS record with an EventBridge S3 event in the body.
    :param s3_client: The S3 client to use for reading and writing files.
    """
    # Guard: a record with no body carries nothing to process.
    if not record.body:
        logger.warning("Empty SQS body", message_id=record.message_id)
        return

    # Resolve the S3 object that triggered this message.
    bridge_data = lambda_handler.get_eventbridge_data_from_s3_event(json.loads(record.body))
    object_key = bridge_data["object_key"]
    bucket_name = bridge_data.get("bucket_name") or S3_BUCKET
    logger.info(f"Processing S3 Object: s3://{bucket_name}/{object_key}")

    persistence_id = lambda_handler.get_persistence_id(object_key, TTC_OUTPUT_PREFIX)
    logger.info(f"Extracted persistence_id: {persistence_id}")

    # Pull both inputs for the augmentation: TTC output and the source eICR.
    ttc_output = _load_ttc_output(persistence_id, s3_client, bucket_name)
    source_eicr = _load_original_eicr(persistence_id, s3_client, bucket_name)

    augmenter_input = TTCAugmenterInput(
        eicr_id=persistence_id,
        nonstandard_codes=_parse_nonstandard_codes(ttc_output),
    )

    augmenter = EICRAugmenter(
        document=source_eicr,
        nonstandard_codes=augmenter_input.nonstandard_codes,
        config=TTCAugmenterConfig(),
    )
    augmentation_metadata = augmenter.augment()

    result = TTCAugmenterOutput(
        eicr_id=augmenter_input.eicr_id,
        augmented_eicr=augmenter.augmented_xml,
        metadata=augmentation_metadata,
    )

    _save_augmentation_outputs(persistence_id, result, s3_client, bucket_name)


def _load_ttc_output(persistence_id: str, s3_client: BaseClient, bucket_name: str) -> dict:
    """Load TTC output from S3.

    :param persistence_id: The persistence ID for the S3 object key.
    :param s3_client: The S3 client.
    :param bucket_name: The S3 bucket name.
    :return: The parsed TTC output dictionary.
    """
    object_key = f"{TTC_OUTPUT_PREFIX}{persistence_id}"
    logger.info(f"Retrieving TTC output from s3://{bucket_name}/{object_key}")
    raw = lambda_handler.get_file_content_from_s3(
        bucket_name=bucket_name,
        object_key=object_key,
        s3_client=s3_client,
    )
    return json.loads(raw)


def _load_original_eicr(persistence_id: str, s3_client: BaseClient, bucket_name: str) -> str:
    """Load original eICR XML from S3.

    :param persistence_id: The persistence ID for the S3 object key.
    :param s3_client: The S3 client.
    :param bucket_name: The S3 bucket name.
    :return: The raw eICR XML string.
    """
    object_key = f"{EICR_INPUT_PREFIX}{persistence_id}"
    logger.info(f"Retrieving eICR from s3://{bucket_name}/{object_key}")
    # Returned as-is: the augmenter consumes the raw XML string.
    return lambda_handler.get_file_content_from_s3(
        bucket_name=bucket_name,
        object_key=object_key,
        s3_client=s3_client,
    )


def _parse_nonstandard_codes(ttc_output: dict) -> list[NonstandardCodeInstance]:
    """Parse nonstandard codes from TTC output.

    The TTC Lambda writes NonstandardCodeInstance model dumps to the schematron_errors
    field of the TTC output. This function validates and reconstructs them.

    :param ttc_output: The TTC output dictionary from S3.
    :return: A list of NonstandardCodeInstance objects.
    """
    error_groups = ttc_output.get("schematron_errors", {}).values()
    # Only entries carrying a "new_translation" key are code instances.
    return [
        NonstandardCodeInstance.model_validate(entry)
        for entries in error_groups
        for entry in entries
        if "new_translation" in entry
    ]


def _save_augmentation_outputs(
    persistence_id: str,
    output: TTCAugmenterOutput,
    s3_client: BaseClient,
    bucket_name: str,
) -> None:
    """Save augmented eICR and metadata to S3.

    Writes two objects: the augmented eICR XML under AUGMENTED_EICR_PREFIX and
    the augmentation metadata JSON under AUGMENTATION_METADATA_PREFIX.

    :param persistence_id: The persistence ID for the S3 object key.
    :param output: The augmentation output containing the augmented eICR and metadata.
    :param s3_client: The S3 client to use for uploading files.
    :param bucket_name: The S3 bucket name to write to.
    """
    # Augmented eICR XML.
    lambda_handler.put_file(
        file_obj=io.BytesIO(output.augmented_eicr.encode("utf-8")),
        bucket_name=bucket_name,
        object_key=f"{AUGMENTED_EICR_PREFIX}{persistence_id}",
        s3_client=s3_client,
    )
    # Augmentation metadata as JSON.
    lambda_handler.put_file(
        file_obj=io.BytesIO(output.metadata.model_dump_json().encode("utf-8")),
        bucket_name=bucket_name,
        object_key=f"{AUGMENTATION_METADATA_PREFIX}{persistence_id}",
        s3_client=s3_client,
    )
Loading
Loading