diff --git a/packages/inference/README.md b/packages/inference/README.md
index 241ba7a523..88d0328883 100644
--- a/packages/inference/README.md
+++ b/packages/inference/README.md
@@ -52,6 +52,7 @@ Currently, we support the following providers:
 - [Hyperbolic](https://hyperbolic.xyz)
 - [Nebius](https://studio.nebius.ai)
 - [Novita](https://novita.ai/?utm_source=github_huggingface&utm_medium=github_readme&utm_campaign=link)
+- [Nscale](https://nscale.com)
 - [Replicate](https://replicate.com)
 - [Sambanova](https://sambanova.ai)
 - [Together](https://together.xyz)
@@ -79,6 +80,7 @@ Only a subset of models are supported when requesting third-party providers. You
 - [Fireworks AI supported models](https://huggingface.co/api/partners/fireworks-ai/models)
 - [Hyperbolic supported models](https://huggingface.co/api/partners/hyperbolic/models)
 - [Nebius supported models](https://huggingface.co/api/partners/nebius/models)
+- [Nscale supported models](https://huggingface.co/api/partners/nscale/models)
 - [Replicate supported models](https://huggingface.co/api/partners/replicate/models)
 - [Sambanova supported models](https://huggingface.co/api/partners/sambanova/models)
 - [Together supported models](https://huggingface.co/api/partners/together/models)
diff --git a/packages/inference/src/lib/getProviderHelper.ts b/packages/inference/src/lib/getProviderHelper.ts
index c9c85d29b3..060cddffb7 100644
--- a/packages/inference/src/lib/getProviderHelper.ts
+++ b/packages/inference/src/lib/getProviderHelper.ts
@@ -8,6 +8,7 @@ import * as HFInference from "../providers/hf-inference";
 import * as Hyperbolic from "../providers/hyperbolic";
 import * as Nebius from "../providers/nebius";
 import * as Novita from "../providers/novita";
+import * as Nscale from "../providers/nscale";
 import * as OpenAI from "../providers/openai";
 import type {
 	AudioClassificationTaskHelper,
@@ -109,6 +110,10 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask, TaskProviderHelper>>> = {
 		"text-generation": new Novita.NovitaTextGenerationTask(),
 		"text-to-video": new Novita.NovitaTextToVideoTask(),
 	},
+	nscale: {
+		conversational: new Nscale.NscaleConversationalTask(),
+		"text-to-image": new Nscale.NscaleTextToImageTask(),
+	},
 	openai: {
 		conversational: new OpenAI.OpenAIConversationalTask(),
 	},
diff --git a/packages/inference/src/providers/nscale.ts b/packages/inference/src/providers/nscale.ts
new file mode 100644
--- /dev/null
+++ b/packages/inference/src/providers/nscale.ts
@@ -0,0 +1,79 @@
+/**
+ * See the registered mapping of HF model ID => Nscale model ID here:
+ *
+ * https://huggingface.co/api/partners/nscale/models
+ *
+ * This is a publicly available mapping.
+ *
+ * If you want to try to run inference for a new model locally before it's registered on huggingface.co,
+ * you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
+ *
+ * - If you work at Nscale and want to update this mapping, please use the model mapping API we provide on huggingface.co
+ * - If you're a community member and want to add a new supported HF model to Nscale, please open an issue on the present repo
+ *   and we will tag Nscale team members.
+ *
+ * Thanks!
+ */
+import type { TextToImageInput } from "@huggingface/tasks";
+import { InferenceOutputError } from "../lib/InferenceOutputError";
+import type { BodyParams } from "../types";
+import { omit } from "../utils/omit";
+import { BaseConversationalTask, TaskProviderHelper, type TextToImageTaskHelper } from "./providerHelper";
+
+const NSCALE_API_BASE_URL = "https://inference.api.nscale.com";
+
+interface NscaleCloudBase64ImageGeneration {
+	data: Array<{
+		b64_json: string;
+	}>;
+}
+
+export class NscaleConversationalTask extends BaseConversationalTask {
+	constructor() {
+		super("nscale", NSCALE_API_BASE_URL);
+	}
+}
+
+export class NscaleTextToImageTask extends TaskProviderHelper implements TextToImageTaskHelper {
+	constructor() {
+		super("nscale", NSCALE_API_BASE_URL);
+	}
+
+	preparePayload(params: BodyParams<TextToImageInput>): Record<string, unknown> {
+		return {
+			...omit(params.args, ["inputs", "parameters"]),
+			...params.args.parameters,
+			response_format: "b64_json",
+			prompt: params.args.inputs,
+			model: params.model,
+		};
+	}
+
+	makeRoute(): string {
+		return "v1/images/generations";
+	}
+
+	async getResponse(
+		response: NscaleCloudBase64ImageGeneration,
+		url?: string,
+		headers?: HeadersInit,
+		outputType?: "url" | "blob"
+	): Promise<string | Blob> {
+		if (
+			typeof response === "object" &&
+			"data" in response &&
+			Array.isArray(response.data) &&
+			response.data.length > 0 &&
+			"b64_json" in response.data[0] &&
+			typeof response.data[0].b64_json === "string"
+		) {
+			const base64Data = response.data[0].b64_json;
+			if (outputType === "url") {
+				return `data:image/jpeg;base64,${base64Data}`;
+			}
+			return fetch(`data:image/jpeg;base64,${base64Data}`).then((res) => res.blob());
+		}
+
+		throw new InferenceOutputError("Expected Nscale text-to-image response format");
+	}
+}
diff --git a/packages/inference/src/types.ts b/packages/inference/src/types.ts
index 1e465c00fd..032bd41901 100644
--- a/packages/inference/src/types.ts
+++ b/packages/inference/src/types.ts
@@ -47,6 +47,7 @@ export const INFERENCE_PROVIDERS = [
 	"hyperbolic",
 	"nebius",
 	"novita",
+	"nscale",
 	"openai",
 	"replicate",
 	"sambanova",
diff --git a/packages/inference/test/InferenceClient.spec.ts b/packages/inference/test/InferenceClient.spec.ts
index f59b2cf2eb..2105902bc0 100644
--- a/packages/inference/test/InferenceClient.spec.ts
+++ b/packages/inference/test/InferenceClient.spec.ts
@@ -1690,4 +1690,65 @@ describe.skip("InferenceClient", () => {
 		},
 		TIMEOUT
 	);
+	describe.concurrent(
+		"Nscale",
+		() => {
+			const client = new InferenceClient(env.HF_NSCALE_KEY ?? "dummy");
"dummy"); + + HARDCODED_MODEL_INFERENCE_MAPPING["nscale"] = { + "meta-llama/Llama-3.1-8B-Instruct": { + hfModelId: "meta-llama/Llama-3.1-8B-Instruct", + providerId: "nscale", + status: "live", + task: "conversational", + }, + "black-forest-labs/FLUX.1-schnell": { + hfModelId: "black-forest-labs/FLUX.1-schnell", + providerId: "flux-schnell", + status: "live", + task: "text-to-image", + }, + }; + + it("chatCompletion", async () => { + const res = await client.chatCompletion({ + model: "meta-llama/Llama-3.1-8B-Instruct", + provider: "nscale", + messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }], + }); + if (res.choices && res.choices.length > 0) { + const completion = res.choices[0].message?.content; + expect(completion).toContain("two"); + } + }); + it("chatCompletion stream", async () => { + const stream = client.chatCompletionStream({ + model: "meta-llama/Llama-3.1-8B-Instruct", + provider: "nscale", + messages: [{ role: "user", content: "Say 'this is a test'" }], + stream: true, + }) as AsyncGenerator; + let fullResponse = ""; + for await (const chunk of stream) { + if (chunk.choices && chunk.choices.length > 0) { + const content = chunk.choices[0].delta?.content; + if (content) { + fullResponse += content; + } + } + } + expect(fullResponse).toBeTruthy(); + expect(fullResponse.length).toBeGreaterThan(0); + }); + it("textToImage", async () => { + const res = await client.textToImage({ + model: "black-forest-labs/FLUX.1-schnell", + provider: "nscale", + inputs: "An astronaut riding a horse", + }); + expect(res).toBeInstanceOf(Blob); + }); + }, + TIMEOUT + ); });