Skip to content

Commit

Permalink
working, only incorrect model download
Browse files Browse the repository at this point in the history
  • Loading branch information
Matteo Omenetti [email protected] committed Jan 10, 2025
1 parent caf8b80 commit fa8420b
Show file tree
Hide file tree
Showing 12 changed files with 1,326 additions and 0 deletions.
114 changes: 114 additions & 0 deletions demo/demo_code_formula_predictor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
#
# Copyright IBM Corp. 2024 - 2024
# SPDX-License-Identifier: MIT
#
import argparse
import logging
import os
import sys
import time
from pathlib import Path

from huggingface_hub import snapshot_download
from PIL import Image

from docling_ibm_models.code_formula_model.code_formula_predictor import CodeFormulaPredictor


def demo(
    logger: logging.Logger,
    artifact_path: str,
    device: str,
    num_threads: int,
    image_dir: str,
    viz_dir: str,
):
    r"""
    Apply CodeFormulaPredictor on the images of the input directory.

    Every regular file found in ``image_dir`` is opened as an image (in sorted
    file-name order), the whole batch is sent to the predictor in one call, and
    the timing plus each prediction are written to ``logger``.

    Parameters
    ----------
    logger : logging.Logger
        Logger receiving the timing summary and the predictions.
    artifact_path : str
        Directory with the model artifacts, forwarded to CodeFormulaPredictor.
    device : str
        Device name forwarded to CodeFormulaPredictor.
    num_threads : int
        Number of threads forwarded to CodeFormulaPredictor.
    image_dir : str
        Directory containing the input images.
    viz_dir : str
        Currently unused by this function; kept for interface compatibility.
    """
    # Create the code/formula predictor
    code_formula_predictor = CodeFormulaPredictor(
        artifact_path, device=device, num_threads=num_threads
    )

    image_dir = Path(image_dir)
    # Sub-directories cannot be opened as images, so keep regular files only.
    image_names = sorted(
        name for name in os.listdir(image_dir) if (image_dir / name).is_file()
    )

    images = []
    for image_name in image_names:
        # copy() forces the pixel data into memory, so the underlying file can
        # be closed right away instead of staying open for the whole run.
        with Image.open(image_dir / image_name) as image:
            images.append(image.copy())

    if not images:
        logger.warning("No images found in %s", image_dir)
        return

    # NOTE(review): the labels are hard-coded, so the predictor only accepts a
    # directory holding exactly two images (first one code, second one formula).
    t0 = time.perf_counter()
    outputs = code_formula_predictor.predict(images, ['code', 'formula'], temperature=0)
    total_ms = 1000 * (time.perf_counter() - t0)
    avg_ms = total_ms / len(images)
    logger.info(
        "For {} images(ms): [total|avg] = [{:.1f}|{:.1f}]".format(
            len(images), total_ms, avg_ms
        )
    )

    for i, output in enumerate(outputs):
        logger.info(f"\nOutput {i}:\n{output}\n\n")


def main(args):
    r"""
    Configure logging, prepare the output directory and run the demo.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command line arguments; the attributes ``num_threads``,
        ``device``, ``image_dir`` and ``viz_dir`` are read.
    """
    num_threads = int(args.num_threads) if args.num_threads is not None else None
    device = args.device.lower()
    image_dir = args.image_dir
    viz_dir = args.viz_dir

    # Initialize logger.
    # basicConfig() installs a handler on the root logger, and hasHandlers()
    # also looks at ancestor loggers, so a later "add a handler if there is
    # none" branch can never run. Configure the intended stream and format
    # directly on the root handler instead.
    logging.basicConfig(
        level=logging.DEBUG,
        stream=sys.stdout,
        format="%(asctime)s %(name)-12s %(levelname)-8s %(message)s",
    )
    logger = logging.getLogger("CodeFormulaPredictor")
    logger.setLevel(logging.DEBUG)

    # Ensure the viz dir
    Path(viz_dir).mkdir(parents=True, exist_ok=True)

    # ! TODO: change this
    # Download models from HF
    # download_path = snapshot_download(repo_id="ds4sd/docling-models", revision="v2.1.0")
    # artifact_path = os.path.join(download_path, "model_artifacts/layout")
    artifact_path = "/dccstor/doc_fig_class/DocFM-Vision-Pretrainer/Vary-master/checkpoints_new_code_equation_model/checkpoint-7000/"

    # Test the CodeFormulaPredictor
    demo(logger, artifact_path, device, num_threads, image_dir, viz_dir)


if __name__ == "__main__":
    # Usage:
    #   python -m demo.demo_code_formula_predictor -i <images_dir>
    parser = argparse.ArgumentParser(description="Test the CodeFormulaPredictor")
    parser.add_argument(
        "-d", "--device", required=False, default="cpu", help="One of [cpu, cuda, mps]"
    )
    # type=int lets argparse reject a non-numeric value with a usage message
    # and keeps the parsed value's type consistent with the integer default.
    parser.add_argument(
        "-n",
        "--num_threads",
        required=False,
        default=4,
        type=int,
        help="Number of threads",
    )
    parser.add_argument(
        "-i",
        "--image_dir",
        required=True,
        help="PNG images input directory",
    )
    parser.add_argument(
        "-v",
        "--viz_dir",
        required=False,
        default="viz/",
        help="Directory to save prediction visualizations",
    )

    args = parser.parse_args()
    main(args)
225 changes: 225 additions & 0 deletions docling_ibm_models/code_formula_model/code_formula_predictor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
#
# Copyright IBM Corp. 2024 - 2024
# SPDX-License-Identifier: MIT
#
import logging
from typing import List, Union

import numpy as np
import torch
from PIL import Image
from transformers import AutoTokenizer

from docling_ibm_models.code_formula_model.utils.constants import *

from docling_ibm_models.code_formula_model.utils.conversations import conv_v1

from docling_ibm_models.code_formula_model.models.vary_opt import varyOPTForCausalLM
from docling_ibm_models.code_formula_model.models.vary_opt_image_processor import VaryOptImageProcessor


_log = logging.getLogger(__name__)


class CodeFormulaPredictor:
    """
    Code and Formula Predictor using a multi-modal vision-language model.

    This class enables the prediction of code or LaTeX representations
    from input images of code snippets or mathematical formulas.

    Attributes
    ----------
    _device : str
        The device on which the model is loaded (e.g., 'cpu' or 'cuda').
    _num_threads : int
        Number of threads used for inference when running on CPU.
    _tokenizer : transformers.PreTrainedTokenizer
        Tokenizer for processing textual inputs to the model.
    _model : transformers.PreTrainedModel
        Pretrained multi-modal vision-language model.
    _image_processor : transformers.ImageProcessor
        Processor for normalizing and preparing input images.
    """

    # Total token budget (prompt + generated tokens) used to derive
    # `max_new_tokens`. NOTE(review): presumably the context window of the
    # model; confirm against the model configuration.
    _MAX_SEQUENCE_LENGTH = 4096

    def __init__(
        self,
        artifacts_path: str,
        device: str = "cpu",
        num_threads: int = 4,
    ):
        """
        Initializes the CodeFormulaPredictor with the specified model artifacts.

        Parameters
        ----------
        artifacts_path : str
            Path to the directory containing the pretrained model files.
        device : str, optional
            Device to run the inference on ('cpu' or 'cuda'), by default "cpu".
        num_threads : int, optional
            Number of threads for CPU inference, by default 4. When None, the
            current torch thread setting is left untouched.
        """
        self._device = device
        self._num_threads = num_threads
        # torch.set_num_threads() requires an int, so skip it for None.
        if device == "cpu" and num_threads is not None:
            torch.set_num_threads(self._num_threads)

        # Prompts are padded on the left so that, in a batch, every prompt ends
        # at the same position and generation continues right after it.
        self._tokenizer = AutoTokenizer.from_pretrained(
            artifacts_path, use_fast=True, padding_side='left'
        )
        self._model = varyOPTForCausalLM.from_pretrained(artifacts_path).to(self._device)
        self._model.eval()

        self._image_processor = VaryOptImageProcessor.from_pretrained(artifacts_path)

        _log.debug("CodeFormulaModel settings: {}".format(self.info()))

    def info(self) -> dict:
        """
        Retrieves configuration details of the CodeFormulaPredictor instance.

        Returns
        -------
        dict
            A dictionary containing configuration details such as the device and
            the number of threads used.
        """
        info = {
            "device": self._device,
            "num_threads": self._num_threads,
        }
        return info

    def _get_prompt(self, label: str) -> str:
        """
        Constructs the prompt for the model based on the input label.

        Parameters
        ----------
        label : str
            The type of input, either 'code' or 'formula'.

        Returns
        -------
        str
            The constructed prompt including necessary tokens and query.

        Raises
        ------
        NotImplementedError
            If the label is not 'code' or 'formula'.
        """
        if label == "code":
            query = "<code_image_to_text>"
        elif label == "formula":
            query = "<equation>"
        else:
            raise NotImplementedError("Label must be either code or formula")

        # Image placeholder: start token, IMAGE_TOKEN_LEN patch tokens, end token.
        qs = (
            DEFAULT_IM_START_TOKEN
            + DEFAULT_IMAGE_PATCH_TOKEN * IMAGE_TOKEN_LEN
            + DEFAULT_IM_END_TOKEN
            + "\n"
        )

        conversation = conv_v1.copy()
        conversation.append_message(conversation.roles[0], qs)
        conversation.append_message(conversation.roles[1], None)

        prompt = conversation.get_prompt()
        # The task query is appended after the conversation template.
        prompt = prompt + "\n" + query

        return prompt

    @staticmethod
    def _to_rgb(image: Union[Image.Image, np.ndarray]) -> Image.Image:
        """Converts a PIL image or numpy array to an RGB PIL image."""
        if isinstance(image, Image.Image):
            return image.convert("RGB")
        if isinstance(image, np.ndarray):
            return Image.fromarray(image).convert("RGB")
        raise TypeError("Not supported input image format")

    @torch.inference_mode()
    def predict(
        self,
        images: List[Union[Image.Image, np.ndarray]],
        labels: List[str],
        temperature: float = 0.1,
    ) -> List[str]:
        """
        Predicts the textual representation of input images (code or LaTeX).

        Parameters
        ----------
        images : List[Union[Image.Image, np.ndarray]]
            List of images to be processed, provided as PIL Image objects or numpy arrays.
        labels : List[str]
            List of labels indicating the type of each image ('code' or 'formula').
        temperature : float, optional
            Sampling temperature for generation, by default set to 0.1.
            A value of 0 disables sampling.

        Returns
        -------
        List[str]
            List of predicted textual outputs for each input image in the given input
            order. An empty input yields an empty list.

        Raises
        ------
        ValueError
            If the temperature is not a number greater or equal to 0, or if the
            number of images differs from the number of labels.
        TypeError
            If any of the input images is not of a supported type (PIL Image or numpy array).
        NotImplementedError
            If any label is not 'code' or 'formula'.
        """
        # bool is a subclass of int but is not a meaningful temperature.
        is_number = isinstance(temperature, (int, float)) and not isinstance(
            temperature, bool
        )
        # `not temperature >= 0` also rejects NaN, which `temperature < 0` lets through.
        if not is_number or not temperature >= 0:
            raise ValueError("Temperature must be a number greater or equal to 0.")

        # Temperature 0 means deterministic decoding: sampling is switched off
        # and no temperature is handed to generate().
        do_sample = temperature != 0
        if not do_sample:
            temperature = None

        if len(labels) != len(images):
            raise ValueError("The number of images must be the same as the number of labels.")

        # Nothing to predict; torch.stack() would fail on an empty list.
        if len(images) == 0:
            return []

        rgb_images = [self._to_rgb(image) for image in images]
        prompts = [self._get_prompt(label) for label in labels]

        images_tensor = torch.stack(
            [self._image_processor(img) for img in rgb_images]
        ).to(self._device)

        tokenized = self._tokenizer(prompts, padding=True, return_tensors="pt")
        tokenized = {k: v.to(self._device) for k, v in tokenized.items()}

        prompt_ids = tokenized["input_ids"]
        attention_mask = tokenized["attention_mask"]
        # prompt_ids is indexed as (batch, sequence) below, so the prompt
        # length is shape[1]; shape[0] would be the batch size.
        prompt_length = prompt_ids.shape[1]

        generate_kwargs = dict(
            input_ids=prompt_ids,
            # Prompts are left-padded, so the mask is needed on every device.
            attention_mask=attention_mask,
            images=images_tensor,
            do_sample=do_sample,
            temperature=temperature,
            max_new_tokens=self._MAX_SEQUENCE_LENGTH - prompt_length,
            use_cache=True,
        )

        # autocast expects the bare device type, e.g. "cuda" for "cuda:0".
        device_type = torch.device(self._device).type
        if device_type == "cpu":
            output_ids_list = self._model.generate(**generate_kwargs)
        else:
            with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
                output_ids_list = self._model.generate(**generate_kwargs)

        # Drop the prompt tokens and decode only the generated continuation.
        outputs = self._tokenizer.batch_decode(
            output_ids_list[:, prompt_length:], skip_special_tokens=True
        )

        return outputs
Loading

0 comments on commit fa8420b

Please sign in to comment.