diff --git a/spacy_llm/models/__init__.py b/spacy_llm/models/__init__.py
index c1427009..94160c42 100644
--- a/spacy_llm/models/__init__.py
+++ b/spacy_llm/models/__init__.py
@@ -1,9 +1,10 @@
 from .hf import dolly_hf, openllama_hf, stablelm_hf
 from .langchain import query_langchain
-from .rest import anthropic, cohere, noop, openai, palm
+from .rest import anthropic, bedrock, cohere, noop, openai, palm
 
 __all__ = [
     "anthropic",
+    "bedrock",
     "cohere",
     "openai",
     "dolly_hf",
diff --git a/spacy_llm/models/rest/__init__.py b/spacy_llm/models/rest/__init__.py
index 96263967..45b061d0 100644
--- a/spacy_llm/models/rest/__init__.py
+++ b/spacy_llm/models/rest/__init__.py
@@ -1,9 +1,10 @@
-from . import anthropic, azure, base, cohere, noop, openai
+from . import anthropic, azure, base, bedrock, cohere, noop, openai
 
 __all__ = [
     "anthropic",
     "azure",
     "base",
+    "bedrock",
     "cohere",
     "openai",
     "noop",
diff --git a/spacy_llm/models/rest/bedrock/__init__.py b/spacy_llm/models/rest/bedrock/__init__.py
new file mode 100644
index 00000000..f7f3ec00
--- /dev/null
+++ b/spacy_llm/models/rest/bedrock/__init__.py
@@ -0,0 +1,4 @@
+from .model import Bedrock
+from .registry import bedrock
+
+__all__ = ["Bedrock", "bedrock"]
diff --git a/spacy_llm/models/rest/bedrock/model.py b/spacy_llm/models/rest/bedrock/model.py
new file mode 100644
index 00000000..38663483
--- /dev/null
+++ b/spacy_llm/models/rest/bedrock/model.py
@@ -0,0 +1,232 @@
+import json
+import os
+import warnings
+from enum import Enum
+from typing import Any, Dict, Iterable, List, Optional, Tuple
+
+from ..base import REST
+
+
+class Models(str, Enum):
+    # Completion models
+    TITAN_EXPRESS = "amazon.titan-text-express-v1"
+    TITAN_LITE = "amazon.titan-text-lite-v1"
+    AI21_JURASSIC_ULTRA = "ai21.j2-ultra-v1"
+    AI21_JURASSIC_MID = "ai21.j2-mid-v1"
+    COHERE_COMMAND = "cohere.command-text-v14"
+    ANTHROPIC_CLAUDE = "anthropic.claude-v2"
+    ANTHROPIC_CLAUDE_INSTANT = "anthropic.claude-instant-v1"
+
+
+TITAN_PARAMS = ["maxTokenCount", "stopSequences", "temperature", "topP"]
+AI21_JURASSIC_PARAMS = [
+    "maxTokens",
+    "temperature",
+    "topP",
+    "countPenalty",
+    "presencePenalty",
+    "frequencyPenalty",
+]
+COHERE_PARAMS = ["max_tokens", "temperature"]
+ANTHROPIC_PARAMS = [
+    "max_tokens_to_sample",
+    "temperature",
+    "top_k",
+    "top_p",
+    "stop_sequences",
+]
+
+# Model IDs grouped by provider; every family has its own request/response format.
+_TITAN_MODELS = (Models.TITAN_EXPRESS, Models.TITAN_LITE)
+_AI21_MODELS = (Models.AI21_JURASSIC_ULTRA, Models.AI21_JURASSIC_MID)
+_COHERE_MODELS = (Models.COHERE_COMMAND,)
+_ANTHROPIC_MODELS = (Models.ANTHROPIC_CLAUDE, Models.ANTHROPIC_CLAUDE_INSTANT)
+
+
+class Bedrock(REST):
+    """Prompts text-completion models hosted on Amazon Bedrock through boto3."""
+
+    def __init__(
+        self,
+        model_id: str,
+        region: str,
+        config: Dict[Any, Any],
+        max_tries: int = 5,
+    ):
+        """Initializes the Bedrock model wrapper.
+        model_id (str): ID of the Bedrock model, one of the values of `Models`.
+        region (str): AWS region of the Bedrock runtime to call.
+        config (Dict[Any, Any]): Generation parameters. Only the keys supported by
+            the selected model family are forwarded to the API.
+        max_tries (int): Retry limit (`max_attempts`) handed to the boto3 client.
+        """
+        self._region = region
+        self._model_id = model_id
+        self._max_tries = max_tries
+        self.strict = True
+        self.endpoint = f"https://bedrock-runtime.{self._region}.amazonaws.com"
+
+        if self._model_id in _TITAN_MODELS:
+            config_params = TITAN_PARAMS
+        elif self._model_id in _AI21_MODELS:
+            config_params = AI21_JURASSIC_PARAMS
+        elif self._model_id in _COHERE_MODELS:
+            config_params = COHERE_PARAMS
+        elif self._model_id in _ANTHROPIC_MODELS:
+            config_params = ANTHROPIC_PARAMS
+        else:
+            # Fail with a clear message instead of an UnboundLocalError further down.
+            raise ValueError(
+                f"Model '{model_id}' is not supported by Bedrock. Supported models: "
+                f"{', '.join(self.get_model_names())}"
+            )
+
+        # Forward only the parameters this model family understands; keys the user
+        # did not provide are skipped instead of raising a KeyError.
+        self._config = {key: config[key] for key in config_params if key in config}
+
+        super().__init__(
+            name=model_id,
+            config=self._config,
+            max_tries=max_tries,
+            strict=True,
+            endpoint="",
+            interval=3,
+            max_request_time=30,
+        )
+
+    def get_session_kwargs(self) -> Dict[str, Optional[str]]:
+        """Reads the AWS credentials from the environment.
+        RETURNS (Dict[str, Optional[str]]): Keyword arguments for `boto3.Session`.
+        """
+        # AWS_PROFILE is optional. The previous `... if not None else "default"`
+        # fallback never triggered (`not None` is always True), so an unset profile
+        # has always been passed on as None; that behavior is kept.
+        profile = os.getenv("AWS_PROFILE")
+        required = {
+            name: os.getenv(name)
+            for name in (
+                "AWS_ACCESS_KEY_ID",
+                "AWS_SECRET_ACCESS_KEY",
+                "AWS_SESSION_TOKEN",
+            )
+        }
+
+        for name, value in [("AWS_PROFILE", profile), *required.items()]:
+            if value is None:
+                warnings.warn(
+                    f"Could not find the {name} to access the Amazon Bedrock . "
+                    f"Ensure you have an {name} set up by making it available as "
+                    f"an environment variable {name}."
+                )
+
+        missing = [name for name, value in required.items() if value is None]
+        if missing:
+            # Raise explicitly: `assert` statements are stripped under `python -O`.
+            raise ValueError(
+                "Missing AWS credentials. Set the environment variable(s): "
+                f"{', '.join(missing)}"
+            )
+
+        return {
+            "profile_name": profile,
+            "region_name": self._region,
+            "aws_access_key_id": required["AWS_ACCESS_KEY_ID"],
+            "aws_secret_access_key": required["AWS_SECRET_ACCESS_KEY"],
+            "aws_session_token": required["AWS_SESSION_TOKEN"],
+        }
+
+    def _build_request_body(self, prompt: str) -> str:
+        """Serializes a prompt into the JSON request body of the model family.
+        prompt (str): Prompt to send.
+        RETURNS (str): JSON-encoded request body.
+        """
+        if self._model_id in _TITAN_MODELS:
+            return json.dumps(
+                {"inputText": prompt, "textGenerationConfig": self._config}
+            )
+        if self._model_id in _ANTHROPIC_MODELS:
+            # Anthropic prompts are wrapped in Human/Assistant turn markers.
+            return json.dumps(
+                {"prompt": f"\n\nHuman: {prompt}\n\nAssistant:", **self._config}
+            )
+        # AI21 Jurassic and Cohere Command use the same flat layout.
+        return json.dumps({"prompt": prompt, **self._config})
+
+    def _parse_response(self, payload: Dict[str, Any]) -> str:
+        """Extracts the generated text from a decoded response body.
+        payload (Dict[str, Any]): JSON-decoded body returned by `invoke_model`.
+        RETURNS (str): Text generated by the model.
+        """
+        if self._model_id in _TITAN_MODELS:
+            return payload["results"][0]["outputText"]
+        if self._model_id in _AI21_MODELS:
+            return payload["completions"][0]["data"]["text"]
+        if self._model_id in _COHERE_MODELS:
+            return payload["generations"][0]["text"]
+        # Remaining family (validated in __init__): Anthropic.
+        return payload["completion"]
+
+    def __call__(self, prompts: Iterable[str]) -> Iterable[str]:
+        """Sends each prompt to the model and collects the responses.
+        prompts (Iterable[str]): Prompts to send.
+        RETURNS (Iterable[str]): One response per prompt, in the same order.
+        """
+        prompts = list(prompts)
+        if not prompts:
+            # Nothing to send, so neither boto3 nor credentials are needed.
+            return []
+
+        try:
+            import boto3
+            from botocore.config import Config
+        except ImportError as err:
+            raise ImportError(
+                "To use Bedrock, you need to install boto3. Use pip install boto3 "
+            ) from err
+
+        # Create the session and client once per call instead of once per prompt.
+        session = boto3.Session(**self.get_session_kwargs())
+        api_config = Config(retries=dict(max_attempts=self._max_tries))
+        bedrock = session.client(service_name="bedrock-runtime", config=api_config)
+
+        api_responses: List[str] = []
+        for prompt in prompts:
+            r = bedrock.invoke_model(
+                body=self._build_request_body(prompt),
+                modelId=self._model_id,
+                accept="application/json",
+                contentType="application/json",
+            )
+            api_responses.append(
+                self._parse_response(json.loads(r["body"].read().decode()))
+            )
+
+        return api_responses
+
+    def _verify_auth(self) -> None:
+        """Checks the credentials by listing the available foundation models.
+        Errors raised by boto3/botocore (e.g. `NoCredentialsError`) propagate
+        unchanged.
+        """
+        try:
+            import boto3
+        except ImportError as err:
+            # Import on its own: with the import inside a `try` guarded by
+            # `except NoCredentialsError`, a missing boto3 surfaced as a NameError.
+            raise ImportError(
+                "To use Bedrock, you need to install boto3. Use pip install boto3 "
+            ) from err
+
+        session = boto3.Session(**self.get_session_kwargs())
+        bedrock = session.client(service_name="bedrock")
+        bedrock.list_foundation_models()
+
+    @property
+    def credentials(self) -> Dict[str, Optional[str]]:  # type: ignore
+        return self.get_session_kwargs()
+
+    @classmethod
+    def get_model_names(cls) -> Tuple[str, ...]:
+        """Returns the IDs of all supported models, in `Models` declaration order."""
+        return tuple(model.value for model in Models)
diff --git a/spacy_llm/models/rest/bedrock/registry.py b/spacy_llm/models/rest/bedrock/registry.py
new file mode 100644
index 00000000..aeaa4534
--- /dev/null
+++ b/spacy_llm/models/rest/bedrock/registry.py
@@ -0,0 +1,54 @@
+from typing import Any, Callable, Dict, Iterable, List
+
+from confection import SimpleFrozenDict
+
+from ....registry import registry
+from .model import Bedrock, Models
+
+_DEFAULT_RETRIES: int = 5
+_DEFAULT_TEMPERATURE: float = 0.0
+_DEFAULT_MAX_TOKEN_COUNT: int = 512
+_DEFAULT_TOP_P: int = 1
+_DEFAULT_TOP_K: int = 250
+_DEFAULT_STOP_SEQUENCES: List[str] = []
+_DEFAULT_COUNT_PENALTY: Dict[str, Any] = {"scale": 0}
+_DEFAULT_PRESENCE_PENALTY: Dict[str, Any] = {"scale": 0}
+_DEFAULT_FREQUENCY_PENALTY: Dict[str, Any] = {"scale": 0}
+_DEFAULT_MAX_TOKEN_TO_SAMPLE: int = 300
+
+
+@registry.llm_models("spacy.Bedrock.v1")
+def bedrock(
+    region: str,
+    model_id: Models = Models.TITAN_EXPRESS,
+    config: Dict[Any, Any] = SimpleFrozenDict(
+        # Params for Titan models
+        temperature=_DEFAULT_TEMPERATURE,
+        maxTokenCount=_DEFAULT_MAX_TOKEN_COUNT,
+        stopSequences=_DEFAULT_STOP_SEQUENCES,
+        topP=_DEFAULT_TOP_P,
+        # Params for Jurassic models
+        maxTokens=_DEFAULT_MAX_TOKEN_COUNT,
+        countPenalty=_DEFAULT_COUNT_PENALTY,
+        presencePenalty=_DEFAULT_PRESENCE_PENALTY,
+        frequencyPenalty=_DEFAULT_FREQUENCY_PENALTY,
+        # Params for Cohere models
+        max_tokens=_DEFAULT_MAX_TOKEN_COUNT,
+        # Params for Anthropic models
+        max_tokens_to_sample=_DEFAULT_MAX_TOKEN_TO_SAMPLE,
+        stop_sequences=_DEFAULT_STOP_SEQUENCES,
+        top_k=_DEFAULT_TOP_K,
+        top_p=_DEFAULT_TOP_P,
+    ),
+    max_tries: int = _DEFAULT_RETRIES,
+) -> Callable[[Iterable[str]], Iterable[str]]:
+    """Returns Bedrock instance for the selected Amazon Bedrock model, using boto3 to
+    prompt the API.
+    region (str): Specify the AWS region for the service
+    model_id (Models): ID of the Bedrock model to prompt (Titan Express by default)
+    config (Dict[Any, Any]): LLM config passed on to the model's initialization. Only
+        the parameters supported by the selected model are sent to the API.
+    max_tries (int): Retry limit (`max_attempts`) handed to the boto3 client.
+    RETURNS (Callable[[Iterable[str]], Iterable[str]]): Bedrock instance.
+    """
+    return Bedrock(model_id=model_id, region=region, config=config, max_tries=max_tries)
diff --git a/usage_examples/ner_v3_titan/README.md b/usage_examples/ner_v3_titan/README.md
new file mode 100644
index 00000000..e08bcffb
--- /dev/null
+++ b/usage_examples/ner_v3_titan/README.md
@@ -0,0 +1,76 @@
+# Using Titan Express Model from Amazon Bedrock for Named Entity Recognition (NER)
+
+
+This example shows how you can use the Titan Express model from Amazon Bedrock for Named Entity Recognition (NER).
+The NER prompt is based on the [PromptNER](https://arxiv.org/abs/2305.15444) paper and
+utilizes Chain-of-Thought reasoning to extract named entities.
+
+First, create new credentials in the AWS Console.
+Record the access key ID, secret access key and session token, and make sure they
+are available as environment variables:
+
+```sh
+export AWS_ACCESS_KEY_ID=""
+export AWS_SECRET_ACCESS_KEY=""
+export AWS_SESSION_TOKEN=""
+```
+
+Then, you can run the pipeline on a sample text via:
+
+
+```sh
+python run_pipeline.py [TEXT] [PATH TO CONFIG] [PATH TO FILE WITH EXAMPLES]
+```
+
+For example:
+
+```sh
+python run_pipeline.py \
+    "Sriracha sauce goes really well with hoisin stir fry, but you should add it after you use the wok." \
+    ./fewshot.cfg \
+    ./examples.json
+```
+
+This example assigns labels for DISH, INGREDIENT, and EQUIPMENT.
+
+You can change around the labels and examples for your use case.
+You can find the few-shot examples in the +`examples.json` file. Feel free to change and update it to your liking. +We also support other file formats, including `yml` and `jsonl` for these examples. + + +### Negative examples + +While not required, the Chain-of-Thought reasoning for the `spacy.NER.v3` task +works best in our experience when both positive and negative examples are provided. + +This prompts the language model with concrete examples of what **is not** an entity +for your use case. + +Here's an example that helps define the INGREDIENT label for the LLM. + +```json +[ + { + "text": "You can't get a great chocolate flavor with carob.", + "spans": [ + { + "text": "chocolate", + "is_entity": false, + "label": "==NONE==", + "reason": "is a flavor in this context, not an ingredient" + }, + { + "text": "carob", + "is_entity": true, + "label": "INGREDIENT", + "reason": "is an ingredient to add chocolate flavor" + } + ] + } + ... +] +``` + +In this example, "chocolate" is not an ingredient even though it could be in other contexts. +We explain that via the "reason" property of this example. 
diff --git a/usage_examples/ner_v3_titan/__init__.py b/usage_examples/ner_v3_titan/__init__.py new file mode 100644 index 00000000..06fab2f6 --- /dev/null +++ b/usage_examples/ner_v3_titan/__init__.py @@ -0,0 +1,3 @@ +from .run_pipeline import run_pipeline + +__all__ = ["run_pipeline"] diff --git a/usage_examples/ner_v3_titan/examples.json b/usage_examples/ner_v3_titan/examples.json new file mode 100644 index 00000000..f38f5e92 --- /dev/null +++ b/usage_examples/ner_v3_titan/examples.json @@ -0,0 +1,36 @@ +[ + { + "text": "You can't get a great chocolate flavor with carob.", + "spans": [ + { + "text": "chocolate", + "is_entity": false, + "label": "==NONE==", + "reason": "is a flavor in this context, not an ingredient" + }, + { + "text": "carob", + "is_entity": true, + "label": "INGREDIENT", + "reason": "is an ingredient to add chocolate flavor" + } + ] + }, + { + "text": "You can probably sand-blast it if it's an anodized aluminum pan", + "spans": [ + { + "text": "sand-blast", + "is_entity": false, + "label": "==NONE==", + "reason": "is a cleaning technique, not some kind of equipment" + }, + { + "text": "anodized aluminum pan", + "is_entity": true, + "label": "EQUIPMENT", + "reason": "is a piece of cooking equipment, anodized is included since it describes the type of pan" + } + ] + } +] diff --git a/usage_examples/ner_v3_titan/fewshot.cfg b/usage_examples/ner_v3_titan/fewshot.cfg new file mode 100644 index 00000000..9ced17e2 --- /dev/null +++ b/usage_examples/ner_v3_titan/fewshot.cfg @@ -0,0 +1,34 @@ +[paths] +examples = null + +[nlp] +lang = "en" +pipeline = ["llm"] + +[components] + +[components.llm] +factory = "llm" + +[components.llm.task] +@llm_tasks = "spacy.NER.v3" +labels = ["DISH", "INGREDIENT", "EQUIPMENT"] +description = Entities are the names food dishes, + ingredients, and any kind of cooking equipment. + Adjectives, verbs, adverbs are not entities. + Pronouns are not entities. 
+
+[components.llm.task.label_definitions]
+DISH = "Known food dishes, e.g. Lobster Ravioli, garlic bread"
+INGREDIENT = "Individual parts of a food dish, including herbs and spices."
+EQUIPMENT = "Any kind of cooking equipment. e.g. oven, cooking pot, grill"
+
+[components.llm.task.examples]
+@misc = "spacy.FewShotReader.v1"
+path = "${paths.examples}"
+
+[components.llm.model]
+@llm_models = "spacy.Bedrock.v1"
+region =
+model_id = amazon.titan-text-express-v1
+
diff --git a/usage_examples/ner_v3_titan/run_pipeline.py b/usage_examples/ner_v3_titan/run_pipeline.py
new file mode 100644
index 00000000..ac182bbc
--- /dev/null
+++ b/usage_examples/ner_v3_titan/run_pipeline.py
@@ -0,0 +1,33 @@
+from pathlib import Path
+
+import typer
+from wasabi import msg
+
+from spacy_llm.util import assemble
+
+Arg = typer.Argument
+Opt = typer.Option
+
+
+# CLI entry point (run through `typer.run` below): assembles the pipeline from the
+# given config, runs it on `text` and prints the text and the entities found.
+def run_pipeline(
+    # fmt: off
+    text: str = Arg("", help="Text to perform Named Entity Recognition on."),
+    config_path: Path = Arg(..., help="Path to the configuration file to use."),
+    examples_path: Path = Arg(..., help="Path to the examples file to use."),
+    verbose: bool = Opt(False, "--verbose", "-v", help="Show extra information."),
+    # fmt: on
+):
+    # `show=verbose`: this message is meant to appear only with --verbose.
+    msg.text(f"Loading config from {config_path}", show=verbose)
+    # Override `paths.examples` in the config with the examples file from the CLI.
+    nlp = assemble(config_path, overrides={"paths.examples": str(examples_path)})
+    doc = nlp(text)
+
+    msg.text(f"Text: {doc.text}")
+    msg.text(f"Entities: {[(ent.text, ent.label_) for ent in doc.ents]}")
+
+
+if __name__ == "__main__":
+    typer.run(run_pipeline)