diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 95d353300c..299772d208 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -20,9 +20,7 @@ It's not a hard requirement, but please consider using an icon from [Gitmoji](ht If you want to run only specific tests, you can do `pnpm test -- -t "test name"`. -You can also do `npx vitest ./packages/hub/src/utils/XetBlob.spec.ts` to run a specific test file. - -Or `cd packages/hub && npx vitest --browser.name=chrome --browser.headless --config vitest-browser.config.mts ./src/utils/XetBlob.spec.ts` to run browser tests on a specific file +You can also do `pnpm --filter hub test ./src/utils/XetBlob.spec.ts` to run a specific test file. ## Adding a package diff --git a/packages/hub/src/lib/commit.spec.ts b/packages/hub/src/lib/commit.spec.ts index 617be8ee89..024155bbc7 100644 --- a/packages/hub/src/lib/commit.spec.ts +++ b/packages/hub/src/lib/commit.spec.ts @@ -33,7 +33,7 @@ describe("commit", () => { try { const readme1 = await downloadFile({ repo, path: "README.md", hubUrl: TEST_HUB_URL }); - assert.strictEqual(readme1?.status, 200); + assert(readme1, "Readme doesn't exist"); const nodeOperation: CommitFile[] = isFrontend ? [] @@ -77,11 +77,9 @@ describe("commit", () => { }); const fileContent = await downloadFile({ repo, path: "test.txt", hubUrl: TEST_HUB_URL }); - assert.strictEqual(fileContent?.status, 200); assert.strictEqual(await fileContent?.text(), "This is me"); const lfsFileContent = await downloadFile({ repo, path: "test.lfs.txt", hubUrl: TEST_HUB_URL }); - assert.strictEqual(lfsFileContent?.status, 200); assert.strictEqual(await lfsFileContent?.text(), lfsContent); const lfsFileUrl = `${TEST_HUB_URL}/${repoName}/raw/main/test.lfs.txt`; @@ -98,7 +96,6 @@ size ${lfsContent.length} if (!isFrontend) { const fileUrlContent = await downloadFile({ repo, path: "tsconfig.json", hubUrl: TEST_HUB_URL }); - assert.strictEqual(fileUrlContent?.status, 200); assert.strictEqual( await fileUrlContent?.text(), (await import("node:fs")).readFileSync("./tsconfig.json", "utf-8") @@ -106,7 +103,6 @@ size ${lfsContent.length} } const webResourceContent = await downloadFile({ repo, path: "lamaral.json", hubUrl: TEST_HUB_URL }); - assert.strictEqual(webResourceContent?.status, 200); assert.strictEqual(await webResourceContent?.text(), await (await fetch(tokenizerJsonUrl)).text()); const readme2 = await downloadFile({ repo, path: "README.md", hubUrl: TEST_HUB_URL }); diff --git a/packages/hub/src/lib/download-file-to-cache-dir.spec.ts b/packages/hub/src/lib/download-file-to-cache-dir.spec.ts index 29d17f2870..fb407c4c2e 100644 --- a/packages/hub/src/lib/download-file-to-cache-dir.spec.ts +++ b/packages/hub/src/lib/download-file-to-cache-dir.spec.ts @@ -1,16 +1,15 @@ import { expect, test, describe, vi, beforeEach } from "vitest"; import type { RepoDesignation, RepoId } from "../types/public"; import { dirname, join } from "node:path"; -import { lstat, mkdir, stat, symlink, writeFile, rename } from "node:fs/promises"; +import { lstat, mkdir, stat, symlink, rename } from "node:fs/promises"; import { pathsInfo } from "./paths-info"; -import type { Stats } from "node:fs"; +import { createWriteStream, type Stats } from "node:fs"; import { getHFHubCachePath, getRepoFolderName } from "./cache-management"; import { toRepoId } from "../utils/toRepoId"; import { downloadFileToCacheDir } from "./download-file-to-cache-dir"; import { createSymlink } from "../utils/symlink"; vi.mock("node:fs/promises", () => ({ - writeFile: vi.fn(), rename: vi.fn(), symlink: vi.fn(), lstat: vi.fn(), @@ -18,6 +17,10 @@ vi.mock("node:fs/promises", () => ({ stat: vi.fn(), })); +vi.mock("node:fs", () => ({ + createWriteStream: vi.fn(), +})); + vi.mock("./paths-info", () => ({ pathsInfo: vi.fn(), })); @@ -63,11 +66,15 @@ describe("downloadFileToCacheDir", () => { beforeEach(() => { vi.resetAllMocks(); // mock 200 request - vi.mocked(fetchMock).mockResolvedValue({ - status: 200, - ok: true, - body: "dummy-body", - } as unknown as Response); + vi.mocked(fetchMock).mockResolvedValue( + new Response("dummy-body", { + status: 200, + headers: { + etag: DUMMY_ETAG, + "Content-Range": "bytes 0-54/55", + }, + }) + ); // prevent to use caching vi.mocked(stat).mockRejectedValue(new Error("Do not exists")); @@ -235,6 +242,9 @@ describe("downloadFileToCacheDir", () => { }, ]); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + vi.mocked(createWriteStream).mockReturnValue(async function* () {} as any); + const output = await downloadFileToCacheDir({ repo: DUMMY_REPO, path: "/README.md", @@ -276,6 +286,9 @@ describe("downloadFileToCacheDir", () => { }, ]); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + vi.mocked(createWriteStream).mockReturnValue(async function* () {} as any); + await downloadFileToCacheDir({ repo: DUMMY_REPO, path: "/README.md", @@ -284,7 +297,7 @@ describe("downloadFileToCacheDir", () => { const incomplete = `${expectedBlob}.incomplete`; // 1. should write fetch#response#body to incomplete file - expect(writeFile).toHaveBeenCalledWith(incomplete, "dummy-body"); + expect(createWriteStream).toHaveBeenCalledWith(incomplete); // 2. should rename the incomplete to the blob expected name expect(rename).toHaveBeenCalledWith(incomplete, expectedBlob); // 3. should create symlink pointing to blob diff --git a/packages/hub/src/lib/download-file-to-cache-dir.ts b/packages/hub/src/lib/download-file-to-cache-dir.ts index 2c2ff44c59..a7b67d9d21 100644 --- a/packages/hub/src/lib/download-file-to-cache-dir.ts +++ b/packages/hub/src/lib/download-file-to-cache-dir.ts @@ -1,12 +1,16 @@ import { getHFHubCachePath, getRepoFolderName } from "./cache-management"; import { dirname, join } from "node:path"; -import { writeFile, rename, lstat, mkdir, stat } from "node:fs/promises"; +import { rename, lstat, mkdir, stat } from "node:fs/promises"; import type { CommitInfo, PathInfo } from "./paths-info"; import { pathsInfo } from "./paths-info"; import type { CredentialsParams, RepoDesignation } from "../types/public"; import { toRepoId } from "../utils/toRepoId"; import { downloadFile } from "./download-file"; import { createSymlink } from "../utils/symlink"; +import { Readable } from "node:stream"; +import type { ReadableStream } from "node:stream/web"; +import { pipeline } from "node:stream/promises"; +import { createWriteStream } from "node:fs"; export const REGEX_COMMIT_HASH: RegExp = new RegExp("^[0-9a-f]{40}$"); @@ -115,15 +119,16 @@ export async function downloadFileToCacheDir( const incomplete = `${blobPath}.incomplete`; console.debug(`Downloading ${params.path} to ${incomplete}`); - const response: Response | null = await downloadFile({ + const blob: Blob | null = await downloadFile({ ...params, revision: commitHash, }); - if (!response || !response.ok || !response.body) throw new Error(`invalid response for file ${params.path}`); + if (!blob) { + throw new Error(`invalid response for file ${params.path}`); + } - // @ts-expect-error resp.body is a Stream, but Stream in internal to node - await writeFile(incomplete, response.body); + await pipeline(Readable.fromWeb(blob.stream() as ReadableStream), createWriteStream(incomplete)); // rename .incomplete file to expect blob await rename(incomplete, blobPath); diff --git a/packages/hub/src/lib/download-file.spec.ts b/packages/hub/src/lib/download-file.spec.ts index 01fc64c945..c25a8b69b8 100644 --- a/packages/hub/src/lib/download-file.spec.ts +++ b/packages/hub/src/lib/download-file.spec.ts @@ -1,65 +1,82 @@ -import { expect, test, describe, vi } from "vitest"; +import { expect, test, describe, assert } from "vitest"; import { downloadFile } from "./download-file"; -import type { RepoId } from "../types/public"; - -const DUMMY_REPO: RepoId = { - name: "hello-world", - type: "model", -}; +import { deleteRepo } from "./delete-repo"; +import { createRepo } from "./create-repo"; +import { TEST_ACCESS_TOKEN, TEST_HUB_URL, TEST_USER } from "../test/consts"; +import { insecureRandomString } from "../utils/insecureRandomString"; describe("downloadFile", () => { - test("hubUrl params should overwrite HUB_URL", async () => { - const fetchMock: typeof fetch = vi.fn(); - vi.mocked(fetchMock).mockResolvedValue({ - status: 200, - ok: true, - } as Response); + test("should download regular file", async () => { + const blob = await downloadFile({ + repo: { + type: "model", + name: "openai-community/gpt2", + }, + path: "README.md", + }); + + const text = await blob?.slice(0, 1000).text(); + assert( + text?.includes(`--- +language: en +tags: +- exbert + +license: mit +--- + + +# GPT-2 - await downloadFile({ - repo: DUMMY_REPO, - path: "/README.md", - hubUrl: "http://dummy-hub", - fetch: fetchMock, +Test the whole generation capabilities here: https://transformer.huggingface.co/doc/gpt2-large`) + ); + }); + test("should downoad xet file", async () => { + const blob = await downloadFile({ + repo: { + type: "model", + name: "celinah/xet-experiments", + }, + path: "large_text.txt", }); - expect(fetchMock).toHaveBeenCalledWith("http://dummy-hub/hello-world/resolve/main//README.md", expect.anything()); + const text = await blob?.slice(0, 100).text(); + expect(text).toMatch("this is a text file.".repeat(10).slice(0, 100)); }); - test("raw params should use raw url", async () => { - const fetchMock: typeof fetch = vi.fn(); - vi.mocked(fetchMock).mockResolvedValue({ - status: 200, - ok: true, - } as Response); + test("should download private file", async () => { + const repoName = `datasets/${TEST_USER}/TEST-${insecureRandomString()}`; - await downloadFile({ - repo: DUMMY_REPO, - path: "README.md", - raw: true, - fetch: fetchMock, + const result = await createRepo({ + accessToken: TEST_ACCESS_TOKEN, + hubUrl: TEST_HUB_URL, + private: true, + repo: repoName, + files: [{ path: ".gitattributes", content: new Blob(["*.html filter=lfs diff=lfs merge=lfs -text"]) }], }); - expect(fetchMock).toHaveBeenCalledWith("https://huggingface.co/hello-world/raw/main/README.md", expect.anything()); - }); + assert.deepStrictEqual(result, { + repoUrl: `${TEST_HUB_URL}/${repoName}`, + }); + + try { + const blob = await downloadFile({ + repo: repoName, + path: ".gitattributes", + hubUrl: TEST_HUB_URL, + accessToken: TEST_ACCESS_TOKEN, + }); - test("internal server error should propagate the error", async () => { - const fetchMock: typeof fetch = vi.fn(); - vi.mocked(fetchMock).mockResolvedValue({ - status: 500, - ok: false, - headers: new Map([["Content-Type", "application/json"]]), - json: () => ({ - error: "Dummy internal error", - }), - } as unknown as Response); + assert(blob, "File should be found"); - await expect(async () => { - await downloadFile({ - repo: DUMMY_REPO, - path: "README.md", - raw: true, - fetch: fetchMock, + const text = await blob?.text(); + assert.strictEqual(text, "*.html filter=lfs diff=lfs merge=lfs -text"); + } finally { + await deleteRepo({ + repo: repoName, + hubUrl: TEST_HUB_URL, + accessToken: TEST_ACCESS_TOKEN, }); - }).rejects.toThrowError("Dummy internal error"); + } }); }); diff --git a/packages/hub/src/lib/download-file.ts b/packages/hub/src/lib/download-file.ts index 4f6ebde2e3..846fcd5ae5 100644 --- a/packages/hub/src/lib/download-file.ts +++ b/packages/hub/src/lib/download-file.ts @@ -1,8 +1,9 @@ -import { HUB_URL } from "../consts"; -import { createApiError } from "../error"; import type { CredentialsParams, RepoDesignation } from "../types/public"; import { checkCredentials } from "../utils/checkCredentials"; -import { toRepoId } from "../utils/toRepoId"; +import { WebBlob } from "../utils/WebBlob"; +import { XetBlob } from "../utils/XetBlob"; +import type { FileDownloadInfoOutput } from "./file-download-info"; +import { fileDownloadInfo } from "./file-download-info"; /** * @returns null when the file doesn't exist @@ -23,43 +24,54 @@ export async function downloadFile( * @default "main" */ revision?: string; - /** - * Fetch only a specific part of the file - */ - range?: [number, number]; hubUrl?: string; /** * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers. */ fetch?: typeof fetch; + /** + * Whether to use the xet protocol to download the file (if applicable). + * + * Currently there's experimental support for it, so it's not enabled by default. + * + * It will be enabled automatically in a future minor version. + * + * @default false + */ + xet?: boolean; + /** + * Can save an http request if provided + */ + downloadInfo?: FileDownloadInfoOutput; } & Partial -): Promise { +): Promise { const accessToken = checkCredentials(params); - const repoId = toRepoId(params.repo); - const url = `${params.hubUrl ?? HUB_URL}/${repoId.type === "model" ? "" : `${repoId.type}s/`}${repoId.name}/${ - params.raw ? "raw" : "resolve" - }/${encodeURIComponent(params.revision ?? "main")}/${params.path}`; - const resp = await (params.fetch ?? fetch)(url, { - headers: { - ...(accessToken - ? { - Authorization: `Bearer ${accessToken}`, - } - : {}), - ...(params.range - ? { - Range: `bytes=${params.range[0]}-${params.range[1]}`, - } - : {}), - }, - }); + const info = + params.downloadInfo ?? + (await fileDownloadInfo({ + accessToken, + repo: params.repo, + path: params.path, + revision: params.revision, + hubUrl: params.hubUrl, + fetch: params.fetch, + raw: params.raw, + })); - if (resp.status === 404 && resp.headers.get("X-Error-Code") === "EntryNotFound") { + if (!info) { return null; - } else if (!resp.ok) { - throw await createApiError(resp); } - return resp; + if (info.xet && params.xet) { + return new XetBlob({ + hash: info.xet.hash, + refreshUrl: info.xet.refreshUrl.href, + fetch: params.fetch, + accessToken, + size: info.size, + }); + } + + return new WebBlob(new URL(info.url), 0, info.size, "", true, params.fetch ?? fetch, accessToken); } diff --git a/packages/hub/src/lib/file-download-info.spec.ts b/packages/hub/src/lib/file-download-info.spec.ts index 75b22b3d2e..d2be156626 100644 --- a/packages/hub/src/lib/file-download-info.spec.ts +++ b/packages/hub/src/lib/file-download-info.spec.ts @@ -13,8 +13,7 @@ describe("fileDownloadInfo", () => { }); assert.strictEqual(info?.size, 536063208); - assert.strictEqual(info?.etag, '"879c5715c18a0b7f051dd33f70f0a5c8dd1522e0a43f6f75520f16167f29279b"'); - assert(info?.downloadLink); + assert.strictEqual(info?.etag, '"a7a17d6d844b5de815ccab5f42cad6d24496db3850a2a43d8258221018ce87d2"'); }); it("should fetch raw LFS pointer info", async () => { @@ -30,7 +29,6 @@ describe("fileDownloadInfo", () => { assert.strictEqual(info?.size, 134); assert.strictEqual(info?.etag, '"9eb98c817f04b051b3bcca591bcd4e03cec88018"'); - assert(!info?.downloadLink); }); it("should fetch non-LFS file info", async () => { @@ -46,4 +44,16 @@ describe("fileDownloadInfo", () => { assert.strictEqual(info?.size, 28); assert.strictEqual(info?.etag, '"a661b1a138dac6dc5590367402d100765010ffd6"'); }); + + it("should fetch xet file info", async () => { + const info = await fileDownloadInfo({ + repo: { + type: "model", + name: "celinah/xet-experiments", + }, + path: "large_text.txt", + }); + assert.strictEqual(info?.size, 62914580); + assert.strictEqual(info?.etag, '"c27f98578d9363b27db0bc1cbd9c692f8e6e90ae98c38cee7bc0a88829debd17"'); + }); }); diff --git a/packages/hub/src/lib/file-download-info.ts b/packages/hub/src/lib/file-download-info.ts index 3dcc79ee9e..2001c132cc 100644 --- a/packages/hub/src/lib/file-download-info.ts +++ b/packages/hub/src/lib/file-download-info.ts @@ -4,13 +4,20 @@ import type { CredentialsParams, RepoDesignation } from "../types/public"; import { checkCredentials } from "../utils/checkCredentials"; import { toRepoId } from "../utils/toRepoId"; +export interface XetFileInfo { + hash: string; + refreshUrl: URL; + /** + * Later, there will also be a `reconstructionUrl` that can be directly used instead of with the hash. + */ +} + export interface FileDownloadInfoOutput { size: number; etag: string; - /** - * In case of LFS file, link to download directly from cloud provider - */ - downloadLink: string | null; + xet?: XetFileInfo; + // URL to fetch (with the access token if private file) + url: string; } /** * @returns null when the file doesn't exist @@ -54,6 +61,7 @@ export async function fileDownloadInfo( Authorization: `Bearer ${accessToken}`, }), Range: "bytes=0-0", + Accept: "application/vnd.xet-fileinfo+json, */*", }, }); @@ -65,28 +73,52 @@ export async function fileDownloadInfo( throw await createApiError(resp); } - const etag = resp.headers.get("ETag"); + let etag: string | undefined; + let size: number | undefined; + let xetInfo: XetFileInfo | undefined; - if (!etag) { - throw new InvalidApiResponseFormatError("Expected ETag"); + if (resp.headers.get("Content-Type")?.includes("application/vnd.xet-fileinfo+json")) { + const json: { casUrl: string; hash: string; refreshUrl: string; size: string; etag: string } = await resp.json(); + + xetInfo = { + hash: json.hash, + refreshUrl: new URL(json.refreshUrl, hubUrl), + }; + + etag = json.etag; + size = parseInt(json.size); } - const contentRangeHeader = resp.headers.get("content-range"); + if (size === undefined || isNaN(size)) { + const contentRangeHeader = resp.headers.get("content-range"); + + if (!contentRangeHeader) { + throw new InvalidApiResponseFormatError("Expected size information"); + } - if (!contentRangeHeader) { - throw new InvalidApiResponseFormatError("Expected size information"); + const [, parsedSize] = contentRangeHeader.split("/"); + size = parseInt(parsedSize); + + if (isNaN(size)) { + throw new InvalidApiResponseFormatError("Invalid file size received"); + } } - const [, parsedSize] = contentRangeHeader.split("/"); - const size = parseInt(parsedSize); + etag ??= resp.headers.get("ETag") ?? undefined; - if (isNaN(size)) { - throw new InvalidApiResponseFormatError("Invalid file size received"); + if (!etag) { + throw new InvalidApiResponseFormatError("Expected ETag"); } return { etag, size, - downloadLink: new URL(resp.url).hostname !== new URL(hubUrl).hostname ? resp.url : null, + xet: xetInfo, + // Cannot use resp.url in case it's a S3 url and the user adds an Authorization header to it. + url: + resp.url && + (new URL(resp.url).hostname === new URL(hubUrl).hostname || resp.headers.get("X-Cache")?.endsWith(" cloudfront")) + ? resp.url + : url, }; } diff --git a/packages/hub/src/lib/parse-safetensors-metadata.ts b/packages/hub/src/lib/parse-safetensors-metadata.ts index 063a503c9f..ca43a00883 100644 --- a/packages/hub/src/lib/parse-safetensors-metadata.ts +++ b/packages/hub/src/lib/parse-safetensors-metadata.ts @@ -89,17 +89,13 @@ async function parseSingleFile( fetch?: typeof fetch; } & Partial ): Promise { - const firstResp = await downloadFile({ - ...params, - path, - range: [0, 7], - }); + const blob = await downloadFile({ ...params, path }); - if (!firstResp) { + if (!blob) { throw new SafetensorParseError(`Failed to parse file ${path}: failed to fetch safetensors header length.`); } - const bufLengthOfHeaderLE = await firstResp.arrayBuffer(); + const bufLengthOfHeaderLE = await blob.slice(0, 8).arrayBuffer(); const lengthOfHeader = new DataView(bufLengthOfHeaderLE).getBigUint64(0, true); // ^little-endian if (lengthOfHeader <= 0) { @@ -111,15 +107,9 @@ async function parseSingleFile( ); } - const secondResp = await downloadFile({ ...params, path, range: [8, 7 + Number(lengthOfHeader)] }); - - if (!secondResp) { - throw new SafetensorParseError(`Failed to parse file ${path}: failed to fetch safetensors header.`); - } - try { // no validation for now, we assume it's a valid FileHeader. - const header: SafetensorsFileHeader = await secondResp.json(); + const header: SafetensorsFileHeader = JSON.parse(await blob.slice(8, 8 + Number(lengthOfHeader)).text()); return header; } catch (err) { throw new SafetensorParseError(`Failed to parse file ${path}: safetensors header is not valid JSON.`); @@ -138,20 +128,19 @@ async function parseShardedIndex( fetch?: typeof fetch; } & Partial ): Promise<{ index: SafetensorsIndexJson; headers: SafetensorsShardedHeaders }> { - const indexResp = await downloadFile({ + const indexBlob = await downloadFile({ ...params, path, - range: [0, 10_000_000], }); - if (!indexResp) { + if (!indexBlob) { throw new SafetensorParseError(`Failed to parse file ${path}: failed to fetch safetensors index.`); } // no validation for now, we assume it's a valid IndexJson. let index: SafetensorsIndexJson; try { - index = await indexResp.json(); + index = JSON.parse(await indexBlob.slice(0, 10_000_000).text()); } catch (error) { throw new SafetensorParseError(`Failed to parse file ${path}: not a valid JSON.`); } diff --git a/packages/hub/src/utils/WebBlob.spec.ts b/packages/hub/src/utils/WebBlob.spec.ts index 68ad69e0d3..242a51e08e 100644 --- a/packages/hub/src/utils/WebBlob.spec.ts +++ b/packages/hub/src/utils/WebBlob.spec.ts @@ -15,7 +15,7 @@ describe("WebBlob", () => { }); it("should create a WebBlob with a slice on the entire resource", async () => { - const webBlob = await WebBlob.create(resourceUrl, { cacheBelow: 0 }); + const webBlob = await WebBlob.create(resourceUrl, { cacheBelow: 0, accessToken: undefined }); expect(webBlob).toMatchObject({ url: resourceUrl, @@ -35,7 +35,7 @@ describe("WebBlob", () => { }); it("should create a WebBlob with a slice on the entire resource, cached", async () => { - const webBlob = await WebBlob.create(resourceUrl, { cacheBelow: 1_000_000 }); + const webBlob = await WebBlob.create(resourceUrl, { cacheBelow: 1_000_000, accessToken: undefined }); expect(webBlob).not.toBeInstanceOf(WebBlob); expect(webBlob.size).toBe(size); @@ -75,7 +75,7 @@ describe("WebBlob", () => { it("should create a slice on the file", async () => { const expectedText = fullText.slice(10, 20); - const slice = (await WebBlob.create(resourceUrl, { cacheBelow: 0 })).slice(10, 20); + const slice = (await WebBlob.create(resourceUrl, { cacheBelow: 0, accessToken: undefined })).slice(10, 20); expect(slice).toMatchObject({ url: resourceUrl, diff --git a/packages/hub/src/utils/WebBlob.ts b/packages/hub/src/utils/WebBlob.ts index ff9aa1e0d7..364bd95094 100644 --- a/packages/hub/src/utils/WebBlob.ts +++ b/packages/hub/src/utils/WebBlob.ts @@ -2,6 +2,8 @@ * WebBlob is a Blob implementation for web resources that supports range requests. */ +import { createApiError } from "../error"; + interface WebBlobCreateOptions { /** * @default 1_000_000 @@ -14,12 +16,20 @@ interface WebBlobCreateOptions { * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers. */ fetch?: typeof fetch; + accessToken: string | undefined; } export class WebBlob extends Blob { static async create(url: URL, opts?: WebBlobCreateOptions): Promise { const customFetch = opts?.fetch ?? fetch; - const response = await customFetch(url, { method: "HEAD" }); + const response = await customFetch(url, { + method: "HEAD", + ...(opts?.accessToken && { + headers: { + Authorization: `Bearer ${opts.accessToken}`, + }, + }), + }); const size = Number(response.headers.get("content-length")); const contentType = response.headers.get("content-type") || ""; @@ -29,7 +39,7 @@ export class WebBlob extends Blob { return await (await customFetch(url)).blob(); } - return new WebBlob(url, 0, size, contentType, true, customFetch); + return new WebBlob(url, 0, size, contentType, true, customFetch, opts?.accessToken); } private url: URL; @@ -38,8 +48,17 @@ export class WebBlob extends Blob { private contentType: string; private full: boolean; private fetch: typeof fetch; - - constructor(url: URL, start: number, end: number, contentType: string, full: boolean, customFetch: typeof fetch) { + private accessToken: string | undefined; + + constructor( + url: URL, + start: number, + end: number, + contentType: string, + full: boolean, + customFetch: typeof fetch, + accessToken: string | undefined + ) { super([]); this.url = url; @@ -48,6 +67,7 @@ export class WebBlob extends Blob { this.contentType = contentType; this.full = full; this.fetch = customFetch; + this.accessToken = accessToken; } override get size(): number { @@ -69,7 +89,8 @@ export class WebBlob extends Blob { Math.min(this.start + end, this.end), this.contentType, start === 0 && end === this.size ? this.full : false, - this.fetch + this.fetch, + this.accessToken ); return slice; @@ -100,12 +121,19 @@ export class WebBlob extends Blob { private fetchRange(): Promise { const fetch = this.fetch; // to avoid this.fetch() which is bound to the instance instead of globalThis if (this.full) { - return fetch(this.url); + return fetch(this.url, { + ...(this.accessToken && { + headers: { + Authorization: `Bearer ${this.accessToken}`, + }, + }), + }).then((resp) => (resp.ok ? resp : createApiError(resp))); } return fetch(this.url, { headers: { Range: `bytes=${this.start}-${this.end - 1}`, + ...(this.accessToken && { Authorization: `Bearer ${this.accessToken}` }), }, - }); + }).then((resp) => (resp.ok ? resp : createApiError(resp))); } } diff --git a/packages/hub/src/utils/XetBlob.spec.ts b/packages/hub/src/utils/XetBlob.spec.ts index 008e2294d2..e3233fab6d 100644 --- a/packages/hub/src/utils/XetBlob.spec.ts +++ b/packages/hub/src/utils/XetBlob.spec.ts @@ -6,12 +6,9 @@ import { sum } from "./sum"; describe("XetBlob", () => { it("should lazy load the first 22 bytes", async () => { const blob = new XetBlob({ - repo: { - type: "model", - name: "celinah/xet-experiments", - }, hash: "7b3b6d07673a88cf467e67c1f7edef1a8c268cbf66e9dd9b0366322d4ab56d9b", size: 5_234_139_343, + refreshUrl: "https://huggingface.co/api/models/celinah/xet-experiments/xet-read-token/main", }); expect(await blob.slice(10, 22).text()).toBe("__metadata__"); @@ -20,10 +17,7 @@ describe("XetBlob", () => { it("should load the first chunk correctly", async () => { let xorbCount = 0; const blob = new XetBlob({ - repo: { - type: "model", - name: "celinah/xet-experiments", - }, + refreshUrl: "https://huggingface.co/api/models/celinah/xet-experiments/xet-read-token/main", hash: "7b3b6d07673a88cf467e67c1f7edef1a8c268cbf66e9dd9b0366322d4ab56d9b", size: 5_234_139_343, fetch: async (url, opts) => { @@ -51,10 +45,7 @@ describe("XetBlob", () => { it("should load just past the first chunk correctly", async () => { let xorbCount = 0; const blob = new XetBlob({ - repo: { - type: "model", - name: "celinah/xet-experiments", - }, + refreshUrl: "https://huggingface.co/api/models/celinah/xet-experiments/xet-read-token/main", hash: "7b3b6d07673a88cf467e67c1f7edef1a8c268cbf66e9dd9b0366322d4ab56d9b", size: 5_234_139_343, fetch: async (url, opts) => { @@ -80,75 +71,62 @@ describe("XetBlob", () => { expect(xorbCount).toBe(2); }); - // Doesn't work in chrome due to caching issues, it caches the partial output when the - // fetch is interrupted in the previous test and then uses that cached output in this test (that requires more data) - if (typeof window === "undefined") { - it("should load the first 200kB correctly", async () => { - let xorbCount = 0; - const blob = new XetBlob({ - repo: { - type: "model", - name: "celinah/xet-experiments", - }, - hash: "7b3b6d07673a88cf467e67c1f7edef1a8c268cbf66e9dd9b0366322d4ab56d9b", - size: 5_234_139_343, - fetch: async (url, opts) => { - if (typeof url === "string" && url.includes("/xorbs/")) { - xorbCount++; - } - return fetch(url, opts); - }, - // internalLogging: true, - }); - - const xetDownload = await blob.slice(0, 200_000).arrayBuffer(); - const bridgeDownload = await fetch( - "https://huggingface.co/celinah/xet-experiments/resolve/main/model5GB.safetensors", - { - headers: { - Range: "bytes=0-199999", - }, + it("should load the first 200kB correctly", async () => { + let xorbCount = 0; + const blob = new XetBlob({ + refreshUrl: "https://huggingface.co/api/models/celinah/xet-experiments/xet-read-token/main", + hash: "7b3b6d07673a88cf467e67c1f7edef1a8c268cbf66e9dd9b0366322d4ab56d9b", + size: 5_234_139_343, + fetch: async (url, opts) => { + if (typeof url === "string" && url.includes("/xorbs/")) { + xorbCount++; } - ).then((res) => res.arrayBuffer()); - - expect(xetDownload.byteLength).toBe(200_000); - expect(new Uint8Array(xetDownload)).toEqual(new Uint8Array(bridgeDownload)); - expect(xorbCount).toBe(2); - }, 60_000); - - it("should load correctly when loading far into a chunk range", async () => { - const blob = new XetBlob({ - repo: { - type: "model", - name: "celinah/xet-experiments", + return fetch(url, opts); + }, + // internalLogging: true, + }); + + const xetDownload = await blob.slice(0, 200_000).arrayBuffer(); + const bridgeDownload = await fetch( + "https://huggingface.co/celinah/xet-experiments/resolve/main/model5GB.safetensors", + { + headers: { + Range: "bytes=0-199999", }, - hash: "7b3b6d07673a88cf467e67c1f7edef1a8c268cbf66e9dd9b0366322d4ab56d9b", - size: 5_234_139_343, - // internalLogging: true, - }); + } + ).then((res) => res.arrayBuffer()); - const xetDownload = await blob.slice(10_000_000, 10_100_000).arrayBuffer(); - const bridgeDownload = await fetch( - "https://huggingface.co/celinah/xet-experiments/resolve/main/model5GB.safetensors", - { - headers: { - Range: "bytes=10000000-10099999", - }, - } - ).then((res) => res.arrayBuffer()); + expect(xetDownload.byteLength).toBe(200_000); + expect(new Uint8Array(xetDownload)).toEqual(new Uint8Array(bridgeDownload)); + expect(xorbCount).toBe(2); + }, 60_000); - console.log("xet", xetDownload.byteLength, "bridge", bridgeDownload.byteLength); - expect(new Uint8Array(xetDownload).length).toEqual(100_000); - expect(new Uint8Array(xetDownload)).toEqual(new Uint8Array(bridgeDownload)); + it("should load correctly when loading far into a chunk range", async () => { + const blob = new XetBlob({ + refreshUrl: "https://huggingface.co/api/models/celinah/xet-experiments/xet-read-token/main", + hash: "7b3b6d07673a88cf467e67c1f7edef1a8c268cbf66e9dd9b0366322d4ab56d9b", + size: 5_234_139_343, + // internalLogging: true, }); - } + + const xetDownload = await blob.slice(10_000_000, 10_100_000).arrayBuffer(); + const bridgeDownload = await fetch( + "https://huggingface.co/celinah/xet-experiments/resolve/main/model5GB.safetensors", + { + headers: { + Range: "bytes=10000000-10099999", + }, + } + ).then((res) => res.arrayBuffer()); + + console.log("xet", xetDownload.byteLength, "bridge", bridgeDownload.byteLength); + expect(new Uint8Array(xetDownload).length).toEqual(100_000); + expect(new Uint8Array(xetDownload)).toEqual(new Uint8Array(bridgeDownload)); + }); it("should load text correctly when offset_into_range starts in a chunk further than the first", async () => { const blob = new XetBlob({ - repo: { - type: "model", - name: "celinah/xet-experiments", - }, + refreshUrl: "https://huggingface.co/api/models/celinah/xet-experiments/xet-read-token/main", hash: "794efea76d8cb372bbe1385d9e51c3384555f3281e629903ecb6abeff7d54eec", size: 62_914_580, }); @@ -238,12 +216,8 @@ describe("XetBlob", () => { const blob = new XetBlob({ hash: "test", - repo: { - name: "test", - type: "model", - }, size: totalSize, - hubUrl: "https://huggingface.co", + refreshUrl: "https://huggingface.co", listener: (e) => debugged.push(e), fetch: async function (_url, opts) { const url = new URL(_url as string); @@ -345,12 +319,8 @@ describe("XetBlob", () => { const blob = new XetBlob({ hash: "test", - repo: { - name: "test", - type: "model", - }, size: totalSize, - hubUrl: "https://huggingface.co", + refreshUrl: "https://huggingface.co", listener: (e) => debugged.push(e), fetch: async function (_url, opts) { const url = new URL(_url as string); @@ -464,12 +434,8 @@ describe("XetBlob", () => { const blob = new XetBlob({ hash: "test", - repo: { - name: "test", - type: "model", - }, size: totalSize, - hubUrl: "https://huggingface.co", + refreshUrl: "https://huggingface.co", listener: (e) => debugged.push(e), fetch: async function (_url, opts) { const url = new URL(_url as string); @@ -578,12 +544,8 @@ describe("XetBlob", () => { const blob = new XetBlob({ hash: "test", - repo: { - name: "test", - type: "model", - }, size: totalSize, - hubUrl: "https://huggingface.co", + refreshUrl: "https://huggingface.co", listener: (e) => debugged.push(e), fetch: async function (_url, opts) { const url = new URL(_url as string); @@ -690,12 +652,8 @@ describe("XetBlob", () => { const blob = new XetBlob({ hash: "test", - repo: { - name: "test", - type: "model", - }, size: totalSize, - hubUrl: "https://huggingface.co", + refreshUrl: "https://huggingface.co", listener: (e) => debugged.push(e), fetch: async function (_url, opts) { const url = new URL(_url as string); @@ -801,12 +759,8 @@ describe("XetBlob", () => { const blob = new XetBlob({ hash: "test", - repo: { - name: "test", - type: "model", - }, size: totalSize, - hubUrl: "https://huggingface.co", + refreshUrl: "https://huggingface.co", listener: (e) => debugged.push(e), fetch: async function (_url, opts) { const url = new URL(_url as string); diff --git a/packages/hub/src/utils/XetBlob.ts b/packages/hub/src/utils/XetBlob.ts index ca91e5cbab..aed5852e73 100644 --- a/packages/hub/src/utils/XetBlob.ts +++ b/packages/hub/src/utils/XetBlob.ts @@ -1,8 +1,6 @@ -import { HUB_URL } from "../consts"; import { createApiError } from "../error"; -import type { CredentialsParams, RepoDesignation, RepoId } from "../types/public"; +import type { CredentialsParams } from "../types/public"; import { checkCredentials } from "./checkCredentials"; -import { toRepoId } from "./toRepoId"; import { decompress as lz4_decompress } from "../vendor/lz4js"; import { RangeList } from "./RangeList"; @@ -14,9 +12,9 @@ type XetBlobCreateOptions = { * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers. */ fetch?: typeof fetch; - repo: RepoDesignation; hash: string; - hubUrl?: string; + // URL to get the access token from + refreshUrl: string; size: number; listener?: (arg: { event: "read" } | { event: "progress"; progress: { read: number; total: number } }) => void; internalLogging?: boolean; @@ -85,8 +83,7 @@ const CHUNK_HEADER_BYTES = 8; export class XetBlob extends Blob { fetch: typeof fetch; accessToken?: string; - repoId: RepoId; - hubUrl: string; + refreshUrl: string; hash: string; start = 0; end = 0; @@ -99,13 +96,12 @@ export class XetBlob extends Blob { this.fetch = params.fetch ?? fetch.bind(globalThis); this.accessToken = checkCredentials(params); - this.repoId = toRepoId(params.repo); - this.hubUrl = params.hubUrl ?? HUB_URL; + this.refreshUrl = params.refreshUrl; this.end = params.size; this.hash = params.hash; this.listener = params.listener; this.internalLogging = params.internalLogging ?? false; - this.hubUrl; + this.refreshUrl; } override get size(): number { @@ -115,9 +111,8 @@ export class XetBlob extends Blob { #clone() { const blob = new XetBlob({ fetch: this.fetch, - repo: this.repoId, hash: this.hash, - hubUrl: this.hubUrl, + refreshUrl: this.refreshUrl, size: this.size, }); @@ -156,7 +151,7 @@ export class XetBlob extends Blob { } this.#reconstructionInfoPromise = (async () => { - const connParams = await getAccessToken(this.repoId, this.accessToken, this.fetch, this.hubUrl); + const connParams = await getAccessToken(this.accessToken, this.fetch, this.refreshUrl); // debug( // `curl '${connParams.casUrl}/reconstruction/${this.hash}' -H 'Authorization: Bearer ${connParams.accessToken}'` @@ -525,8 +520,8 @@ const jwts: Map< } > = new Map(); -function cacheKey(params: { repoId: RepoId; initialAccessToken: string | undefined }): string { - return `${params.repoId.type}:${params.repoId.name}:${params.initialAccessToken}`; +function cacheKey(params: { refreshUrl: string; initialAccessToken: string | undefined }): string { + return JSON.stringify([params.refreshUrl, params.initialAccessToken]); } // exported for testing purposes @@ -592,12 +587,11 @@ export function bg4_regoup_bytes(bytes: Uint8Array): Uint8Array { } async function getAccessToken( - repoId: RepoId, initialAccessToken: string | undefined, customFetch: typeof fetch, - hubUrl: string + refreshUrl: string ): Promise<{ accessToken: string; casUrl: string }> { - const key = cacheKey({ repoId, initialAccessToken }); + const key = cacheKey({ refreshUrl, initialAccessToken }); const jwt = jwts.get(key); @@ -612,8 +606,7 @@ async function getAccessToken( } const promise = (async () => { - const url = `${hubUrl}/api/${repoId.type}s/${repoId.name}/xet-read-token/main`; - const resp = await customFetch(url, { + const resp = await customFetch(refreshUrl, { headers: { ...(initialAccessToken ? { @@ -629,11 +622,10 @@ async function getAccessToken( const json: { accessToken: string; casUrl: string; exp: number } = await resp.json(); const jwt = { - repoId, accessToken: json.accessToken, expiresAt: new Date(json.exp * 1000), initialAccessToken, - hubUrl, + refreshUrl, casUrl: json.casUrl, }; @@ -660,7 +652,7 @@ async function getAccessToken( }; })(); - jwtPromises.set(repoId.name, promise); + jwtPromises.set(key, promise); return promise; } diff --git a/packages/hub/src/utils/createBlob.ts b/packages/hub/src/utils/createBlob.ts index 0cf54206da..5d5f200a66 100644 --- a/packages/hub/src/utils/createBlob.ts +++ b/packages/hub/src/utils/createBlob.ts @@ -11,9 +11,9 @@ import { isFrontend } from "./isFrontend"; * From the frontend: * - support http resources with absolute or relative URLs */ -export async function createBlob(url: URL, opts?: { fetch?: typeof fetch }): Promise { +export async function createBlob(url: URL, opts?: { fetch?: typeof fetch; accessToken?: string }): Promise { if (url.protocol === "http:" || url.protocol === "https:") { - return WebBlob.create(url, { fetch: opts?.fetch }); + return WebBlob.create(url, { fetch: opts?.fetch, accessToken: opts?.accessToken }); } if (isFrontend) { diff --git a/packages/jinja/test/e2e.test.js b/packages/jinja/test/e2e.test.js index 3ce98e2342..1ec3d0f43f 100644 --- a/packages/jinja/test/e2e.test.js +++ b/packages/jinja/test/e2e.test.js @@ -716,12 +716,11 @@ describe("End-to-end tests", () => { it("should parse a chat template from the Hugging Face Hub", async () => { const repo = "TheBloke/Mistral-7B-Instruct-v0.1-GPTQ"; - const tokenizerConfig = await ( - await downloadFile({ - repo, - path: "tokenizer_config.json", - }) - ).json(); + const blob = await downloadFile({ + repo, + path: "tokenizer_config.json", + }); + const tokenizerConfig = JSON.parse(await blob.text()); const template = new Template(tokenizerConfig.chat_template); const result = template.render(TEST_CUSTOM_TEMPLATES[repo].data);