Skip to content

Commit 4a22043

Browse files
Add Pruna AI library snippets (no formatting changes) (#1733)
This update enhances support for Pruna AI, providing users with tailored code snippets for model integrations with Transformers and Diffusers. - Introduced a new library entry for Pruna AI in model-libraries. - Added main entry point and specific snippet generation functions for diffusers and transformers models. - Cleaned up whitespace inconsistencies in existing snippets. TLDR: Pruna API normally mimics the Transformers and Diffusers API, so we can use `PrunaModel.from_pretrained` on top of pipelines or specific models. We re-use the underlying snippets for both the library and do some greedy replacements of certain part of the code snippets. example ```python import torch from diffusers import FluxFillPipeline from diffusers.utils import load_image image = load_image("https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/cup.png") mask = load_image("https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/cup_mask.png") pipe = FluxFillPipeline.from_pretrained("black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16).to("cuda") image = pipe( prompt="a white paper cup", image=image, mask_image=mask, height=1632, width=1232, guidance_scale=30, num_inference_steps=50, max_sequence_length=512, generator=torch.Generator("cpu").manual_seed(0) ).images[0] image.save(f"flux-fill-dev.png") ``` becomes ```python import torch from pruna import PrunaModel from diffusers.utils import load_image image = load_image("https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/cup.png") mask = load_image("https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/cup_mask.png") pipe = PrunaModel.from_pretrained("black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16).to("cuda") image = pipe( prompt="a white paper cup", image=image, mask_image=mask, height=1632, width=1232, guidance_scale=30, num_inference_steps=50, max_sequence_length=512, generator=torch.Generator("cpu").manual_seed(0) ).images[0] image.save(f"flux-fill-dev.png") ``` --------- Co-authored-by: Lucain <[email protected]>
1 parent 15d7071 commit 4a22043

File tree

2 files changed

+88
-0
lines changed

2 files changed

+88
-0
lines changed

packages/tasks/src/model-libraries-snippets.ts

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1905,6 +1905,87 @@ export const model2vec = (model: ModelData): string[] => [
19051905
model = StaticModel.from_pretrained("${model.id}")`,
19061906
];
19071907

1908+
export const pruna = (model: ModelData): string[] => {
1909+
let snippets: string[];
1910+
1911+
if (model.tags.includes("diffusers")) {
1912+
snippets = pruna_diffusers(model);
1913+
} else if (model.tags.includes("transformers")) {
1914+
snippets = pruna_transformers(model);
1915+
} else {
1916+
snippets = pruna_default(model);
1917+
}
1918+
1919+
const ensurePrunaModelImport = (snippet: string): string => {
1920+
if (!/^from pruna import PrunaModel/m.test(snippet)) {
1921+
return `from pruna import PrunaModel\n${snippet}`;
1922+
}
1923+
return snippet;
1924+
};
1925+
snippets = snippets.map(ensurePrunaModelImport);
1926+
1927+
if (model.tags.includes("pruna_pro-ai")) {
1928+
return snippets.map((snippet) =>
1929+
snippet.replace(/\bpruna\b/g, "pruna_pro").replace(/\bPrunaModel\b/g, "PrunaProModel")
1930+
);
1931+
}
1932+
1933+
return snippets;
1934+
};
1935+
1936+
const pruna_diffusers = (model: ModelData): string[] => {
1937+
const diffusersSnippets = diffusers(model);
1938+
1939+
return diffusersSnippets.map((snippet) =>
1940+
snippet
1941+
// Replace pipeline classes with PrunaModel
1942+
.replace(/\b\w*Pipeline\w*\b/g, "PrunaModel")
1943+
// Clean up diffusers imports containing PrunaModel
1944+
.replace(/from diffusers import ([^,\n]*PrunaModel[^,\n]*)/g, "")
1945+
.replace(/from diffusers import ([^,\n]+),?\s*([^,\n]*PrunaModel[^,\n]*)/g, "from diffusers import $1")
1946+
.replace(/from diffusers import\s*(\n|$)/g, "")
1947+
// Fix PrunaModel imports
1948+
.replace(/from diffusers import PrunaModel/g, "from pruna import PrunaModel")
1949+
.replace(/from diffusers import ([^,\n]+), PrunaModel/g, "from diffusers import $1")
1950+
.replace(/from diffusers import PrunaModel, ([^,\n]+)/g, "from diffusers import $1")
1951+
// Clean up whitespace
1952+
.replace(/\n\n+/g, "\n")
1953+
.trim()
1954+
);
1955+
};
1956+
1957+
const pruna_transformers = (model: ModelData): string[] => {
1958+
const info = model.transformersInfo;
1959+
const transformersSnippets = transformers(model);
1960+
1961+
// Replace pipeline with PrunaModel
1962+
let processedSnippets = transformersSnippets.map((snippet) =>
1963+
snippet
1964+
.replace(/from transformers import pipeline/g, "from pruna import PrunaModel")
1965+
.replace(/pipeline\([^)]*\)/g, `PrunaModel.from_pretrained("${model.id}")`)
1966+
);
1967+
1968+
// Additional cleanup if auto_model info is available
1969+
if (info?.auto_model) {
1970+
processedSnippets = processedSnippets.map((snippet) =>
1971+
snippet
1972+
.replace(new RegExp(`from transformers import ${info.auto_model}\n?`, "g"), "")
1973+
.replace(new RegExp(`${info.auto_model}.from_pretrained`, "g"), "PrunaModel.from_pretrained")
1974+
.replace(new RegExp(`^.*from.*import.*(, *${info.auto_model})+.*$`, "gm"), (line) =>
1975+
line.replace(new RegExp(`, *${info.auto_model}`, "g"), "")
1976+
)
1977+
);
1978+
}
1979+
1980+
return processedSnippets;
1981+
};
1982+
1983+
const pruna_default = (model: ModelData): string[] => [
1984+
`from pruna import PrunaModel
1985+
model = PrunaModel.from_pretrained("${model.id}")
1986+
`,
1987+
];
1988+
19081989
export const nemo = (model: ModelData): string[] => {
19091990
let command: string[] | undefined = undefined;
19101991
// Resolve the tag to a nemo domain/sub-domain

packages/tasks/src/model-libraries.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -801,6 +801,13 @@ export const MODEL_LIBRARIES_UI_ELEMENTS = {
801801
filter: false,
802802
countDownloads: `path_extension:"pth"`,
803803
},
804+
"pruna-ai": {
805+
prettyLabel: "Pruna AI",
806+
repoName: "Pruna AI",
807+
repoUrl: "https://github.com/PrunaAI/pruna",
808+
snippets: snippets.pruna,
809+
docsUrl: "https://docs.pruna.ai",
810+
},
804811
pxia: {
805812
prettyLabel: "pxia",
806813
repoName: "pxia",

0 commit comments

Comments
 (0)