Add lm-format-enforcer benchmark
rlouf committed Aug 13, 2024
1 parent aaea650 commit 0a1da14
Showing 4 changed files with 73 additions and 1 deletion.
18 changes: 18 additions & 0 deletions .github/workflows/tests.yml
@@ -0,0 +1,18 @@
name: Tests

on:
  pull_request:
    branches: [main]
  push:
    branches: [main]

jobs:
  style:
    name: Check the code style
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - uses: actions/setup-python@v4
        with:
          python-version: "3.10"
      - uses: pre-commit/[email protected]
7 changes: 6 additions & 1 deletion pyproject.toml
@@ -4,4 +4,9 @@ version = "0.1"
 authors = [{"name" = "The Outlines developers"}]
 description = "A benchmarking suite for structured generation libraries."
 requires-python = ">=3.10"
-dependencies = ["outlines==0.0.46", "transformers==4.44.0", "torch==2.4.0"]
+dependencies = [
+    "lm-format-enforcer==0.10.6",
+    "outlines==0.0.46",
+    "torch==2.4.0",
+    "transformers==4.44.0",
+]
6 changes: 6 additions & 0 deletions setup.cfg
@@ -0,0 +1,6 @@
[flake8]
max-line-length = 88
select = C,E,F,W
ignore = E203,E231,E501,E741,W503,W504,C901,E731
per-file-ignores =
    **/__init__.py:F401,F403
43 changes: 43 additions & 0 deletions src/lfe.py
@@ -0,0 +1,43 @@
"""Benchmark the lm-format-enforcer library."""
from lmformatenforcer import RegexParser, TokenEnforcer
from lmformatenforcer.integrations.transformers import (
    build_token_enforcer_tokenizer_data,
)
from transformers import AutoTokenizer

models = [
"meta-llama/Llama-2-7b-hf", # 32,000 tokens vocabulary
"gpt2", # 50,257 tokens vocabulary
"meta-llama/Meta-Llama-3.1-8B-Instruct", # 128,256 tokens vocabulary
"google/gemma-2-2b-it", # 256,128 tokens vocabulary
]

# Each case pairs a regex with a string that matches it.
case = [(r"\d{3}-\d{2}-\d{4}", "203-22-1234")]


class LMFormatEnforcer:
    params = [models, case]
    param_names = ["model", "regex"]
    timeout = 600

    def setup(self, model, _):
        """Set up the benchmark.

        We convert the tokenizer during setup as this only
        needs to be done once for a given model.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(
            model, clean_up_tokenization_spaces=True
        )
        self.tokenizer_data = build_token_enforcer_tokenizer_data(self.tokenizer)

    def time_lfe(self, _, regex):
        regex_string, regex_example = regex
        regex_example_tokens = self.tokenizer.encode(regex_example)

        parser = RegexParser(regex_string)
        token_enforcer = TokenEnforcer(self.tokenizer_data, parser)

        # Query the allowed next tokens for every prefix of the example,
        # mimicking one constrained decoding pass over the full match.
        for i in range(len(regex_example_tokens)):
            _ = token_enforcer.get_allowed_tokens(regex_example_tokens[: i + 1])

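The class layout above (params, param_names, timeout, a setup method, and a time_-prefixed method) follows the airspeed velocity (asv) benchmark convention, which is presumably how this suite is run. For context on what the timed loop measures: at each decoding step, lm-format-enforcer maps the current token prefix to the set of token ids that can still complete the regex. Below is a minimal standalone sketch using only the calls that appear in the benchmark above; gpt2 is chosen here simply because it needs no gated-model access.

"""Sketch: inspect the tokens lm-format-enforcer allows at one decoding step."""
from lmformatenforcer import RegexParser, TokenEnforcer
from lmformatenforcer.integrations.transformers import (
    build_token_enforcer_tokenizer_data,
)
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer_data = build_token_enforcer_tokenizer_data(tokenizer)

# The same US Social Security number pattern as the benchmark case.
parser = RegexParser(r"\d{3}-\d{2}-\d{4}")
token_enforcer = TokenEnforcer(tokenizer_data, parser)

# After generating "203-", only continuations that can still complete
# the pattern are allowed.
prefix_tokens = tokenizer.encode("203-")
allowed = token_enforcer.get_allowed_tokens(prefix_tokens)
print(f"{len(allowed)} token ids are allowed after '203-'")

time_lfe repeats this query once per prefix of the fully tokenized example, so the measured cost grows with both the example length and the vocabulary size, which is why the benchmarked models span vocabularies from 32,000 to 256,128 tokens.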