diff --git a/packages/tasks/src/library-to-tasks.ts b/packages/tasks/src/library-to-tasks.ts index 71fd8e600a..7799d056d6 100644 --- a/packages/tasks/src/library-to-tasks.ts +++ b/packages/tasks/src/library-to-tasks.ts @@ -43,6 +43,10 @@ export const LIBRARY_TASK_MAPPING: Partial { return tensorflowttsUnknown(model); }; +const thirdaiUDT = (model: ModelData): string[] => [ + `from thirdai import bolt + +model = bolt.UniversalDeepTransformer.load("${model.id}") +`, +]; + +const thirdaiNeuralDB = (model: ModelData): string[] => [ + `from thirdai import neural_db as ndb + +model = ndb.NeuralDB.from_checkpoint("${model.id}") +`, +]; + +export const thirdai = (model: ModelData): string[] => { + if (model.tags.includes("neural-db")) { + return thirdaiNeuralDB(model); + } + return thirdaiUDT(model); +}; + export const timm = (model: ModelData): string[] => [ `import timm diff --git a/packages/tasks/src/model-libraries.ts b/packages/tasks/src/model-libraries.ts index 156d21fe24..d0a2777512 100644 --- a/packages/tasks/src/model-libraries.ts +++ b/packages/tasks/src/model-libraries.ts @@ -389,6 +389,13 @@ export const MODEL_LIBRARIES_UI_ELEMENTS = { repoUrl: "https://github.com/TensorSpeech/TensorFlowTTS", snippets: snippets.tensorflowtts, }, + thirdai: { + prettyLabel: "thirdai", + repoName: "thirdai", + repoUrl: "https://github.com/ThirdAILabs/Demos", + snippets: snippets.thirdai, + filter: false, + }, timesfm: { prettyLabel: "TimesFM", repoName: "timesfm", diff --git a/packages/tasks/src/tasks/index.ts b/packages/tasks/src/tasks/index.ts index e3f60b89f2..4f3dff528c 100644 --- a/packages/tasks/src/tasks/index.ts +++ b/packages/tasks/src/tasks/index.ts @@ -143,7 +143,14 @@ export const TASKS_MODEL_LIBRARIES: Record = { "tabular-classification": ["sklearn"], "tabular-regression": ["sklearn"], "tabular-to-text": ["transformers"], - "text-classification": ["adapter-transformers", "setfit", "spacy", "transformers", "transformers.js"], + "text-classification": [ + "adapter-transformers", + "setfit", + "spacy", + "thirdai", + "transformers", + "transformers.js" + ], "text-generation": ["transformers", "transformers.js"], "text-retrieval": [], "text-to-image": ["diffusers"], @@ -158,6 +165,7 @@ export const TASKS_MODEL_LIBRARIES: Record = { "spacy", "span-marker", "stanza", + "thirdai", "transformers", "transformers.js", ],