Skip to content

Commit

Permalink
allez
Browse files Browse the repository at this point in the history
  • Loading branch information
julien-c committed Jan 23, 2025
1 parent 8aafe25 commit a941c0b
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 11 deletions.
11 changes: 10 additions & 1 deletion packages/tasks/src/inference-providers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,13 @@ export const INFERENCE_PROVIDERS = ["hf-inference", "fal-ai", "replicate", "samb

export type InferenceProvider = (typeof INFERENCE_PROVIDERS)[number];

export const HF_HUB_INFERENCE_PROXY_TEMPLATE = `https://huggingface.co/api/inference-proxy/{{PROVIDER}}`;
const HF_HUB_INFERENCE_PROXY_TEMPLATE = `https://huggingface.co/api/inference-proxy/{{PROVIDER}}`;

/**
 * Resolve the URL to set as `baseURL` in the OpenAI SDK for a given inference provider.
 *
 * @param provider - The inference provider the request will be routed to.
 * @returns The HF Inference API URL for `"hf-inference"`, otherwise the
 *          Hub inference-proxy URL with the provider interpolated.
 */
export function openAIbaseUrl(provider: InferenceProvider): string {
	if (provider === "hf-inference") {
		return "https://api-inference.huggingface.co/v1/";
	}
	return HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", provider);
}
15 changes: 5 additions & 10 deletions packages/tasks/src/snippets/js.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { HF_HUB_INFERENCE_PROXY_TEMPLATE, type InferenceProvider } from "../inference-providers.js";
import { HF_HUB_INFERENCE_PROXY_TEMPLATE, openAIbaseUrl, type InferenceProvider } from "../inference-providers.js";
import type { PipelineType } from "../pipelines.js";
import type { ChatCompletionInputMessage, GenerationParameters } from "../tasks/index.js";
import { stringifyGenerationConfig, stringifyMessages } from "./common.js";
Expand Down Expand Up @@ -51,11 +51,6 @@ export const snippetTextGeneration = (
top_p?: GenerationParameters["top_p"];
}
): InferenceSnippet[] => {
const openAIbaseUrl =
provider === "hf-inference"
? "https://api-inference.huggingface.co/v1/"
: HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", provider);

if (model.tags.includes("conversational")) {
// Conversational model detected, so we display a code snippet that features the Messages API
const streaming = opts?.streaming ?? true;
Expand Down Expand Up @@ -103,7 +98,7 @@ for await (const chunk of stream) {
content: `import { OpenAI } from "openai";
const client = new OpenAI({
baseURL: "${openAIbaseUrl}",
baseURL: "${openAIbaseUrl(provider)}",
apiKey: "${accessToken || `{API_TOKEN}`}"
});
Expand Down Expand Up @@ -147,7 +142,7 @@ console.log(chatCompletion.choices[0].message);`,
content: `import { OpenAI } from "openai";
const client = new OpenAI({
baseURL: "${openAIbaseUrl}",
baseURL: "${openAIbaseUrl(provider)}",
apiKey: "${accessToken || `{API_TOKEN}`}"
});
Expand Down Expand Up @@ -187,8 +182,8 @@ export const snippetZeroShotClassification = (model: ModelDataMinimal, accessTok
}
query({"inputs": ${getModelInputSnippet(
model
)}, "parameters": {"candidate_labels": ["refund", "legal", "faq"]}}).then((response) => {
model
)}, "parameters": {"candidate_labels": ["refund", "legal", "faq"]}}).then((response) => {
console.log(JSON.stringify(response));
});`,
},
Expand Down

0 comments on commit a941c0b

Please sign in to comment.