From 182cb7c9968227009bbb5c6bd45d07fa1a755638 Mon Sep 17 00:00:00 2001 From: SBrandeis Date: Tue, 21 Jan 2025 17:52:24 +0100 Subject: [PATCH 1/2] Better types for HfInference --- packages/inference/package.json | 3 +- packages/inference/pnpm-lock.yaml | 8 ++ packages/inference/src/HfInference.ts | 85 +++++++++++++++++--- packages/inference/src/tasks/index.ts | 4 + packages/inference/src/utils/typedEntries.ts | 5 ++ packages/inference/test/HfInference.spec.ts | 24 +++--- 6 files changed, 104 insertions(+), 25 deletions(-) create mode 100644 packages/inference/src/utils/typedEntries.ts diff --git a/packages/inference/package.json b/packages/inference/package.json index 03ca74a2ed..0c297f413a 100644 --- a/packages/inference/package.json +++ b/packages/inference/package.json @@ -55,7 +55,8 @@ "@huggingface/tasks": "workspace:^" }, "devDependencies": { - "@types/node": "18.13.0" + "@types/node": "18.13.0", + "type-fest": "^3.13.1" }, "resolutions": {} } diff --git a/packages/inference/pnpm-lock.yaml b/packages/inference/pnpm-lock.yaml index 5bb56da2d1..6d2f5e5c2b 100644 --- a/packages/inference/pnpm-lock.yaml +++ b/packages/inference/pnpm-lock.yaml @@ -13,9 +13,17 @@ devDependencies: '@types/node': specifier: 18.13.0 version: 18.13.0 + type-fest: + specifier: ^3.13.1 + version: 3.13.1 packages: /@types/node@18.13.0: resolution: {integrity: sha512-gC3TazRzGoOnoKAhUx+Q0t8S9Tzs74z7m0ipwGpSqQrleP14hKxP4/JUeEQcD3W1/aIpnWl8pHowI7WokuZpXg==} dev: true + + /type-fest@3.13.1: + resolution: {integrity: sha512-tLq3bSNx+xSpwvAJnzrK0Ep5CLNWjvFTOp71URMaAEWBfRb9nnJiBoUe0tF8bI4ZFO3omgBR6NvnbzVUT3Ly4g==} + engines: {node: '>=14.16'} + dev: true diff --git a/packages/inference/src/HfInference.ts b/packages/inference/src/HfInference.ts index 6cc268cf29..401abefc5f 100644 --- a/packages/inference/src/HfInference.ts +++ b/packages/inference/src/HfInference.ts @@ -5,7 +5,7 @@ import type { DistributiveOmit } from "./utils/distributive-omit"; /* eslint-disable @typescript-eslint/no-empty-interface */ /* eslint-disable @typescript-eslint/no-unsafe-declaration-merging */ -type Task = typeof tasks; +type Task = Omit; type TaskWithNoAccessToken = { [key in keyof Task]: ( @@ -21,22 +21,83 @@ type TaskWithNoAccessTokenNoEndpointUrl = { ) => ReturnType; }; -export class HfInference { - private readonly accessToken: string; - private readonly defaultOptions: Options; +export class HfInference implements Task { + protected readonly accessToken: string; + protected readonly defaultOptions: Options; + + speechToText: typeof tasks.speechToText; + audioClassification: typeof tasks.audioClassification; + automaticSpeechRecognition: typeof tasks.automaticSpeechRecognition; + textToSpeech: typeof tasks.textToSpeech; + audioToAudio: typeof tasks.audioToAudio; + imageClassification: typeof tasks.imageClassification; + imageSegmentation: typeof tasks.imageSegmentation; + imageToText: typeof tasks.imageToText; + objectDetection: typeof tasks.objectDetection; + textToImage: typeof tasks.textToImage; + imageToImage: typeof tasks.imageToImage; + zeroShotImageClassification: typeof tasks.zeroShotImageClassification; + featureExtraction: typeof tasks.featureExtraction; + fillMask: typeof tasks.fillMask; + questionAnswering: typeof tasks.questionAnswering; + sentenceSimilarity: typeof tasks.sentenceSimilarity; + summarization: typeof tasks.summarization; + tableQuestionAnswering: typeof tasks.tableQuestionAnswering; + textClassification: typeof tasks.textClassification; + textGeneration: typeof tasks.textGeneration; + tokenClassification: typeof tasks.tokenClassification; + translation: typeof tasks.translation; + zeroShotClassification: typeof tasks.zeroShotClassification; + chatCompletion: typeof tasks.chatCompletion; + documentQuestionAnswering: typeof tasks.documentQuestionAnswering; + visualQuestionAnswering: typeof tasks.visualQuestionAnswering; + tabularRegression: typeof tasks.tabularRegression; + tabularClassification: typeof tasks.tabularClassification; + textGenerationStream: typeof tasks.textGenerationStream; + chatCompletionStream: typeof tasks.chatCompletionStream; + + static mapInferenceFn(instance: HfInference, func: (...args: [TArgs, Options?]) => TOut) { + return function (...[args, options]: Parameters<(...args: [TArgs, Options?]) => TOut>): TOut { + return func({ ...args, accessToken: instance.accessToken }, { ...instance.defaultOptions, ...(options ?? {}) }); + }; + } constructor(accessToken = "", defaultOptions: Options = {}) { this.accessToken = accessToken; this.defaultOptions = defaultOptions; - for (const [name, fn] of Object.entries(tasks)) { - Object.defineProperty(this, name, { - enumerable: false, - value: (params: RequestArgs, options: Options) => - // eslint-disable-next-line @typescript-eslint/no-explicit-any - fn({ ...params, accessToken } as any, { ...defaultOptions, ...options }), - }); - } + this.speechToText = HfInference.mapInferenceFn(this, tasks.speechToText); + this.audioClassification = HfInference.mapInferenceFn(this, tasks.audioClassification); + this.automaticSpeechRecognition = HfInference.mapInferenceFn(this, tasks.automaticSpeechRecognition); + this.textToSpeech = HfInference.mapInferenceFn(this, tasks.textToSpeech); + this.audioToAudio = HfInference.mapInferenceFn(this, tasks.audioToAudio); + this.imageClassification = HfInference.mapInferenceFn(this, tasks.imageClassification); + this.imageSegmentation = HfInference.mapInferenceFn(this, tasks.imageSegmentation); + this.imageToText = HfInference.mapInferenceFn(this, tasks.imageToText); + this.objectDetection = HfInference.mapInferenceFn(this, tasks.objectDetection); + this.textToImage = HfInference.mapInferenceFn(this, tasks.textToImage); + this.imageToImage = HfInference.mapInferenceFn(this, tasks.imageToImage); + this.zeroShotImageClassification = HfInference.mapInferenceFn(this, tasks.zeroShotImageClassification); + this.featureExtraction = HfInference.mapInferenceFn(this, tasks.featureExtraction); + this.fillMask = HfInference.mapInferenceFn(this, tasks.fillMask); + this.questionAnswering = HfInference.mapInferenceFn(this, tasks.questionAnswering); + this.sentenceSimilarity = HfInference.mapInferenceFn(this, tasks.sentenceSimilarity); + this.summarization = HfInference.mapInferenceFn(this, tasks.summarization); + this.tableQuestionAnswering = HfInference.mapInferenceFn(this, tasks.tableQuestionAnswering); + this.textClassification = HfInference.mapInferenceFn(this, tasks.textClassification); + this.textGeneration = HfInference.mapInferenceFn(this, tasks.textGeneration); + this.tokenClassification = HfInference.mapInferenceFn(this, tasks.tokenClassification); + this.translation = HfInference.mapInferenceFn(this, tasks.translation); + this.zeroShotClassification = HfInference.mapInferenceFn(this, tasks.zeroShotClassification); + this.chatCompletion = HfInference.mapInferenceFn(this, tasks.chatCompletion); + this.documentQuestionAnswering = HfInference.mapInferenceFn(this, tasks.documentQuestionAnswering); + this.visualQuestionAnswering = HfInference.mapInferenceFn(this, tasks.visualQuestionAnswering); + this.tabularRegression = HfInference.mapInferenceFn(this, tasks.tabularRegression); + this.tabularClassification = HfInference.mapInferenceFn(this, tasks.tabularClassification); + + /// Streaming methods + this.textGenerationStream = HfInference.mapInferenceFn(this, tasks.textGenerationStream); + this.chatCompletionStream = HfInference.mapInferenceFn(this, tasks.chatCompletionStream); } /** diff --git a/packages/inference/src/tasks/index.ts b/packages/inference/src/tasks/index.ts index ba3108446c..2b91f72949 100644 --- a/packages/inference/src/tasks/index.ts +++ b/packages/inference/src/tasks/index.ts @@ -1,3 +1,5 @@ +import { automaticSpeechRecognition } from "./audio/automaticSpeechRecognition"; + // Custom tasks with arbitrary inputs and outputs export * from "./custom/request"; export * from "./custom/streamingRequest"; @@ -40,3 +42,5 @@ export * from "./multimodal/visualQuestionAnswering"; // Tabular tasks export * from "./tabular/tabularRegression"; export * from "./tabular/tabularClassification"; + +export const speechToText = automaticSpeechRecognition; diff --git a/packages/inference/src/utils/typedEntries.ts b/packages/inference/src/utils/typedEntries.ts new file mode 100644 index 0000000000..ee6343b3ae --- /dev/null +++ b/packages/inference/src/utils/typedEntries.ts @@ -0,0 +1,5 @@ +import type { Entries } from "type-fest"; + +export function typedEntries>(obj: T): Entries { + return Object.entries(obj) as Entries; +} diff --git a/packages/inference/test/HfInference.spec.ts b/packages/inference/test/HfInference.spec.ts index d2db78fd4c..84e12ee8b8 100644 --- a/packages/inference/test/HfInference.spec.ts +++ b/packages/inference/test/HfInference.spec.ts @@ -589,18 +589,18 @@ describe.concurrent("HfInference", () => { generated_text: "a large brown and white giraffe standing in a field ", }); }); - it("request - openai-community/gpt2", async () => { - expect( - await hf.request({ - model: "openai-community/gpt2", - inputs: "one plus two equals", - }) - ).toMatchObject([ - { - generated_text: expect.any(String), - }, - ]); - }); + // it("request - openai-community/gpt2", async () => { + // expect( + // await hf.request({ + // model: "openai-community/gpt2", + // inputs: "one plus two equals", + // }) + // ).toMatchObject([ + // { + // generated_text: expect.any(String), + // }, + // ]); + // }); // Skipped at the moment because takes forever it.skip("tabularRegression", async () => { From 2c8429f34007bc2a3d600538143eecc84f020bee Mon Sep 17 00:00:00 2001 From: SBrandeis Date: Tue, 21 Jan 2025 17:56:44 +0100 Subject: [PATCH 2/2] rm unused typedEntries --- packages/inference/package.json | 3 +-- packages/inference/pnpm-lock.yaml | 8 -------- packages/inference/src/utils/typedEntries.ts | 5 ----- 3 files changed, 1 insertion(+), 15 deletions(-) delete mode 100644 packages/inference/src/utils/typedEntries.ts diff --git a/packages/inference/package.json b/packages/inference/package.json index 0c297f413a..03ca74a2ed 100644 --- a/packages/inference/package.json +++ b/packages/inference/package.json @@ -55,8 +55,7 @@ "@huggingface/tasks": "workspace:^" }, "devDependencies": { - "@types/node": "18.13.0", - "type-fest": "^3.13.1" + "@types/node": "18.13.0" }, "resolutions": {} } diff --git a/packages/inference/pnpm-lock.yaml b/packages/inference/pnpm-lock.yaml index 6d2f5e5c2b..5bb56da2d1 100644 --- a/packages/inference/pnpm-lock.yaml +++ b/packages/inference/pnpm-lock.yaml @@ -13,17 +13,9 @@ devDependencies: '@types/node': specifier: 18.13.0 version: 18.13.0 - type-fest: - specifier: ^3.13.1 - version: 3.13.1 packages: /@types/node@18.13.0: resolution: {integrity: sha512-gC3TazRzGoOnoKAhUx+Q0t8S9Tzs74z7m0ipwGpSqQrleP14hKxP4/JUeEQcD3W1/aIpnWl8pHowI7WokuZpXg==} dev: true - - /type-fest@3.13.1: - resolution: {integrity: sha512-tLq3bSNx+xSpwvAJnzrK0Ep5CLNWjvFTOp71URMaAEWBfRb9nnJiBoUe0tF8bI4ZFO3omgBR6NvnbzVUT3Ly4g==} - engines: {node: '>=14.16'} - dev: true diff --git a/packages/inference/src/utils/typedEntries.ts b/packages/inference/src/utils/typedEntries.ts deleted file mode 100644 index ee6343b3ae..0000000000 --- a/packages/inference/src/utils/typedEntries.ts +++ /dev/null @@ -1,5 +0,0 @@ -import type { Entries } from "type-fest"; - -export function typedEntries>(obj: T): Entries { - return Object.entries(obj) as Entries; -}