-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
working, only incorrect model download
- Loading branch information
1 parent
caf8b80
commit fa8420b
Showing
12 changed files
with
1,326 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
# | ||
# Copyright IBM Corp. 2024 - 2024 | ||
# SPDX-License-Identifier: MIT | ||
# | ||
import argparse | ||
import logging | ||
import os | ||
import sys | ||
import time | ||
from pathlib import Path | ||
|
||
from huggingface_hub import snapshot_download | ||
from PIL import Image | ||
|
||
from docling_ibm_models.code_formula_model.code_formula_predictor import CodeFormulaPredictor | ||
|
||
|
||
def demo( | ||
logger: logging.Logger, | ||
artifact_path: str, | ||
device: str, | ||
num_threads: int, | ||
image_dir: str, | ||
viz_dir: str, | ||
): | ||
r""" | ||
Apply LayoutPredictor on the input image directory | ||
If you want to load from PDF: | ||
pdf_image = pyvips.Image.new_from_file("test_data/ADS.2007.page_123.pdf", page=0) | ||
""" | ||
# Create the layout predictor | ||
code_formula_predictor = CodeFormulaPredictor(artifact_path, device=device, num_threads=num_threads) | ||
|
||
image_dir = Path(image_dir) | ||
images = [] | ||
image_names = os.listdir(image_dir) | ||
image_names.sort() | ||
for image_name in image_names: | ||
image = Image.open(image_dir / image_name) | ||
images.append(image) | ||
|
||
t0 = time.perf_counter() | ||
outputs = code_formula_predictor.predict(images, ['code', 'formula'], temperature=0) | ||
total_ms = 1000 * (time.perf_counter() - t0) | ||
avg_ms = (total_ms / len(image_names)) if len(image_names) > 0 else 0 | ||
logger.info( | ||
"For {} images(ms): [total|avg] = [{:.1f}|{:.1f}]".format( | ||
len(image_names), total_ms, avg_ms | ||
) | ||
) | ||
|
||
for i, output in enumerate(outputs): | ||
logger.info(f"\nOutput {i}:\n{output}\n\n") | ||
|
||
|
||
def main(args): | ||
num_threads = int(args.num_threads) if args.num_threads is not None else None | ||
device = args.device.lower() | ||
image_dir = args.image_dir | ||
viz_dir = args.viz_dir | ||
|
||
# Initialize logger | ||
logging.basicConfig(level=logging.DEBUG) | ||
logger = logging.getLogger("CodeFormulaPredictor") | ||
logger.setLevel(logging.DEBUG) | ||
if not logger.hasHandlers(): | ||
handler = logging.StreamHandler(sys.stdout) | ||
formatter = logging.Formatter( | ||
"%(asctime)s %(name)-12s %(levelname)-8s %(message)s" | ||
) | ||
handler.setFormatter(formatter) | ||
logger.addHandler(handler) | ||
|
||
# Ensure the viz dir | ||
Path(viz_dir).mkdir(parents=True, exist_ok=True) | ||
|
||
# ! TODO: change this | ||
# Download models from HF | ||
# download_path = snapshot_download(repo_id="ds4sd/docling-models", revision="v2.1.0") | ||
# artifact_path = os.path.join(download_path, "model_artifacts/layout") | ||
artifact_path = "/dccstor/doc_fig_class/DocFM-Vision-Pretrainer/Vary-master/checkpoints_new_code_equation_model/checkpoint-7000/" | ||
|
||
# Test the LayoutPredictor | ||
demo(logger, artifact_path, device, num_threads, image_dir, viz_dir) | ||
|
||
|
||
if __name__ == "__main__": | ||
r""" | ||
python -m demo.demo_code_formula_predictor -i <images_dir> | ||
""" | ||
parser = argparse.ArgumentParser(description="Test the CodeFormulaPredictor") | ||
parser.add_argument( | ||
"-d", "--device", required=False, default="cpu", help="One of [cpu, cuda, mps]" | ||
) | ||
parser.add_argument( | ||
"-n", "--num_threads", required=False, default=4, help="Number of threads" | ||
) | ||
parser.add_argument( | ||
"-i", | ||
"--image_dir", | ||
required=True, | ||
help="PNG images input directory", | ||
) | ||
parser.add_argument( | ||
"-v", | ||
"--viz_dir", | ||
required=False, | ||
default="viz/", | ||
help="Directory to save prediction visualizations", | ||
) | ||
|
||
args = parser.parse_args() | ||
main(args) |
225 changes: 225 additions & 0 deletions
225
docling_ibm_models/code_formula_model/code_formula_predictor.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,225 @@ | ||
# | ||
# Copyright IBM Corp. 2024 - 2024 | ||
# SPDX-License-Identifier: MIT | ||
# | ||
import logging | ||
from typing import List, Union | ||
|
||
import numpy as np | ||
import torch | ||
from PIL import Image | ||
from transformers import AutoTokenizer | ||
|
||
from docling_ibm_models.code_formula_model.utils.constants import * | ||
|
||
from docling_ibm_models.code_formula_model.utils.conversations import conv_v1 | ||
|
||
from docling_ibm_models.code_formula_model.models.vary_opt import varyOPTForCausalLM | ||
from docling_ibm_models.code_formula_model.models.vary_opt_image_processor import VaryOptImageProcessor | ||
|
||
|
||
_log = logging.getLogger(__name__) | ||
|
||
|
||
class CodeFormulaPredictor: | ||
""" | ||
Code and Formula Predictor using a multi-modal vision-language model. | ||
This class enables the prediction of code or LaTeX representations | ||
from input images of code snippets or mathematical formulas. | ||
Attributes | ||
---------- | ||
_device : str | ||
The device on which the model is loaded (e.g., 'cpu' or 'cuda'). | ||
_num_threads : int | ||
Number of threads used for inference when running on CPU. | ||
_tokenizer : transformers.PreTrainedTokenizer | ||
Tokenizer for processing textual inputs to the model. | ||
_model : transformers.PreTrainedModel | ||
Pretrained multi-modal vision-language model. | ||
_image_processor : transformers.ImageProcessor | ||
Processor for normalizing and preparing input images. | ||
_temperature : float | ||
Sampling temperature for generation; controls randomness in predictions. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
artifacts_path: str, | ||
device: str = "cpu", | ||
num_threads: int = 4, | ||
): | ||
""" | ||
Initializes the CodeFormulaPredictor with the specified model artifacts. | ||
Parameters | ||
---------- | ||
artifacts_path : str | ||
Path to the directory containing the pretrained model files. | ||
device : str, optional | ||
Device to run the inference on ('cpu' or 'cuda'), by default "cpu". | ||
num_threads : int, optional | ||
Number of threads for CPU inference, by default 4. | ||
""" | ||
self._device = device | ||
self._num_threads = num_threads | ||
if device == "cpu": | ||
torch.set_num_threads(self._num_threads) | ||
|
||
self._tokenizer = AutoTokenizer.from_pretrained(artifacts_path, use_fast=True, padding_side='left') | ||
self._model = varyOPTForCausalLM.from_pretrained(artifacts_path).to(self._device) | ||
self._model.eval() | ||
|
||
self._image_processor = VaryOptImageProcessor.from_pretrained(artifacts_path) | ||
|
||
_log.debug("CodeFormulaModel settings: {}".format(self.info())) | ||
|
||
def info(self) -> dict: | ||
""" | ||
Retrieves configuration details of the CodeFormulaPredictor instance. | ||
Returns | ||
------- | ||
dict | ||
A dictionary containing configuration details such as the device and | ||
the number of threads used. | ||
""" | ||
info = { | ||
"device": self._device, | ||
"num_threads": self._num_threads, | ||
} | ||
return info | ||
|
||
def _get_prompt(self, label: str) -> str: | ||
""" | ||
Constructs the prompt for the model based on the input label. | ||
Parameters | ||
---------- | ||
label : str | ||
The type of input, either 'code' or 'formula'. | ||
Returns | ||
------- | ||
str | ||
The constructed prompt including necessary tokens and query. | ||
Raises | ||
------ | ||
NotImplementedError | ||
If the label is not 'code' or 'formula'. | ||
""" | ||
if label == "code": | ||
query = "<code_image_to_text>" | ||
elif label == "formula": | ||
query = "<equation>" | ||
else: | ||
raise NotImplementedError("Label must be either code or formula") | ||
|
||
qs = ( | ||
DEFAULT_IM_START_TOKEN | ||
+ DEFAULT_IMAGE_PATCH_TOKEN * IMAGE_TOKEN_LEN | ||
+ DEFAULT_IM_END_TOKEN | ||
+ "\n" | ||
) | ||
|
||
conversation = conv_v1.copy() | ||
conversation.append_message(conversation.roles[0], qs) | ||
conversation.append_message(conversation.roles[1], None) | ||
|
||
prompt = conversation.get_prompt() | ||
prompt = prompt + "\n" + query | ||
|
||
return prompt | ||
|
||
@torch.inference_mode() | ||
def predict( | ||
self, | ||
images: List[Union[Image.Image, np.ndarray]], | ||
labels: List[str], | ||
temperature: float = 0.1, | ||
) -> List[str]: | ||
""" | ||
Predicts the textual representation of input images (code or LaTeX). | ||
Parameters | ||
---------- | ||
images : List[Union[Image.Image, np.ndarray]] | ||
List of images to be processed, provided as PIL Image objects or numpy arrays. | ||
labels : List[str] | ||
List of labels indicating the type of each image ('code' or 'formula'). | ||
temperature : float, optional | ||
Sampling temperature for generation, by default set to 0.1. | ||
Returns | ||
------- | ||
List[str] | ||
List of predicted textual outputs for each input image in the given input | ||
order. | ||
Raises | ||
------ | ||
TypeError | ||
If any of the input images is not of a supported type (PIL Image or numpy array). | ||
Excpetion | ||
In case the temperature is an invalid number. | ||
""" | ||
if (type(temperature) != float and type(temperature) != int) or temperature < 0: | ||
raise Exception("Temperature must be a number greater or equal to 0.") | ||
|
||
do_sample = True | ||
if temperature == 0: | ||
do_sample = False | ||
temperature = None | ||
|
||
if len(labels) != len(images): | ||
raise Exception("The number of images must be the same as the number of labels.") | ||
|
||
images_tmp = [] | ||
for image in images: | ||
if isinstance(image, Image.Image): | ||
image = image.convert("RGB") | ||
elif isinstance(image, np.ndarray): | ||
image = Image.fromarray(image).convert("RGB") | ||
else: | ||
raise TypeError("Not supported input image format") | ||
images_tmp.append(image) | ||
images = images_tmp | ||
|
||
images_tensor = torch.stack([self._image_processor(img) for img in images]).to(self._device) | ||
|
||
prompts = [self._get_prompt(label) for label in labels] | ||
|
||
tokenized = self._tokenizer(prompts, padding=True, return_tensors="pt") | ||
tokenized = {k: v.to(self._device) for k, v in tokenized.items()} | ||
|
||
prompt_ids = tokenized["input_ids"] | ||
attention_mask = tokenized["attention_mask"] | ||
|
||
if self._device == "cpu": | ||
output_ids_list = self._model.generate( | ||
input_ids=prompt_ids, | ||
attention_mask=attention_mask, | ||
images=images_tensor, | ||
do_sample=do_sample, | ||
temperature=temperature, | ||
max_new_tokens=4096 - prompt_ids.shape[0], | ||
use_cache=True, | ||
) | ||
else: | ||
with torch.autocast(device_type=self._device, dtype=torch.bfloat16): | ||
output_ids_list = self._model.generate( | ||
prompt_ids, | ||
images=images_tensor, | ||
do_sample=do_sample, | ||
temperature=temperature, | ||
max_new_tokens=4096 - prompt_ids.shape[0], | ||
use_cache=True, | ||
) | ||
|
||
outputs = self._tokenizer.batch_decode( | ||
output_ids_list[:, prompt_ids.shape[1] :], skip_special_tokens=True | ||
) | ||
|
||
return outputs |
Oops, something went wrong.