Skip to content

Commit 85cc8c0

Browse files
committed
✨ Support for Featherless.ai as inference provider.
1 parent 68c6201 commit 85cc8c0

File tree

4 files changed

+120
-1
lines changed

4 files changed

+120
-1
lines changed

packages/inference/src/lib/getProviderHelper.ts

+5-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import * as Cohere from "../providers/cohere";
44
import * as FalAI from "../providers/fal-ai";
55
import * as Fireworks from "../providers/fireworks-ai";
66
import * as HFInference from "../providers/hf-inference";
7-
7+
import * as FeatherlessAI from "../providers/featherless-ai";
88
import * as Hyperbolic from "../providers/hyperbolic";
99
import * as Nebius from "../providers/nebius";
1010
import * as Novita from "../providers/novita";
@@ -62,6 +62,10 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask,
6262
"text-to-video": new FalAI.FalAITextToVideoTask(),
6363
"automatic-speech-recognition": new FalAI.FalAIAutomaticSpeechRecognitionTask(),
6464
},
65+
"featherless-ai": {
66+
conversational: new FeatherlessAI.FeatherlessAIConversationalTask(),
67+
"text-generation": new FeatherlessAI.FeatherlessAITextGenerationTask(),
68+
},
6569
"hf-inference": {
6670
"text-to-image": new HFInference.HFInferenceTextToImageTask(),
6771
conversational: new HFInference.HFInferenceConversationalTask(),
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import type {
	ChatCompletionOutput,
	TextGenerationOutput,
	TextGenerationOutputFinishReason,
} from "@huggingface/tasks";

import { InferenceOutputError } from "../lib/InferenceOutputError";
import type { BodyParams } from "../types";
import { BaseConversationalTask, BaseTextGenerationTask } from "./providerHelper";
4+
5+
interface FeatherlessAITextCompletionOutput extends Omit<ChatCompletionOutput, "choices"> {
6+
choices: Array<{
7+
text: string;
8+
finish_reason: TextGenerationOutputFinishReason;
9+
seed: number;
10+
logprobs: unknown;
11+
index: number;
12+
}>;
13+
}
14+
15+
// Base URL of the Featherless.ai inference API (OpenAI-compatible endpoints).
const FEATHERLESS_API_BASE_URL = "https://api.featherless.ai";
16+
17+
export class FeatherlessAIConversationalTask extends BaseConversationalTask {
18+
constructor() {
19+
super("featherless-ai", FEATHERLESS_API_BASE_URL);
20+
}
21+
}
22+
23+
export class FeatherlessAITextGenerationTask extends BaseTextGenerationTask {
24+
constructor() {
25+
super("featherless-ai", FEATHERLESS_API_BASE_URL);
26+
}
27+
28+
override preparePayload(params: BodyParams): Record<string, unknown> {
29+
return {
30+
model: params.model,
31+
...params.args,
32+
...params.args.parameters,
33+
prompt: params.args.inputs,
34+
};
35+
}
36+
37+
override async getResponse(response: FeatherlessAITextCompletionOutput): Promise<TextGenerationOutput> {
38+
if (
39+
typeof response === "object" &&
40+
"choices" in response &&
41+
Array.isArray(response?.choices) &&
42+
typeof response?.model === "string"
43+
) {
44+
const completion = response.choices[0];
45+
return {
46+
generated_text: completion.text,
47+
};
48+
}
49+
throw new InferenceOutputError("Expected Together text generation response format");
50+
}
51+
}

packages/inference/src/types.ts

+1
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ export const INFERENCE_PROVIDERS = [
4141
"cerebras",
4242
"cohere",
4343
"fal-ai",
44+
"featherless-ai",
4445
"fireworks-ai",
4546
"hf-inference",
4647
"hyperbolic",

packages/inference/test/InferenceClient.spec.ts

+63
Original file line numberDiff line numberDiff line change
@@ -826,6 +826,69 @@ describe.concurrent("InferenceClient", () => {
826826
TIMEOUT
827827
);
828828

829+
describe.concurrent(
830+
"Featherless",
831+
() => {
832+
HARDCODED_MODEL_ID_MAPPING['featherless-ai'] = {
833+
"meta-llama/Llama-3.1-8B": "meta-llama/Meta-Llama-3.1-8B",
834+
"meta-llama/Llama-3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct",
835+
};
836+
837+
it("chatCompletion", async () => {
838+
const res = await chatCompletion({
839+
accessToken: env.HF_FEATHERLESS_KEY ?? "dummy",
840+
model: "meta-llama/Llama-3.1-8B-Instruct",
841+
provider: "featherless-ai",
842+
messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
843+
temperature: 0.1,
844+
});
845+
846+
expect(res).toBeDefined();
847+
expect(res.choices).toBeDefined();
848+
expect(res.choices?.length).toBeGreaterThan(0);
849+
850+
if (res.choices && res.choices.length > 0) {
851+
const completion = res.choices[0].message?.content;
852+
expect(completion).toBeDefined();
853+
expect(typeof completion).toBe("string");
854+
expect(completion).toContain("two");
855+
}
856+
});
857+
858+
it("chatCompletion stream", async () => {
859+
const stream = chatCompletionStream({
860+
accessToken: env.HF_FEATHERLESS_KEY ?? "dummy",
861+
model: "meta-llama/Llama-3.1-8B-Instruct",
862+
provider: "featherless-ai",
863+
messages: [{ role: "user", content: "Complete the equation 1 + 1 = , just the answer" }],
864+
}) as AsyncGenerator<ChatCompletionStreamOutput>;
865+
let out = "";
866+
for await (const chunk of stream) {
867+
if (chunk.choices && chunk.choices.length > 0) {
868+
out += chunk.choices[0].delta.content;
869+
}
870+
}
871+
expect(out).toContain("2");
872+
});
873+
874+
it("textGeneration", async () => {
875+
const res = await textGeneration({
876+
accessToken: env.HF_FEATHERLESS_KEY ?? "dummy",
877+
model: "meta-llama/Llama-3.1-8B",
878+
provider: "featherless-ai",
879+
inputs: "Paris is a city of ",
880+
parameters: {
881+
temperature: 0,
882+
top_p: 0.01,
883+
max_tokens: 10,
884+
},
885+
});
886+
expect(res).toMatchObject({ generated_text: "2.2 million people, and it is the" });
887+
});
888+
},
889+
TIMEOUT
890+
);
891+
829892
describe.concurrent(
830893
"Replicate",
831894
() => {

0 commit comments

Comments
 (0)