Add Swarmind as an inference provider #1619

Status: Open. Wants to merge 3 commits into base branch `main`.
1 change: 1 addition & 0 deletions packages/inference/README.md
@@ -63,6 +63,7 @@ Currently, we support the following providers:
- [Cohere](https://cohere.com)
- [Cerebras](https://cerebras.ai/)
- [Groq](https://groq.com)
- [Swarmind](https://swarmind.ai)

To send requests to a third-party provider, you have to pass the `provider` parameter to the inference function. The default value of the `provider` parameter is `"auto"`, which selects the first provider available for the model, sorted by your preferred order in https://hf.co/settings/inference-providers.
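The defaulting rule described above can be sketched as follows. This is a minimal illustration, not the package's API: `ChatParams` and `withDefaultProvider` are hypothetical names invented for this sketch.

```typescript
// Hypothetical sketch of the provider-defaulting rule described above.
interface ChatParams {
  model: string;
  provider?: string; // e.g. "swarmind"; omitted means "auto"
  messages: { role: string; content: string }[];
}

// "auto" selects the first provider available for the model, ordered by
// the user's preferences at hf.co/settings/inference-providers.
function withDefaultProvider(params: ChatParams): ChatParams & { provider: string } {
  return { ...params, provider: params.provider ?? "auto" };
}

const explicit = withDefaultProvider({
  model: "TheDrummer/Tiger-Gemma-9B-v1",
  provider: "swarmind",
  messages: [{ role: "user", content: "Hello" }],
});

const defaulted = withDefaultProvider({
  model: "TheDrummer/Tiger-Gemma-9B-v1",
  messages: [{ role: "user", content: "Hello" }],
});

console.log(explicit.provider, defaulted.provider); // "swarmind" "auto"
```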

5 changes: 5 additions & 0 deletions packages/inference/src/lib/getProviderHelper.ts
@@ -12,6 +12,7 @@ import * as Novita from "../providers/novita.js";
import * as Nscale from "../providers/nscale.js";
import * as OpenAI from "../providers/openai.js";
import * as OvhCloud from "../providers/ovhcloud.js";
import * as Swarmind from "../providers/swarmind.js";
import type {
AudioClassificationTaskHelper,
AudioToAudioTaskHelper,
@@ -147,6 +148,10 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask,
conversational: new Sambanova.SambanovaConversationalTask(),
"feature-extraction": new Sambanova.SambanovaFeatureExtractionTask(),
},
swarmind: {
conversational: new Swarmind.SwarmindConversationalTask(),
"text-generation": new Swarmind.SwarmindTextGenerationTask(),
},
together: {
"text-to-image": new Together.TogetherTextToImageTask(),
conversational: new Together.TogetherConversationalTask(),
1 change: 1 addition & 0 deletions packages/inference/src/providers/consts.ts
@@ -34,5 +34,6 @@ export const HARDCODED_MODEL_INFERENCE_MAPPING: Record<
ovhcloud: {},
replicate: {},
sambanova: {},
swarmind: {},
together: {},
};
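For local testing before a model is registered on huggingface.co, one of the empty per-provider dictionaries above can be filled in by hand. A minimal sketch of such an entry, using the field names and values that appear in this PR's test file (the `ModelMappingEntry` type name is made up for illustration; the `"live" | "staging"` union is an assumption):

```typescript
// Sketch: a dev-only hardcoded mapping entry for Swarmind.
// Field names mirror the entry used in this PR's test file.
type ModelMappingEntry = {
  hfModelId: string;   // model ID on the Hugging Face Hub
  providerId: string;  // model ID on the provider's side
  status: "live" | "staging";
  task: string;
};

const devMapping: Record<string, Record<string, ModelMappingEntry>> = {
  swarmind: {
    "TheDrummer/Tiger-Gemma-9B-v1": {
      hfModelId: "TheDrummer/Tiger-Gemma-9B-v1",
      providerId: "tiger-gemma-9b-v1-i1",
      status: "live",
      task: "conversational",
    },
  },
};

console.log(devMapping.swarmind["TheDrummer/Tiger-Gemma-9B-v1"].providerId);
// tiger-gemma-9b-v1-i1
```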
45 changes: 45 additions & 0 deletions packages/inference/src/providers/swarmind.ts
@@ -0,0 +1,45 @@
import { BaseConversationalTask, BaseTextGenerationTask } from "./providerHelper.js";

/**
* See the registered mapping of HF model ID => Swarmind model ID here:
*
* https://huggingface.co/api/partners/swarmind/models
*
* This is a publicly available mapping.
*
* If you want to try to run inference for a new model locally before it's registered on huggingface.co,
* you can add it to the dictionary "HARDCODED_MODEL_INFERENCE_MAPPING" in consts.ts, for dev purposes.
*
* - If you work at Swarmind and want to update this mapping, please use the model mapping API we provide on huggingface.co
* - If you're a community member and want to add a new supported HF model to Swarmind, please open an issue on the present repo
* and we will tag Swarmind team members.
*
* Thanks!
*/

const SWARMIND_API_BASE_URL = "https://api.swarmind.ai/lai/private";

export class SwarmindTextGenerationTask extends BaseTextGenerationTask {
constructor() {
super("swarmind", SWARMIND_API_BASE_URL);
}

override makeRoute(): string {
return "/v1/chat/completions";
}
}

export class SwarmindConversationalTask extends BaseConversationalTask {
constructor() {
super("swarmind", SWARMIND_API_BASE_URL);
}

override makeRoute(): string {
return "/v1/chat/completions";
}
}
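Note that both helpers above return the same chat-completions route. A minimal sketch of the resulting request URL, assuming the base URL and route are concatenated directly (the actual joining logic lives in `TaskProviderHelper`, so this is an illustration only):

```typescript
// Sketch only: assumes simple concatenation of base URL and route.
const SWARMIND_API_BASE_URL = "https://api.swarmind.ai/lai/private";

function makeSwarmindUrl(route: string): string {
  return `${SWARMIND_API_BASE_URL}${route}`;
}

console.log(makeSwarmindUrl("/v1/chat/completions"));
// https://api.swarmind.ai/lai/private/v1/chat/completions
```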
1 change: 1 addition & 0 deletions packages/inference/src/types.ts
@@ -61,6 +61,7 @@ export const INFERENCE_PROVIDERS = [
"ovhcloud",
"replicate",
"sambanova",
"swarmind",
"together",
] as const;

51 changes: 51 additions & 0 deletions packages/inference/test/InferenceClient.spec.ts
@@ -2107,4 +2107,55 @@ describe.skip("InferenceClient", () => {
},
TIMEOUT
);
describe.concurrent(
"Swarmind",
() => {
const client = new InferenceClient(env.HF_SWARMIND_KEY ?? "dummy");

HARDCODED_MODEL_INFERENCE_MAPPING["swarmind"] = {
"TheDrummer/Tiger-Gemma-9B-v1": {
hfModelId: "TheDrummer/Tiger-Gemma-9B-v1",
providerId: "tiger-gemma-9b-v1-i1",
status: "live",
task: "conversational",
},
};

it("chatCompletion", async () => {
const res = await client.chatCompletion({
model: "TheDrummer/Tiger-Gemma-9B-v1",
provider: "swarmind",
messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
});
expect(res.choices).toBeDefined();
expect(res.choices?.length).toBeGreaterThan(0);
const completion = res.choices?.[0]?.message?.content;
expect(completion).toContain("two");
});

it("chatCompletion stream", async () => {
const stream = client.chatCompletionStream({
model: "TheDrummer/Tiger-Gemma-9B-v1",
provider: "swarmind",
messages: [{ role: "user", content: "Say 'this is a test'" }],
stream: true,
}) as AsyncGenerator<ChatCompletionStreamOutput>;

let fullResponse = "";
for await (const chunk of stream) {
if (chunk.choices && chunk.choices.length > 0) {
const content = chunk.choices[0].delta?.content;
if (content) {
fullResponse += content;
}
}
}

// Verify we got a meaningful response
expect(fullResponse).toBeTruthy();
expect(fullResponse.length).toBeGreaterThan(0);
});
},
TIMEOUT
);
});