Skip to content

Commit

Permalink
more slow progress
Browse files Browse the repository at this point in the history
  • Loading branch information
julien-c committed Jan 24, 2025
1 parent 5b696c0 commit f6bb70a
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 32 deletions.
4 changes: 4 additions & 0 deletions packages/tasks/src/snippets/curl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ export const snippetBasic = (
}
return [
{
client: "curl",
content: `\
curl https://api-inference.huggingface.co/models/${model.id} \\
-X POST \\
Expand Down Expand Up @@ -55,6 +56,7 @@ export const snippetTextGeneration = (
};
return [
{
client: "curl",
content: `curl '${baseUrl}' \\
-H 'Authorization: Bearer ${accessToken || `{API_TOKEN}`}' \\
-H 'Content-Type: application/json' \\
Expand Down Expand Up @@ -89,6 +91,7 @@ export const snippetZeroShotClassification = (
}
return [
{
client: "curl",
content: `curl https://api-inference.huggingface.co/models/${model.id} \\
-X POST \\
-d '{"inputs": ${getModelInputSnippet(model, true)}, "parameters": {"candidate_labels": ["refund", "legal", "faq"]}}' \\
Expand All @@ -108,6 +111,7 @@ export const snippetFile = (
}
return [
{
client: "curl",
content: `curl https://api-inference.huggingface.co/models/${model.id} \\
-X POST \\
--data-binary '@${getModelInputSnippet(model, true, true)}' \\
Expand Down
72 changes: 51 additions & 21 deletions packages/tasks/src/snippets/js.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,36 +5,66 @@ import { stringifyGenerationConfig, stringifyMessages } from "./common.js";
import { getModelInputSnippet } from "./inputs.js";
import type { InferenceSnippet, ModelDataMinimal } from "./types.js";

/**
 * Maps an Inference API pipeline tag to the corresponding `HfInference`
 * client method name in `@huggingface/inference`.
 *
 * Used to decide whether a "huggingface.js" snippet can be generated for a
 * model (tags absent from this table fall back to a raw `fetch` snippet).
 * Note: both "text-generation" and "text2text-generation" route to the same
 * `textGeneration` client method.
 */
const HFJS_METHODS: Record<string, string> = {
	"text-classification": "textClassification",
	"token-classification": "tokenClassification",
	"table-question-answering": "tableQuestionAnswering",
	"question-answering": "questionAnswering",
	"translation": "translation",
	"summarization": "summarization",
	"feature-extraction": "featureExtraction",
	"text-generation": "textGeneration",
	"text2text-generation": "textGeneration",
	"fill-mask": "fillMask",
	"sentence-similarity": "sentenceSimilarity",
};

export const snippetBasic = (
model: ModelDataMinimal,
accessToken: string,
provider: InferenceProvider
): InferenceSnippet[] => {
if (provider !== "hf-inference") {
return [];
}
return [
...(model.pipeline_tag && model.pipeline_tag in HFJS_METHODS
? [
{
client: "huggingface.js",
content: `\
import { HfInference } from "@huggingface/inference";
const client = new HfInference("${accessToken || `{API_TOKEN}`}");
const image = await client.${HFJS_METHODS[model.pipeline_tag]}({
model: "${model.id}",
inputs: ${getModelInputSnippet(model)},
provider: "${provider}",
});
`,
},
]
: []),
{
client: "fetch",
content: `async function query(data) {
const response = await fetch(
"https://api-inference.huggingface.co/models/${model.id}",
{
headers: {
Authorization: "Bearer ${accessToken || `{API_TOKEN}`}",
"Content-Type": "application/json",
},
method: "POST",
body: JSON.stringify(data),
}
);
const result = await response.json();
return result;
content: `\
async function query(data) {
const response = await fetch(
"https://api-inference.huggingface.co/models/${model.id}",
{
headers: {
Authorization: "Bearer ${accessToken || `{API_TOKEN}`}",
"Content-Type": "application/json",
},
method: "POST",
body: JSON.stringify(data),
}
query({"inputs": ${getModelInputSnippet(model)}}).then((response) => {
console.log(JSON.stringify(response));
});`,
);
const result = await response.json();
return result;
}
query({"inputs": ${getModelInputSnippet(model)}}).then((response) => {
console.log(JSON.stringify(response));
});`,
},
];
};
Expand Down
31 changes: 21 additions & 10 deletions packages/tasks/src/snippets/python.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { openAIbaseUrl, type InferenceProvider } from "../inference-providers.js";
import { HF_HUB_INFERENCE_PROXY_TEMPLATE, openAIbaseUrl, type InferenceProvider } from "../inference-providers.js";
import type { PipelineType } from "../pipelines.js";
import type { ChatCompletionInputMessage, GenerationParameters } from "../tasks/index.js";
import { stringifyGenerationConfig, stringifyMessages } from "./common.js";
Expand Down Expand Up @@ -125,6 +125,7 @@ print(completion.choices[0].message)`,
export const snippetZeroShotClassification = (model: ModelDataMinimal): InferenceSnippet[] => {
return [
{
client: "requests",
content: `def query(payload):
response = requests.post(API_URL, headers=headers, json=payload)
return response.json()
Expand All @@ -140,6 +141,7 @@ output = query({
export const snippetZeroShotImageClassification = (model: ModelDataMinimal): InferenceSnippet[] => {
return [
{
client: "requests",
content: `def query(data):
with open(data["image_path"], "rb") as f:
img = f.read()
Expand All @@ -161,6 +163,7 @@ output = query({
export const snippetBasic = (model: ModelDataMinimal): InferenceSnippet[] => {
return [
{
client: "requests",
content: `def query(payload):
response = requests.post(API_URL, headers=headers, json=payload)
return response.json()
Expand All @@ -175,6 +178,7 @@ output = query({
export const snippetFile = (model: ModelDataMinimal): InferenceSnippet[] => {
return [
{
client: "requests",
content: `def query(filename):
with open(filename, "rb") as f:
data = f.read()
Expand Down Expand Up @@ -248,6 +252,7 @@ image = Image.open(io.BytesIO(image_bytes))`,
export const snippetTabular = (model: ModelDataMinimal): InferenceSnippet[] => {
return [
{
client: "requests",
content: `def query(payload):
response = requests.post(API_URL, headers=headers, json=payload)
return response.content
Expand All @@ -258,17 +263,14 @@ export const snippetTabular = (model: ModelDataMinimal): InferenceSnippet[] => {
];
};

export const snippetTextToAudio = (
model: ModelDataMinimal,
accessToken: string,
provider: InferenceProvider
): InferenceSnippet[] => {
export const snippetTextToAudio = (model: ModelDataMinimal): InferenceSnippet[] => {
// Transformers TTS pipeline and api-inference-community (AIC) pipeline outputs are diverged
// with the latest update to inference-api (IA).
// Transformers IA returns a byte object (wav file), whereas AIC returns wav and sampling_rate.
if (model.library_name === "transformers") {
return [
{
client: "requests",
content: `def query(payload):
response = requests.post(API_URL, headers=headers, json=payload)
return response.content
Expand All @@ -284,6 +286,7 @@ Audio(audio_bytes)`,
} else {
return [
{
client: "requests",
content: `def query(payload):
response = requests.post(API_URL, headers=headers, json=payload)
return response.json()
Expand All @@ -302,6 +305,7 @@ Audio(audio, rate=sampling_rate)`,
export const snippetDocumentQuestionAnswering = (model: ModelDataMinimal): InferenceSnippet[] => {
return [
{
client: "requests",
content: `def query(payload):
with open(payload["image"], "rb") as f:
img = f.read()
Expand Down Expand Up @@ -372,17 +376,24 @@ export function getPythonInferenceSnippet(
? pythonSnippets[model.pipeline_tag]?.(model, accessToken, provider) ?? []
: [];

const baseUrl =
provider === "hf-inference"
? `https://api-inference.huggingface.co/models/${model.id}`
: HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", provider);

return snippets.map((snippet) => {
return {
...snippet,
content: snippet.content.includes("requests")
? `import requests
content:
snippet.client === "requests"
? `\
import requests
API_URL = "https://api-inference.huggingface.co/models/${model.id}"
API_URL = "${baseUrl}"
headers = {"Authorization": ${accessToken ? `"Bearer ${accessToken}"` : `f"Bearer {API_TOKEN}"`}}
${snippet.content}`
: snippet.content,
: snippet.content,
};
});
}
Expand Down
2 changes: 1 addition & 1 deletion packages/tasks/src/snippets/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@ export type ModelDataMinimal = Pick<

/**
 * A single generated inference code snippet.
 *
 * NOTE(review): the scraped diff retained both the deleted optional
 * `client?` member and the added required `client` member (a duplicate
 * identifier); this keeps only the new, required form as the commit intends.
 */
export interface InferenceSnippet {
	content: string;
	client: string; // for instance: `client` could be `huggingface_hub` or `openai` client for Python snippets
	setup?: string; // Optional setup / installation steps
}

0 comments on commit f6bb70a

Please sign in to comment.