From f6bb70a597d77913184e9d2049d243b12e948229 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Fri, 24 Jan 2025 14:43:17 +0100 Subject: [PATCH] more slow progress --- packages/tasks/src/snippets/curl.ts | 4 ++ packages/tasks/src/snippets/js.ts | 72 +++++++++++++++++++-------- packages/tasks/src/snippets/python.ts | 31 ++++++++---- packages/tasks/src/snippets/types.ts | 2 +- 4 files changed, 77 insertions(+), 32 deletions(-) diff --git a/packages/tasks/src/snippets/curl.ts b/packages/tasks/src/snippets/curl.ts index bc512764b..4fad12679 100644 --- a/packages/tasks/src/snippets/curl.ts +++ b/packages/tasks/src/snippets/curl.ts @@ -15,6 +15,7 @@ export const snippetBasic = ( } return [ { + client: "curl", content: `\ curl https://api-inference.huggingface.co/models/${model.id} \\ -X POST \\ @@ -55,6 +56,7 @@ export const snippetTextGeneration = ( }; return [ { + client: "curl", content: `curl '${baseUrl}' \\ -H 'Authorization: Bearer ${accessToken || `{API_TOKEN}`}' \\ -H 'Content-Type: application/json' \\ @@ -89,6 +91,7 @@ export const snippetZeroShotClassification = ( } return [ { + client: "curl", content: `curl https://api-inference.huggingface.co/models/${model.id} \\ -X POST \\ -d '{"inputs": ${getModelInputSnippet(model, true)}, "parameters": {"candidate_labels": ["refund", "legal", "faq"]}}' \\ @@ -108,6 +111,7 @@ export const snippetFile = ( } return [ { + client: "curl", content: `curl https://api-inference.huggingface.co/models/${model.id} \\ -X POST \\ --data-binary '@${getModelInputSnippet(model, true, true)}' \\ diff --git a/packages/tasks/src/snippets/js.ts b/packages/tasks/src/snippets/js.ts index 0f2b6ec8d..c48aa00e9 100644 --- a/packages/tasks/src/snippets/js.ts +++ b/packages/tasks/src/snippets/js.ts @@ -5,36 +5,66 @@ import { stringifyGenerationConfig, stringifyMessages } from "./common.js"; import { getModelInputSnippet } from "./inputs.js"; import type { InferenceSnippet, ModelDataMinimal } from "./types.js"; +const HFJS_METHODS: Record = { + "text-classification": "textClassification", + "token-classification": "tokenClassification", + "table-question-answering": "tableQuestionAnswering", + "question-answering": "questionAnswering", + translation: "translation", + summarization: "summarization", + "feature-extraction": "featureExtraction", + "text-generation": "textGeneration", + "text2text-generation": "textGeneration", + "fill-mask": "fillMask", + "sentence-similarity": "sentenceSimilarity", +}; + export const snippetBasic = ( model: ModelDataMinimal, accessToken: string, provider: InferenceProvider ): InferenceSnippet[] => { - if (provider !== "hf-inference") { - return []; - } return [ + ...(model.pipeline_tag && model.pipeline_tag in HFJS_METHODS + ? [ + { + client: "huggingface.js", + content: `\ +import { HfInference } from "@huggingface/inference"; + +const client = new HfInference("${accessToken || `{API_TOKEN}`}"); + +const image = await client.${HFJS_METHODS[model.pipeline_tag]}({ + model: "${model.id}", + inputs: ${getModelInputSnippet(model)}, + provider: "${provider}", +}); +`, + }, + ] + : []), { client: "fetch", - content: `async function query(data) { - const response = await fetch( - "https://api-inference.huggingface.co/models/${model.id}", - { - headers: { - Authorization: "Bearer ${accessToken || `{API_TOKEN}`}", - "Content-Type": "application/json", - }, - method: "POST", - body: JSON.stringify(data), - } - ); - const result = await response.json(); - return result; + content: `\ +async function query(data) { + const response = await fetch( + "https://api-inference.huggingface.co/models/${model.id}", + { + headers: { + Authorization: "Bearer ${accessToken || `{API_TOKEN}`}", + "Content-Type": "application/json", + }, + method: "POST", + body: JSON.stringify(data), } - - query({"inputs": ${getModelInputSnippet(model)}}).then((response) => { - console.log(JSON.stringify(response)); - });`, + ); + const result = await response.json(); + return result; +} + +query({"inputs": ${getModelInputSnippet(model)}}).then((response) => { + console.log(JSON.stringify(response)); +});`, }, ]; }; diff --git a/packages/tasks/src/snippets/python.ts b/packages/tasks/src/snippets/python.ts index 8203a1bb0..53f5da8e8 100644 --- a/packages/tasks/src/snippets/python.ts +++ b/packages/tasks/src/snippets/python.ts @@ -1,4 +1,4 @@ -import { openAIbaseUrl, type InferenceProvider } from "../inference-providers.js"; +import { HF_HUB_INFERENCE_PROXY_TEMPLATE, openAIbaseUrl, type InferenceProvider } from "../inference-providers.js"; import type { PipelineType } from "../pipelines.js"; import type { ChatCompletionInputMessage, GenerationParameters } from "../tasks/index.js"; import { stringifyGenerationConfig, stringifyMessages } from "./common.js"; @@ -125,6 +125,7 @@ print(completion.choices[0].message)`, export const snippetZeroShotClassification = (model: ModelDataMinimal): InferenceSnippet[] => { return [ { + client: "requests", content: `def query(payload): response = requests.post(API_URL, headers=headers, json=payload) return response.json() @@ -140,6 +141,7 @@ output = query({ export const snippetZeroShotImageClassification = (model: ModelDataMinimal): InferenceSnippet[] => { return [ { + client: "requests", content: `def query(data): with open(data["image_path"], "rb") as f: img = f.read() @@ -161,6 +163,7 @@ output = query({ export const snippetBasic = (model: ModelDataMinimal): InferenceSnippet[] => { return [ { + client: "requests", content: `def query(payload): response = requests.post(API_URL, headers=headers, json=payload) return response.json() @@ -175,6 +178,7 @@ output = query({ export const snippetFile = (model: ModelDataMinimal): InferenceSnippet[] => { return [ { + client: "requests", content: `def query(filename): with open(filename, "rb") as f: data = f.read() @@ -248,6 +252,7 @@ image = Image.open(io.BytesIO(image_bytes))`, export const snippetTabular = (model: ModelDataMinimal): InferenceSnippet[] => { return [ { + client: "requests", content: `def query(payload): response = requests.post(API_URL, headers=headers, json=payload) return response.content @@ -258,17 +263,14 @@ export const snippetTabular = (model: ModelDataMinimal): InferenceSnippet[] => { ]; }; -export const snippetTextToAudio = ( - model: ModelDataMinimal, - accessToken: string, - provider: InferenceProvider -): InferenceSnippet[] => { +export const snippetTextToAudio = (model: ModelDataMinimal): InferenceSnippet[] => { // Transformers TTS pipeline and api-inference-community (AIC) pipeline outputs are diverged // with the latest update to inference-api (IA). // Transformers IA returns a byte object (wav file), whereas AIC returns wav and sampling_rate. if (model.library_name === "transformers") { return [ { + client: "requests", content: `def query(payload): response = requests.post(API_URL, headers=headers, json=payload) return response.content @@ -284,6 +286,7 @@ Audio(audio_bytes)`, } else { return [ { + client: "requests", content: `def query(payload): response = requests.post(API_URL, headers=headers, json=payload) return response.json() @@ -302,6 +305,7 @@ Audio(audio, rate=sampling_rate)`, export const snippetDocumentQuestionAnswering = (model: ModelDataMinimal): InferenceSnippet[] => { return [ { + client: "requests", content: `def query(payload): with open(payload["image"], "rb") as f: img = f.read() @@ -372,17 +376,24 @@ export function getPythonInferenceSnippet( ? pythonSnippets[model.pipeline_tag]?.(model, accessToken, provider) ?? [] : []; + const baseUrl = + provider === "hf-inference" + ? `https://api-inference.huggingface.co/models/${model.id}` + : HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", provider); + return snippets.map((snippet) => { return { ...snippet, - content: snippet.content.includes("requests") - ? `import requests + content: + snippet.client === "requests" + ? `\ +import requests -API_URL = "https://api-inference.huggingface.co/models/${model.id}" +API_URL = "${baseUrl}" headers = {"Authorization": ${accessToken ? `"Bearer ${accessToken}"` : `f"Bearer {API_TOKEN}"`}} ${snippet.content}` - : snippet.content, + : snippet.content, }; }); } diff --git a/packages/tasks/src/snippets/types.ts b/packages/tasks/src/snippets/types.ts index f6bdc65be..20098c3b5 100644 --- a/packages/tasks/src/snippets/types.ts +++ b/packages/tasks/src/snippets/types.ts @@ -12,6 +12,6 @@ export type ModelDataMinimal = Pick< export interface InferenceSnippet { content: string; - client?: string; // for instance: `client` could be `huggingface_hub` or `openai` client for Python snippets + client: string; // for instance: `client` could be `huggingface_hub` or `openai` client for Python snippets setup?: string; // Optional setup / installation steps }