Support ModernBERT #53

Open · sam-bercovici opened this issue Jan 26, 2025 · 13 comments

@sam-bercovici (Contributor)

Adding support for using ModernBERT.

@bclavie (Collaborator) commented Feb 21, 2025

Hey. ModernBERT-based rerankers are already supported as long as your transformers version is recent enough. ModernBERT itself isn't a reranking model, so no extra changes are needed to support it.
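
For reference, loading a ModernBERT-based reranker through the library looks roughly like this (a minimal sketch based on the documented Reranker API; the model name, taken from later in this thread, is only an example):

# Minimal sketch: loading a ModernBERT-based cross-encoder via rerankers.
# The model name is only an example; any compatible cross-encoder works.
from rerankers import Reranker

ranker = Reranker("Alibaba-NLP/gte-reranker-modernbert-base", model_type="cross-encoder")
results = ranker.rank(
    query="What is ModernBERT?",
    docs=["ModernBERT is a recent encoder-only transformer architecture.",
          "An unrelated document."],
)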

bclavie closed this as completed Feb 21, 2025
@sam-bercovici (Contributor, Author) commented Feb 22, 2025

Hi @bclavie,
Using a model like "Alibaba-NLP/gte-reranker-modernbert-base" with transformers crashed my code.
I had to jump through a few hoops to get it to work. See the following code that made it work:

# Imports were omitted from the original snippet; the module paths below are
# assumed from the rerankers package layout and may need adjusting for your version.
from typing import List, Optional, Union

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from rerankers.documents import Document
from rerankers.models.ranker import BaseRanker
from rerankers.results import RankedResults, Result
from rerankers.utils import get_device, get_dtype, prep_docs, vprint


class DisableCompileContextManager:
    def __init__(self):
        self._original_compile = torch.compile

    def __enter__(self):
        # Turn torch.compile into a no-op
        torch.compile = lambda *args, **kwargs: lambda x: x  # type: ignore

    def __exit__(self, exc_type, exc_val, exc_tb):
        torch.compile = self._original_compile


class TransformerRanker(BaseRanker):
    def __init__(
        self,
        model_name_or_path: str,
        dtype: Optional[Union[str, torch.dtype]] = None,
        device: Optional[Union[str, torch.device]] = None,
        batch_size: int = 16,
        verbose: int = 1,
        max_length: int = 0,
        **kwargs,
    ):
        self.verbose = verbose
        self.device = get_device(device, verbose=self.verbose)
        self.dtype = get_dtype(dtype, self.device, self.verbose)
        self.max_length = max_length
        model_kwargs = kwargs.get("model_kwargs", {})
        with DisableCompileContextManager():
            self.model = AutoModelForSequenceClassification.from_pretrained(
                model_name_or_path,
                torch_dtype=self.dtype,
                **model_kwargs,
            ).to(self.device)
            vprint(f"Loaded model {model_name_or_path}", self.verbose)
            vprint(f"Using device {self.device}.", self.verbose)
            vprint(f"Using dtype {self.dtype}.", self.verbose)
            self.model.eval()
            tokenizer_kwargs = kwargs.get("tokenizer_kwargs", {})
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name_or_path,
                **tokenizer_kwargs,
            )
            self.ranking_type = "pointwise"
            self.batch_size = batch_size

    @torch.inference_mode()
    def rank(
        self,
        query: str,
        docs: Union[str, List[str], Document, List[Document]],
        doc_ids: Optional[Union[List[str], List[int]]] = None,
        metadata: Optional[List[dict]] = None,
        batch_size: Optional[int] = None,
    ) -> RankedResults:
        docs = prep_docs(docs, doc_ids, metadata)
        inputs = [(query, doc.text) for doc in docs]

        # Override self.batch_size if explicitly set
        if batch_size is None:
            batch_size = self.batch_size
        batched_inputs = [
            inputs[i : i + batch_size] for i in range(0, len(inputs), batch_size)
        ]
        scores = []
        for batch in batched_inputs:
            # tokenized_inputs = self.tokenize(batch)
            with torch.no_grad():
                if self.max_length:
                    tokenized_inputs = self.tokenizer(
                        batch,
                        max_length=self.max_length,
                        return_tensors="pt",
                        padding=True,
                        truncation=True,
                    ).to(self.device)
                else:
                    tokenized_inputs = self.tokenizer(
                        batch,
                        return_tensors="pt",
                        padding=True,
                        truncation=True,
                    ).to(self.device)

                batch_scores = self.model(**tokenized_inputs).logits.squeeze()
                batch_scores = batch_scores.detach().cpu().numpy().tolist()
                if isinstance(batch_scores, float):  # Handling the case of single score
                    scores.append(batch_scores)
                else:
                    scores.extend(batch_scores)
        if len(scores) == 1:
            return Result(document=docs[0], score=scores[0])
        else:
            ranked_results = [
                Result(document=doc, score=score, rank=idx + 1)
                for idx, (doc, score) in enumerate(
                    sorted(zip(docs, scores), key=lambda x: x[1], reverse=True)
                )
            ]
            return RankedResults(results=ranked_results, query=query, has_scores=True)

    @torch.inference_mode()
    def score(self, query: str, doc: str) -> float:
        inputs = self.tokenize((query, doc))  # type: ignore
        outputs = self.model(**inputs)
        score = outputs.logits.squeeze().detach().cpu().numpy().astype(float)
        return score
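
A minimal usage sketch for the class above (the query and documents are placeholders; the model name is the one discussed in this thread):

# Hypothetical usage of the workaround class defined above.
ranker = TransformerRanker("Alibaba-NLP/gte-reranker-modernbert-base")
results = ranker.rank(
    query="What is ModernBERT?",
    docs=["ModernBERT is an encoder-only transformer.", "Unrelated text."],
)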

@bclavie (Collaborator) commented Feb 22, 2025

Ooh this might be a broader ModernBERT issue rather than transformers itself here! Could you please share the exact error message? I'd be happy to look into what caused it.

@sam-bercovici (Contributor, Author) commented Feb 22, 2025

@bclavie, I will try to recreate the issue next week; I will need to revert my code and get the error back.
In general, I am running this on a local machine with an Nvidia 4060, and running the reranking in a few threads in parallel.
Again, I will try to recreate the errors and report here.
Would you mind reopening this issue to track this until we conclude?

@bclavie (Collaborator) commented Feb 22, 2025

@bclavie, I will try to recreate the issue next week; I will need to revert my code and get the error back. In general, I am running this on a local machine with an Nvidia 4060, and running the reranking in a few threads in parallel. Again, I will try to recreate the errors and report here. Would you mind reopening this issue to track this until we conclude?

Sure! I'm curious what's causing this for you -- I suspect it might be something wrong with how we set up modernbert loading. I've been unable to reproduce the issue on either 4090 or Mac MPS 🤔

bclavie reopened this Feb 22, 2025
@sam-bercovici (Contributor, Author) commented Feb 22, 2025

@bclavie, it was easy to get the error back:
I got the exception: "Detected that you are using FX to symbolically trace a dynamo-optimized function. This is not supported at the moment."
I think this is the relevant log:
File "/home/samb/src/trellizai/feature-generator/.conda/lib/python3.12/site-packages/rerankers/models/transformer_ranker.py", line 78, in rank
batch_scores = self.model(**tokenized_inputs).logits.squeeze()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/samb/src/trellizai/feature-generator/.conda/lib/python3.12/concurrent/futures/_base.py", line 401, in __get_result
raise self._exception
File "/home/samb/src/trellizai/feature-generator/.conda/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/samb/src/trellizai/feature-generator/.conda/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/samb/src/trellizai/feature-generator/.conda/lib/python3.12/site-packages/transformers/models/modernbert/modeling_modernbert.py", line 1239, in forward
outputs = self.model(
^^^^^^^^^^^
File "/home/samb/src/trellizai/feature-generator/.conda/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/samb/src/trellizai/feature-generator/.conda/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/samb/src/trellizai/feature-generator/.conda/lib/python3.12/site-packages/transformers/models/modernbert/modeling_modernbert.py", line 976, in forward
layer_outputs = encoder_layer(
^^^^^^^^^^^^^^
File "/home/samb/src/trellizai/feature-generator/.conda/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/samb/src/trellizai/feature-generator/.conda/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/samb/src/trellizai/feature-generator/.conda/lib/python3.12/site-packages/transformers/models/modernbert/modeling_modernbert.py", line 577, in forward
self.compiled_mlp(hidden_states)
File "/home/samb/src/trellizai/feature-generator/.conda/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 544, in _fn
raise RuntimeError(

@bclavie (Collaborator) commented Feb 22, 2025

Thank you! Could you try modifying the model init call to:

AutoModelForSequenceClassification.from_pretrained(
    model_name_or_path,
    torch_dtype=self.dtype,
    reference_compile=False,
    **model_kwargs,
).to(self.device)

(the reference_compile flag). It should only turn itself on if your machine supports it, but I'm wondering if this is causing the issue. The other culprit would be that it's trying to use the FA2 implementation without FA2 being present, but that'd be strange...
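
If it helps, here is a sketch of passing the flag without modifying library code, assuming the top-level Reranker forwards model_kwargs into from_pretrained the same way the TransformerRanker snippet above does:

# Sketch only: assumes Reranker forwards model_kwargs through to
# AutoModelForSequenceClassification.from_pretrained, mirroring the
# kwargs.get("model_kwargs", {}) pattern in the snippet above.
from rerankers import Reranker

ranker = Reranker(
    "Alibaba-NLP/gte-reranker-modernbert-base",
    model_type="cross-encoder",
    model_kwargs={"reference_compile": False},
)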

@sam-bercovici (Contributor, Author) commented Feb 22, 2025

@bclavie, it looks like it worked. My code did not get the exceptions it got before.
After that, I modified the rerankers library code in TransformerRanker and added reference_compile=False;
my code stopped getting exceptions.

@sam-bercovici (Contributor, Author)

Hi @bclavie,
how do you want to progress with this?

@sam-bercovici (Contributor, Author)

Hi @bclavie,

Two more nits.

1. In the score function:

inputs = self.tokenize((query, doc))  # type: ignore

I think it should be:

inputs = self.tokenize([(query, doc)])

2. In the rank function, when len(scores) == 1, returning a Result instead of RankedResults breaks LangChain.
   Does it make sense to return RankedResults containing one Result instead of a bare Result?

@bclavie (Collaborator) commented Mar 10, 2025

Hey, my bad for letting this thread go quiet, I had to go on a med leave for a bit!

I'll be updating the library once I'm recovered to fix this issue:

In the rank function, when len(scores) == 1, returning a Result instead of RankedResults breaks LangChain.
Does it make sense to return RankedResults containing one Result instead of a bare Result?

As well as go with uncompiled ModernBERT by default. Thanks again!

@sam-bercovici (Contributor, Author)

I moved to another computer and am trying to run the code as per this discussion.
I am getting:
flash_attn/ops/triton/rotary.py", line 166, in apply_rotary
assert sin.shape == cos.shape

The code I am using follows:

# (Imports as in the earlier snippet, plus Tuple from typing and
# BatchEncoding from transformers for the tokenize signature below.)
class TransformerRanker(BaseRanker):
    def __init__(
        self,
        model_name_or_path: str,
        dtype: Optional[Union[str, torch.dtype]] = None,
        device: Optional[Union[str, torch.device]] = None,
        batch_size: int = 16,
        verbose: int = 1,
        **kwargs,
    ):
        self.verbose = verbose
        self.device = get_device(device, verbose=self.verbose)
        self.dtype = get_dtype(dtype, self.device, self.verbose)
        self.is_monobert = "monobert" in model_name_or_path.lower()
        model_kwargs = kwargs.get("model_kwargs", {})
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_name_or_path,
            torch_dtype=self.dtype,
            reference_compile=False,
            **model_kwargs,
        ).to(self.device)
        vprint(f"Loaded model {model_name_or_path}", self.verbose)
        vprint(f"Using device {self.device}.", self.verbose)
        vprint(f"Using dtype {self.dtype}.", self.verbose)
        self.model.eval()
        tokenizer_kwargs = kwargs.get("tokenizer_kwargs", {})
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path,
            **tokenizer_kwargs,
        )
        self.ranking_type = "pointwise"
        self.batch_size = batch_size

    # added Tuple[str, str] as this is what you pass in the score function, inputs = self.tokenize((query, doc)) below
    def tokenize(
        self, inputs: Union[str, List[str], Tuple[str, str], List[Tuple[str, str]]]
    ) -> BatchEncoding:
        return self.tokenizer(
            inputs, return_tensors="pt", padding=True, truncation=True
        ).to(self.device)

    @torch.inference_mode()
    def rank(
        self,
        query: str,
        docs: Union[str, List[str], Document, List[Document]],
        doc_ids: Optional[Union[List[str], List[int]]] = None,
        metadata: Optional[List[dict]] = None,
        batch_size: Optional[int] = None,
    ) -> RankedResults:
        docs = prep_docs(docs, doc_ids, metadata)
        inputs = [(query, doc.text) for doc in docs]

        # Override self.batch_size if explicitly set
        if batch_size is None:
            batch_size = self.batch_size
        batched_inputs = [
            inputs[i : i + batch_size] for i in range(0, len(inputs), batch_size)
        ]
        scores: List[Union[float, List[float]]] = []
        for batch in batched_inputs:
            tokenized_inputs = self.tokenize(batch)
            batch_scores = self.model(**tokenized_inputs).logits.squeeze()
            if self.dtype != torch.float32:
                batch_scores = batch_scores.float()
            batch_scores = batch_scores.detach().cpu().numpy().tolist()
            if isinstance(batch_scores, float):  # Handling the case of single score
                scores.append(batch_scores)
            else:
                scores.extend(batch_scores)
        if self.is_monobert:
            scores = [x[1] - x[0] for x in scores]  # type: ignore
        if len(scores) == 1:  # TODO - this is different than the original code
            # return Result(document=docs[0], score=scores[0])
            return RankedResults(
                results=[Result(document=docs[0], score=scores[0])],
                query=query,
                has_scores=True,
            )
        else:
            ranked_results = [
                Result(document=doc, score=score, rank=idx + 1)
                for idx, (doc, score) in enumerate(
                    sorted(zip(docs, scores), key=lambda x: x[1], reverse=True)
                )
            ]
            return RankedResults(results=ranked_results, query=query, has_scores=True)

    @torch.inference_mode()
    def score(self, query: str, doc: str) -> float:
        inputs = self.tokenize((query, doc))
        outputs = self.model(**inputs)
        score = outputs.logits.squeeze().detach().cpu().numpy().astype(float)
        return score

@bclavie (Collaborator) commented Mar 20, 2025

Hey, this seems to be another issue that is specifically related to ModernBERT's implementation in HF Transformers 🤔

Let me look into this further to see whether I can find a more generalised solution for rerankers.
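
For anyone hitting the same assertion in the meantime, one possible workaround (not confirmed in this thread) is to force a non-flash-attention backend when loading the model, so the flash_attn Triton rotary kernel is never invoked; attn_implementation is a standard from_pretrained argument in recent transformers releases:

# Possible workaround sketch, not verified against this exact error:
# load the model with SDPA attention so flash_attn's rotary kernel is skipped.
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_name = "Alibaba-NLP/gte-reranker-modernbert-base"
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    reference_compile=False,
    attn_implementation="sdpa",  # or "eager"; both avoid the flash_attn path
).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(model_name)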
