pre-commit
Matteo Omenetti [email protected] committed Jan 10, 2025
1 parent 99f2aa2 commit 66700ab
Showing 6 changed files with 109 additions and 62 deletions.
23 changes: 14 additions & 9 deletions docling_ibm_models/code_formula_model/code_formula_predictor.py
@@ -10,14 +10,13 @@
from PIL import Image
from transformers import AutoTokenizer

from docling_ibm_models.code_formula_model.models.sam_opt import SamOPTForCausalLM
from docling_ibm_models.code_formula_model.models.sam_opt_image_processor import (
SamOptImageProcessor,
)
from docling_ibm_models.code_formula_model.utils.constants import *

from docling_ibm_models.code_formula_model.utils.conversations import conv_v1

from docling_ibm_models.code_formula_model.models.sam_opt import SamOPTForCausalLM
from docling_ibm_models.code_formula_model.models.sam_opt_image_processor import SamOptImageProcessor


_log = logging.getLogger(__name__)


@@ -67,10 +66,12 @@ def __init__(
if device == "cpu":
torch.set_num_threads(self._num_threads)

self._tokenizer = AutoTokenizer.from_pretrained(artifacts_path, use_fast=True, padding_side='left')
self._tokenizer = AutoTokenizer.from_pretrained(
artifacts_path, use_fast=True, padding_side="left"
)
self._model = SamOPTForCausalLM.from_pretrained(artifacts_path).to(self._device)
self._model.eval()

self._image_processor = SamOptImageProcessor.from_pretrained(artifacts_path)

_log.debug("CodeFormulaModel settings: {}".format(self.info()))
@@ -174,7 +175,9 @@ def predict(
temperature = None

if len(labels) != len(images):
raise Exception("The number of images must be the same as the number of labels.")
raise Exception(
"The number of images must be the same as the number of labels."
)

images_tmp = []
for image in images:
Expand All @@ -187,7 +190,9 @@ def predict(
images_tmp.append(image)
images = images_tmp

images_tensor = torch.stack([self._image_processor(img) for img in images]).to(self._device)
images_tensor = torch.stack([self._image_processor(img) for img in images]).to(
self._device
)

prompts = [self._get_prompt(label) for label in labels]

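For reference, a minimal sketch of the validation-and-batching pattern that these reformatted `predict()` lines implement. `batch_inputs` is a hypothetical helper name, and the image processor is assumed to return one `(C, H, W)` tensor per PIL image.

```python
from typing import List

import torch
from PIL import Image


def batch_inputs(
    images: List[Image.Image],
    labels: List[str],
    image_processor,
    device: str = "cpu",
) -> torch.Tensor:
    # Same guard the reformatted predict() applies: one label per image.
    if len(labels) != len(images):
        raise Exception(
            "The number of images must be the same as the number of labels."
        )

    # Each processed image is assumed to come back as a (C, H, W) tensor;
    # stacking gives the (B, C, H, W) batch that is moved to the target device.
    return torch.stack([image_processor(img) for img in images]).to(device)
```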
93 changes: 57 additions & 36 deletions docling_ibm_models/code_formula_model/models/sam.py
@@ -4,13 +4,13 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from functools import partial
from typing import Optional, Tuple, Type

import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Optional, Tuple, Type

from functools import partial

class MLPBlock(nn.Module):
def __init__(
@@ -98,7 +98,9 @@ def __init__(
if use_abs_pos:
# Initialize absolute positional embedding with pretrain image size.
self.pos_embed = nn.Parameter(
torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
torch.zeros(
1, img_size // patch_size, img_size // patch_size, embed_dim
)
)

self.blocks = nn.ModuleList()
@@ -135,9 +137,10 @@ def __init__(
LayerNorm2d(out_chans),
)


self.net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False)
self.net_3 = nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1, bias=False)
self.net_3 = nn.Conv2d(
512, 1024, kernel_size=3, stride=2, padding=1, bias=False
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.patch_embed(x)
@@ -151,7 +154,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.net_2(x)
x = self.net_3(x)


return x


@@ -198,7 +200,9 @@ def __init__(
)

self.norm2 = norm_layer(dim)
self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
self.mlp = MLPBlock(
embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer
)

self.window_size = window_size

@@ -263,23 +267,34 @@ def __init__(
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, H, W, _ = x.shape
# qkv with shape (3, B, nHead, H * W, C)
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
qkv = (
self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
)
# q, k, v with shape (B * nHead, H * W, C)
q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)

attn = (q * self.scale) @ k.transpose(-2, -1)

if self.use_rel_pos:
attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
attn = add_decomposed_rel_pos(
attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)
)

attn = attn.softmax(dim=-1)
x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
x = (
(attn @ v)
.view(B, self.num_heads, H, W, -1)
.permute(0, 2, 3, 1, 4)
.reshape(B, H, W, -1)
)
x = self.proj(x)

return x


def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
def window_partition(
x: torch.Tensor, window_size: int
) -> Tuple[torch.Tensor, Tuple[int, int]]:
"""
Partition into non-overlapping windows with padding if needed.
Args:
@@ -299,12 +314,17 @@ def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, T
Hp, Wp = H + pad_h, W + pad_w

x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
windows = (
x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
)
return windows, (Hp, Wp)


def window_unpartition(
windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
windows: torch.Tensor,
window_size: int,
pad_hw: Tuple[int, int],
hw: Tuple[int, int],
) -> torch.Tensor:
"""
Window unpartition into original sequences and removing padding.
@@ -320,7 +340,9 @@ def window_unpartition(
Hp, Wp = pad_hw
H, W = hw
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
x = windows.view(
B, Hp // window_size, Wp // window_size, window_size, window_size, -1
)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)

if Hp > H or Wp > W:
@@ -394,7 +416,9 @@ def add_decomposed_rel_pos(
rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)

attn = (
attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
attn.view(B, q_h, q_w, k_h, k_w)
+ rel_h[:, :, :, :, None]
+ rel_w[:, :, :, None, :]
).view(B, q_h * q_w, k_h * k_w)

return attn
@@ -434,15 +458,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x



def build_sam_vit_b(checkpoint=None, image_size=1024):
return _build_sam(
encoder_embed_dim=768,
encoder_depth=12,
encoder_num_heads=12,
encoder_global_attn_indexes=[2, 5, 8, 11],
checkpoint=checkpoint,
image_size=image_size
image_size=image_size,
)


@@ -452,30 +475,28 @@ def _build_sam(
encoder_num_heads,
encoder_global_attn_indexes,
checkpoint=None,
image_size=1024
image_size=1024,
):
prompt_embed_dim = 256
vit_patch_size = 16
image_encoder=ImageEncoderViT(
depth=encoder_depth,
embed_dim=encoder_embed_dim,
img_size=image_size,
mlp_ratio=4,
norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
num_heads=encoder_num_heads,
patch_size=vit_patch_size,
qkv_bias=True,
use_rel_pos=True,
global_attn_indexes=encoder_global_attn_indexes,
window_size=14,
out_chans=prompt_embed_dim,
)
image_encoder = ImageEncoderViT(
depth=encoder_depth,
embed_dim=encoder_embed_dim,
img_size=image_size,
mlp_ratio=4,
norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
num_heads=encoder_num_heads,
patch_size=vit_patch_size,
qkv_bias=True,
use_rel_pos=True,
global_attn_indexes=encoder_global_attn_indexes,
window_size=14,
out_chans=prompt_embed_dim,
)

if checkpoint is not None:
# with open(checkpoint, "rb") as f:
state_dict = torch.load(checkpoint)


image_encoder.load_state_dict(state_dict, strict=True)
return image_encoder

37 changes: 27 additions & 10 deletions docling_ibm_models/code_formula_model/models/sam_opt.py
@@ -17,22 +17,31 @@

import torch
import torch.nn as nn

from transformers import AutoConfig, AutoModelForCausalLM
from transformers import (
AutoConfig,
AutoModelForCausalLM,
OPTConfig,
OPTForCausalLM,
OPTModel,
)
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)

from docling_ibm_models.code_formula_model.models.sam import build_sam_vit_b

from transformers import OPTConfig, OPTModel, OPTForCausalLM


class SamOptConfig(OPTConfig):
model_type = "sam_opt"

def __init__(self, sam_image_size=1024, sam_mm_projector_in=1024, sam_mm_projector_out=768, **kwargs):
def __init__(
self,
sam_image_size=1024,
sam_mm_projector_in=1024,
sam_mm_projector_out=768,
**kwargs,
):
super().__init__(**kwargs)
self.sam_image_size = sam_image_size
self.sam_mm_projector_in = sam_mm_projector_in
@@ -79,16 +88,24 @@ def forward(
image_features = self.mm_projector(image_features)

new_input_embeds = []
for cur_input_ids, cur_input_embeds, cur_image_features in zip(input_ids, inputs_embeds, image_features):
image_start_token_position = torch.where(cur_input_ids == im_start_token)[0].item()

cur_image_features = cur_image_features.to(device=cur_input_embeds.device)
for cur_input_ids, cur_input_embeds, cur_image_features in zip(
input_ids, inputs_embeds, image_features
):
image_start_token_position = torch.where(
cur_input_ids == im_start_token
)[0].item()

cur_image_features = cur_image_features.to(
device=cur_input_embeds.device
)
num_patches = cur_image_features.shape[0]
cur_input_embeds = torch.cat(
(
cur_input_embeds[: image_start_token_position + 1],
cur_image_features,
cur_input_embeds[image_start_token_position + num_patches + 1 :],
cur_input_embeds[
image_start_token_position + num_patches + 1 :
],
),
dim=0,
)
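The most involved hunk above reflows the per-sample embedding splice in `SamOPTForCausalLM.forward()`. A self-contained sketch of that same slicing logic follows; `splice_image_features` is a hypothetical name, not part of the module.

```python
import torch


def splice_image_features(
    input_embeds: torch.Tensor,    # (seq_len, hidden) token embeddings for one sample
    image_features: torch.Tensor,  # (num_patches, hidden) projected SAM features
    input_ids: torch.Tensor,       # (seq_len,) token ids for the same sample
    im_start_token: int,
) -> torch.Tensor:
    # Locate the <img> start token, then overwrite the placeholder positions
    # that follow it with the projected image features.
    start = torch.where(input_ids == im_start_token)[0].item()
    num_patches = image_features.shape[0]
    return torch.cat(
        (
            input_embeds[: start + 1],
            image_features.to(device=input_embeds.device),
            input_embeds[start + num_patches + 1 :],
        ),
        dim=0,
    )
```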
4 changes: 2 additions & 2 deletions docling_ibm_models/code_formula_model/models/sam_opt_image_processor.py
@@ -2,10 +2,10 @@
# Copyright IBM Corp. 2024 - 2024
# SPDX-License-Identifier: MIT
#
from transformers.image_processing_utils import ImageProcessingMixin
from torchvision.transforms import functional as F
from PIL import Image
from torchvision.transforms import functional as F
from transformers import AutoImageProcessor
from transformers.image_processing_utils import ImageProcessingMixin


class SamOptImageProcessor(ImageProcessingMixin):
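The processor itself only gets an import reorder; for context, a usage sketch matching how code_formula_predictor.py loads and calls it. The artifacts path is a placeholder, not a path from the repository.

```python
from PIL import Image

from docling_ibm_models.code_formula_model.models.sam_opt_image_processor import (
    SamOptImageProcessor,
)

# Load the processor from a local artifacts directory (placeholder path) and
# turn a PIL image into the tensor fed to the SAM encoder.
processor = SamOptImageProcessor.from_pretrained("/path/to/artifacts")
pixel_tensor = processor(Image.new("RGB", (640, 480)))
```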
2 changes: 1 addition & 1 deletion docling_ibm_models/code_formula_model/utils/constants.py
@@ -3,4 +3,4 @@
DEFAULT_IM_START_TOKEN = "<img>"
DEFAULT_IM_END_TOKEN = "</img>"

IMAGE_TOKEN_LEN = 256
IMAGE_TOKEN_LEN = 256
12 changes: 8 additions & 4 deletions docling_ibm_models/code_formula_model/utils/conversations.py
@@ -1,17 +1,20 @@
import dataclasses
from enum import auto, Enum
from enum import Enum, auto
from typing import List


class SeparatorStyle(Enum):
"""Different separator style."""

SINGLE = auto()
TWO = auto()
MPT = auto()


@dataclasses.dataclass
class Conversation:
"""A class that keeps all conversation history."""

system: str
roles: List[str]
messages: List[List[str]]
@@ -36,7 +39,7 @@ def copy(self):

def get_prompt(self):
if self.sep_style == SeparatorStyle.SINGLE:
ret = self.system + self.sep + '\n'
ret = self.system + self.sep + "\n"
for role, message in self.messages:
if message:
if type(message) is tuple:
@@ -58,9 +61,9 @@ def get_prompt(self):
return ret
if self.sep_style == SeparatorStyle.MPT:
if self.system:
ret = self.system + self.sep
ret = self.system + self.sep
else:
ret = ''
ret = ""
for role, message in self.messages:
if message:
if type(message) is tuple:
@@ -75,6 +78,7 @@ def get_prompt(self):
def append_message(self, role, message):
self.messages.append([role, message])


conv_v1 = Conversation(
system="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
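A minimal sketch of how the `conv_v1` template can be driven, assuming `roles[0]` is the user role and `roles[1]` the assistant. The message text is purely illustrative; it is not how code_formula_predictor.py builds its prompts.

```python
from docling_ibm_models.code_formula_model.utils.conversations import conv_v1

# Copy the shared template, add one user turn and an empty assistant slot,
# then render the prompt string produced by the configured separator style.
conv = conv_v1.copy()
conv.append_message(conv.roles[0], "Describe this formula.")  # illustrative text only
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
print(prompt)
```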
