From e37ad350d34abe9cd6387f13a93e4481d2ef3266 Mon Sep 17 00:00:00 2001
From: Bhabaranjan Panigrahi <112833588+bhaba-ranjan@users.noreply.github.com>
Date: Fri, 22 Mar 2024 20:00:21 -0400
Subject: [PATCH 1/2] feat: add client for Predibase

---
 src/llmperf/common.py                       |   5 +-
 src/llmperf/ray_clients/predibase_client.py | 127 ++++++++++++++++++++
 2 files changed, 131 insertions(+), 1 deletion(-)
 create mode 100644 src/llmperf/ray_clients/predibase_client.py

diff --git a/src/llmperf/common.py b/src/llmperf/common.py
index 3efefa1..42659f1 100644
--- a/src/llmperf/common.py
+++ b/src/llmperf/common.py
@@ -5,10 +5,11 @@
 )
 from llmperf.ray_clients.sagemaker_client import SageMakerClient
 from llmperf.ray_clients.vertexai_client import VertexAIClient
+from .ray_clients.predibase_client import PrediBaseClient
 from llmperf.ray_llm_client import LLMClient
 
-SUPPORTED_APIS = ["openai", "anthropic", "litellm"]
+SUPPORTED_APIS = ["openai", "anthropic", "litellm", "predibase"]
 
 
 def construct_clients(llm_api: str, num_clients: int) -> List[LLMClient]:
@@ -28,6 +29,8 @@ def construct_clients(llm_api: str, num_clients: int) -> List[LLMClient]:
         clients = [SageMakerClient.remote() for _ in range(num_clients)]
     elif llm_api == "vertexai":
         clients = [VertexAIClient.remote() for _ in range(num_clients)]
+    elif llm_api == "predibase":
+        clients = [PrediBaseClient.remote() for _ in range(num_clients)]
     elif llm_api in SUPPORTED_APIS:
         clients = [LiteLLMClient.remote() for _ in range(num_clients)]
     else:
diff --git a/src/llmperf/ray_clients/predibase_client.py b/src/llmperf/ray_clients/predibase_client.py
new file mode 100644
index 0000000..08212f8
--- /dev/null
+++ b/src/llmperf/ray_clients/predibase_client.py
@@ -0,0 +1,127 @@
import json
import os
import time
from typing import Any, Dict, Tuple

import ray
import requests

from llmperf.ray_llm_client import LLMClient
from llmperf.models import RequestConfig
from llmperf import common_metrics


@ray.remote
class PrediBaseClient(LLMClient):
    """Client for Predibase (TGI-style) streaming text-generation endpoints."""

    def llm_request(
        self, request_config: RequestConfig
    ) -> Tuple[Dict[str, Any], str, RequestConfig]:
        prompt, prompt_len = request_config.prompt

        if not request_config.sampling_params:
            raise ValueError(
                "Set sampling_params to populate the request body parameters."
            )
        # TGI-style endpoints expect "max_new_tokens" rather than "max_tokens".
        if "max_tokens" in request_config.sampling_params:
            request_config.sampling_params["max_new_tokens"] = (
                request_config.sampling_params.pop("max_tokens")
            )

        body = {
            "inputs": prompt,
            "parameters": request_config.sampling_params,
        }

        time_to_next_token = []
        tokens_received = 0
        ttft = 0
        error_response_code = -1
        generated_text = ""
        error_msg = ""
        output_throughput = 0
        total_request_time = 0

        metrics = {}
        metrics[common_metrics.ERROR_CODE] = None
        metrics[common_metrics.ERROR_MSG] = ""

        start_time = time.monotonic()
        most_recent_received_token_time = time.monotonic()

        address = os.environ.get("PREDIBASE_API_BASE")
        key = os.environ.get("PREDIBASE_API_KEY")

        if not address:
            raise ValueError("The environment variable PREDIBASE_API_BASE must be set.")

        headers = {"Content-Type": "application/json"}
        if not key:
            print("Warning: PREDIBASE_API_KEY is not set.")
        else:
            headers["Authorization"] = f"Bearer {key}"

        if not address.endswith("/"):
            address = address + "/"
        address += "generate_stream"

        try:
            with requests.post(
                address,
                json=body,
                stream=True,
                timeout=180,
                headers=headers,
            ) as response:
                if response.status_code != 200:
                    error_msg = response.text
                    error_response_code = response.status_code
                    response.raise_for_status()

                for chunk in response.iter_lines(chunk_size=None):
                    chunk = chunk.strip()
                    if not chunk:
                        continue
                    # Each server-sent event is prefixed with "data:".
                    stem = b"data:"
                    if chunk.startswith(stem):
                        chunk = chunk[len(stem) :]
                    if chunk == b"[DONE]":
                        continue
                    data = json.loads(chunk)
                    if "error" in data:
                        error_msg = data["error"]
                        raise RuntimeError(error_msg)
                    tokens_received += 1

                    delta = data["token"]
                    if delta.get("text", None):
                        if not ttft:
                            ttft = time.monotonic() - start_time
                            time_to_next_token.append(ttft)
                        else:
                            time_to_next_token.append(
                                time.monotonic() - most_recent_received_token_time
                            )
                        most_recent_received_token_time = time.monotonic()
                        generated_text += delta["text"]

                total_request_time = time.monotonic() - start_time
                output_throughput = tokens_received / total_request_time

        except Exception as e:
            metrics[common_metrics.ERROR_MSG] = error_msg
            metrics[common_metrics.ERROR_CODE] = error_response_code
            print(f"Warning Or Error: {e}")
            print(error_response_code)

        # This should match metrics[common_metrics.E2E_LAT]; leave it here for now.
        metrics[common_metrics.INTER_TOKEN_LAT] = sum(time_to_next_token)
        metrics[common_metrics.TTFT] = ttft
        metrics[common_metrics.E2E_LAT] = total_request_time
        metrics[common_metrics.REQ_OUTPUT_THROUGHPUT] = output_throughput
        metrics[common_metrics.NUM_TOTAL_TOKENS] = tokens_received + prompt_len
        metrics[common_metrics.NUM_OUTPUT_TOKENS] = tokens_received
        metrics[common_metrics.NUM_INPUT_TOKENS] = prompt_len

        return metrics, generated_text, request_config

From f5223fba40904acab3dc81cbdc0eec15c3e2623c Mon Sep 17 00:00:00 2001
From: Bhabaranjan Panigrahi <112833588+bhaba-ranjan@users.noreply.github.com>
Date: Fri, 22 Mar 2024 20:14:28 -0400
Subject: [PATCH 2/2] fix: use absolute import for the Predibase client

---
 src/llmperf/common.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/llmperf/common.py b/src/llmperf/common.py
index 42659f1..7af546b 100644
--- a/src/llmperf/common.py
+++ b/src/llmperf/common.py
@@ -5,7 +5,7 @@
 )
 from llmperf.ray_clients.sagemaker_client import SageMakerClient
 from llmperf.ray_clients.vertexai_client import VertexAIClient
-from .ray_clients.predibase_client import PrediBaseClient
+from llmperf.ray_clients.predibase_client import PrediBaseClient
 from llmperf.ray_llm_client import LLMClient
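
For reviewers, a minimal smoke-test sketch of the new client follows. It is illustrative only and not part of either patch: the endpoint URL, API token, and deployment name below are placeholders, and it assumes a reachable Predibase-style generate_stream endpoint. Note that sampling_params must carry max_tokens, which the client renames to max_new_tokens before posting.

import ray

from llmperf.common import construct_clients
from llmperf.models import RequestConfig

# The client reads its endpoint and token from the environment, so forward
# them to the Ray workers. Both values below are placeholders.
ray.init(
    runtime_env={
        "env_vars": {
            "PREDIBASE_API_BASE": "https://serving.example.invalid",  # placeholder
            "PREDIBASE_API_KEY": "my-api-token",  # placeholder
        }
    }
)

# construct_clients(llm_api="predibase", ...) is what PATCH 1/2 enables.
clients = construct_clients(llm_api="predibase", num_clients=1)

request_config = RequestConfig(
    model="my-deployment",  # placeholder deployment name
    prompt=("Translate 'hello' to French.", 8),  # (prompt text, prompt length in tokens)
    sampling_params={"max_tokens": 64},  # renamed to max_new_tokens by the client
)

# llm_request runs on a Ray actor and returns (metrics, generated_text, request_config).
metrics, generated_text, _ = ray.get(clients[0].llm_request.remote(request_config))
print(generated_text)
print(metrics)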